diff options
| author | Roy van Rijn <roy.van.rijn@gmail.com> | 2024-01-07 19:41:43 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-07 19:41:43 +0100 |
| commit | e665d715499b2f2cac765cd3314c948a56602061 (patch) | |
| tree | 78438a8b084342ac438b1e1a10f57e833e9dd43e /src/main/java/dev | |
| parent | ff7d4a1750cec14d889132124a8d09f4042c2ea6 (diff) | |
Roy: Adding a bit of unsafe...
Co-authored-by: Gunnar Morling <gunnar.morling@googlemail.com>
Diffstat (limited to 'src/main/java/dev')
| -rw-r--r-- | src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java | 409 |
1 files changed, 155 insertions, 254 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java index c74415e..aa22bef 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java @@ -15,24 +15,22 @@ */ package dev.morling.onebrc; -import java.io.File; import java.io.IOException; -import java.io.RandomAccessFile; -import java.nio.ByteBuffer; +import java.lang.foreign.Arena; +import java.lang.reflect.Field; import java.nio.ByteOrder; import java.nio.channels.FileChannel; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Objects; import java.util.TreeMap; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; +import sun.misc.Unsafe; + /** * Changelog: * @@ -50,6 +48,7 @@ import java.util.stream.Stream; * Not using SWAR for EOL: 2850 ms * Inlining hash calculation: 2450 ms * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love) + * Added unsafe memory access: 1900 ms (keeping the long[] small and local) * * Best performing JVM on MacBook M2 Pro: 21.0.1-graal * `sdk use java 21.0.1-graal` @@ -58,210 +57,173 @@ import java.util.stream.Stream; public class CalculateAverage_royvanrijn { private static final String FILE = "./measurements.txt"; - // private static final String FILE = "./src/test/resources/samples/measurements-10000-unique-keys.txt"; - - static final class Measurement { - int min, max, count; - long sum; - - public Measurement() { - this.min = 1000; - this.max = -1000; - } - - public Measurement updateWith(int measurement) { - min = min(min, measurement); - max = max(max, measurement); - sum += measurement; - count++; - return this; - } - public Measurement updateWith(Measurement measurement) { - min = min(min, measurement.min); - max = max(max, measurement.max); - sum += measurement.sum; - count += measurement.count; - return this; - } + private static final Unsafe UNSAFE = initUnsafe(); + private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); - public String toString() { - return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max); + private static Unsafe initUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); } - - private double round(double value) { - return Math.round(value) / 10.0; + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); } } public static void main(String[] args) throws Exception { new CalculateAverage_royvanrijn().run(); - // new CalculateAverage_royvanrijn().runTests(); } - private void run() throws Exception { - - var results = getFileSegments(new File(FILE)).stream().map(segment -> { - - long segmentEnd = segment.end(); - try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) { - var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start()); - - // Work with any UTF-8 city name, up to 100 in length: - var cityNameAsLongArray = new long[16]; - var delimiterPointerAndHash = new int[2]; - - // Calculate using native ordering (fastest?): - bb.order(ByteOrder.nativeOrder()); - - // Record the order it is and calculate accordingly: - final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN); - MeasurementRepository measurements = new MeasurementRepository(); - - int startPointer; - int limit = bb.limit(); - while ((startPointer = bb.position()) < limit) { - - int delimiterPointer, endPointer; - - // SWAR method to find delimiter *and* record the cityname as long[] *and* calculate a hash: - findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian); - delimiterPointer = delimiterPointerAndHash[0]; - - // Simple lookup is faster for '\n' (just three options) - if (delimiterPointer >= limit) { - return measurements; - } - // Extract the measurement value (10x): - final int cityNameLength = delimiterPointer - startPointer; - - int measuredValue; - int neg = 1; - if (bb.get(++delimiterPointer) == '-') { - neg = -1; - delimiterPointer++; - } - byte dot; - if ((dot = (bb.get(delimiterPointer + 1))) == '.') { - measuredValue = neg * ((bb.get(delimiterPointer)) * 10 + (bb.get(delimiterPointer + 2)) - 528); - endPointer = delimiterPointer + 3; - } - else { - measuredValue = neg * (bb.get(delimiterPointer) * 100 + dot * 10 + bb.get(delimiterPointer + 3) - 5328); - endPointer = delimiterPointer + 4; - } - - // Store everything in a custom hashtable: - measurements.update(cityNameAsLongArray, bb, cityNameLength, delimiterPointerAndHash[1]).updateWith(measuredValue); - - bb.position(endPointer + 1); // skip to next line. + public void run() throws Exception { + + // Calculate input segments. + int numberOfChunks = Runtime.getRuntime().availableProcessors(); + long[] chunks = getSegments(numberOfChunks); + + // Parallel processing of segments. + TreeMap<String, Measurement> results = IntStream.range(0, chunks.length - 1) + .mapToObj(chunkIndex -> process(chunks[chunkIndex], chunks[chunkIndex + 1])).parallel() + .flatMap(MeasurementRepository::get) + .collect(Collectors.toMap(e -> e.city, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new)); + + System.out.println(results); + } + + private static long[] getSegments(int numberOfChunks) throws IOException { + try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { + long fileSize = fileChannel.size(); + long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks; + long[] chunks = new long[numberOfChunks + 1]; + long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); + chunks[0] = mappedAddress; + long endAddress = mappedAddress + fileSize; + for (int i = 1; i < numberOfChunks; ++i) { + long chunkAddress = mappedAddress + i * segmentSize; + // Align to first row start. + while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { + // nop } - return measurements; + chunks[i] = Math.min(chunkAddress, endAddress); } - catch (IOException e) { - throw new RuntimeException(e); - } - }).parallel() - .flatMap(v -> v.get()) - .collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new)); + chunks[numberOfChunks] = endAddress; + return chunks; + } + } - System.out.println(results); + private MeasurementRepository process(long fromAddress, long toAddress) { - // System.out.println("Processed: " + results.entrySet().stream().mapToLong(e -> e.getValue().count).sum()); + MeasurementRepository repository = new MeasurementRepository(); + long ptr = fromAddress; + long[] dataBuffer = new long[16]; + while ((ptr = processEntity(dataBuffer, ptr, toAddress, repository)) < toAddress) + ; + + return repository; } - /** - * -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values: - */ private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); /** * Already looping the longs here, lets shoehorn in making a hash */ - private void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output, - final long[] asLong, final boolean bufferBigEndian) { + private long processEntity(final long[] data, final long start, final long limit, final MeasurementRepository measurementRepository) { int hash = 1; - int i; - int lCnt = 0; + long i; + int dataPtr = 0; for (i = start; i <= limit - 8; i += 8) { - long word = bb.getLong(i); - if (bufferBigEndian) { + long word = UNSAFE.getLong(i); + if (isBigEndian) { word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this } - final long match = word ^ pattern; + final long match = word ^ SEPARATOR_PATTERN; long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; if (mask != 0) { - final int index = Long.numberOfTrailingZeros(mask) >> 3; - output[0] = (i + index); - final long partialHash = word & ((mask >> 7) - 1); - asLong[lCnt] = partialHash; - output[1] = longHashStep(hash, partialHash); - return; + final long partialWord = word & ((mask >> 7) - 1); + hash = longHashStep(hash, partialWord); + data[dataPtr] = partialWord; + + final int index = Long.numberOfTrailingZeros(mask) >> 3; + return process(start, i + index, hash, data, measurementRepository); } - asLong[lCnt++] = word; + data[dataPtr++] = word; hash = longHashStep(hash, word); } // Handle remaining bytes near the limit of the buffer: - long partialHash = 0; + long partialWord = 0; int len = 0; for (; i < limit; i++) { byte read; - if ((read = bb.get(i)) == (byte) pattern) { - asLong[lCnt] = partialHash; - output[0] = i; - output[1] = longHashStep(hash, partialHash); - return; + if ((read = UNSAFE.getByte(i)) == ';') { + hash = longHashStep(hash, partialWord); + data[dataPtr] = partialWord; + return process(start, i, hash, data, measurementRepository); } - partialHash = partialHash | ((long) read << (len << 3)); + partialWord = partialWord | ((long) read << (len << 3)); len++; } - output[0] = limit; // delimiter not found + return limit; } - private static int longHashStep(final int hash, final long word) { - return 31 * hash + (int) (word ^ (word >>> 32)); - } + private static final long DOT_BITS = 0x10101000; + private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); - private static long compilePattern(final byte value) { - return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | - ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; - } + private long process(final long startAddress, final long delimiterAddress, final int hash, final long[] data, final MeasurementRepository measurementRepository) { - record FileSegment(long start, long end) { + long word = UNSAFE.getLong(delimiterAddress + 1); + if (isBigEndian) { + word = Long.reverseBytes(word); + } + final long invWord = ~word; + final int decimalSepPos = Long.numberOfTrailingZeros(invWord & DOT_BITS); + final long signed = (invWord << 59) >> 63; + final long designMask = ~(signed & 0xFF); + final long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L; + final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; + final int measurement = (int) ((absValue ^ signed) - signed); + + // Store: + measurementRepository.update(startAddress, data, (int) (delimiterAddress - startAddress), hash, measurement); + + return delimiterAddress + (decimalSepPos >> 3) + 4; // Determine next start: + // return nextAddress; } - private static List<FileSegment> getFileSegments(final File file) throws IOException { - final int numberOfSegments = Runtime.getRuntime().availableProcessors(); - final long fileSize = file.length(); - final long segmentSize = fileSize / numberOfSegments; - final List<FileSegment> segments = new ArrayList<>(); - if (segmentSize < 1000) { - segments.add(new FileSegment(0, fileSize)); - return segments; + static final class Measurement { + int min, max, count; + long sum; + + public Measurement() { + this.min = 1000; + this.max = -1000; } - try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) { - long segStart = 0; - long segEnd = segmentSize; - while (segStart < fileSize) { - segEnd = findSegment(randomAccessFile, segEnd, fileSize); - segments.add(new FileSegment(segStart, segEnd)); - segStart = segEnd; // Just re-use the end and go from there. - segEnd = Math.min(fileSize, segEnd + segmentSize); - } + + public Measurement updateWith(int measurement) { + min = min(min, measurement); + max = max(max, measurement); + sum += measurement; + count++; + return this; + } + + public Measurement updateWith(Measurement measurement) { + min = min(min, measurement.min); + max = max(max, measurement.max); + sum += measurement.sum; + count += measurement.count; + return this; } - return segments; - } - private static long findSegment(RandomAccessFile raf, long location, final long fileSize) throws IOException { - raf.seek(location); - while (location < fileSize) { - location++; - if (raf.read() == '\n') - return location; + public String toString() { + return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max); + } + + private double round(double value) { + return Math.round(value) / 10.0; } - return location; } // branchless max (unprecise for large numbers, but good enough) @@ -278,85 +240,69 @@ public class CalculateAverage_royvanrijn { return b + (diff & dsgn); } + private static int longHashStep(final int hash, final long word) { + return 31 * hash + (int) (word ^ (word >>> 32)); + } + + private static long compilePattern(final byte value) { + return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | + ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value; + } + /** * A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed. * * So I've written an extremely simple linear probing hashmap that should work well enough. */ class MeasurementRepository { - private int tableSize = 1 << 20; // can grow in theory, made large enough not to (this is faster) + private int tableSize = 1 << 20; // large enough for the contest. private int tableMask = (tableSize - 1); - private int tableLimit = (int) (tableSize * LOAD_FACTOR); - private int tableFilled = 0; - private static final float LOAD_FACTOR = 0.8f; - private Entry[] table = new Entry[tableSize]; + private MeasurementRepository.Entry[] table = new MeasurementRepository.Entry[tableSize]; + + record Entry(long address, long[] data, int length, int hash, String city, Measurement measurement) { - record Entry(int hash, long[] nameBytesInLong, String cityName, Measurement measurement) { @Override public String toString() { - return cityName + "=" + measurement; + return city + "=" + measurement; } } - public Measurement update(long[] nameBytesInLong, ByteBuffer bb, int length, int calculatedHash) { + public void update(long address, long[] data, int length, int hash, int temperature) { - final int nameBytesInLongLength = 1 + (length >>> 3); - - int index = calculatedHash & tableMask; - Entry tableEntry; + int dataLength = length >> 3; + int index = hash & tableMask; + MeasurementRepository.Entry tableEntry; while ((tableEntry = table[index]) != null - && (tableEntry.hash != calculatedHash || !arrayEquals(tableEntry.nameBytesInLong, nameBytesInLong, nameBytesInLongLength))) { // search for the right spot + && (tableEntry.hash != hash || tableEntry.length != length || !arrayEquals(tableEntry.data, data, dataLength))) { // search for the right spot index = (index + 1) & tableMask; } if (tableEntry != null) { - return tableEntry.measurement; + tableEntry.measurement.updateWith(temperature); + return; } // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. Measurement measurement = new Measurement(); - // Now create a string: - byte[] buffer = new byte[length]; - bb.get(buffer, 0, length); - String cityName = new String(buffer, 0, length); + byte[] bytes = new byte[length]; + for (int i = 0; i < length; i++) { + bytes[i] = UNSAFE.getByte(address + i); + } + String city = new String(bytes); - // Store the long[] for faster equals: - long[] nameBytesInLongCopy = new long[nameBytesInLongLength]; - System.arraycopy(nameBytesInLong, 0, nameBytesInLongCopy, 0, nameBytesInLongLength); + long[] dataCopy = new long[dataLength]; + System.arraycopy(data, 0, dataCopy, 0, dataLength); // And add entry: - Entry toAdd = new Entry(calculatedHash, nameBytesInLongCopy, cityName, measurement); + MeasurementRepository.Entry toAdd = new MeasurementRepository.Entry(address, dataCopy, length, hash, city, measurement); table[index] = toAdd; - // Resize the table if filled too much: - if (++tableFilled > tableLimit) { - resizeTable(); - } - - return toAdd.measurement; - } - - private void resizeTable() { - // Resize the table: - Entry[] oldEntries = table; - table = new Entry[tableSize <<= 2]; // x2 - tableMask = (tableSize - 1); - tableLimit = (int) (tableSize * LOAD_FACTOR); - - for (Entry entry : oldEntries) { - if (entry != null) { - int updatedTableIndex = entry.hash & tableMask; - while (table[updatedTableIndex] != null) { - updatedTableIndex = (updatedTableIndex + 1) & tableMask; - } - table[updatedTableIndex] = entry; - } - } + toAdd.measurement.updateWith(temperature); } - public Stream<Entry> get() { + public Stream<MeasurementRepository.Entry> get() { return Arrays.stream(table).filter(Objects::nonNull); } } @@ -372,49 +318,4 @@ public class CalculateAverage_royvanrijn { return true; } - public void runTests() { - // Method used for debugging purposes, easy to make mistakes with all the bit hacking. - - // These all have the same hashes: - testInput("Delft;-12.4", 0, true, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); - testInput("aDelft;-12.4", 1, true, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); - - testInput("Delft;-12.4", 0, false, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L }); - testInput("aDelft;-12.4", 1, false, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L }); - - testInput("Rotterdam;-12.4", 0, true, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, true, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, true, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - - testInput("Rotterdam;-12.4", 0, false, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, false, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, false, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L }); - - // These have different hashes from the strings above: - testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, true, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); - testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, false, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L }); - } - - private void testInput(final String inputString, final int start, final boolean bigEndian, final int[] expectedDelimiterAndHash, final long[] expectedCityNameLong) { - - byte[] input = inputString.getBytes(StandardCharsets.UTF_8); - - ByteBuffer buffer = ByteBuffer.wrap(input).order(bigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); - - int[] output = new int[2]; - long[] cityName = new long[128]; - findNextDelimiterAndCalculateHash(buffer, SEPARATOR_PATTERN, start, buffer.limit(), output, cityName, bigEndian); - - if (!Arrays.equals(output, expectedDelimiterAndHash)) { - System.out.println("Error in delimiter or hash"); - System.out.println("Expected: " + Arrays.toString(expectedDelimiterAndHash)); - System.out.println("Received: " + Arrays.toString(output)); - } - int amountLong = 1 + ((output[0] - start) >>> 3); - if (!Arrays.equals(cityName, 0, amountLong, expectedCityNameLong, 0, amountLong)) { - System.out.println("Error in long array"); - System.out.println("Expected: " + Arrays.toString(expectedCityNameLong)); - System.out.println("Received: " + Arrays.toString(cityName)); - } - } } |
