aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorJaromir Hamala <jaromir.hamala@gmail.com>2024-01-17 18:28:03 +0100
committerGitHub <noreply@github.com>2024-01-17 18:28:03 +0100
commit927880b97ec3c0ae354773ae645dc8c0bbec8345 (patch)
treedd661b7302d424fa768596e4b631e14503de52e8 /src/main
parent77872e197d1e6237c327bdbd5fbd925648ca4337 (diff)
edge-case in hashing fixed (#459)
also a bunch of smaller improvements
Diffstat (limited to 'src/main')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java303
1 files changed, 151 insertions, 152 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
index 6fb89bb..5373cb0 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
@@ -22,15 +22,24 @@ import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel.MapMode;
-import java.util.Map;
-import java.util.TreeMap;
+import java.util.*;
+/**
+ * I figured out it would be very hard to win the main competition of the One Billion Rows Challenge.
+ * but I think this code has a good chance to win a special prize for the Ugliest Solution ever! :)
+ *
+ * Anyway, if you can make sense out of not exactly idiomatic Java code, and you enjoy pushing performance limits
+ * then QuestDB - the fastest open-source time-series database - is hiring: https://questdb.io/careers/core-database-engineer/
+ *
+ */
public class CalculateAverage_jerrinot {
private static final Unsafe UNSAFE = unsafe();
private static final String MEASUREMENTS_TXT = "measurements.txt";
// todo: with hyper-threading enable we would be better of with availableProcessors / 2;
// todo: validate the testing env. params.
private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();
+ // private static final int THREAD_COUNT = 4;
+
private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static Unsafe unsafe() {
@@ -72,7 +81,7 @@ public class CalculateAverage_jerrinot {
Processor[] processors = new Processor[THREAD_COUNT];
Thread[] threads = new Thread[THREAD_COUNT];
- for (int i = 0; i < THREAD_COUNT; i++) {
+ for (int i = 0; i < THREAD_COUNT - 1; i++) {
long startA = chunkStartOffsets[i * chunkPerThread];
long endA = chunkStartOffsets[i * chunkPerThread + 1];
long startB = chunkStartOffsets[i * chunkPerThread + 1];
@@ -89,8 +98,22 @@ public class CalculateAverage_jerrinot {
thread.start();
}
+ int ownIndex = THREAD_COUNT - 1;
+ long startA = chunkStartOffsets[ownIndex * chunkPerThread];
+ long endA = chunkStartOffsets[ownIndex * chunkPerThread + 1];
+ long startB = chunkStartOffsets[ownIndex * chunkPerThread + 1];
+ long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2];
+ long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2];
+ long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3];
+ long startD = chunkStartOffsets[ownIndex * chunkPerThread + 3];
+ long endD = chunkStartOffsets[ownIndex * chunkPerThread + 4];
+ Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD);
+ processor.run();
+
var accumulator = new TreeMap<String, Processor.StationStats>();
- for (int i = 0; i < THREAD_COUNT; i++) {
+ processor.accumulateStatus(accumulator);
+
+ for (int i = 0; i < THREAD_COUNT - 1; i++) {
Thread t = threads[i];
t.join();
processors[i].accumulateStatus(accumulator);
@@ -131,7 +154,7 @@ public class CalculateAverage_jerrinot {
private static class Processor implements Runnable {
private static final int MAP_SLOT_COUNT = ceilPow2(10000);
- private static final int STATION_MAX_NAME_BYTES = 104;
+ private static final int STATION_MAX_NAME_BYTES = 120;
private static final long COUNT_OFFSET = 0;
private static final long MIN_OFFSET = 4;
@@ -162,23 +185,16 @@ public class CalculateAverage_jerrinot {
private long endC;
private long cursorD;
private long endD;
- private long maskA;
- private long maskB;
- private long maskC;
- private long maskD;
- // credit: merykitty
- private long parseAndStoreTemperature(long startCursor, long baseEntryPtr) {
- long word = UNSAFE.getLong(startCursor);
- final long negateda = ~word;
- final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000);
- final long signed = (negateda << 59) >> 63;
- final long removeSignMask = ~(signed & 0xFF);
- final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
- final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- final int temperature = (int) ((absValue ^ signed) - signed);
+ // private long maxClusterLen;
+ // credit: merykitty
+ private long parseAndStoreTemperature(long startCursor, long baseEntryPtr, long word) {
+ // long word = UNSAFE.getLong(startCursor);
long countPtr = baseEntryPtr + COUNT_OFFSET;
+ int cnt = UNSAFE.getInt(countPtr);
+ UNSAFE.putInt(countPtr, cnt + 1);
+
long minPtr = baseEntryPtr + MIN_OFFSET;
long maxPtr = baseEntryPtr + MAX_OFFSET;
long sumPtr = baseEntryPtr + SUM_OFFSET;
@@ -186,16 +202,23 @@ public class CalculateAverage_jerrinot {
int min = UNSAFE.getInt(minPtr);
int max = UNSAFE.getInt(maxPtr);
long sum = UNSAFE.getLong(sumPtr);
- // try if min/max intrinsics are paying off
- // maybe braching is better? the branch is becoming more predictable with
- // each new sample.
- max = Math.max(max, temperature);
- min = Math.min(min, temperature);
+
+ final long negateda = ~word;
+ final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000);
+ final long signed = (negateda << 59) >> 63;
+ final long removeSignMask = ~(signed & 0xFF);
+ final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
+ final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ final int temperature = (int) ((absValue ^ signed) - signed);
sum += temperature;
- UNSAFE.putInt(countPtr, UNSAFE.getInt(countPtr) + 1);
- UNSAFE.putInt(minPtr, min);
- UNSAFE.putInt(maxPtr, max);
UNSAFE.putLong(sumPtr, sum);
+
+ if (temperature > max) {
+ UNSAFE.putInt(maxPtr, temperature);
+ }
+ if (temperature < min) {
+ UNSAFE.putInt(minPtr, temperature);
+ }
return startCursor + (dotPos / 8) + 3;
}
@@ -227,13 +250,13 @@ public class CalculateAverage_jerrinot {
int count = UNSAFE.getInt(baseAddress + COUNT_OFFSET);
long sum = UNSAFE.getLong(baseAddress + SUM_OFFSET);
- // todo: lambdas bootstrap probably cost us
- accumulator.compute(name, (_, v) -> {
- if (v == null) {
- return new StationStats(min, max, count, sum);
- }
- return new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum);
- });
+ var v = accumulator.get(name);
+ if (v == null) {
+ accumulator.put(name, new StationStats(min, max, count, sum));
+ }
+ else {
+ accumulator.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum));
+ }
}
}
@@ -260,11 +283,22 @@ public class CalculateAverage_jerrinot {
private void doTail() {
// todo: we would be probably better of without all that code dup. ("compilers hates him!")
// System.out.println("done ILP");
+ doOne(cursorA, endA);
+ // System.out.println("done A");
+ doOne(cursorB, endB);
+ // System.out.println("done B");
+ doOne(cursorC, endC);
+ // System.out.println("done C");
+ doOne(cursorD, endD);
+ // System.out.println("done D");
+ }
+
+ private void doOne(long cursorA, long endA) {
while (cursorA < endA) {
long startA = cursorA;
long delimiterWordA = UNSAFE.getLong(cursorA);
long hashA = 0;
- maskA = getDelimiterMask(delimiterWordA);
+ long maskA = getDelimiterMask(delimiterWordA);
while (maskA == 0) {
hashA ^= delimiterWordA;
cursorA += 8;
@@ -273,81 +307,15 @@ public class CalculateAverage_jerrinot {
}
final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
final long semicolonA = cursorA + (delimiterByteA >> 3);
- final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1);
+ final long maskedWordA = delimiterWordA & ((maskA - 1) ^ maskA) >>> 8;
hashA ^= maskedWordA;
int intHashA = (int) (hashA ^ (hashA >> 32));
intHashA = intHashA ^ (intHashA >> 17);
long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA);
- cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA);
- }
- // System.out.println("done A");
- while (cursorB < endB) {
- long startB = cursorB;
- long delimiterWordB = UNSAFE.getLong(cursorB);
- long hashB = 0;
- maskB = getDelimiterMask(delimiterWordB);
- while (maskB == 0) {
- hashB ^= delimiterWordB;
- cursorB += 8;
- delimiterWordB = UNSAFE.getLong(cursorB);
- maskB = getDelimiterMask(delimiterWordB);
- }
- final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
- final long semicolonB = cursorB + (delimiterByteB >> 3);
- final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1);
- hashB ^= maskedWordB;
- int intHashB = (int) (hashB ^ (hashB >> 32));
- intHashB = intHashB ^ (intHashB >> 17);
-
- long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB);
- cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB);
- }
- // System.out.println("done B");
- while (cursorC < endC) {
- long startC = cursorC;
- long delimiterWordC = UNSAFE.getLong(cursorC);
- long hashC = 0;
- maskC = getDelimiterMask(delimiterWordC);
- while (maskC == 0) {
- hashC ^= delimiterWordC;
- cursorC += 8;
- delimiterWordC = UNSAFE.getLong(cursorC);
- maskC = getDelimiterMask(delimiterWordC);
- }
- final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
- final long semicolonC = cursorC + (delimiterByteC >> 3);
- final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1);
- hashC ^= maskedWordC;
- int intHashC = (int) (hashC ^ (hashC >> 32));
- intHashC = intHashC ^ (intHashC >> 17);
-
- long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC);
- cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC);
- }
- // System.out.println("done C");
- while (cursorD < endD) {
- long startD = cursorD;
- long delimiterWordD = UNSAFE.getLong(cursorD);
- long hashD = 0;
- maskD = getDelimiterMask(delimiterWordD);
- while (maskD == 0) {
- hashD ^= delimiterWordD;
- cursorD += 8;
- delimiterWordD = UNSAFE.getLong(cursorD);
- maskD = getDelimiterMask(delimiterWordD);
- }
- final int delimiterByteD = Long.numberOfTrailingZeros(maskD);
- final long semicolonD = cursorD + (delimiterByteD >> 3);
- final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1);
- hashD ^= maskedWordD;
- int intHashD = (int) (hashD ^ (hashD >> 32));
- intHashD = intHashD ^ (intHashD >> 17);
-
- long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD);
- cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD);
+ long temperatureWordA = UNSAFE.getLong(semicolonA + 1);
+ cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA);
}
- // System.out.println("done D");
}
@Override
@@ -359,10 +327,14 @@ public class CalculateAverage_jerrinot {
long startC = cursorC;
long startD = cursorD;
- long delimiterWordA = UNSAFE.getLong(cursorA);
- long delimiterWordB = UNSAFE.getLong(cursorB);
- long delimiterWordC = UNSAFE.getLong(cursorC);
- long delimiterWordD = UNSAFE.getLong(cursorD);
+ long currentWordA = UNSAFE.getLong(startA);
+ // long delimiterWordA2 = UNSAFE.getLong(startA + 8);
+ long currentWordB = UNSAFE.getLong(startB);
+ // long delimiterWordB2 = UNSAFE.getLong(startB + 8);
+ long currentWordC = UNSAFE.getLong(startC);
+ // long delimiterWordCa = UNSAFE.getLong(startC + 8);
+ long currentWordD = UNSAFE.getLong(startD);
+ // long delimiterWordD2 = UNSAFE.getLong(startD + 8);
long hashA = 0;
long hashB = 0;
@@ -370,58 +342,62 @@ public class CalculateAverage_jerrinot {
long hashD = 0;
// credits for the hashing idea: royvanrijn
- maskA = getDelimiterMask(delimiterWordA);
+ long maskA = getDelimiterMask(currentWordA);
while (maskA == 0) {
- hashA ^= delimiterWordA;
+ hashA ^= currentWordA;
cursorA += 8;
- delimiterWordA = UNSAFE.getLong(cursorA);
- maskA = getDelimiterMask(delimiterWordA);
+ currentWordA = UNSAFE.getLong(cursorA);
+ maskA = getDelimiterMask(currentWordA);
}
final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
final long semicolonA = cursorA + (delimiterByteA >> 3);
- final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1);
+ long temperatureWordA = UNSAFE.getLong(semicolonA + 1);
+ final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8;
hashA ^= maskedWordA;
int intHashA = (int) (hashA ^ (hashA >> 32));
intHashA = intHashA ^ (intHashA >> 17);
- maskB = getDelimiterMask(delimiterWordB);
+ long maskB = getDelimiterMask(currentWordB);
while (maskB == 0) {
- hashB ^= delimiterWordB;
+ hashB ^= currentWordB;
cursorB += 8;
- delimiterWordB = UNSAFE.getLong(cursorB);
- maskB = getDelimiterMask(delimiterWordB);
+ currentWordB = UNSAFE.getLong(cursorB);
+ maskB = getDelimiterMask(currentWordB);
}
final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
final long semicolonB = cursorB + (delimiterByteB >> 3);
- final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1);
+ long temperatureWordB = UNSAFE.getLong(semicolonB + 1);
+ final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8;
hashB ^= maskedWordB;
int intHashB = (int) (hashB ^ (hashB >> 32));
intHashB = intHashB ^ (intHashB >> 17);
- maskC = getDelimiterMask(delimiterWordC);
+ long maskC = getDelimiterMask(currentWordC);
while (maskC == 0) {
- hashC ^= delimiterWordC;
+ hashC ^= currentWordC;
cursorC += 8;
- delimiterWordC = UNSAFE.getLong(cursorC);
- maskC = getDelimiterMask(delimiterWordC);
+ currentWordC = UNSAFE.getLong(cursorC);
+ maskC = getDelimiterMask(currentWordC);
}
final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
final long semicolonC = cursorC + (delimiterByteC >> 3);
- final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1);
+ long temperatureWordC = UNSAFE.getLong(semicolonC + 1);
+ final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8;
hashC ^= maskedWordC;
int intHashC = (int) (hashC ^ (hashC >> 32));
intHashC = intHashC ^ (intHashC >> 17);
- maskD = getDelimiterMask(delimiterWordD);
+ long maskD = getDelimiterMask(currentWordD);
while (maskD == 0) {
- hashD ^= delimiterWordD;
+ hashD ^= currentWordD;
cursorD += 8;
- delimiterWordD = UNSAFE.getLong(cursorD);
- maskD = getDelimiterMask(delimiterWordD);
+ currentWordD = UNSAFE.getLong(cursorD);
+ maskD = getDelimiterMask(currentWordD);
}
final int delimiterByteD = Long.numberOfTrailingZeros(maskD);
final long semicolonD = cursorD + (delimiterByteD >> 3);
- final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1);
+ long temperatureWordD = UNSAFE.getLong(semicolonD + 1);
+ final long maskedWordD = currentWordD & ((maskD - 1) ^ maskD) >>> 8;
hashD ^= maskedWordD;
int intHashD = (int) (hashD ^ (hashD >> 32));
intHashD = intHashD ^ (intHashD >> 17);
@@ -431,51 +407,74 @@ public class CalculateAverage_jerrinot {
long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC);
long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD);
- cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA);
- cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB);
- cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC);
- cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD);
+ cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA, temperatureWordA);
+ cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB, temperatureWordB);
+ cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC, temperatureWordC);
+ cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD, temperatureWordD);
}
doTail();
}
private long getOrCreateEntryBaseOffset(long semicolonA, long startA, int intHashA, long maskedWordA) {
- int lenA = (int) (semicolonA - startA);
+ // hashSet.add(intHashA);
+ long lenLong = semicolonA - startA;
+ int lenA = (int) lenLong;
+
+ // assert lenA != 0;
+ // byte[] nameArr = new byte[lenA];
+ // for (int i = 0; i < lenA; i++) {
+ // nameArr[i] = UNSAFE.getByte(startA + i);
+ // }
+ // String nameStr = new String(nameArr);
+ // Integer oldHash = nameToHash.put(nameStr, intHashA);
+ // assert oldHash == null || oldHash == intHashA : "name: " + nameStr + ", old hash = " + oldHash + ", new hash = " + intHashA;
+
long mapIndexA = intHashA & MAP_MASK;
+ // long clusterLen = 0;
for (;;) {
long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map;
long lenPtr = basePtr + LEN_OFFSET;
int len = UNSAFE.getInt(lenPtr);
- if (len == 0) {
+ if (len == lenA) {
+ if (nameMatch(startA, maskedWordA, basePtr, lenLong)) {
+ // if (clusterLen > maxClusterLen) {
+ // maxClusterLen = clusterLen;
+ // System.out.println("max cluster len: " + clusterLen);
+ // }
+ return basePtr;
+ }
+ }
+ else if (len == 0) {
// todo: uncommon branch maybe?
// empty slot
UNSAFE.copyMemory(semicolonA - lenA, basePtr + NAME_OFFSET, lenA);
UNSAFE.putInt(lenPtr, lenA);
+ // todo: this could be a single putLong()
UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE);
return basePtr;
}
- if (len == lenA) {
- boolean match = true;
- long namePtr = basePtr + NAME_OFFSET;
- int fullLen = (len >> 3) << 3;
- long offset;
- // todo: this is worth exploring further.
- // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient
- // for majority of names. so we would be left with just a single branch which is almost never taken?
- for (offset = 0; offset < fullLen; offset += 8) {
- match &= (UNSAFE.getLong(startA + offset) == UNSAFE.getLong(namePtr + offset));
- }
-
- long maskedWordInMap = UNSAFE.getLong(namePtr + offset);
- match &= (maskedWordInMap == maskedWordA);
+ mapIndexA = ++mapIndexA & MAP_MASK;
+ // clusterLen++;
+ }
+ }
- if (match) {
- return basePtr;
- }
+ private static boolean nameMatch(long startA, long maskedWordA, long basePtr, long len) {
+ long namePtr = basePtr + NAME_OFFSET;
+ long fullLen = len & ~7L;
+ long offset;
+
+ // todo: this is worth exploring further.
+ // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient
+ // for majority of names. so we would be left with just a single branch which is almost never taken?
+ for (offset = 0; offset < fullLen; offset += 8) {
+ if (UNSAFE.getLong(startA + offset) != UNSAFE.getLong(namePtr + offset)) {
+ return false;
}
- mapIndexA = ++mapIndexA & MAP_MASK;
}
+
+ long maskedWordInMap = UNSAFE.getLong(namePtr + fullLen);
+ return (maskedWordInMap == maskedWordA);
}
}