about | summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
author Jaromir Hamala <jaromir.hamala@gmail.com> 2024-01-28 11:34:28 +0100
committer GitHub <noreply@github.com> 2024-01-28 11:34:28 +0100
commit d9ab36a241e4f38e404b5fd5f92de86337dd459c (patch)
tree ed86c57026e00c9f23e5a05da0198e0b25de8e16 /src/main/java/dev/morling/onebrc
parent a6cd83fc9817de787591d27b1fe5d6527bb3aebd (diff)
jerrinot's improvement (#607)
* some random changes with minimal, if any, effect
* use munmap() trick (credit: thomaswue)
* some smaller tweaks
* use native image
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java 278
1 files changed, 167 insertions, 111 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
index 2492c0f..36e3182 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
@@ -18,6 +18,7 @@ package dev.morling.onebrc;
import sun.misc.Unsafe;
import java.io.File;
+import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
@@ -54,9 +55,29 @@ public class CalculateAverage_jerrinot {
}
public static void main(String[] args) throws Exception {
+ // credits for spawning new workers: thomaswue
+ if (args.length == 0 || !("--worker".equals(args[0]))) {
+ spawnWorker();
+ return;
+ }
calculate();
}
+ private static void spawnWorker() throws IOException {
+ ProcessHandle.Info info = ProcessHandle.current().info();
+ ArrayList<String> workerCommand = new ArrayList<>();
+ info.command().ifPresent(workerCommand::add);
+ info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
+ workerCommand.add("--worker");
+ new ProcessBuilder()
+ .command(workerCommand)
+ .inheritIO()
+ .redirectOutput(ProcessBuilder.Redirect.PIPE)
+ .start()
+ .getInputStream()
+ .transferTo(System.out);
+ }
+
static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT);
final long length = file.length();
@@ -140,6 +161,7 @@ public class CalculateAverage_jerrinot {
}
sb.append('}');
System.out.println(sb);
+ System.out.close();
}
public static int ceilPow2(int i) {
@@ -187,7 +209,7 @@ public class CalculateAverage_jerrinot {
private static final int SLOW_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * SLOW_MAP_ENTRY_SIZE_BYTES;
private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
- private static final long MAP_MASK = MAPS_SLOT_COUNT - 1;
+ private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
private long slowMap;
private long slowMapNamesPtr;
@@ -281,9 +303,9 @@ public class CalculateAverage_jerrinot {
doOne(cursorC, endC);
transferToHeap();
- UNSAFE.freeMemory(fastMap);
- UNSAFE.freeMemory(slowMap);
- UNSAFE.freeMemory(slowMapNamesLo);
+ // UNSAFE.freeMemory(fastMap);
+ // UNSAFE.freeMemory(slowMap);
+ // UNSAFE.freeMemory(slowMapNamesLo);
}
private void transferToHeap() {
@@ -339,11 +361,11 @@ public class CalculateAverage_jerrinot {
long mask = getDelimiterMask(currentWord);
long firstWordMask = ((mask - 1) ^ mask) >>> 8;
final long isMaskZeroA = ((mask | -mask) >>> 63) ^ 1;
- long ext = -isMaskZeroA & 0xFF00_0000_0000_0000L;
+ long ext = -isMaskZeroA;
firstWordMask |= ext;
long maskedFirstWord = currentWord & firstWordMask;
- long hash = hash(maskedFirstWord);
+ int hash = hash(maskedFirstWord);
while (mask == 0) {
cursor += 8;
currentWord = UNSAFE.getLong(cursor);
@@ -353,22 +375,22 @@ public class CalculateAverage_jerrinot {
final long semicolon = cursor + (delimiterByte >> 3);
final long maskedWord = currentWord & ((mask - 1) ^ mask) >>> 8;
- long len = semicolon - start;
- long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, (int) hash, maskedWord);
+ int len = (int) (semicolon - start);
+ long baseEntryPtr = getOrCreateEntryBaseOffsetSlow(len, start, hash, maskedWord);
long temperatureWord = UNSAFE.getLong(semicolon + 1);
cursor = parseAndStoreTemperature(semicolon + 1, baseEntryPtr, temperatureWord);
}
}
- private static long hash(long word1) {
+ private static int hash(long word) {
// credit: mtopolnik
long seed = 0x51_7c_c1_b7_27_22_0a_95L;
int rotDist = 17;
-
- long hash = word1;
+ //
+ long hash = word;
hash *= seed;
hash = Long.rotateLeft(hash, rotDist);
- return hash;
+ return (int) hash;
}
@Override
@@ -382,69 +404,87 @@ public class CalculateAverage_jerrinot {
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
while (cursorA < endA && cursorB < endB && cursorC < endC) {
+ long currentWordA = UNSAFE.getLong(cursorA);
+ long currentWordB = UNSAFE.getLong(cursorB);
+ long currentWordC = UNSAFE.getLong(cursorC);
+
long startA = cursorA;
long startB = cursorB;
long startC = cursorC;
- long currentWordA = UNSAFE.getLong(startA);
- long currentWordB = UNSAFE.getLong(startB);
- long currentWordC = UNSAFE.getLong(startC);
-
long maskA = getDelimiterMask(currentWordA);
long maskB = getDelimiterMask(currentWordB);
long maskC = getDelimiterMask(currentWordC);
- long firstWordMaskA = (maskA ^ (maskA - 1)) >>> 8;
- long firstWordMaskB = (maskB ^ (maskB - 1)) >>> 8;
- long firstWordMaskC = (maskC ^ (maskC - 1)) >>> 8;
-
- final long isMaskZeroA = ((maskA | -maskA) >>> 63) ^ 1;
- final long isMaskZeroB = ((maskB | -maskB) >>> 63) ^ 1;
- final long isMaskZeroC = ((maskC | -maskC) >>> 63) ^ 1;
-
- long extA = -isMaskZeroA & 0xFF00_0000_0000_0000L;
- long extB = -isMaskZeroB & 0xFF00_0000_0000_0000L;
- long extC = -isMaskZeroC & 0xFF00_0000_0000_0000L;
-
- firstWordMaskA |= extA;
- firstWordMaskB |= extB;
- firstWordMaskC |= extC;
-
- long maskedFirstWordA = currentWordA & firstWordMaskA;
- long maskedFirstWordB = currentWordB & firstWordMaskB;
- long maskedFirstWordC = currentWordC & firstWordMaskC;
-
- // assertMasks(isMaskZeroA, maskA);
-
- long hashA = hash(maskedFirstWordA);
- long hashB = hash(maskedFirstWordB);
- long hashC = hash(maskedFirstWordC);
-
- cursorA += isMaskZeroA * 8;
- cursorB += isMaskZeroB * 8;
- cursorC += isMaskZeroC * 8;
-
- currentWordA = UNSAFE.getLong(cursorA);
- currentWordB = UNSAFE.getLong(cursorB);
- currentWordC = UNSAFE.getLong(cursorC);
+ long maskComplementA = -maskA;
+ long maskComplementB = -maskB;
+ long maskComplementC = -maskC;
+
+ long maskWithDelimiterA = (maskA ^ (maskA - 1));
+ long maskWithDelimiterB = (maskB ^ (maskB - 1));
+ long maskWithDelimiterC = (maskC ^ (maskC - 1));
+
+ long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
+ long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
+ long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
+
+ cursorA += isMaskZeroA << 3;
+ cursorB += isMaskZeroB << 3;
+ cursorC += isMaskZeroC << 3;
+
+ long nextWordA = UNSAFE.getLong(cursorA);
+ long nextWordB = UNSAFE.getLong(cursorB);
+ long nextWordC = UNSAFE.getLong(cursorC);
+
+ long firstWordMaskA = maskWithDelimiterA >>> 8;
+ long firstWordMaskB = maskWithDelimiterB >>> 8;
+ long firstWordMaskC = maskWithDelimiterC >>> 8;
+
+ long nextMaskA = getDelimiterMask(nextWordA);
+ long nextMaskB = getDelimiterMask(nextWordB);
+ long nextMaskC = getDelimiterMask(nextWordC);
+
+ boolean slowA = nextMaskA == 0;
+ boolean slowB = nextMaskB == 0;
+ boolean slowC = nextMaskC == 0;
+ boolean slowSome = (slowA || slowB || slowC);
+
+ long extA = -isMaskZeroA;
+ long extB = -isMaskZeroB;
+ long extC = -isMaskZeroC;
+
+ long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
+ long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
+ long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
+
+ int hashA = hash(maskedFirstWordA);
+ int hashB = hash(maskedFirstWordB);
+ int hashC = hash(maskedFirstWordC);
+
+ currentWordA = nextWordA;
+ currentWordB = nextWordB;
+ currentWordC = nextWordC;
+
+ maskA = nextMaskA;
+ maskB = nextMaskB;
+ maskC = nextMaskC;
+ if (slowSome) {
+ while (maskA == 0) {
+ cursorA += 8;
+ currentWordA = UNSAFE.getLong(cursorA);
+ maskA = getDelimiterMask(currentWordA);
+ }
- maskA = getDelimiterMask(currentWordA);
- while (maskA == 0) {
- cursorA += 8;
- currentWordA = UNSAFE.getLong(cursorA);
- maskA = getDelimiterMask(currentWordA);
- }
- maskB = getDelimiterMask(currentWordB);
- while (maskB == 0) {
- cursorB += 8;
- currentWordB = UNSAFE.getLong(cursorB);
- maskB = getDelimiterMask(currentWordB);
- }
- maskC = getDelimiterMask(currentWordC);
- while (maskC == 0) {
- cursorC += 8;
- currentWordC = UNSAFE.getLong(cursorC);
- maskC = getDelimiterMask(currentWordC);
+ while (maskB == 0) {
+ cursorB += 8;
+ currentWordB = UNSAFE.getLong(cursorB);
+ maskB = getDelimiterMask(currentWordB);
+ }
+ while (maskC == 0) {
+ cursorC += 8;
+ currentWordC = UNSAFE.getLong(cursorC);
+ maskC = getDelimiterMask(currentWordC);
+ }
}
final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
@@ -458,40 +498,57 @@ public class CalculateAverage_jerrinot {
long digitStartA = semicolonA + 1;
long digitStartB = semicolonB + 1;
long digitStartC = semicolonC + 1;
+
long temperatureWordA = UNSAFE.getLong(digitStartA);
long temperatureWordB = UNSAFE.getLong(digitStartB);
long temperatureWordC = UNSAFE.getLong(digitStartC);
- final long maskedWordA = currentWordA & ((maskA - 1) ^ maskA) >>> 8;
- final long maskedWordB = currentWordB & ((maskB - 1) ^ maskB) >>> 8;
- final long maskedWordC = currentWordC & ((maskC - 1) ^ maskC) >>> 8;
+ long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
+ long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
+ long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
- long lenA = semicolonA - startA;
- long lenB = semicolonB - startB;
- long lenC = semicolonC - startC;
+ final long maskedLastWordA = currentWordA & lastWordMaskA;
+ final long maskedLastWordB = currentWordB & lastWordMaskB;
+ final long maskedLastWordC = currentWordC & lastWordMaskC;
- long baseEntryPtrA;
- if (lenA > 15) {
- baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, (int) hashA, maskedWordA);
- }
- else {
- baseEntryPtrA = getOrCreateEntryBaseOffsetFast(lenA, (int) hashA, maskedWordA, maskedFirstWordA);
- }
+ int lenA = (int) (semicolonA - startA);
+ int lenB = (int) (semicolonB - startB);
+ int lenC = (int) (semicolonC - startC);
- long baseEntryPtrB;
- if (lenB > 15) {
- baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, (int) hashB, maskedWordB);
- }
- else {
- baseEntryPtrB = getOrCreateEntryBaseOffsetFast(lenB, (int) hashB, maskedWordB, maskedFirstWordB);
- }
+ int mapIndexA = hashA & MAP_MASK;
+ int mapIndexB = hashB & MAP_MASK;
+ int mapIndexC = hashC & MAP_MASK;
+ long baseEntryPtrA;
+ long baseEntryPtrB;
long baseEntryPtrC;
- if (lenC > 15) {
- baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, (int) hashC, maskedWordC);
+
+ if (slowSome) {
+ if (slowA) {
+ baseEntryPtrA = getOrCreateEntryBaseOffsetSlow(lenA, startA, hashA, maskedLastWordA);
+ }
+ else {
+ baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
+ }
+
+ if (slowB) {
+ baseEntryPtrB = getOrCreateEntryBaseOffsetSlow(lenB, startB, hashB, maskedLastWordB);
+ }
+ else {
+ baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
+ }
+
+ if (slowC) {
+ baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
+ }
+ else {
+ baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
+ }
}
else {
- baseEntryPtrC = getOrCreateEntryBaseOffsetFast(lenC, (int) hashC, maskedWordC, maskedFirstWordC);
+ baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA);
+ baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB);
+ baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC);
}
cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
@@ -502,36 +559,35 @@ public class CalculateAverage_jerrinot {
// System.out.println("Longest chain: " + longestChain);
}
- private long getOrCreateEntryBaseOffsetFast(long lenLong, int hash, long maskedLastWord, long maskedFirstWord) {
- int lenA = (int) lenLong;
- long mapIndexA = hash & MAP_MASK;
+ private long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord) {
for (;;) {
long basePtr = mapIndexA * FAST_MAP_ENTRY_SIZE_BYTES + fastMap;
+ long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1);
+ long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
+ if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
+ return basePtr;
+ }
long lenPtr = basePtr + MAP_LEN_OFFSET;
int len = UNSAFE.getInt(lenPtr);
- if (len == lenA) {
- long namePart1 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART1);
- long namePart2 = UNSAFE.getLong(basePtr + FAST_MAP_NAME_PART2);
- if (namePart1 == maskedFirstWord && namePart2 == maskedLastWord) {
- return basePtr;
- }
- }
- else if (len == 0) {
- UNSAFE.putInt(lenPtr, lenA);
- // todo: this could be a single putLong()
- UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
- UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
- UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord);
- UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
- return basePtr;
+ if (len == 0) {
+ return newEntryFast(lenA, maskedLastWord, maskedFirstWord, lenPtr, basePtr);
}
mapIndexA = ++mapIndexA & MAP_MASK;
}
}
- private long getOrCreateEntryBaseOffsetSlow(long lenLong, long startPtr, int hash, long maskedLastWord) {
- long fullLen = lenLong & ~7L;
- int lenA = (int) lenLong;
+ private static long newEntryFast(int lenA, long maskedLastWord, long maskedFirstWord, long lenPtr, long basePtr) {
+ UNSAFE.putInt(lenPtr, lenA);
+ // todo: this could be a single putLong()
+ UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
+ UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
+ UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART1, maskedFirstWord);
+ UNSAFE.putLong(basePtr + FAST_MAP_NAME_PART2, maskedLastWord);
+ return basePtr;
+ }
+
+ private long getOrCreateEntryBaseOffsetSlow(int lenA, long startPtr, int hash, long maskedLastWord) {
+ long fullLen = lenA & ~7L;
long mapIndexA = hash & MAP_MASK;
for (;;) {
long basePtr = mapIndexA * SLOW_MAP_ENTRY_SIZE_BYTES + slowMap;
@@ -550,7 +606,7 @@ public class CalculateAverage_jerrinot {
UNSAFE.putInt(basePtr + MAP_MAX_OFFSET, Integer.MIN_VALUE);
UNSAFE.putInt(basePtr + MAP_MIN_OFFSET, Integer.MAX_VALUE);
UNSAFE.copyMemory(startPtr, slowMapNamesPtr, lenA);
- long alignedLen = (lenLong & ~7L) + 8;
+ long alignedLen = (lenA & ~7L) + 8;
slowMapNamesPtr += alignedLen;
return basePtr;
}