about | summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
author Parth Mudgal <artpar@gmail.com> 2024-01-12 14:08:09 +0530
committer GitHub <noreply@github.com> 2024-01-12 09:38:09 +0100
commit f37b304fc3006933dd1245218c41404c336df280 (patch)
tree 5ebbe6b5e26302233bdb69a2e9fecddd875330ed /src/main/java/dev/morling/onebrc
parent af1946fcb5468e85142a4eff7e954e7a6d980530 (diff)
inline hash calculation and number parsing (#200)
* no number parsing with precalculated map
* verify tests
* better loop with direct hash to measurement mapping
* accept formatting changes
* Use unsafe
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java 412
1 files changed, 209 insertions, 203 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java
index 835e65e..4faf322 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artpar.java
@@ -15,11 +15,16 @@
*/
package dev.morling.onebrc;
+import sun.misc.Unsafe;
+
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.RandomAccessFile;
-import java.nio.MappedByteBuffer;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.reflect.Field;
+import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
@@ -37,17 +42,22 @@ import java.util.stream.Collectors;
public class CalculateAverage_artpar {
public static final int N_THREADS = 8;
private static final String FILE = "./measurements.txt";
+ private static final int INT_MAP_SIZE = 8192; // from calculateIntegerByteMapTest()
+ final static int[] byteHashMapToInt = calculateIntegerByteMap();
+ private static final Unsafe UNSAFE = initUnsafe();
// private static final VectorSpecies<Integer> SPECIES = IntVector.SPECIES_PREFERRED;
// final int VECTOR_SIZE = 512;
// final int VECTOR_SIZE_1 = VECTOR_SIZE - 1;
- final int SIZE = 1024 * 1024;
+ final int AVERAGE_CHUNK_SIZE = 1024 * 64;
+ final int AVERAGE_CHUNK_SIZE_1 = AVERAGE_CHUNK_SIZE - 1;
public CalculateAverage_artpar() throws IOException {
long start = Instant.now().toEpochMilli();
Path measurementFile = Paths.get(FILE);
long fileSize = Files.size(measurementFile);
- long expectedChunkSize = Math.max(fileSize / 8, 1024);
+ // System.out.println("File size - " + fileSize);
+ int expectedChunkSize = Math.toIntExact(Math.min(fileSize / N_THREADS, Integer.MAX_VALUE / 2));
ExecutorService threadPool = Executors.newFixedThreadPool(N_THREADS);
@@ -56,52 +66,50 @@ public class CalculateAverage_artpar {
List<Future<Map<String, MeasurementAggregator>>> futures = new ArrayList<>();
long bytesReadCurrent = 0;
- try (FileChannel fileChannel = FileChannel.open(measurementFile, StandardOpenOption.READ)) {
- for (int i = 0; i < 8; i++) {
+ FileChannel fileChannel = FileChannel.open(measurementFile, StandardOpenOption.READ);
+ for (int i = 0; chunkStartPosition < fileSize; i++) {
- long chunkSize = expectedChunkSize;
- chunkSize = fis.skipBytes(Math.toIntExact(chunkSize));
+ int chunkSize = expectedChunkSize;
+ chunkSize = fis.skipBytes(chunkSize);
- bytesReadCurrent += chunkSize;
- while (((char) fis.read()) != '\n' && bytesReadCurrent < fileSize) {
- chunkSize++;
- bytesReadCurrent++;
- }
+ bytesReadCurrent += chunkSize;
+ while (((char) fis.read()) != '\n' && bytesReadCurrent < fileSize) {
+ chunkSize++;
+ bytesReadCurrent++;
+ }
- // System.out.println("[" + chunkStartPosition + "] - [" + (chunkStartPosition + chunkSize) + " bytes");
- if (chunkStartPosition + chunkSize >= fileSize) {
- chunkSize = fileSize - chunkStartPosition;
- }
- if (chunkSize < 1) {
- break;
- }
- if (chunkSize > Integer.MAX_VALUE) {
- chunkSize = Integer.MAX_VALUE;
- }
+ // System.out.println("[" + chunkStartPosition + "] - [" + (chunkStartPosition + chunkSize) + " bytes");
+ if (chunkStartPosition + chunkSize >= fileSize) {
+ chunkSize = (int) Math.min(fileSize - chunkStartPosition, Integer.MAX_VALUE);
+ }
+ if (chunkSize < 1) {
+ break;
+ }
+ if (chunkSize >= Integer.MAX_VALUE) {
+ throw new RuntimeException();
+ }
- MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, chunkStartPosition,
- chunkSize);
+ // MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, chunkStartPosition,
+ // chunkSize);
- ReaderRunnable readerRunnable = new ReaderRunnable(mappedByteBuffer);
- Future<Map<String, MeasurementAggregator>> future = threadPool.submit(readerRunnable::run);
- // System.out.println("Added future [" + chunkStartPosition + "][" + chunkSize + "]");
- futures.add(future);
- chunkStartPosition = chunkStartPosition + chunkSize + 1;
- }
+ ReaderRunnable readerRunnable = new ReaderRunnable(chunkStartPosition, chunkSize, fileChannel);
+ Future<Map<String, MeasurementAggregator>> future = threadPool.submit(readerRunnable::run);
+ // System.out.println("Added future [" + chunkStartPosition + "][" + chunkSize + "]");
+ futures.add(future);
+ chunkStartPosition = chunkStartPosition + chunkSize + 1;
}
+
fis.close();
- Map<String, MeasurementAggregator> globalMap = futures.parallelStream()
- .flatMap(future -> {
- try {
- return future.get().entrySet().stream();
- }
- catch (InterruptedException | ExecutionException e) {
- throw new RuntimeException(e);
- }
- }).parallel().collect(Collectors.toMap(
- Map.Entry::getKey, Map.Entry::getValue,
- MeasurementAggregator::combine));
+ Map<String, MeasurementAggregator> globalMap = futures.parallelStream().flatMap(future -> {
+ try {
+ return future.get().entrySet().stream();
+ }
+ catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ }).parallel().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, MeasurementAggregator::combine));
+ fileChannel.close();
Map<String, ResultRow> results = globalMap.entrySet().stream().parallel()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().finish()));
@@ -136,33 +144,99 @@ public class CalculateAverage_artpar {
}
+ public static int[] calculateIntegerByteMapTest() {
+ int[] intToIntMap = null;
+ for (int j = 0; j < 10000; j++) {
+ int length = 2000 + j;
+ intToIntMap = new int[length];
+ boolean hasHashClash = false;
+ Map<Integer, Integer> byteHashToInt = new HashMap<>();
+ for (int i = -999; i < 1000; i++) {
+ int hashCode = hashInteger(i);
+
+ // String s = new String(value);
+ int position = hashCode & (length - 1);
+ // System.out.printf("%.1f => %s length [%d] hash [%d] => %d\n", number, s, s.length(), hashCode, position);
+ if (byteHashToInt.containsKey(hashCode) || intToIntMap[position] != 0) {
+ // System.err.println("HashClash [" + hashCode + "] -> " +
+ // byteHashToInt.get(
+ // hashCode) + " vs " + number + " == [" + position + "] =>" + intToIntMap[position]);
+ hasHashClash = true;
+ break;
+ }
+ else {
+ byteHashToInt.put(hashCode, i);
+ intToIntMap[position] = i;
+ }
+ }
+ if (!hasHashClash) {
+ // 8192
+ System.out.println("NoHash clash at [" + length + "]");
+ // throw new RuntimeException("clash");
+ return intToIntMap;
+ }
+
+ }
+ System.out.println("Fail");
+ return null;
+ }
+
+ private static int hashInteger(int i) {
+ float number = i / 10f;
+ String numberString = String.format("%.1f", number);
+ byte[] value = numberString.getBytes();
+
+ int hashCode = 1;
+ for (int k = 0; k < value.length; k++) {
+ hashCode = hashCode * 31 + value[k];
+ }
+ return hashCode;
+ }
+
+ public static int[] calculateIntegerByteMap() {
+ long start = System.currentTimeMillis();
+ int[] intToIntMap = new int[INT_MAP_SIZE];
+ for (int i = -999; i < 1000; i++) {
+ float number = i / 10f;
+ byte[] value = String.format("%.1f", number).getBytes();
+
+ int hashCode = 1;
+ for (byte b : value) {
+ hashCode = hashCode * 31 + b;
+ }
+ int position = hashCode & (INT_MAP_SIZE - 1);
+ intToIntMap[position] = i;
+ }
+ long end = System.currentTimeMillis();
+ // System.out.println("calculateIntegerByteMap " + (end - start) + " ms");
+ return intToIntMap;
+ }
+
public static void main(String[] args) throws IOException {
new CalculateAverage_artpar();
}
- public static int hashCode(byte[] array, int length) {
-
- int h = 1;
- int i = 0;
- for (; i + 7 < length; i += 8) {
- h = 31 * 31 * 31 * 31 * 31 * 31 * 31 * 31 * h + 31 * 31 * 31 * 31
- * 31 * 31 * 31 * array[i] + 31 * 31 * 31 * 31 * 31 * 31
- * array[i + 1]
- + 31 * 31 * 31 * 31 * 31 * array[i + 2] + 31
- * 31 * 31 * 31 * array[i + 3]
- + 31 * 31 * 31 * array[i + 4]
- + 31 * 31 * array[i + 5] + 31 * array[i + 6] + array[i + 7];
+ private static Unsafe initUnsafe() {
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ return (Unsafe) theUnsafe.get(Unsafe.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
}
+ }
- for (; i + 3 < length; i += 4) {
- h = 31 * 31 * 31 * 31 * h + 31 * 31 * 31 * array[i] + 31 * 31
- * array[i + 1] + 31 * array[i + 2] + array[i + 3];
+ static boolean unsafeEquals(long aStart, long aLength, long bStart, long bLength) {
+ if (aLength != bLength) {
+ return false;
}
- for (; i < length; i++) {
- h = 31 * h + array[i];
+ for (int i = 0; i < aLength; ++i) {
+ if (UNSAFE.getByte(aStart + i) != UNSAFE.getByte(bStart + i)) {
+ return false;
+ }
}
-
- return h;
+ return true;
}
private record ResultRow(double min, double mean, double max) {
@@ -180,43 +254,25 @@ public class CalculateAverage_artpar {
}
private static class MeasurementAggregator {
- private double min = Double.POSITIVE_INFINITY;
- private double max = Double.NEGATIVE_INFINITY;
+ private int min = 999;
+ private int max = -999;
private double sum;
private long count;
- public MeasurementAggregator() {
- }
-
- // public MeasurementAggregator(double min, double max, double sum, long count) {
- // this.min = min;
- // this.max = max;
- // this.sum = sum;
- // this.count = count;
- // }
-
MeasurementAggregator combine(MeasurementAggregator other) {
- min = Math.min(min, other.min);
- max = Math.max(max, other.max);
+ min = other.min + ((min - other.min) & ((min - other.min) >> (32 * 8 - 1)));
+ max = max - ((max - other.max) & ((max - other.max) >> (32 * 8 - 1)));
sum += other.sum;
count += other.count;
return this;
}
- // MeasurementAggregator combine(double otherMin, double otherMax, double otherSum, long otherCount) {
- // min = Math.min(min, otherMin);
- // max = Math.max(max, otherMax);
- // sum += otherSum;
- // count += otherCount;
- // return this;
- // }
-
- MeasurementAggregator combine(double value) {
- min = Math.min(min, value);
- max = Math.max(max, value);
+ void combine(int value) {
sum += value;
- count += 1;
- return this;
+ count++;
+
+ min = value + ((min - value) & ((min - value) >> (32 * 8 - 1))); // min(x, y)
+ max = max - ((max - value) & ((max - value) >> (32 * 8 - 1))); // max(x, y)
}
ResultRow finish() {
@@ -227,150 +283,100 @@ public class CalculateAverage_artpar {
static class StationName {
public final int hash;
- private final String name;
- // private final int index;
+ private final ByteBuffer nameBytes;
+ private final MeasurementAggregator measurementAggregator = new MeasurementAggregator();
public int count = 0;
- // public int[] values = new int[VECTOR_SIZE];
- public MeasurementAggregator measurementAggregator = new MeasurementAggregator();
- public StationName(String name, int hash) {
- this.name = name;
- // this.index = index;
+ public StationName(ByteBuffer nameBytes, int hash) {
+ this.nameBytes = nameBytes;
this.hash = hash;
}
}
private class ReaderRunnable {
- private final MappedByteBuffer mappedByteBuffer;
+ private final long startPosition;
+ private final FileChannel fileChannel;
+ private final int chunkSize;
StationNameMap stationNameMap = new StationNameMap();
- // double[][] stationValueMap = new double[SIZE][];
- private ReaderRunnable(MappedByteBuffer mappedByteBuffer) {
- this.mappedByteBuffer = mappedByteBuffer;
+ private ReaderRunnable(long startPosition, int chunkSize, FileChannel fileChannel) throws IOException {
+ this.chunkSize = chunkSize;
+ this.startPosition = startPosition;
+ this.fileChannel = fileChannel;
+ // mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startPosition, chunkSize);
}
- public Map<String, MeasurementAggregator> run() {
- // System.out.println("Started future - " + mappedByteBuffer.position());
-
- int doubleValue;
- long start = Date.from(Instant.now()).getTime();
- // int totalBytesRead = 0;
+ public Map<String, MeasurementAggregator> run() throws IOException {
+ MemorySegment mappedSegment = fileChannel.map(FileChannel.MapMode.READ_ONLY,
+ startPosition, chunkSize, Arena.global());
- // ByteBuffer nameBuffer = ByteBuffer.allocate(128);
- int MAPPED_BYTE_BUFFER_SIZE = 8192;
- byte[] rawBuffer = new byte[32];
+ long rawBufferAddress = UNSAFE.allocateMemory(100);
int rawBufferReadIndex = 0;
- StationName matchedStation = null;
- boolean expectedName = true;
-
- byte[] mappedBytes = new byte[MAPPED_BYTE_BUFFER_SIZE];
- int mappedBytesReadIndex;
- boolean negative = false;
- int start1 = 0;
- int result = 0;
-
- while (mappedByteBuffer.hasRemaining()) {
- int remaining = mappedByteBuffer.remaining();
- int bytesToRead = Math.min(remaining, MAPPED_BYTE_BUFFER_SIZE);
- mappedByteBuffer.get(mappedBytes, 0, bytesToRead);
- remaining = mappedByteBuffer.remaining();
- mappedBytesReadIndex = 0;
-
- while (mappedBytesReadIndex < bytesToRead) {
- byte b = mappedBytes[mappedBytesReadIndex];
- mappedBytesReadIndex++;
-
- if (expectedName) {
- if (b != ';') {
- rawBuffer[rawBufferReadIndex] = b;
- rawBufferReadIndex++;
- continue;
- }
- else {
- expectedName = false;
- matchedStation = stationNameMap.getOrCreate(rawBuffer, rawBufferReadIndex);
- rawBufferReadIndex = 0;
- negative = false;
- start1 = 0;
- result = 0;
- continue;
- }
- }
-
- while (b != '\n') {
- rawBuffer[rawBufferReadIndex] = b;
- rawBufferReadIndex++;
-
- if (mappedBytesReadIndex < bytesToRead) {
- b = mappedBytes[mappedBytesReadIndex];
- mappedBytesReadIndex++;
- }
- else {
- break;
- }
- }
-
- if (b != '\n') {
- if (mappedBytesReadIndex == bytesToRead && remaining > 0) {
- continue;
- }
- }
-
- // Check for negative numbers
- if (rawBuffer[0] == '-') {
- negative = true;
- start1++;
- }
-
- for (int i = start1; i < rawBufferReadIndex; i++) {
- byte c = rawBuffer[i];
- if (c != '.') {
- result = result * 10 + (c - '0');
- }
- }
-
- doubleValue = negative ? -result : result;
- rawBufferReadIndex = 0;
- matchedStation.measurementAggregator.combine(doubleValue);
- matchedStation.count++;
- expectedName = true;
+ long position = mappedSegment.address();
+ long endPosition = position + chunkSize;
+ byte b;
+ int hash;
+ int nameHash;
+ hash = 1;
+
+ while (position < endPosition) {
+
+ while ((position < endPosition) &&
+ (b = UNSAFE.getByte(position++)) != ';') {
+ UNSAFE.putByte(rawBufferAddress + rawBufferReadIndex++, b);
+ hash = hash * 31 + b;
}
- }
+ nameHash = hash;
+ hash = 1;
- long end = Date.from(Instant.now()).getTime();
- // System.out.println("Took [" + ((end - start) / 1000) + "s for " + totalBytesRead / 1024 + " kb");
+ while ((position < endPosition) &&
+ (b = UNSAFE.getByte(position++)) != '\n') {
+ hash = hash * 31 + b;
+ }
+ stationNameMap.getOrCreate(rawBufferAddress, rawBufferReadIndex,
+ byteHashMapToInt[hash & (INT_MAP_SIZE - 1)], nameHash);
+ rawBufferReadIndex = 0;
+ hash = 1;
- return Arrays.stream(stationNameMap.names).parallel().filter(Objects::nonNull)
- .collect(Collectors.toMap(e -> e.name, e -> e.measurementAggregator));
- // return groupedMeasurements;
+ }
+ return Arrays.stream(stationNameMap.names).parallel().filter(Objects::nonNull).collect(
+ Collectors.toMap(e -> StandardCharsets.UTF_8.decode(e.nameBytes).toString(),
+ e -> e.measurementAggregator, MeasurementAggregator::combine));
}
}
class StationNameMap {
- int[] indexes = new int[SIZE];
- StationName[] names = new StationName[SIZE];
+ int[] indexes = new int[AVERAGE_CHUNK_SIZE];
+ StationName[] names = new StationName[AVERAGE_CHUNK_SIZE];
int currentIndex = 0;
+ ByteBuffer bytesForName = ByteBuffer.allocateDirect(1000 * 100);
+ int nameBufferIndex = 0;
- public StationName getOrCreate(byte[] stationNameBytes, int length) {
-
- int hash = CalculateAverage_artpar.hashCode(stationNameBytes, length);
-
- int position = Math.abs(hash) % SIZE;
- while (indexes[position] != 0 && names[indexes[position]].hash != hash) {
- position = ++position % SIZE;
+ public void getOrCreate(long stationNameBytesAddress, int length, int doubleValue, int hash) {
+ int position = hash & AVERAGE_CHUNK_SIZE_1;
+ while (indexes[position] != 0 && (names[indexes[position]].hash != hash)) {
+ position = ++position & AVERAGE_CHUNK_SIZE_1;
}
if (indexes[position] != 0) {
- return names[indexes[position]];
+ StationName stationName = names[indexes[position]];
+ stationName.measurementAggregator.combine(doubleValue);
+ }
+ else {
+ ByteBuffer nameSlice = bytesForName.slice(nameBufferIndex, length);
+ nameBufferIndex += length;
+ for (int i = 0; i < length; i++) {
+ nameSlice.put(UNSAFE.getByte(stationNameBytesAddress + i));
+ }
+ nameSlice.flip();
+ StationName stationName = new StationName(nameSlice, hash);
+ indexes[position] = ++currentIndex;
+ names[indexes[position]] = stationName;
+ stationName.measurementAggregator.combine(doubleValue);
}
- StationName stationName = new StationName(
- new String(stationNameBytes, 0, length, StandardCharsets.UTF_8), hash);
- indexes[position] = ++currentIndex;
- names[indexes[position]] = stationName;
- return stationName;
}
}
-}
+}
\ No newline at end of file