aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
authorJaromir Hamala <jaromir.hamala@gmail.com>2024-01-15 18:55:22 +0100
committerGitHub <noreply@github.com>2024-01-15 18:55:22 +0100
commitdbdd89a84779761ca092e5aaeb6f6e92394a422d (patch)
tree0ea9c89a013a2efa42b8b658fe8e018e0a7c3901 /src/main/java
parent677d94e5cf9769d02843edf06686abfea7112d1d (diff)
jerrinot's initial submission (#424)
* initial version let's exploit that superscalar beauty! * give credits where credits is due also: added ideas I don't want to forget
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java482
1 files changed, 482 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
new file mode 100644
index 0000000..6fb89bb
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
@@ -0,0 +1,482 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import sun.misc.Unsafe;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
+import java.nio.channels.FileChannel.MapMode;
+import java.util.Map;
+import java.util.TreeMap;
+
+public class CalculateAverage_jerrinot {
+ private static final Unsafe UNSAFE = unsafe();
+ private static final String MEASUREMENTS_TXT = "measurements.txt";
+ // todo: with hyper-threading enable we would be better of with availableProcessors / 2;
+ // todo: validate the testing env. params.
+ private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();
+ private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
+
+ private static Unsafe unsafe() {
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ return (Unsafe) theUnsafe.get(Unsafe.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ calculate();
+ }
+
+ static void calculate() throws Exception {
+ final File file = new File(MEASUREMENTS_TXT);
+ final long length = file.length();
+ // final int chunkCount = Runtime.getRuntime().availableProcessors();
+ int chunkPerThread = 4;
+ final int chunkCount = THREAD_COUNT * chunkPerThread;
+ final var chunkStartOffsets = new long[chunkCount + 1];
+ try (var raf = new RandomAccessFile(file, "r")) {
+ // credit - chunking code: mtopolnik
+ final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
+ for (int i = 1; i < chunkStartOffsets.length - 1; i++) {
+ var start = length * i / (chunkStartOffsets.length - 1);
+ raf.seek(start);
+ while (raf.read() != (byte) '\n') {
+ }
+ start = raf.getFilePointer();
+ chunkStartOffsets[i] = start + inputBase;
+ }
+ chunkStartOffsets[0] = inputBase;
+ chunkStartOffsets[chunkCount] = inputBase + length;
+
+ Processor[] processors = new Processor[THREAD_COUNT];
+ Thread[] threads = new Thread[THREAD_COUNT];
+
+ for (int i = 0; i < THREAD_COUNT; i++) {
+ long startA = chunkStartOffsets[i * chunkPerThread];
+ long endA = chunkStartOffsets[i * chunkPerThread + 1];
+ long startB = chunkStartOffsets[i * chunkPerThread + 1];
+ long endB = chunkStartOffsets[i * chunkPerThread + 2];
+ long startC = chunkStartOffsets[i * chunkPerThread + 2];
+ long endC = chunkStartOffsets[i * chunkPerThread + 3];
+ long startD = chunkStartOffsets[i * chunkPerThread + 3];
+ long endD = chunkStartOffsets[i * chunkPerThread + 4];
+
+ Processor processor = new Processor(startA, endA, startB, endB, startC, endC, startD, endD);
+ processors[i] = processor;
+ Thread thread = new Thread(processor);
+ threads[i] = thread;
+ thread.start();
+ }
+
+ var accumulator = new TreeMap<String, Processor.StationStats>();
+ for (int i = 0; i < THREAD_COUNT; i++) {
+ Thread t = threads[i];
+ t.join();
+ processors[i].accumulateStatus(accumulator);
+ }
+
+ var sb = new StringBuilder();
+ boolean first = true;
+ for (Map.Entry<String, Processor.StationStats> statsEntry : accumulator.entrySet()) {
+ if (first) {
+ sb.append("{");
+ first = false;
+ }
+ else {
+ sb.append(", ");
+ }
+ var value = statsEntry.getValue();
+ var name = statsEntry.getKey();
+ int min = value.min;
+ int max = value.max;
+ int count = value.count;
+ long sum2 = value.sum;
+ sb.append(String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum2 / count) / 10.0, max / 10.0));
+ }
+ System.out.print(sb);
+ System.out.println('}');
+ }
+ }
+
+ public static int ceilPow2(int i) {
+ i--;
+ i |= i >> 1;
+ i |= i >> 2;
+ i |= i >> 4;
+ i |= i >> 8;
+ i |= i >> 16;
+ return i + 1;
+ }
+
+ private static class Processor implements Runnable {
+ private static final int MAP_SLOT_COUNT = ceilPow2(10000);
+ private static final int STATION_MAX_NAME_BYTES = 104;
+
+ private static final long COUNT_OFFSET = 0;
+ private static final long MIN_OFFSET = 4;
+ private static final long MAX_OFFSET = 8;
+ private static final long SUM_OFFSET = 12;
+ private static final long LEN_OFFSET = 20;
+ private static final long NAME_OFFSET = 24;
+
+ private static final int MAP_ENTRY_SIZE_BYTES = +Integer.BYTES // count // 0
+ + Integer.BYTES // min // +4
+ + Integer.BYTES // max // +8
+ + Long.BYTES // sum // +12
+ + Integer.BYTES // station name len // +20
+ + STATION_MAX_NAME_BYTES; // +24
+
+ private static final int MAP_SIZE_BYTES = MAP_SLOT_COUNT * MAP_ENTRY_SIZE_BYTES;
+ private static final long MAP_MASK = MAP_SLOT_COUNT - 1;
+
+ // todo: some fields could probably be converted to locals
+
+ private final long map;
+
+ private long cursorA;
+ private long endA;
+ private long cursorB;
+ private long endB;
+ private long cursorC;
+ private long endC;
+ private long cursorD;
+ private long endD;
+ private long maskA;
+ private long maskB;
+ private long maskC;
+ private long maskD;
+
+ // credit: merykitty
+ private long parseAndStoreTemperature(long startCursor, long baseEntryPtr) {
+ long word = UNSAFE.getLong(startCursor);
+ final long negateda = ~word;
+ final int dotPos = Long.numberOfTrailingZeros(negateda & 0x10101000);
+ final long signed = (negateda << 59) >> 63;
+ final long removeSignMask = ~(signed & 0xFF);
+ final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
+ final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ final int temperature = (int) ((absValue ^ signed) - signed);
+
+ long countPtr = baseEntryPtr + COUNT_OFFSET;
+ long minPtr = baseEntryPtr + MIN_OFFSET;
+ long maxPtr = baseEntryPtr + MAX_OFFSET;
+ long sumPtr = baseEntryPtr + SUM_OFFSET;
+
+ int min = UNSAFE.getInt(minPtr);
+ int max = UNSAFE.getInt(maxPtr);
+ long sum = UNSAFE.getLong(sumPtr);
+ // try if min/max intrinsics are paying off
+ // maybe braching is better? the branch is becoming more predictable with
+ // each new sample.
+ max = Math.max(max, temperature);
+ min = Math.min(min, temperature);
+ sum += temperature;
+ UNSAFE.putInt(countPtr, UNSAFE.getInt(countPtr) + 1);
+ UNSAFE.putInt(minPtr, min);
+ UNSAFE.putInt(maxPtr, max);
+ UNSAFE.putLong(sumPtr, sum);
+ return startCursor + (dotPos / 8) + 3;
+ }
+
+ private static long getDelimiterMask(final long word) {
+ // credit royvanrijn
+ final long match = word ^ SEPARATOR_PATTERN;
+ return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
+ }
+
+ // todo: immutability cost us in allocations, but that's probably peanuts in the grand scheme of things. still worth checking
+ // maybe JVM trusting Final in Records offsets it ..a test is needed
+ record StationStats(int min, int max, int count, long sum) {
+ }
+
+ void accumulateStatus(TreeMap<String, StationStats> accumulator) {
+ for (long baseAddress = map; baseAddress < map + MAP_SIZE_BYTES; baseAddress += MAP_ENTRY_SIZE_BYTES) {
+ long len = UNSAFE.getInt(baseAddress + LEN_OFFSET);
+ if (len == 0) {
+ continue;
+ }
+ byte[] nameArr = new byte[(int) len];
+ long baseNameAddr = baseAddress + NAME_OFFSET;
+ for (int i = 0; i < len; i++) {
+ nameArr[i] = UNSAFE.getByte(baseNameAddr + i);
+ }
+ String name = new String(nameArr);
+ int min = UNSAFE.getInt(baseAddress + MIN_OFFSET);
+ int max = UNSAFE.getInt(baseAddress + MAX_OFFSET);
+ int count = UNSAFE.getInt(baseAddress + COUNT_OFFSET);
+ long sum = UNSAFE.getLong(baseAddress + SUM_OFFSET);
+
+ // todo: lambdas bootstrap probably cost us
+ accumulator.compute(name, (_, v) -> {
+ if (v == null) {
+ return new StationStats(min, max, count, sum);
+ }
+ return new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum);
+ });
+ }
+ }
+
+ Processor(long startA, long endA, long startB, long endB, long startC, long endC, long startD, long endD) {
+ this.cursorA = startA;
+ this.cursorB = startB;
+ this.cursorC = startC;
+ this.cursorD = startD;
+ this.endA = endA;
+ this.endB = endB;
+ this.endC = endC;
+ this.endD = endD;
+ this.map = UNSAFE.allocateMemory(MAP_SIZE_BYTES);
+
+ int i;
+ for (i = 0; i < MAP_SIZE_BYTES; i += 8) {
+ UNSAFE.putLong(map + i, 0);
+ }
+ for (i = i - 8; i < MAP_SIZE_BYTES; i++) {
+ UNSAFE.putByte(map + i, (byte) 0);
+ }
+ }
+
+ private void doTail() {
+ // todo: we would be probably better of without all that code dup. ("compilers hates him!")
+ // System.out.println("done ILP");
+ while (cursorA < endA) {
+ long startA = cursorA;
+ long delimiterWordA = UNSAFE.getLong(cursorA);
+ long hashA = 0;
+ maskA = getDelimiterMask(delimiterWordA);
+ while (maskA == 0) {
+ hashA ^= delimiterWordA;
+ cursorA += 8;
+ delimiterWordA = UNSAFE.getLong(cursorA);
+ maskA = getDelimiterMask(delimiterWordA);
+ }
+ final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
+ final long semicolonA = cursorA + (delimiterByteA >> 3);
+ final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1);
+ hashA ^= maskedWordA;
+ int intHashA = (int) (hashA ^ (hashA >> 32));
+ intHashA = intHashA ^ (intHashA >> 17);
+
+ long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA);
+ cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA);
+ }
+ // System.out.println("done A");
+ while (cursorB < endB) {
+ long startB = cursorB;
+ long delimiterWordB = UNSAFE.getLong(cursorB);
+ long hashB = 0;
+ maskB = getDelimiterMask(delimiterWordB);
+ while (maskB == 0) {
+ hashB ^= delimiterWordB;
+ cursorB += 8;
+ delimiterWordB = UNSAFE.getLong(cursorB);
+ maskB = getDelimiterMask(delimiterWordB);
+ }
+ final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
+ final long semicolonB = cursorB + (delimiterByteB >> 3);
+ final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1);
+ hashB ^= maskedWordB;
+ int intHashB = (int) (hashB ^ (hashB >> 32));
+ intHashB = intHashB ^ (intHashB >> 17);
+
+ long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB);
+ cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB);
+ }
+ // System.out.println("done B");
+ while (cursorC < endC) {
+ long startC = cursorC;
+ long delimiterWordC = UNSAFE.getLong(cursorC);
+ long hashC = 0;
+ maskC = getDelimiterMask(delimiterWordC);
+ while (maskC == 0) {
+ hashC ^= delimiterWordC;
+ cursorC += 8;
+ delimiterWordC = UNSAFE.getLong(cursorC);
+ maskC = getDelimiterMask(delimiterWordC);
+ }
+ final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
+ final long semicolonC = cursorC + (delimiterByteC >> 3);
+ final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1);
+ hashC ^= maskedWordC;
+ int intHashC = (int) (hashC ^ (hashC >> 32));
+ intHashC = intHashC ^ (intHashC >> 17);
+
+ long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC);
+ cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC);
+ }
+ // System.out.println("done C");
+ while (cursorD < endD) {
+ long startD = cursorD;
+ long delimiterWordD = UNSAFE.getLong(cursorD);
+ long hashD = 0;
+ maskD = getDelimiterMask(delimiterWordD);
+ while (maskD == 0) {
+ hashD ^= delimiterWordD;
+ cursorD += 8;
+ delimiterWordD = UNSAFE.getLong(cursorD);
+ maskD = getDelimiterMask(delimiterWordD);
+ }
+ final int delimiterByteD = Long.numberOfTrailingZeros(maskD);
+ final long semicolonD = cursorD + (delimiterByteD >> 3);
+ final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1);
+ hashD ^= maskedWordD;
+ int intHashD = (int) (hashD ^ (hashD >> 32));
+ intHashD = intHashD ^ (intHashD >> 17);
+
+ long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD);
+ cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD);
+ }
+ // System.out.println("done D");
+ }
+
+ @Override
+ public void run() {
+ while (cursorA < endA && cursorB < endB && cursorC < endC && cursorD < endD) {
+ // todo: experiment with different inter-leaving
+ long startA = cursorA;
+ long startB = cursorB;
+ long startC = cursorC;
+ long startD = cursorD;
+
+ long delimiterWordA = UNSAFE.getLong(cursorA);
+ long delimiterWordB = UNSAFE.getLong(cursorB);
+ long delimiterWordC = UNSAFE.getLong(cursorC);
+ long delimiterWordD = UNSAFE.getLong(cursorD);
+
+ long hashA = 0;
+ long hashB = 0;
+ long hashC = 0;
+ long hashD = 0;
+
+ // credits for the hashing idea: royvanrijn
+ maskA = getDelimiterMask(delimiterWordA);
+ while (maskA == 0) {
+ hashA ^= delimiterWordA;
+ cursorA += 8;
+ delimiterWordA = UNSAFE.getLong(cursorA);
+ maskA = getDelimiterMask(delimiterWordA);
+ }
+ final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
+ final long semicolonA = cursorA + (delimiterByteA >> 3);
+ final long maskedWordA = delimiterWordA & ((maskA >>> 7) - 1);
+ hashA ^= maskedWordA;
+ int intHashA = (int) (hashA ^ (hashA >> 32));
+ intHashA = intHashA ^ (intHashA >> 17);
+
+ maskB = getDelimiterMask(delimiterWordB);
+ while (maskB == 0) {
+ hashB ^= delimiterWordB;
+ cursorB += 8;
+ delimiterWordB = UNSAFE.getLong(cursorB);
+ maskB = getDelimiterMask(delimiterWordB);
+ }
+ final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
+ final long semicolonB = cursorB + (delimiterByteB >> 3);
+ final long maskedWordB = delimiterWordB & ((maskB >>> 7) - 1);
+ hashB ^= maskedWordB;
+ int intHashB = (int) (hashB ^ (hashB >> 32));
+ intHashB = intHashB ^ (intHashB >> 17);
+
+ maskC = getDelimiterMask(delimiterWordC);
+ while (maskC == 0) {
+ hashC ^= delimiterWordC;
+ cursorC += 8;
+ delimiterWordC = UNSAFE.getLong(cursorC);
+ maskC = getDelimiterMask(delimiterWordC);
+ }
+ final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
+ final long semicolonC = cursorC + (delimiterByteC >> 3);
+ final long maskedWordC = delimiterWordC & ((maskC >>> 7) - 1);
+ hashC ^= maskedWordC;
+ int intHashC = (int) (hashC ^ (hashC >> 32));
+ intHashC = intHashC ^ (intHashC >> 17);
+
+ maskD = getDelimiterMask(delimiterWordD);
+ while (maskD == 0) {
+ hashD ^= delimiterWordD;
+ cursorD += 8;
+ delimiterWordD = UNSAFE.getLong(cursorD);
+ maskD = getDelimiterMask(delimiterWordD);
+ }
+ final int delimiterByteD = Long.numberOfTrailingZeros(maskD);
+ final long semicolonD = cursorD + (delimiterByteD >> 3);
+ final long maskedWordD = delimiterWordD & ((maskD >>> 7) - 1);
+ hashD ^= maskedWordD;
+ int intHashD = (int) (hashD ^ (hashD >> 32));
+ intHashD = intHashD ^ (intHashD >> 17);
+
+ long baseEntryPtrA = getOrCreateEntryBaseOffset(semicolonA, startA, intHashA, maskedWordA);
+ long baseEntryPtrB = getOrCreateEntryBaseOffset(semicolonB, startB, intHashB, maskedWordB);
+ long baseEntryPtrC = getOrCreateEntryBaseOffset(semicolonC, startC, intHashC, maskedWordC);
+ long baseEntryPtrD = getOrCreateEntryBaseOffset(semicolonD, startD, intHashD, maskedWordD);
+
+ cursorA = parseAndStoreTemperature(semicolonA + 1, baseEntryPtrA);
+ cursorB = parseAndStoreTemperature(semicolonB + 1, baseEntryPtrB);
+ cursorC = parseAndStoreTemperature(semicolonC + 1, baseEntryPtrC);
+ cursorD = parseAndStoreTemperature(semicolonD + 1, baseEntryPtrD);
+ }
+ doTail();
+ }
+
+ private long getOrCreateEntryBaseOffset(long semicolonA, long startA, int intHashA, long maskedWordA) {
+ int lenA = (int) (semicolonA - startA);
+ long mapIndexA = intHashA & MAP_MASK;
+ for (;;) {
+ long basePtr = mapIndexA * MAP_ENTRY_SIZE_BYTES + map;
+ long lenPtr = basePtr + LEN_OFFSET;
+ int len = UNSAFE.getInt(lenPtr);
+ if (len == 0) {
+ // todo: uncommon branch maybe?
+ // empty slot
+ UNSAFE.copyMemory(semicolonA - lenA, basePtr + NAME_OFFSET, lenA);
+ UNSAFE.putInt(lenPtr, lenA);
+ UNSAFE.putInt(basePtr + MAX_OFFSET, Integer.MIN_VALUE);
+ UNSAFE.putInt(basePtr + MIN_OFFSET, Integer.MAX_VALUE);
+ return basePtr;
+ }
+ if (len == lenA) {
+ boolean match = true;
+ long namePtr = basePtr + NAME_OFFSET;
+ int fullLen = (len >> 3) << 3;
+ long offset;
+ // todo: this is worth exploring further.
+ // @mtopolnik has an interesting algo with 2 unconditioned long loads: this is sufficient
+ // for majority of names. so we would be left with just a single branch which is almost never taken?
+ for (offset = 0; offset < fullLen; offset += 8) {
+ match &= (UNSAFE.getLong(startA + offset) == UNSAFE.getLong(namePtr + offset));
+ }
+
+ long maskedWordInMap = UNSAFE.getLong(namePtr + offset);
+ match &= (maskedWordInMap == maskedWordA);
+
+ if (match) {
+ return basePtr;
+ }
+ }
+ mapIndexA = ++mapIndexA & MAP_MASK;
+ }
+ }
+ }
+
+}