aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java283
1 files changed, 215 insertions, 68 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
index 5fc38ae..cd4d572 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -21,10 +21,12 @@ import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;
@@ -44,6 +46,7 @@ import java.util.stream.Collectors;
* Improved String skip: 3250 ms
* Segmenting files: 3150 ms (based on spullara's code)
* Not using SWAR for EOL: 2850 ms
+ * Inlining hash calculation: 2450 ms
*
* Best performing JVM on MacBook M2 Pro: 21.0.1-graal
* `sdk use java 21.0.1-graal`
@@ -59,8 +62,8 @@ public class CalculateAverage_royvanrijn {
long sum;
public Measurement() {
- this.min = 10000;
- this.max = -10000;
+ this.min = 1000;
+ this.max = -1000;
}
public Measurement updateWith(int measurement) {
@@ -88,8 +91,32 @@ public class CalculateAverage_royvanrijn {
}
}
- public static final void main(String[] args) throws Exception {
+ public static void main(String[] args) throws Exception {
new CalculateAverage_royvanrijn().run();
+ // new CalculateAverage_royvanrijn().runTests();
+ }
+
+ private void testInput(final String inputString, final int start, final boolean bigEndian, final int[] expectedDelimiterAndHash, final long[] expectedCityNameLong) {
+
+ byte[] input = inputString.getBytes(StandardCharsets.UTF_8);
+
+ ByteBuffer buffer = ByteBuffer.wrap(input).order(bigEndian ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN);
+
+ int[] output = new int[2];
+ long[] cityName = new long[128];
+ findNextDelimiterAndCalculateHash(buffer, SEPARATOR_PATTERN, start, buffer.limit(), output, cityName, bigEndian);
+
+ if (!Arrays.equals(output, expectedDelimiterAndHash)) {
+ System.out.println("Error in delimiter or hash");
+ System.out.println("Expected: " + Arrays.toString(expectedDelimiterAndHash));
+ System.out.println("Received: " + Arrays.toString(output));
+ }
+ int amountLong = 1 + ((output[0] - start) >>> 3);
+ if (!Arrays.equals(cityName, 0, amountLong, expectedCityNameLong, 0, amountLong)) {
+ System.out.println("Error in long array");
+ System.out.println("Expected: " + Arrays.toString(expectedCityNameLong));
+ System.out.println("Received: " + Arrays.toString(cityName));
+ }
}
private void run() throws Exception {
@@ -99,30 +126,43 @@ public class CalculateAverage_royvanrijn {
long segmentEnd = segment.end();
try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) {
var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start());
- var buffer = new byte[64];
- // Force little endian:
- bb.order(ByteOrder.LITTLE_ENDIAN);
+ // Work with any UTF-8 city name, up to 100 in length:
+ var buffer = new byte[106]; // 100 + ; + -XX.X + \n
+ var cityNameAsLongArray = new long[13]; // 13*8=104=kenough.
+ var delimiterPointerAndHash = new int[2];
+
+ // Calculate using native ordering (fastest?):
+ bb.order(ByteOrder.nativeOrder());
- BitTwiddledMap measurements = new BitTwiddledMap();
+ // Record the order it is and calculate accordingly:
+ final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN);
+ MeasurementRepository measurements = new MeasurementRepository();
int startPointer;
int limit = bb.limit();
while ((startPointer = bb.position()) < limit) {
- // SWAR is faster for ';'
- int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit);
+ // SWAR method to find delimiter *and* record the cityname as long[] *and* calculate a hash:
+ findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, startPointer, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian);
+ int delimiterPointer = delimiterPointerAndHash[0];
- // Simple is faster for '\n' (just three options)
+ // Simple lookup is faster for '\n' (just three options)
int endPointer;
- if (bb.get(separatorPointer + 4) == '\n') {
- endPointer = separatorPointer + 4;
+
+ if (delimiterPointer >= limit) {
+ bb.position(limit); // skip to next line.
+ return measurements;
+ }
+
+ if (bb.get(delimiterPointer + 4) == '\n') {
+ endPointer = delimiterPointer + 4;
}
- else if (bb.get(separatorPointer + 5) == '\n') {
- endPointer = separatorPointer + 5;
+ else if (bb.get(delimiterPointer + 5) == '\n') {
+ endPointer = delimiterPointer + 5;
}
else {
- endPointer = separatorPointer + 6;
+ endPointer = delimiterPointer + 6;
}
// Read the entry in a single get():
@@ -130,20 +170,22 @@ public class CalculateAverage_royvanrijn {
bb.position(endPointer + 1); // skip to next line.
// Extract the measurement value (10x):
- final int nameLength = separatorPointer - startPointer;
- final int valueLength = endPointer - separatorPointer - 1;
- final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength);
- measurements.getOrCreate(buffer, nameLength).updateWith(measured);
+ final int cityNameLength = delimiterPointer - startPointer;
+ final int measuredValueLength = endPointer - delimiterPointer - 1;
+ final int measuredValue = branchlessParseInt(buffer, cityNameLength + 1, measuredValueLength);
+
+ // Store everything in a custom hashtable:
+ measurements.update(buffer, cityNameLength, delimiterPointerAndHash[1], cityNameAsLongArray).updateWith(measuredValue);
}
return measurements;
}
catch (IOException e) {
throw new RuntimeException(e);
}
- }).parallel().flatMap(v -> v.values.stream())
- .collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, (m1, m2) -> m1.updateWith(m2), TreeMap::new));
+ }).parallel()
+ .flatMap(v -> v.values.stream())
+ .collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, Measurement::updateWith, TreeMap::new));
- // Seems to perform better than actually using a TreeMap:
System.out.println(results);
}
@@ -151,47 +193,119 @@ public class CalculateAverage_royvanrijn {
* -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values:
*/
private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
+ private static final long[] PARTIAL_INDEX_MASKS = new long[]{ 0L, 255L, 65535L, 16777215L, 4294967295L, 1099511627775L, 281474976710655L, 72057594037927935L };
+
+ public void runTests() {
+
+ // Method used for debugging purposes, easy to make mistakes with all the bit hacking.
+
+ // These all have the same hashes:
+ testInput("Delft;-12.4", 0, true, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L });
+ testInput("aDelft;-12.4", 1, true, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L });
+
+ testInput("Delft;-12.4", 0, false, new int[]{ 5, 1718384401 }, new long[]{ 499934586180L });
+ testInput("aDelft;-12.4", 1, false, new int[]{ 6, 1718384401 }, new long[]{ 499934586180L });
+
+ testInput("Rotterdam;-12.4", 0, true, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+ testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, true, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+ testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, true, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+
+ testInput("Rotterdam;-12.4", 0, false, new int[]{ 9, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+ testInput("abcdefghijklmnpoqrstuvwxyzRotterdam;-12.4", 26, false, new int[]{ 35, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+ testInput("abcdefghijklmnpoqrstuvwxyzARotterdam;-12.4", 27, false, new int[]{ 36, -784321989 }, new long[]{ 7017859899421126482L, 109L });
+
+ // These have different hashes from the strings above:
+ testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, true, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L });
+ testInput("abcdefghijklmnpoqrstuvwxyzAROtterdam;-12.4", 27, false, new int[]{ 36, -792194501 }, new long[]{ 7017859899421118290L, 109L });
+
+ MeasurementRepository repository = new MeasurementRepository();
+
+ // Simulate adding two entries with the same hash:
+ byte[] b1 = "City1;10.0".getBytes();
+ byte[] b2 = "City2;41.1".getBytes();
+ repository.update(b1, 5, 1234, new long[]{ 1234L });
+ repository.update(b2, 5, 1234, new long[]{ 4321L });
+ // And update the same record shouldn't add a third (this happened):
+ repository.update(b1, 5, 1234, new long[]{ 1234L });
+
+ if (repository.values.size() != 2) {
+ System.out.println("Error, should have two entries:");
+ System.out.println(repository.values);
+ }
+
+ MeasurementRepository.Entry firstInserted = repository.values.getFirst();
+ if (!firstInserted.cityName.equals("City1")) {
+ System.out.println("Error, should have correct name: " + firstInserted.cityName);
+ }
+ }
- private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) {
+ /**
+ * Already looping the longs here, lets shoehorn in making a hash
+ */
+ private void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output,
+ final long[] asLong, final boolean bufferBigEndian) {
+ int hash = 1;
int i;
+ int lCnt = 0;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
+ if (bufferBigEndian)
+ word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
int index = firstAnyPattern(word, pattern);
if (index < Long.BYTES) {
- return i + index;
+ final long partialHash = word & PARTIAL_INDEX_MASKS[index];
+ asLong[lCnt] = partialHash;
+ hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash;
+ output[0] = (i + index);
+ output[1] = hash;
+ return;
}
+ asLong[lCnt++] = word;
+ hash = 961 * hash + 31 * (int) (word >>> 32) + (int) word;
}
// Handle remaining bytes
+ long partialHash = 0;
for (; i < limit; i++) {
- if (bb.get(i) == (byte) pattern) {
- return i;
+ byte read;
+ if ((read = bb.get(i)) == (byte) pattern) {
+ asLong[lCnt] = partialHash;
+ hash = 961 * hash + 31 * (int) (partialHash >>> 32) + (int) partialHash;
+ output[0] = i;
+ output[1] = hash;
+ return;
}
+ partialHash = partialHash << 8 | read;
}
- return limit; // delimiter not found
+ output[0] = limit; // delimiter not found
+ output[1] = hash;
}
- private static long compilePattern(byte value) {
+ private static long compilePattern(final byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}
- private static int firstAnyPattern(long word, long pattern) {
+ private static int firstAnyPattern(final long word, final long pattern) {
final long match = word ^ pattern;
long mask = match - 0x0101010101010101L;
mask &= ~match;
mask &= 0x8080808080808080L;
- return Long.numberOfTrailingZeros(mask) >>> 3;
+ return Long.numberOfTrailingZeros(mask) >> 3;
}
record FileSegment(long start, long end) {
}
/** Using this way to segment the file is much prettier, from spullara */
- private static List<FileSegment> getFileSegments(File file) throws IOException {
+ private static List<FileSegment> getFileSegments(final File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors();
final long fileSize = file.length();
final long segmentSize = fileSize / numberOfSegments;
final List<FileSegment> segments = new ArrayList<>();
+ if (segmentSize < 1000) {
+ segments.add(new FileSegment(0, fileSize));
+ return segments;
+ }
try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
for (int i = 0; i < numberOfSegments; i++) {
long segStart = i * segmentSize;
@@ -205,7 +319,7 @@ public class CalculateAverage_royvanrijn {
return segments;
}
- private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
+ private static long findSegment(final int i, final int skipSegment, RandomAccessFile raf, long location, final long fileSize) throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
@@ -226,7 +340,7 @@ public class CalculateAverage_royvanrijn {
* @param input
* @return int value x10
*/
- private static int branchlessParseInt(final byte[] input, int start, int length) {
+ private static int branchlessParseInt(final byte[] input, final int start, final int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[start] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
@@ -258,66 +372,99 @@ public class CalculateAverage_royvanrijn {
*
* So I've written an extremely simple linear probing hashmap that should work well enough.
*/
- class BitTwiddledMap {
- private static final int SIZE = 16384; // A bit larger than the number of keys, needs power of two
- private int[] indices = new int[SIZE]; // Hashtable is just an int[]
+ class MeasurementRepository {
+ private int size = 16384;// 16384; // Much larger than the number of cities, needs power of two
+ private int[] indices = new int[size]; // Hashtable is just an int[]
+
+ MeasurementRepository() {
+ populateEmptyIndices(indices);
+ }
- BitTwiddledMap() {
+ private void populateEmptyIndices(int[] array) {
// Optimized fill with -1, fastest method:
- int len = indices.length;
- if (len > 0) {
- indices[0] = -1;
- }
+ int len = array.length;
+ array[0] = -1;
// Value of i will be [1, 2, 4, 8, 16, 32, ..., len]
for (int i = 1; i < len; i += i) {
- System.arraycopy(indices, 0, indices, i, i);
+ System.arraycopy(array, 0, array, i, i);
}
}
- private List<Entry> values = new ArrayList<>(512);
+ private final List<Entry> values = new ArrayList<>(512);
- record Entry(int hash, byte[] key, Measurement measurement) {
+ record Entry(int hash, long[] cityNameAsLong, String cityName, Measurement measurement) {
@Override
public String toString() {
- return new String(key) + "=" + measurement;
+ return cityName + "=" + measurement;
}
}
- /**
- * Who needs methods like add(), merge(), compute() etc, we need one, getOrCreate.
- * @param key
- * @return
- */
- public Measurement getOrCreate(byte[] key, int length) {
- int inHash;
- int index = (SIZE - 1) & (inHash = hashCode(key, length));
+ public Measurement update(byte[] buffer, int length, int calculatedHash, long[] cityNameAsLongArray) {
+
+ final int cityNameAsLongLength = 1 + (length >>> 3); // amount of longs that captures this cityname
+
+ int hashtableIndex = (size - 1) & calculatedHash;
int valueIndex;
+
Entry retrievedEntry = null;
- while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) {
- index = (index + 1) % SIZE;
+
+ while (true) { // search for the right spot
+ if ((valueIndex = indices[hashtableIndex]) == -1) {
+ break; // Empty slot found, stop the loop
+ }
+ else {
+ // Non-empty slot, retrieve entry
+ if ((retrievedEntry = values.get(valueIndex)).hash == calculatedHash &&
+ arrayEquals(retrievedEntry.cityNameAsLong, cityNameAsLongArray, cityNameAsLongLength)) {
+ break; // Both hash and cityname match, stop the loop
+ }
+ }
+ // Move to the next index
+ hashtableIndex = (hashtableIndex + 1) % size;
}
+
if (valueIndex >= 0) {
return retrievedEntry.measurement;
}
- // New entry, insert into table and return.
- indices[index] = values.size();
- // Only parse this once:
- byte[] actualKey = new byte[length];
- System.arraycopy(key, 0, actualKey, 0, length);
+ // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!)
+
+ // Keep the already processed longs for fast equals:
+ long[] cityNameAsLongArrayCopy = new long[cityNameAsLongLength];
+ System.arraycopy(cityNameAsLongArray, 0, cityNameAsLongArrayCopy, 0, cityNameAsLongLength);
+
+ Entry toAdd = new Entry(calculatedHash, cityNameAsLongArrayCopy, new String(buffer, 0, length), new Measurement());
+
+ // Code to regrow (if we get more unique entries): (not needed/not optimized yet)
+ // if (values.size() > size / 2) {
+ // // We probably don't want this...
+ //
+ // int newSize = size << 1;
+ // int[] newIndices = new int[newSize];
+ // populateEmptyIndices(newIndices);
+ // for (int i = 0; i < values.size(); i++) {
+ // Entry e = values.get(i);
+ // int updatedIndex = (newSize - 1) & e.hash;
+ // newIndices[updatedIndex] = i;
+ // }
+ // indices = newIndices;
+ // size = newSize;
+ // }
+ indices[hashtableIndex] = values.size();
- Entry toAdd = new Entry(inHash, actualKey, new Measurement());
values.add(toAdd);
return toAdd.measurement;
}
+ }
- private static int hashCode(byte[] a, int length) {
- int result = 1;
- for (int i = 0; i < length; i++) {
- result = 31 * result + a[i];
- }
- return result;
+ /**
+ * For case multiple hashes are equal (however unlikely) check the actual key (using longs)
+ */
+ private boolean arrayEquals(final long[] a, final long[] b, final int length) {
+ for (int i = 0; i < length; i++) {
+ if (a[i] != b[i])
+ return false;
}
+ return true;
}
-
}