aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRoy van Rijn <roy.van.rijn@gmail.com>2024-01-11 11:12:05 +0100
committerGitHub <noreply@github.com>2024-01-11 11:12:05 +0100
commit8c248714061535a819181ca137876fa1435a7e7d (patch)
tree876b92713fea4cd172f8655702bdf9656f7cbc86 /src
parentb0c9952c082d2c9dc772328100b60ef0da82f0c9 (diff)
Fixing the off-by-one error and updating to native, redone layout of code. (#307)
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java368
1 files changed, 166 insertions, 202 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
index aa22bef..4cf9925 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -18,52 +18,52 @@ package dev.morling.onebrc;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
-import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.Objects;
-import java.util.TreeMap;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import java.util.stream.Stream;
import sun.misc.Unsafe;
/**
* Changelog:
*
- * Initial submission: 62000 ms
- * Chunked reader: 16000 ms
- * Optimized parser: 13000 ms
- * Branchless methods: 11000 ms
- * Adding memory mapped files: 6500 ms (based on bjhara's submission)
- * Skipping string creation: 4700 ms
- * Custom hashmap... 4200 ms
- * Added SWAR token checks: 3900 ms
- * Skipped String creation: 3500 ms (idea from kgonia)
- * Improved String skip: 3250 ms
- * Segmenting files: 3150 ms (based on spullara's code)
- * Not using SWAR for EOL: 2850 ms
- * Inlining hash calculation: 2450 ms
- * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love)
- * Added unsafe memory access: 1900 ms (keeping the long[] small and local)
- *
- * Best performing JVM on MacBook M2 Pro: 21.0.1-graal
- * `sdk use java 21.0.1-graal`
+ * Initial submission: 62000 ms
+ * Chunked reader: 16000 ms
+ * Optimized parser: 13000 ms
+ * Branchless methods: 11000 ms
+ * Adding memory mapped files: 6500 ms (based on bjhara's submission)
+ * Skipping string creation: 4700 ms
+ * Custom hashmap... 4200 ms
+ * Added SWAR token checks: 3900 ms
+ * Skipped String creation: 3500 ms (idea from kgonia)
+ * Improved String skip: 3250 ms
+ * Segmenting files: 3150 ms (based on spullara's code)
+ * Not using SWAR for EOL: 2850 ms
+ * Inlining hash calculation: 2450 ms
+ * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love)
+ * Added unsafe memory access: 1900 ms (keeping the long[] small and local)
+ * Fixed bug, UNSAFE bytes String: 1850 ms
+ * Separate hash from entries: 1550 ms
+ * Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine)
+ * Improved layout/predictability 1450 ms (on par with Thomas Wuerthinger)
*
+ * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas.
*/
public class CalculateAverage_royvanrijn {
private static final String FILE = "./measurements.txt";
private static final Unsafe UNSAFE = initUnsafe();
- private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
private static Unsafe initUnsafe() {
try {
- Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
theUnsafe.setAccessible(true);
return (Unsafe) theUnsafe.get(Unsafe.class);
}
@@ -73,32 +73,42 @@ public class CalculateAverage_royvanrijn {
}
public static void main(String[] args) throws Exception {
- new CalculateAverage_royvanrijn().run();
- }
-
- public void run() throws Exception {
// Calculate input segments.
- int numberOfChunks = Runtime.getRuntime().availableProcessors();
- long[] chunks = getSegments(numberOfChunks);
+ final int numberOfChunks = Runtime.getRuntime().availableProcessors();
+ final long[] chunks = getSegments(numberOfChunks);
+
+ final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1)
+ .mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1]))
+ .parallel()
+ .toList();
+
+ // Sometimes simple is better:
+ final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10);
+ for (Entry[] entries : repositories) {
+ for (Entry entry : entries) {
+ if (entry != null)
+ measurements.merge(entry.city, entry, Entry::mergeWith);
+ }
+ }
- // Parallel processing of segments.
- TreeMap<String, Measurement> results = IntStream.range(0, chunks.length - 1)
- .mapToObj(chunkIndex -> process(chunks[chunkIndex], chunks[chunkIndex + 1])).parallel()
- .flatMap(MeasurementRepository::get)
- .collect(Collectors.toMap(e -> e.city, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new));
+ System.out.print("{" +
+ measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
+ System.out.println("}");
- System.out.println(results);
}
- private static long[] getSegments(int numberOfChunks) throws IOException {
+ /**
+ * Simpler way to get the segments and launch parallel processing by thomaswue
+ */
+ private static long[] getSegments(final int numberOfChunks) throws IOException {
try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
- long fileSize = fileChannel.size();
- long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
- long[] chunks = new long[numberOfChunks + 1];
- long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
+ final long fileSize = fileChannel.size();
+ final long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
+ final long[] chunks = new long[numberOfChunks + 1];
+ final long mappedAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
chunks[0] = mappedAddress;
- long endAddress = mappedAddress + fileSize;
+ final long endAddress = mappedAddress + fileSize;
for (int i = 1; i < numberOfChunks; ++i) {
long chunkAddress = mappedAddress + i * segmentSize;
// Align to first row start.
@@ -112,108 +122,36 @@ public class CalculateAverage_royvanrijn {
}
}
- private MeasurementRepository process(long fromAddress, long toAddress) {
-
- MeasurementRepository repository = new MeasurementRepository();
- long ptr = fromAddress;
- long[] dataBuffer = new long[16];
- while ((ptr = processEntity(dataBuffer, ptr, toAddress, repository)) < toAddress)
- ;
-
- return repository;
- }
-
- private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
-
- /**
- * Already looping the longs here, lets shoehorn in making a hash
- */
- private long processEntity(final long[] data, final long start, final long limit, final MeasurementRepository measurementRepository) {
- int hash = 1;
- long i;
- int dataPtr = 0;
- for (i = start; i <= limit - 8; i += 8) {
- long word = UNSAFE.getLong(i);
- if (isBigEndian) {
- word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
- }
- final long match = word ^ SEPARATOR_PATTERN;
- long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
-
- if (mask != 0) {
-
- final long partialWord = word & ((mask >> 7) - 1);
- hash = longHashStep(hash, partialWord);
- data[dataPtr] = partialWord;
-
- final int index = Long.numberOfTrailingZeros(mask) >> 3;
- return process(start, i + index, hash, data, measurementRepository);
- }
- data[dataPtr++] = word;
- hash = longHashStep(hash, word);
- }
- // Handle remaining bytes near the limit of the buffer:
- long partialWord = 0;
- int len = 0;
- for (; i < limit; i++) {
- byte read;
- if ((read = UNSAFE.getByte(i)) == ';') {
- hash = longHashStep(hash, partialWord);
- data[dataPtr] = partialWord;
- return process(start, i, hash, data, measurementRepository);
- }
- partialWord = partialWord | ((long) read << (len << 3));
- len++;
+ private static final int TABLE_SIZE = 1 << 18; // large enough for the contest.
+ private static final int TABLE_MASK = (TABLE_SIZE - 1);
+
+ static final class Entry {
+ private final long[] data;
+ private final String city;
+ private int min, max, count;
+ private long sum;
+
+ Entry(final long[] data, String city, int temp) {
+ this.data = data;
+ this.city = city;
+ this.min = temp;
+ this.max = temp;
+ this.sum = temp;
+ this.count = 1;
}
- return limit;
- }
-
- private static final long DOT_BITS = 0x10101000;
- private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
- private long process(final long startAddress, final long delimiterAddress, final int hash, final long[] data, final MeasurementRepository measurementRepository) {
-
- long word = UNSAFE.getLong(delimiterAddress + 1);
- if (isBigEndian) {
- word = Long.reverseBytes(word);
- }
- final long invWord = ~word;
- final int decimalSepPos = Long.numberOfTrailingZeros(invWord & DOT_BITS);
- final long signed = (invWord << 59) >> 63;
- final long designMask = ~(signed & 0xFF);
- final long digits = ((word & designMask) << (28 - decimalSepPos)) & 0x0F000F0F00L;
- final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
- final int measurement = (int) ((absValue ^ signed) - signed);
-
- // Store:
- measurementRepository.update(startAddress, data, (int) (delimiterAddress - startAddress), hash, measurement);
-
- return delimiterAddress + (decimalSepPos >> 3) + 4; // Determine next start:
- // return nextAddress;
- }
-
- static final class Measurement {
- int min, max, count;
- long sum;
-
- public Measurement() {
- this.min = 1000;
- this.max = -1000;
- }
-
- public Measurement updateWith(int measurement) {
- min = min(min, measurement);
- max = max(max, measurement);
+ public void updateWith(int measurement) {
+ min = Math.min(min, measurement);
+ max = Math.max(max, measurement);
sum += measurement;
count++;
- return this;
}
- public Measurement updateWith(Measurement measurement) {
- min = min(min, measurement.min);
- max = max(max, measurement.max);
- sum += measurement.sum;
- count += measurement.count;
+ public Entry mergeWith(Entry entry) {
+ min = Math.min(min, entry.min);
+ max = Math.max(max, entry.max);
+ sum += entry.sum;
+ count += entry.count;
return this;
}
@@ -221,101 +159,127 @@ public class CalculateAverage_royvanrijn {
return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max);
}
- private double round(double value) {
+ private static double round(double value) {
return Math.round(value) / 10.0;
}
}
- // branchless max (unprecise for large numbers, but good enough)
- static int max(final int a, final int b) {
- final int diff = a - b;
- final int dsgn = diff >> 31;
- return a - (diff & dsgn);
- }
+ private static Entry createNewEntry(final long[] buffer, final long startAddress, final int lengthLongs, final int lengthBytes, final int temp) {
- // branchless min (unprecise for large numbers, but good enough)
- static int min(final int a, final int b) {
- final int diff = a - b;
- final int dsgn = diff >> 31;
- return b + (diff & dsgn);
- }
+ // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here.
+ final byte[] bytes = new byte[lengthBytes];
+ UNSAFE.copyMemory(null, startAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes);
+ final String city = new String(bytes, StandardCharsets.UTF_8);
- private static int longHashStep(final int hash, final long word) {
- return 31 * hash + (int) (word ^ (word >>> 32));
- }
+ final long[] bufferCopy = new long[lengthLongs];
+ System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs);
- private static long compilePattern(final byte value) {
- return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
- ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
+ // Add the entry:
+ return new Entry(bufferCopy, city, temp);
}
- /**
- * A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed.
- *
- * So I've written an extremely simple linear probing hashmap that should work well enough.
- */
- class MeasurementRepository {
- private int tableSize = 1 << 20; // large enough for the contest.
- private int tableMask = (tableSize - 1);
+ private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
- private MeasurementRepository.Entry[] table = new MeasurementRepository.Entry[tableSize];
+ Entry[] table = new Entry[TABLE_SIZE];
- record Entry(long address, long[] data, int length, int hash, String city, Measurement measurement) {
+ long ptr = fromAddress;
+ long[] buffer = new long[14];
- @Override
- public String toString() {
- return city + "=" + measurement;
- }
- }
+ while (ptr < toAddress) {
- public void update(long address, long[] data, int length, int hash, int temperature) {
+ int bufferPtr = 0;
+ long startAddress = ptr;
+ long hash = 1;
- int dataLength = length >> 3;
- int index = hash & tableMask;
- MeasurementRepository.Entry tableEntry;
- while ((tableEntry = table[index]) != null
- && (tableEntry.hash != hash || tableEntry.length != length || !arrayEquals(tableEntry.data, data, dataLength))) { // search for the right spot
- index = (index + 1) & tableMask;
- }
+ long word = UNSAFE.getLong(ptr);
+ long mask = getDelimiterMask(word);
+
+ while (mask == 0) {
+ buffer[bufferPtr++] = word;
+ hash ^= word;
+ ptr += 8;
- if (tableEntry != null) {
- tableEntry.measurement.updateWith(temperature);
- return;
+ word = UNSAFE.getLong(ptr);
+ mask = getDelimiterMask(word);
}
- // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here.
- Measurement measurement = new Measurement();
+ // Found delimiter:
+ final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3);
+ final long numberBits = UNSAFE.getLong(delimiterAddress + 1);
- byte[] bytes = new byte[length];
- for (int i = 0; i < length; i++) {
- bytes[i] = UNSAFE.getByte(address + i);
- }
- String city = new String(bytes);
+ // Finish the masks and hash:
+ final long partialWord = word & ((mask >> 7) - 1);
+ buffer[bufferPtr++] = partialWord;
+ hash ^= partialWord;
- long[] dataCopy = new long[dataLength];
- System.arraycopy(data, 0, dataCopy, 0, dataLength);
+ final long invNumberBits = ~numberBits;
+ final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS);
- // And add entry:
- MeasurementRepository.Entry toAdd = new MeasurementRepository.Entry(address, dataCopy, length, hash, city, measurement);
- table[index] = toAdd;
+ // Update counter asap, lets CPU predict.
+ ptr = delimiterAddress + (decimalSepPos >> 3) + 4;
- toAdd.measurement.updateWith(temperature);
- }
+ int intHash = (int) (hash ^ (hash >>> 31)); // offset for extra entropy
+
+ // Awesome idea of merykitty:
+ final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos);
+
+ int index = intHash & TABLE_MASK;
- public Stream<MeasurementRepository.Entry> get() {
- return Arrays.stream(table).filter(Objects::nonNull);
+ // Find or insert the entry:
+ while (true) {
+ Entry tableEntry = table[index];
+ if (tableEntry == null) {
+ final int length = (int) (delimiterAddress - startAddress);
+ table[index] = createNewEntry(buffer, startAddress, bufferPtr, length, temp);
+ break;
+ }
+ else if (bufferPtr == tableEntry.data.length) {
+ if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) {
+ index = (index + 1) & TABLE_MASK;
+ continue;
+ }
+ // No differences in array
+ tableEntry.updateWith(temp);
+ break;
+ }
+ // Move to the next index
+ index = (index + 1) & TABLE_MASK;
+ }
}
+ return table;
+ }
+
+ private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) {
+ final long signed = (invNumberBits << 59) >> 63;
+ final long minusFilter = ~(signed & 0xFF);
+ final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L;
+ final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result
+ final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
+ return temp;
+ }
+
+ private static long getDelimiterMask(final long word) {
+ long match = word ^ SEPARATOR_PATTERN;
+ return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
+ }
+
+ private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
+ private static final long DOT_BITS = 0x10101000;
+ private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
+
+ private static long compilePattern(final byte value) {
+ return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
+ ((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}
/**
* For case multiple hashes are equal (however unlikely) check the actual key (using longs)
*/
- private boolean arrayEquals(final long[] a, final long[] b, final int length) {
+ static boolean arrayEquals(final long[] a, final long[] b, final int length) {
for (int i = 0; i < length; i++) {
if (a[i] != b[i])
return false;
}
return true;
}
-
}