aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarko Topolnik <marko.topolnik@gmail.com>2024-01-11 20:02:14 +0100
committerGitHub <noreply@github.com>2024-01-11 20:02:14 +0100
commit95459f56407c5c75e7ef08500eb711e5c94654ab (patch)
tree760a4ae40d9968e3e8377a86dcac59af201fd381 /src
parent8ec9ba861a50f80be7606d374ce88653a8fd0040 (diff)
Entry into the contest, calculate_average_mtopolnik.sh (#246)
* calculate_average_mtopolnik * short hash (just first 8 bytes of name) * Remove unneeded checks * Remove archiving classes * 2x larger hashtable * Add "set" to setters * Simplify parsing temperature, remove newline search * Reduce the size of the name slot * Store name length and use to detect collision * Reduce memory loads in parseTemperature * Use short for min/max * Extract constant for semicolon * Fix script header * Explicit bash shell in shebang * Inline usage of broadcast semicolon * Try vectorization * Remove vectorization * Go Unsafe * Use SWAR temperature parsing by merykitty * Inline some things * Remove commented-out MemorySegment usage * Inline namesMem.asSlice() invocation * Try out JVM JIT flags * Implement strcmp * Remove unused instance variables * Optimize hashing * Put station name into hashtable * Reorder method * Remove usage of MemorySegment.getUtf8String Replace with UNSAFE.copyMemory() and new String() * Fix hashing bug * Remove outdated comments * Fix informative constants * Use broadcastByte() more * Improve method naming * More hashing * Revert more hashing * Add commented-out code to hash 16 bytes * Slight cleanup * Align hashtable at cacheline boundary * Add Graal Native image * Revert Graal Native image This reverts commit d916a42326d89bd1a841bbbecfae185adb8679d7. * Simplify shell script (no SDK selection) * Move a constant, zero out hashtable on start * Better name comparison * Add prepare_mtopolnik.sh * Cleaner idiom in name comparison * AND instead of MOD for hashtable indexing * Improve word masking code * Fix formatting * Reduce memory loads * Remove endianness checks * Avoid hash == 0 problem * Fix subtle bug * MergeSort of parellel results * Touch up perf * Touch up perf * Remove -Xmx256m * Extract result printing method * Print allocation details on OOME * Single mmap * Use global allocation arena
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java530
1 files changed, 530 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
new file mode 100644
index 0000000..fe487fc
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
@@ -0,0 +1,530 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import sun.misc.Unsafe;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.reflect.Field;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel.MapMode;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+public class CalculateAverage_mtopolnik {
+ // Reflectively obtained sun.misc.Unsafe instance, used for raw memory access throughout this class.
+ private static final Unsafe UNSAFE = unsafe();
+ // NOTE(review): not referenced anywhere else in this file; processChunk() asserts against a literal 100 instead.
+ private static final int MAX_NAME_LEN = 100;
+ // Number of slots in each thread's stats hashtable. It is a power of two so that
+ // TABLE_INDEX_MASK can select a slot with a bitwise AND.
+ private static final int STATS_TABLE_SIZE = 1 << 16;
+ private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1;
+ // Name of the input file; being a relative path, it is resolved against the working directory.
+ private static final String MEASUREMENTS_TXT = "measurements.txt";
+ // Byte that separates the station name from the temperature on each input line.
+ private static final byte SEMICOLON = ';';
+ // SEMICOLON replicated into all eight bytes of a long, for the word-at-a-time search in posOfSemicolon().
+ private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON);
+
+ // These two are just informative, I let the IDE calculate them for me
+ private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE;
+ private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD;
+
+ private static Unsafe unsafe() {
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ return (Unsafe) theUnsafe.get(Unsafe.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+    /**
+     * Aggregated measurements of a single weather station. Temperatures are kept as integers
+     * in tenths of a degree; {@link #toString()} converts them back for output.
+     * Identity ({@code equals}, {@code hashCode}, {@code compareTo}) is based on the name only.
+     */
+    static class StationStats implements Comparable<StationStats> {
+        String name;
+        long sum;
+        int count;
+        int min;
+        int max;
+
+        /**
+         * Formats the entry as {@code name=min/mean/max} with one decimal place each.
+         * The locale is pinned so that the decimal separator is always a dot,
+         * regardless of the JVM's default locale.
+         */
+        @Override
+        public String toString() {
+            return String.format(java.util.Locale.ROOT, "%s=%.1f/%.1f/%.1f",
+                    name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
+        }
+
+        /**
+         * Two entries are equal when they carry the same station name.
+         * Returns {@code false} for {@code null} and for objects of any other class.
+         */
+        @Override
+        public boolean equals(Object that) {
+            if (this == that) {
+                return true;
+            }
+            if (that == null || that.getClass() != StationStats.class) {
+                return false;
+            }
+            final String otherName = ((StationStats) that).name;
+            return name == null ? otherName == null : name.equals(otherName);
+        }
+
+        /** Hash of the station name, consistent with {@link #equals(Object)}. */
+        @Override
+        public int hashCode() {
+            return name == null ? 0 : name.hashCode();
+        }
+
+        /** Orders entries by station name. */
+        @Override
+        public int compareTo(StationStats that) {
+            return name.compareTo(that.name);
+        }
+    }
+
+ /**
+ * Program entry point; delegates all work to calculate().
+ *
+ * @param args not used
+ * @throws Exception whatever calculate() throws
+ */
+ public static void main(String[] args) throws Exception {
+ calculate();
+ }
+
+    /**
+     * Memory-maps the measurements file, splits it into one chunk per available processor with
+     * every chunk starting right after a newline, processes the chunks on separate threads and
+     * finally prints the merged, name-sorted result.
+     *
+     * @throws Exception if the file cannot be opened, read or mapped, or if the calling thread
+     *                   is interrupted while waiting for the workers
+     */
+    static void calculate() throws Exception {
+        final File file = new File(MEASUREMENTS_TXT);
+        final long length = file.length();
+        final int chunkCount = Runtime.getRuntime().availableProcessors();
+        final var results = new StationStats[chunkCount][];
+        final var chunkStartOffsets = new long[chunkCount];
+        try (var raf = new RandomAccessFile(file, "r")) {
+            final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
+            // Chunk 0 starts at offset 0; every later chunk starts just past the first newline
+            // found at or after its evenly spaced split point.
+            for (int i = 1; i < chunkStartOffsets.length; i++) {
+                raf.seek(length * i / chunkStartOffsets.length);
+                // read() keeps returning -1 once the end of the file is reached, so EOF must
+                // terminate the search as well; otherwise a file without a trailing newline
+                // would make this loop spin forever. The chunk then starts at EOF and is empty.
+                int ch;
+                do {
+                    ch = raf.read();
+                } while (ch != '\n' && ch != -1);
+                chunkStartOffsets[i] = raf.getFilePointer();
+            }
+            final var threads = new Thread[chunkCount];
+            for (int i = 0; i < chunkCount; i++) {
+                final long chunkStart = chunkStartOffsets[i];
+                final long chunkLimit = (i + 1 < chunkCount) ? chunkStartOffsets[i + 1] : length;
+                threads[i] = new Thread(new ChunkProcessor(inputBase + chunkStart, inputBase + chunkLimit, results, i));
+            }
+            for (var thread : threads) {
+                thread.start();
+            }
+            for (var thread : threads) {
+                thread.join();
+            }
+        }
+        mergeSortAndPrint(results);
+    }
+
+ private static class ChunkProcessor implements Runnable {
+ // Size of the scratch buffer that processChunk() uses when fewer than 16 bytes remain in the chunk.
+ private static final long NAMEBUF_SIZE = 2 * Long.BYTES;
+ // Alignment requested for the stats table allocation in run().
+ private static final int CACHELINE_SIZE = 64;
+
+ // Absolute address of the first byte of this chunk.
+ private final long inputBase;
+ // Length of this chunk in bytes; cursor and all parsing offsets are relative to inputBase.
+ private final long inputSize;
+ // Output array shared by all processors; this one writes only the element at myIndex.
+ private final StationStats[][] results;
+ private final int myIndex;
+
+ // The two fields below are assigned in run(), once the native memory is allocated.
+ private StatsAccessor stats;
+ private long nameBufBase;
+ // Offset, relative to inputBase, of the next unparsed byte.
+ private long cursor;
+
+        /**
+         * Creates a processor for the bytes in {@code [chunkStart, chunkLimit)}.
+         *
+         * @param chunkStart absolute address of the first byte of the chunk
+         * @param chunkLimit absolute address one past the last byte of the chunk
+         * @param results    array into which the processor stores its sorted results
+         * @param myIndex    index of the {@code results} element this processor fills
+         */
+        ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) {
+            this.results = results;
+            this.myIndex = myIndex;
+            this.inputBase = chunkStart;
+            this.inputSize = chunkLimit - chunkStart;
+        }
+
+        /**
+         * Allocates the native stats table and the scratch name buffer from a confined arena,
+         * then parses the chunk and exports the results. The arena, and with it the native
+         * memory, is released when this method returns. If the allocation fails with an
+         * {@link OutOfMemoryError}, a diagnostic line is printed to standard error and the
+         * error is rethrown.
+         */
+        @Override
+        public void run() {
+            final long tableBytes = STATS_TABLE_SIZE * StatsAccessor.SIZEOF;
+            try (Arena arena = Arena.ofConfined()) {
+                // The message is formatted before allocating, presumably so that the
+                // out-of-memory handler only has to print.
+                final String oomMessagePrefix = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ",
+                        Thread.currentThread().getName(), tableBytes + NAMEBUF_SIZE);
+                long allocatedSoFar = 0;
+                try {
+                    stats = new StatsAccessor(arena.allocate(tableBytes, CACHELINE_SIZE));
+                    allocatedSoFar = tableBytes;
+                    nameBufBase = arena.allocate(NAMEBUF_SIZE).address();
+                }
+                catch (OutOfMemoryError oom) {
+                    System.err.print(oomMessagePrefix);
+                    System.err.println(allocatedSoFar);
+                    throw oom;
+                }
+                processChunk();
+                exportResults();
+            }
+        }
+
+ /**
+ * Parses the chunk line by line. For every line it locates the semicolon, hashes the
+ * station name, parses the temperature (which also moves the cursor to the next line)
+ * and folds the reading into the stats table.
+ */
+ private void processChunk() {
+ while (cursor < inputSize) {
+ long word1;
+ long word2;
+ if (cursor + 2 * Long.BYTES <= inputSize) {
+ // At least 16 bytes remain in the chunk, load them directly from the input.
+ word1 = UNSAFE.getLong(inputBase + cursor);
+ word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES);
+ }
+ else {
+ // Fewer than 16 bytes remain: copy them into the zeroed scratch buffer and load
+ // from there, so that no read goes past the end of the chunk.
+ UNSAFE.putLong(nameBufBase, 0);
+ UNSAFE.putLong(nameBufBase + Long.BYTES, 0);
+ UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
+ word1 = UNSAFE.getLong(nameBufBase);
+ word2 = UNSAFE.getLong(nameBufBase + Long.BYTES);
+ }
+ // Chunk-relative offset of the semicolon that ends the station name.
+ long posOfSemicolon = posOfSemicolon(word1, word2);
+ // Zero out everything from the semicolon onwards, keeping only name bytes in the words.
+ word1 = maskWord(word1, posOfSemicolon - cursor);
+ word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES);
+ // Only the first eight bytes of the name take part in the hash.
+ long hash = hash(word1);
+ long namePos = cursor;
+ long nameLen = posOfSemicolon - cursor;
+ assert nameLen <= 100 : "nameLen > 100";
+ // Side effect: sets cursor to the start of the next line.
+ int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon);
+ updateStats(hash, namePos, nameLen, word1, word2, temperature);
+ }
+ }
+
+ /**
+ * Folds one temperature reading into the stats table, creating the station's slot on first
+ * sight. The table is probed linearly, starting at the slot selected by the low bits of the hash.
+ *
+ * @param hash hash of the station name
+ * @param namePos chunk-relative offset of the first byte of the name
+ * @param nameLen length of the name in bytes
+ * @param nameWord1 first eight bytes of the name, with bytes past the name zeroed
+ * @param nameWord2 bytes 8 to 15 of the name, with bytes past the name zeroed
+ * @param temperature the reading, in tenths of a degree
+ */
+ private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) {
+ int tableIndex = (int) (hash & TABLE_INDEX_MASK);
+ // NOTE(review): nothing bounds this probe loop, so it would not terminate if every slot
+ // were occupied by other names.
+ while (true) {
+ stats.gotoIndex(tableIndex);
+ if (stats.hash() == hash && stats.nameLen() == nameLen
+ && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) {
+ // The slot already belongs to this station: accumulate.
+ stats.setSum(stats.sum() + temperature);
+ stats.setCount(stats.count() + 1);
+ stats.setMin((short) Integer.min(stats.min(), temperature));
+ stats.setMax((short) Integer.max(stats.max(), temperature));
+ return;
+ }
+ if (stats.nameLen() != 0) {
+ // The slot is taken by a different station, try the next one (wrapping around).
+ tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
+ continue;
+ }
+ // A name length of zero marks a free slot: claim it for this station.
+ stats.setHash(hash);
+ stats.setNameLen((int) nameLen);
+ stats.setSum(temperature);
+ stats.setCount(1);
+ stats.setMin((short) temperature);
+ stats.setMax((short) temperature);
+ UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen);
+ return;
+ }
+ }
+
+ private int parseTemperatureAndAdvanceCursor(long semicolonPos) {
+ long startOffset = semicolonPos + 1;
+ if (startOffset <= inputSize - Long.BYTES) {
+ return parseTemperatureSwarAndAdvanceCursor(startOffset);
+ }
+ return parseTemperatureSimpleAndAdvanceCursor(startOffset);
+ }
+
+ // Credit: merykitty
+ /**
+ * Parses a temperature from a single eight-byte load, without branching, and sets the cursor
+ * to the start of the next line.
+ *
+ * @param startOffset chunk-relative offset of the first temperature character; the caller
+ * ensures that eight bytes can be loaded from it within the chunk
+ * @return the temperature in tenths of a degree
+ */
+ private int parseTemperatureSwarAndAdvanceCursor(long startOffset) {
+ long word = UNSAFE.getLong(inputBase + startOffset);
+ final long negated = ~word;
+ // ASCII digits have bit 4 set, '.' does not. The mask looks at bit 4 of bytes 1 to 3,
+ // so dotPos becomes the bit index (12, 20 or 28) that identifies the byte holding the dot.
+ final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000);
+ // All ones if bit 4 of the first byte is clear (a '-' sign), zero otherwise.
+ final long signed = (negated << 59) >> 63;
+ // Clears the first byte when it holds the sign.
+ final long removeSignMask = ~(signed & 0xFF);
+ // Shift the dot to a fixed position and keep only the low nibbles of the digit bytes.
+ final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
+ // One multiplication combines the digits as 100 * tens + 10 * ones + tenths in bits 32 to 41.
+ final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ // Negates absValue when signed is all ones, leaves it unchanged when signed is zero.
+ final int temperature = (int) ((absValue ^ signed) - signed);
+ // dotPos / 8 is the byte index of the dot; skip the dot, the fractional digit and the newline.
+ cursor = startOffset + (dotPos / 8) + 3;
+ return temperature;
+ }
+
+ /**
+ * Parses a temperature of the form [-]d.d or [-]dd.d character by character and sets the
+ * cursor to the start of the next line. Used when fewer than eight bytes remain in the chunk.
+ *
+ * @param startOffset chunk-relative offset of the first temperature character
+ * @return the temperature in tenths of a degree
+ */
+ private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) {
+ final byte minus = (byte) '-';
+ final byte zero = (byte) '0';
+ final byte dot = (byte) '.';
+
+ // Temperature plus the following newline is at least 4 chars, so this is always safe:
+ int fourCh = UNSAFE.getInt(inputBase + startOffset);
+ final int mask = 0xFF;
+ // `shift` is the bit offset, within fourCh, of the character currently held in `ch`.
+ byte ch = (byte) (fourCh & mask);
+ int shift = 0;
+ int temperature;
+ int sign;
+ if (ch == minus) {
+ sign = -1;
+ shift += 8;
+ ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ }
+ else {
+ sign = 1;
+ }
+ // First integer digit.
+ temperature = ch - zero;
+ shift += 8;
+ ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ if (ch == dot) {
+ // Single integer digit: the fractional digit follows the dot.
+ shift += 8;
+ ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ }
+ else {
+ // Second integer digit; step over the dot (hence 16 bits) to the fractional digit.
+ temperature = 10 * temperature + (ch - zero);
+ shift += 16;
+ // The last character may be past the four loaded bytes, load it from memory.
+ // Checking that with another `if` is self-defeating for performance.
+ ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8));
+ }
+ // Fractional digit.
+ temperature = 10 * temperature + (ch - zero);
+ // At this point `shift / 8` is the index of the fractional digit, the last character
+ // of the temperature field.
+ // A newline character follows the temperature, and so we advance
+ // the cursor past the newline to the start of the next line.
+ cursor = startOffset + (shift / 8) + 2;
+ return sign * temperature;
+ }
+
+ private static long hash(long word1) {
+ long seed = 0x51_7c_c1_b7_27_22_0a_95L;
+ int rotDist = 17;
+
+ long hash = word1;
+ hash *= seed;
+ hash = Long.rotateLeft(hash, rotDist);
+ // hash ^= word2;
+ // hash *= seed;
+ // hash = Long.rotateLeft(hash, rotDist);
+ return hash;
+ }
+
+ /**
+ * Tells whether the name stored in a stats slot equals the name being parsed. The caller
+ * has already established that both names have the same length.
+ *
+ * The first 16 bytes are compared as two longs. This relies on the slot bytes past the
+ * stored name being zero: the table is zero-filled on creation and updateStats() copies
+ * exactly nameLen bytes into a slot.
+ *
+ * @param statsAddr address of the name stored in the slot
+ * @param inputAddr address of the name in the input
+ * @param len length of the name in bytes
+ * @param inputWord1 first eight bytes loaded at the start of the input name
+ * @param inputWord2 the following eight bytes
+ * @return true if the names are equal
+ */
+ private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) {
+ boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr);
+ boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES);
+ // Non-short-circuit OR: both words are always compared, presumably to avoid an extra branch.
+ if (mismatch1 | mismatch2) {
+ return false;
+ }
+ // Names longer than 16 bytes: compare the rest byte by byte.
+ for (int i = 2 * Long.BYTES; i < len; i++) {
+ if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+        /**
+         * Keeps the {@code len} lowest-addressed bytes of {@code word} (its least significant
+         * bytes) and zeroes the rest.
+         *
+         * @param word eight bytes loaded from memory
+         * @param len  number of bytes to keep; values of eight or more keep the whole word,
+         *             zero or any negative value clears it
+         * @return the masked word
+         */
+        private static long maskWord(long word, long len) {
+            // Clamp to 0..8 first. Without the clamp, len <= -8 gave a half-distance of 64 or
+            // more, which Java reduces modulo 64, so the mask came out as all ones instead of zero.
+            long keptBytes = Long.max(0, Long.min(Long.BYTES, len));
+            long halfShiftDistance = (Long.BYTES - keptBytes) << 2;
+            // Shift in two halves: a single shift by 64 would be a no-op in Java.
+            long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance;
+            return word & mask;
+        }
+
+ // 0x01 and 0x80 replicated into every byte of a long, used by the byte-match test below.
+ private static final long BROADCAST_0x01 = broadcastByte(0x01);
+ private static final long BROADCAST_0x80 = broadcastByte(0x80);
+
+ // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/
+ // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169
+ /**
+ * Finds the first semicolon at or after the cursor.
+ *
+ * @param word1 the eight bytes at the cursor
+ * @param word2 the following eight bytes
+ * @return chunk-relative offset of the semicolon
+ * @throws RuntimeException if the chunk holds no semicolon at or after the cursor
+ */
+ long posOfSemicolon(long word1, long word2) {
+ // XOR turns every semicolon byte into zero. The expression that follows raises the top
+ // bit of zero bytes, so the lowest set bit of matchBits marks the first semicolon.
+ long diff = word1 ^ BROADCAST_SEMICOLON;
+ long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
+ diff = word2 ^ BROADCAST_SEMICOLON;
+ long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
+ if ((matchBits1 | matchBits2) != 0) {
+ // trailing1 is 8 * k + 7 for a match in byte k of word1, or 64 if word1 has no match.
+ int trailing1 = Long.numberOfTrailingZeros(matchBits1);
+ int match1IsNonZero = trailing1 & 63;
+ match1IsNonZero |= match1IsNonZero >>> 3;
+ match1IsNonZero |= match1IsNonZero >>> 1;
+ match1IsNonZero |= match1IsNonZero >>> 1;
+ // Now match1IsNonZero has its lowest bit set if word1 contains a match, else it is 0.
+ // Use it to raise the lowest bit in matchBits2 if trailing1 is nonzero. This forces
+ // trailing2 to be zero if word1 contains a match.
+ int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
+ // Match in word1: (8 * k + 7) >> 3 == k. Match only in word2: (64 | 8 * k + 7) >> 3 == 8 + k.
+ return cursor + ((trailing1 | trailing2) >> 3);
+ }
+ // No semicolon in the first 16 bytes: keep scanning one word at a time for as long as
+ // a full word can be loaded from within the chunk.
+ long offset = cursor + 2 * Long.BYTES;
+ for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) {
+ var block = UNSAFE.getLong(inputBase + offset);
+ diff = block ^ BROADCAST_SEMICOLON;
+ long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
+ if (matchBits != 0) {
+ return offset + Long.numberOfTrailingZeros(matchBits) / 8;
+ }
+ }
+ // Fewer than eight bytes remain, finish byte by byte.
+ return posOfSemicolonSimple(offset);
+ }
+
+        /**
+         * Scans the chunk byte by byte for a semicolon.
+         *
+         * @param offset chunk-relative offset at which to start looking
+         * @return chunk-relative offset of the first semicolon found
+         * @throws RuntimeException if the end of the chunk is reached without finding one
+         */
+        private long posOfSemicolonSimple(long offset) {
+            long pos = offset;
+            while (pos < inputSize) {
+                if (UNSAFE.getByte(inputBase + pos) == SEMICOLON) {
+                    return pos;
+                }
+                pos++;
+            }
+            throw new RuntimeException("Semicolon not found");
+        }
+
+ // Copies the results from native memory to Java heap and puts them into the results array.
+ private void exportResults() {
+ var exportedStats = new ArrayList<StationStats>(10_000);
+ for (int i = 0; i < STATS_TABLE_SIZE; i++) {
+ stats.gotoIndex(i);
+ if (stats.nameLen() == 0) {
+ continue;
+ }
+ var sum = stats.sum();
+ var count = stats.count();
+ var min = stats.min();
+ var max = stats.max();
+ var name = stats.exportNameString();
+ var stationStats = new StationStats();
+ stationStats.name = name;
+ stationStats.sum = sum;
+ stationStats.count = count;
+ stationStats.min = min;
+ stationStats.max = max;
+ exportedStats.add(stationStats);
+ }
+ StationStats[] exported = exportedStats.toArray(new StationStats[0]);
+ Arrays.sort(exported);
+ results[myIndex] = exported;
+ }
+
+        // Scratch buffer for longToString(), in native byte order so that the bytes come out
+        // in the order they have in memory.
+        private final ByteBuffer buf = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.nativeOrder());
+
+        /**
+         * Decodes the eight bytes of {@code word} as UTF-8 text. No caller is visible in this
+         * file; it looks like a debugging aid.
+         *
+         * @param word the eight bytes to decode
+         * @return the decoded string
+         */
+        private String longToString(long word) {
+            buf.clear().putLong(word);
+            final byte[] wordBytes = buf.array();
+            return new String(wordBytes, StandardCharsets.UTF_8);
+        }
+ }
+
+    /**
+     * Replicates a byte value into all eight bytes of a long by repeatedly OR-ing the value
+     * with a shifted copy of itself.
+     *
+     * @param b the value to replicate, expected to be in the range 0 to 255
+     * @return a long with every byte equal to {@code b}
+     */
+    private static long broadcastByte(int b) {
+        long replicated = b;
+        // Each step doubles the number of filled bytes: 1 -> 2 -> 4 -> 8.
+        for (int shift = Byte.SIZE; shift <= Integer.SIZE; shift <<= 1) {
+            replicated |= replicated << shift;
+        }
+        return replicated;
+    }
+
+ /**
+ * Cursor-style accessor over a table of fixed-size stats slots in native memory.
+ * gotoIndex() selects a slot; the getters and setters then read and write that slot's fields.
+ *
+ * Slot layout: hash (long), name length (int), sum (int), count (int), min (short),
+ * max (short), followed by the name bytes.
+ */
+ static class StatsAccessor {
+ // Bytes reserved for the station name within a slot.
+ static final int NAME_SLOT_SIZE = 104;
+ static final long HASH_OFFSET = 0;
+ static final long NAMELEN_OFFSET = HASH_OFFSET + Long.BYTES;
+ // NOTE(review): the sum is stored as a 32-bit int whereas StationStats.sum is a long;
+ // confirm that a single chunk can never accumulate enough to overflow it.
+ static final long SUM_OFFSET = NAMELEN_OFFSET + Integer.BYTES;
+ static final long COUNT_OFFSET = SUM_OFFSET + Integer.BYTES;
+ static final long MIN_OFFSET = COUNT_OFFSET + Integer.BYTES;
+ static final long MAX_OFFSET = MIN_OFFSET + Short.BYTES;
+ static final long NAME_OFFSET = MAX_OFFSET + Short.BYTES;
+ // Slot size: NAME_OFFSET + NAME_SLOT_SIZE rounded up to a multiple of 8 (128 with this layout).
+ static final long SIZEOF = (NAME_OFFSET + NAME_SLOT_SIZE - 1) / 8 * 8 + 8;
+
+ static final int ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
+
+ // Address of slot 0.
+ private final long address;
+ // Address of the currently selected slot, set by gotoIndex().
+ private long slotBase;
+
+ /**
+ * Zero-fills the segment and remembers its address. In a zeroed slot the name length
+ * is 0, which the table's users treat as the marker of a free slot.
+ */
+ StatsAccessor(MemorySegment memSeg) {
+ memSeg.fill((byte) 0);
+ this.address = memSeg.address();
+ }
+
+ // Selects the slot that all following getter and setter calls operate on.
+ void gotoIndex(int index) {
+ slotBase = address + index * SIZEOF;
+ }
+
+ long hash() {
+ return UNSAFE.getLong(slotBase + HASH_OFFSET);
+ }
+
+ int nameLen() {
+ return UNSAFE.getInt(slotBase + NAMELEN_OFFSET);
+ }
+
+ int sum() {
+ return UNSAFE.getInt(slotBase + SUM_OFFSET);
+ }
+
+ int count() {
+ return UNSAFE.getInt(slotBase + COUNT_OFFSET);
+ }
+
+ short min() {
+ return UNSAFE.getShort(slotBase + MIN_OFFSET);
+ }
+
+ short max() {
+ return UNSAFE.getShort(slotBase + MAX_OFFSET);
+ }
+
+ // Address of the first name byte of the current slot.
+ long nameAddress() {
+ return slotBase + NAME_OFFSET;
+ }
+
+ // Copies the current slot's name bytes into a Java byte array and decodes them as UTF-8.
+ String exportNameString() {
+ final var bytes = new byte[nameLen()];
+ UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen());
+ return new String(bytes, StandardCharsets.UTF_8);
+ }
+
+ void setHash(long hash) {
+ UNSAFE.putLong(slotBase + HASH_OFFSET, hash);
+ }
+
+ void setNameLen(int nameLen) {
+ UNSAFE.putInt(slotBase + NAMELEN_OFFSET, nameLen);
+ }
+
+ void setSum(int sum) {
+ UNSAFE.putInt(slotBase + SUM_OFFSET, sum);
+ }
+
+ void setCount(int count) {
+ UNSAFE.putInt(slotBase + COUNT_OFFSET, count);
+ }
+
+ void setMin(short min) {
+ UNSAFE.putShort(slotBase + MIN_OFFSET, min);
+ }
+
+ void setMax(short max) {
+ UNSAFE.putShort(slotBase + MAX_OFFSET, max);
+ }
+ }
+
+ private static void mergeSortAndPrint(StationStats[][] results) {
+ var onFirst = true;
+ System.out.print('{');
+ var cursors = new int[results.length];
+ var indexOfMin = 0;
+ StationStats curr = null;
+ int exhaustedCount;
+ while (true) {
+ exhaustedCount = 0;
+ StationStats min = null;
+ for (int i = 0; i < cursors.length; i++) {
+ if (cursors[i] == results[i].length) {
+ exhaustedCount++;
+ continue;
+ }
+ StationStats candidate = results[i][cursors[i]];
+ if (min == null || min.compareTo(candidate) > 0) {
+ indexOfMin = i;
+ min = candidate;
+ }
+ }
+ if (exhaustedCount == cursors.length) {
+ if (!onFirst) {
+ System.out.print(", ");
+ }
+ System.out.print(curr);
+ break;
+ }
+ cursors[indexOfMin]++;
+ if (curr == null) {
+ curr = min;
+ }
+ else if (min.equals(curr)) {
+ curr.sum += min.sum;
+ curr.count += min.count;
+ curr.min = Integer.min(curr.min, min.min);
+ curr.max = Integer.max(curr.max, min.max);
+ }
+ else {
+ if (onFirst) {
+ onFirst = false;
+ }
+ else {
+ System.out.print(", ");
+ }
+ System.out.print(curr);
+ curr = min;
+ }
+ }
+ System.out.println('}');
+ }
+}