diff options
| author | Roy van Rijn <roy.van.rijn@gmail.com> | 2024-01-05 16:38:40 +0100 |
|---|---|---|
| committer | Gunnar Morling <gunnar.morling@googlemail.com> | 2024-01-05 17:44:36 +0100 |
| commit | 3a2e0ed26746c912e47f8e55374ecbd0e1ccff7b (patch) | |
| tree | 9d2b561a15e6506137ca84655d5322154c1545ff /src | |
| parent | 631722158cc7a2de1fb3cfc06ca8fe1696f158a4 (diff) | |
Adding more speed improvements, going for first again.
Updating script
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java | 378 |
1 files changed, 164 insertions, 214 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java index cd4d572..c74415e 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java @@ -28,8 +28,10 @@ import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.TreeMap; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Changelog: @@ -47,6 +49,7 @@ import java.util.stream.Collectors; * Segmenting files: 3150 ms (based on spullara's code) * Not using SWAR for EOL: 2850 ms * Inlining hash calculation: 2450 ms + * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love) * * Best performing JVM on MacBook M2 Pro: 21.0.1-graal * `sdk use java 21.0.1-graal` @@ -55,8 +58,8 @@ import java.util.stream.Collectors; public class CalculateAverage_royvanrijn { private static final String FILE = "./measurements.txt"; + // private static final String FILE = "./src/test/resources/samples/measurements-10000-unique-keys.txt"; - // mutable state now instead of records, ugh, less instantiation. static final class Measurement { int min, max, count; long sum; @@ -96,29 +99,6 @@ public class CalculateAverage_royvanrijn { // new CalculateAverage_royvanrijn().runTests(); } - private void testInput(final String inputString, final int start, final boolean bigEndian, final int[] expectedDelimiterAndHash, final long[] expectedCityNameLong) { - - byte[] input = inputString.getBytes(StandardCharsets.UTF_8); - - ByteBuffer buffer = ByteBuffer.wrap(input).order(bigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); - - int[] output = new int[2]; - long[] cityName = new long[128]; - findNextDelimiterAndCalculateHash(buffer, SEPARATOR_PATTERN, start, buffer.limit(), output, cityName, bigEndian); - - if (!Arrays.equals(output, expectedDelimiterAndHash)) { - System.out.println("Error in delimiter or hash"); - System.out.println("Expected: " + Arrays.toString(expectedDelimiterAndHash)); - System.out.println("Received: " + Arrays.toString(output)); - } - int amountLong = 1 + ((output[0] - start) >>> 3); - if (!Arrays.equals(cityName, 0, amountLong, expectedCityNameLong, 0, amountLong)) { - System.out.println("Error in long array"); - System.out.println("Expected: " + Arrays.toString(expectedCityNameLong)); - System.out.println("Received: " + Arrays.toString(cityName)); - } - } - private void run() throws Exception { var results = getFileSegments(new File(FILE)).stream().map(segment -> { @@ -128,8 +108,7 @@ public class CalculateAverage_royvanrijn { var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start()); // Work with any UTF-8 city name, up to 100 in length: - var buffer = new byte[106]; // 100 + ; + -XX.X + \n - var cityNameAsLongArray = new long[13]; // 13*8=104=kenough. + var cityNameAsLongArray = new long[16]; var delimiterPointerAndHash = new int[2]; // Calculate using native ordering (fastest?): @@ -143,39 +122,39 @@ public class CalculateAverage_royvanrijn { int limit = bb.limit(); while ((startPointer = bb.position()) < limit) { + int delimiterPointer, endPointer; + // SWAR method to find delimiter *and* record the cityname as long[] *and* calculate a hash: findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian); - int delimiterPointer = delimiterPointerAndHash[0]; + delimiterPointer = delimiterPointerAndHash[0]; // Simple lookup is faster for '\n' (just three options) - int endPointer; - if (delimiterPointer >= limit) { - bb.position(limit); // skip to next line. return measurements; } + // Extract the measurement value (10x): + final int cityNameLength = delimiterPointer - startPointer; - if (bb.get(delimiterPointer + 4) == '\n') { - endPointer = delimiterPointer + 4; + int measuredValue; + int neg = 1; + if (bb.get(++delimiterPointer) == '-') { + neg = -1; + delimiterPointer++; } - else if (bb.get(delimiterPointer + 5) == '\n') { - endPointer = delimiterPointer + 5; + byte dot; + if ((dot = (bb.get(delimiterPointer + 1))) == '.') { + measuredValue = neg * ((bb.get(delimiterPointer)) * 10 + (bb.get(delimiterPointer + 2)) - 528); + endPointer = delimiterPointer + 3; } else { - endPointer = delimiterPointer + 6; + measuredValue = neg * (bb.get(delimiterPointer) * 100 + dot * 10 + bb.get(delimiterPointer + 3) - 5328); + endPointer = delimiterPointer + 4; } - // Read the entry in a single get(): - bb.get(buffer, 0, endPointer - startPointer); - bb.position(endPointer + 1); // skip to next line. - - // Extract the measurement value (10x): - final int cityNameLength = delimiterPointer - startPointer; - final int measuredValueLength = endPointer - delimiterPointer - 1; - final int measuredValue = branchlessParseInt(buffer, cityNameLength + 1, measuredValueLength); - // Store everything in a custom hashtable: - measurements.update(buffer, cityNameLength, delimiterPointerAndHash[1], cityNameAsLongArray).updateWith(measuredValue); + measurements.update(cityNameAsLongArray, bb, cityNameLength, delimiterPointerAndHash[1]).updateWith(measuredValue); + + bb.position(endPointer + 1); // skip to next line. } return measurements; } @@ -183,61 +162,18 @@ public class CalculateAverage_royvanrijn { throw new RuntimeException(e); } }).parallel() - .flatMap(v -> v.values.stream()) + .flatMap(v -> v.get()) .collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new)); System.out.println(results); + + // System.out.println("Processed: " + results.entrySet().stream().mapToLong(e -> e.getValue().count).sum()); } /** * -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values: */ private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); - private static final long[] PARTIAL_INDEX_MASKS = new long[]{ 0L, 255L, 65535L, 16777215L, 4294967295L, 1099511627775L, 281474976710655L, 72057594037927935L }; - - public void runTests() { - - // Method used for debugging purposes, easy to make mistakes with all the bit hacking. - - // These all have the same hashes: - testInput("Delft;-12.4", 0, true, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); - testInput("aDelft;-12.4", 1, true, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); - - testInput("Delft;-12.4", 0, false, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); - testInput("aDelft;-12.4", 1, false, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); - - testInput("Rotterdam;-12.4", 0, true, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, true, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, true, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - - testInput("Rotterdam;-12.4", 0, false, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, false, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, false, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - - // These have different hashes from the strings above: - testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, true, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, false, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); - - MeasurementRepository repository = new MeasurementRepository(); - - // Simulate adding two entries with the same hash: - byte[] b1 = "City1;10.0".getBytes(); - byte[] b2 = "City2;41.1".getBytes(); - repository.update(b1, 5, 1234, new long[]{ 1234L }); - repository.update(b2, 5, 1234, new long[]{ 4321L }); - // And update the same record shouldn't add a third (this happened): - repository.update(b1, 5, 1234, new long[]{ 1234L }); - - if (repository.values.size() != 2) { - System.out.println("Error, should have two entries:"); - System.out.println(repository.values); - } - - MeasurementRepository.Entry firstInserted = repository.values.getFirst(); - if (!firstInserted.cityName.equals("City1")) { - System.out.println("Error, should have correct name: " + firstInserted.cityName); - } - } /** * Already looping the longs here, lets shoehorn in making a hash @@ -249,35 +185,43 @@ public class CalculateAverage_royvanrijn { int lCnt = 0; for (i = start; i <= limit - 8; i += 8) { long word = bb.getLong(i); - if (bufferBigEndian) + if (bufferBigEndian) { word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this - int index = firstAnyPattern(word, pattern); - if (index < Long.BYTES) { - final long partialHash = word & PARTIAL_INDEX_MASKS[index]; - asLong[lCnt] = partialHash; - hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash; + } + final long match = word ^ pattern; + long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; + + if (mask != 0) { + final int index = Long.numberOfTrailingZeros(mask) >> 3; output[0] = (i + index); - output[1] = hash; + + final long partialHash = word & ((mask >> 7) - 1); + asLong[lCnt] = partialHash; + output[1] = longHashStep(hash, partialHash); return; } asLong[lCnt++] = word; - hash = 961 * hash + 31 * (int) (word >>> 32) + (int) word; + hash = longHashStep(hash, word); } - // Handle remaining bytes + // Handle remaining bytes near the limit of the buffer: long partialHash = 0; + int len = 0; for (; i < limit; i++) { byte read; if ((read = bb.get(i)) == (byte) pattern) { asLong[lCnt] = partialHash; - hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash; output[0] = i; - output[1] = hash; + output[1] = longHashStep(hash, partialHash); return; } - partialHash = partialHash << 8 | read; + partialHash = partialHash | ((long) read << (len << 3)); + len++; } output[0] = limit; // delimiter not found - output[1] = hash; + } + + private static int longHashStep(final int hash, final long word) { + return 31 * hash + (int) (word ^ (word >>> 32)); } private static long compilePattern(final byte value) { @@ -285,18 +229,9 @@ public class CalculateAverage_royvanrijn { ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; } - private static int firstAnyPattern(final long word, final long pattern) { - final long match = word ^ pattern; - long mask = match - 0x0101010101010101L; - mask &= ~match; - mask &= 0x8080808080808080L; - return Long.numberOfTrailingZeros(mask) >> 3; - } - record FileSegment(long start, long end) { } - /** Using this way to segment the file is much prettier, from spullara */ private static List<FileSegment> getFileSegments(final File file) throws IOException { final int numberOfSegments = Runtime.getRuntime().availableProcessors(); final long fileSize = file.length(); @@ -307,52 +242,28 @@ public class CalculateAverage_royvanrijn { return segments; } try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { - for (int i = 0; i < numberOfSegments; i++) { - long segStart = i * segmentSize; - long segEnd = (i == numberOfSegments - 1) ? fileSize : segStart + segmentSize; - segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd); - segEnd = findSegment(i, numberOfSegments - 1, randomAccessFile, segEnd, fileSize); - + long segStart = 0; + long segEnd = segmentSize; + while (segStart < fileSize) { + segEnd = findSegment(randomAccessFile, segEnd, fileSize); segments.add(new FileSegment(segStart, segEnd)); + segStart = segEnd; // Just re-use the end and go from there. + segEnd = Math.min(fileSize, segEnd + segmentSize); } } return segments; } - private static long findSegment(final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) throws IOException { - if (i != skipSegment) { - raf.seek(location); - while (location < fileSize) { - location++; - if (raf.read() == '\n') - return location; - } + private static long findSegment(RandomAccessFile raf, long location, final long fileSize) throws IOException { + raf.seek(location); + while (location < fileSize) { + location++; + if (raf.read() == '\n') + return location; } return location; } - /** - * Branchless parser, goes from String to int (10x): - * "-1.2" to -12 - * "40.1" to 401 - * etc. - * - * @param input - * @return int value x10 - */ - private static int branchlessParseInt(final byte[] input, final int start, final int length) { - // 0 if positive, 1 if negative - final int negative = ~(input[start] >> 4) & 1; - // 0 if nr length is 3, 1 if length is 4 - final int has4 = ((length - negative) >> 2) & 1; - - final int digit1 = input[start + negative] - '0'; - final int digit2 = input[start + negative + has4]; - final int digit3 = input[start + negative + has4 + 2]; - - return (-negative ^ (has4 * (digit1 * 100) + digit2 * 10 + digit3 - 528) - negative); // 528 == ('0' * 10 + '0') - } - // branchless max (unprecise for large numbers, but good enough) static int max(final int a, final int b) { final int diff = a - b; @@ -373,88 +284,81 @@ public class CalculateAverage_royvanrijn { * So I've written an extremely simple linear probing hashmap that should work well enough. */ class MeasurementRepository { - private int size = 16384;// 16384; // Much larger than the number of cities, needs power of two - private int[] indices = new int[size]; // Hashtable is just an int[] + private int tableSize = 1 << 20; // can grow in theory, made large enough not to (this is faster) + private int tableMask = (tableSize - 1); + private int tableLimit = (int) (tableSize * LOAD_FACTOR); + private int tableFilled = 0; + private static final float LOAD_FACTOR = 0.8f; - MeasurementRepository() { - populateEmptyIndices(indices); - } + private Entry[] table = new Entry[tableSize]; - private void populateEmptyIndices(int[] array) { - // Optimized fill with -1, fastest method: - int len = array.length; - array[0] = -1; - // Value of i will be [1, 2, 4, 8, 16, 32, ..., len] - for (int i = 1; i < len; i += i) { - System.arraycopy(array, 0, array, i, i); - } - } - - private final List<Entry> values = new ArrayList<>(512); - - record Entry(int hash, long[] cityNameAsLong, String cityName, Measurement measurement) { + record Entry(int hash, long[] nameBytesInLong, String cityName, Measurement measurement) { @Override public String toString() { return cityName + "=" + measurement; } } - public Measurement update(byte[] buffer, int length, int calculatedHash, long[] cityNameAsLongArray) { - - final int cityNameAsLongLength = 1 + (length >>> 3); // amount of longs that captures this cityname + public Measurement update(long[] nameBytesInLong, ByteBuffer bb, int length, int calculatedHash) { - int hashtableIndex = (size - 1) & calculatedHash; - int valueIndex; + final int nameBytesInLongLength = 1 + (length >>> 3); - Entry retrievedEntry = null; + int index = calculatedHash & tableMask; + Entry tableEntry; + while ((tableEntry = table[index]) != null + && (tableEntry.hash != calculatedHash || !arrayEquals(tableEntry.nameBytesInLong, nameBytesInLong, nameBytesInLongLength))) { // search for the right spot + index = (index + 1) & tableMask; + } - while (true) { // search for the right spot - if ((valueIndex = indices[hashtableIndex]) == -1) { - break; // Empty slot found, stop the loop - } - else { - // Non-empty slot, retrieve entry - if ((retrievedEntry = values.get(valueIndex)).hash == calculatedHash && - arrayEquals(retrievedEntry.cityNameAsLong, cityNameAsLongArray, cityNameAsLongLength)) { - break; // Both hash and cityname match, stop the loop - } - } - // Move to the next index - hashtableIndex = (hashtableIndex + 1) % size; + if (tableEntry != null) { + return tableEntry.measurement; } - if (valueIndex >= 0) { - return retrievedEntry.measurement; + // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. + Measurement measurement = new Measurement(); + + // Now create a string: + byte[] buffer = new byte[length]; + bb.get(buffer, 0, length); + String cityName = new String(buffer, 0, length); + + // Store the long[] for faster equals: + long[] nameBytesInLongCopy = new long[nameBytesInLongLength]; + System.arraycopy(nameBytesInLong, 0, nameBytesInLongCopy, 0, nameBytesInLongLength); + + // And add entry: + Entry toAdd = new Entry(calculatedHash, nameBytesInLongCopy, cityName, measurement); + table[index] = toAdd; + + // Resize the table if filled too much: + if (++tableFilled > tableLimit) { + resizeTable(); } - // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) - - // Keep the already processed longs for fast equals: - long[] cityNameAsLongArrayCopy = new long[cityNameAsLongLength]; - System.arraycopy(cityNameAsLongArray, 0, cityNameAsLongArrayCopy, 0, cityNameAsLongLength); - - Entry toAdd = new Entry(calculatedHash, cityNameAsLongArrayCopy, new String(buffer, 0, length), new Measurement()); - - // Code to regrow (if we get more unique entries): (not needed/not optimized yet) - // if (values.size() > size / 2) { - // // We probably don't want this... - // - // int newSize = size << 1; - // int[] newIndices = new int[newSize]; - // populateEmptyIndices(newIndices); - // for (int i = 0; i < values.size(); i++) { - // Entry e = values.get(i); - // int updatedIndex = (newSize - 1) & e.hash; - // newIndices[updatedIndex] = i; - // } - // indices = newIndices; - // size = newSize; - // } - indices[hashtableIndex] = values.size(); - - values.add(toAdd); return toAdd.measurement; } + + private void resizeTable() { + // Resize the table: + Entry[] oldEntries = table; + table = new Entry[tableSize <<= 2]; // x2 + tableMask = (tableSize - 1); + tableLimit = (int) (tableSize * LOAD_FACTOR); + + for (Entry entry : oldEntries) { + if (entry != null) { + int updatedTableIndex = entry.hash & tableMask; + while (table[updatedTableIndex] != null) { + updatedTableIndex = (updatedTableIndex + 1) & tableMask; + } + table[updatedTableIndex] = entry; + } + } + } + + public Stream<Entry> get() { + return Arrays.stream(table).filter(Objects::nonNull); + } } /** @@ -467,4 +371,50 @@ public class CalculateAverage_royvanrijn { } return true; } + + public void runTests() { + // Method used for debugging purposes, easy to make mistakes with all the bit hacking. + + // These all have the same hashes: + testInput("Delft;-12.4", 0, true, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); + testInput("aDelft;-12.4", 1, true, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); + + testInput("Delft;-12.4", 0, false, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); + testInput("aDelft;-12.4", 1, false, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); + + testInput("Rotterdam;-12.4", 0, true, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, true, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, true, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + + testInput("Rotterdam;-12.4", 0, false, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, false, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, false, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); + + // These have different hashes from the strings above: + testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, true, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); + testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, false, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); + } + + private void testInput(final String inputString, final int start, final boolean bigEndian, final int[] expectedDelimiterAndHash, final long[] expectedCityNameLong) { + + byte[] input = inputString.getBytes(StandardCharsets.UTF_8); + + ByteBuffer buffer = ByteBuffer.wrap(input).order(bigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); + + int[] output = new int[2]; + long[] cityName = new long[128]; + findNextDelimiterAndCalculateHash(buffer, SEPARATOR_PATTERN, start, buffer.limit(), output, cityName, bigEndian); + + if (!Arrays.equals(output, expectedDelimiterAndHash)) { + System.out.println("Error in delimiter or hash"); + System.out.println("Expected: " + Arrays.toString(expectedDelimiterAndHash)); + System.out.println("Received: " + Arrays.toString(output)); + } + int amountLong = 1 + ((output[0] - start) >>> 3); + if (!Arrays.equals(cityName, 0, amountLong, expectedCityNameLong, 0, amountLong)) { + System.out.println("Error in long array"); + System.out.println("Expected: " + Arrays.toString(expectedCityNameLong)); + System.out.println("Received: " + Arrays.toString(cityName)); + } + } } |
