aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java188
1 files changed, 113 insertions, 75 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
index f1e8303..307833f 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -53,8 +52,13 @@ import sun.misc.Unsafe;
* Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine)
* Improved layout/predictability: 1400 ms
* Delayed String creation again: 1350 ms
+ * Remove writing to buffer: 1335 ms
+ * Optimized collecting at the end: 1310 ms
+ * Adding a lot of comments: priceless
*
* Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai for ideas.
+ *
+ * Follow me at: @royvanrijn
*/
public class CalculateAverage_royvanrijn {
@@ -74,29 +78,24 @@ public class CalculateAverage_royvanrijn {
}
public static void main(String[] args) throws Exception {
-
// Calculate input segments.
final int numberOfChunks = Runtime.getRuntime().availableProcessors();
final long[] chunks = getSegments(numberOfChunks);
- final List<Entry[]> repositories = IntStream.range(0, chunks.length - 1)
+ final Map<String, Entry> measurements = HashMap.newHashMap(1 << 10);
+ IntStream.range(0, chunks.length - 1)
.mapToObj(chunkIndex -> processMemoryArea(chunks[chunkIndex], chunks[chunkIndex + 1]))
.parallel()
- .toList();
-
- // Sometimes simple is better:
- final HashMap<String, Entry> measurements = HashMap.newHashMap(1 << 10);
- for (Entry[] entries : repositories) {
- for (Entry entry : entries) {
- if (entry != null)
- measurements.merge(extractedCityFromLongArray(entry.data, entry.length), entry, Entry::mergeWith);
- }
- }
+ .forEachOrdered(repo -> { // make sure it's ordered, no concurrent map
+ for (Entry entry : repo) {
+ if (entry != null)
+ measurements.merge(turnLongArrayIntoString(entry.data, entry.length), entry, Entry::mergeWith);
+ }
+ });
System.out.print("{" +
measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
System.out.println("}");
-
}
/**
@@ -123,15 +122,20 @@ public class CalculateAverage_royvanrijn {
}
}
- private static final int TABLE_SIZE = 1 << 19; // large enough for the contest.
- private static final int TABLE_MASK = (TABLE_SIZE - 1);
-
+ // This is where I store the hashtable entry data in the "hot loop"
+ // The long[] contains the name in bytes (yeah, confusing)
+ // I've tried flyweight-ing, carrying all the data in a single byte[],
+ // where you offset type-indices: min:int,max:int,count:int,etc.
+ //
+ // The performance was just a little worse than this simple class.
static final class Entry {
- private final long[] data;
- private int min, max, count, length;
+
+ private int min, max, count;
+ private byte length;
private long sum;
+ private final long[] data;
- Entry(final long[] data, int length, int temp) {
+ Entry(final long[] data, byte length, int temp) {
this.data = data;
this.length = length;
this.min = temp;
@@ -164,127 +168,161 @@ public class CalculateAverage_royvanrijn {
}
}
- /**
- * Delay String creation until the end:
- * @param data
- * @param length
- * @return
- */
- private static String extractedCityFromLongArray(final long[] data, final int length) {
- // Initiate as late as possible:
+ // Only parse the String at the final end, when we have only the needed entries left that we need to output:
+ private static String turnLongArrayIntoString(final long[] data, final int length) {
+ // Create our target byte[]
final byte[] bytes = new byte[length];
+ // The power of magic allows us to just copy the memory in there.
UNSAFE.copyMemory(data, Unsafe.ARRAY_LONG_BASE_OFFSET, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
+ // And construct a String()
return new String(bytes, StandardCharsets.UTF_8);
}
- private static Entry createNewEntry(final long[] buffer, final int lengthLongs, final int lengthBytes, final int temp) {
-
+ private static Entry createNewEntry(final long fromAddress, final int lengthLongs, final byte lengthBytes, final int temp) {
+ // Make a copy of our working buffer, store this in a new Entry:
final long[] bufferCopy = new long[lengthLongs];
- System.arraycopy(buffer, 0, bufferCopy, 0, lengthLongs);
-
- // Add the entry:
+ // Just copy everything over, bytes into the long[]
+ UNSAFE.copyMemory(null, fromAddress, bufferCopy, Unsafe.ARRAY_BYTE_BASE_OFFSET, lengthBytes);
return new Entry(bufferCopy, lengthBytes, temp);
}
- private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
+ private static final int TABLE_SIZE = 1 << 19;
+ private static final int TABLE_MASK = (TABLE_SIZE - 1);
- final Entry[] table = new Entry[TABLE_SIZE];
- final long[] buffer = new long[16];
+ private static Entry[] processMemoryArea(final long fromAddress, final long toAddress) {
- long ptr = fromAddress;
- int bufferPtr;
+ int packedBytes;
long hash;
+ long ptr = fromAddress;
long word;
long mask;
+ final Entry[] table = new Entry[TABLE_SIZE];
+
+ // Go from start to finish address through the bytes:
while (ptr < toAddress) {
final long startAddress = ptr;
- bufferPtr = 0;
- hash = 1;
+ packedBytes = 1;
+ hash = 0;
word = UNSAFE.getLong(ptr);
mask = getDelimiterMask(word);
+ // Removed writing to a buffer here, why would we, we know the address and we'll need to check there anyway.
while (mask == 0) {
- buffer[bufferPtr++] = word;
+ // If the mask is zero, we have no ';'
+ packedBytes++;
+ // So we continue building the hash:
hash ^= word;
ptr += 8;
+ // And getting a new value and mask:
word = UNSAFE.getLong(ptr);
mask = getDelimiterMask(word);
}
+
// Found delimiter:
- final long delimiterAddress = ptr + (Long.numberOfTrailingZeros(mask) >> 3);
- final long numberBits = UNSAFE.getLong(delimiterAddress + 1);
+ final int delimiterByte = Long.numberOfTrailingZeros(mask);
+ final long delimiterAddress = ptr + (delimiterByte >> 3);
// Finish the masks and hash:
- word = word & ((mask >> 7) - 1);
- buffer[bufferPtr++] = word;
- hash ^= word;
+ final long partialWord = word & ((mask >>> 7) - 1);
+ hash ^= partialWord;
- final long invNumberBits = ~numberBits;
- final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBits & DOT_BITS);
+ // Read a long value from memory starting from the delimiter + 1, the number part:
+ final long numberBytes = UNSAFE.getLong(delimiterAddress + 1);
+ final long invNumberBytes = ~numberBytes;
- // Update counter asap, lets CPU predict.
+ // Adjust our pointer
+ final int decimalSepPos = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS);
ptr = delimiterAddress + (decimalSepPos >> 3) + 4;
- // Awesome idea of merykitty:
- final int temp = extractTemp(numberBits, invNumberBits, decimalSepPos);
-
- int intHash = (int) (hash ^ (hash >>> 33)); // offset for extra entropy
+ // Calculate the final hash and index of the table:
+ int intHash = (int) (hash ^ (hash >> 32));
+ intHash = intHash ^ (intHash >> 17);
int index = intHash & TABLE_MASK;
// Find or insert the entry:
while (true) {
Entry tableEntry = table[index];
if (tableEntry == null) {
- final int length = (int) (delimiterAddress - startAddress);
- table[index] = createNewEntry(buffer, bufferPtr, length, temp);
+ final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes);
+ // Create a new entry:
+ final byte length = (byte) (delimiterAddress - startAddress);
+ table[index] = createNewEntry(startAddress, packedBytes, length, temp);
break;
}
- else if (bufferPtr == tableEntry.data.length) {
- if (!arrayEquals(buffer, tableEntry.data, bufferPtr)) {
- index = (index + 1) & TABLE_MASK;
- continue;
- }
- // No differences in array
+ // Don't bother re-checking things here like hash or length.
+ // we'll need to check the content anyway if it's a hit, which is most times
+ else if (memoryEqualsEntry(startAddress, tableEntry.data, partialWord, packedBytes)) {
+ // temperature, you're not temporary my friend
+ final int temp = extractTemp(decimalSepPos, invNumberBytes, numberBytes);
+ // No differences, same entry:
tableEntry.updateWith(temp);
break;
}
- // Move to the next index
+ // Move to the next in the table, linear probing:
index = (index + 1) & TABLE_MASK;
}
}
return table;
}
- private static int extractTemp(final long numberBits, final long invNumberBits, final int decimalSepPos) {
+ /*
+ * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___
+ * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __|
+ * | () | _ / __|| . |\ V / _ \| |_| |_| | ._|
+ * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___|
+ * ---------------- BETTER SOFTWARE, FASTER --
+ *
+ * https://www.openvalue.eu/
+ *
+ * Made you look.
+ *
+ */
+
+ private static final long DOT_BITS = 0x10101000;
+ private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
+
+ private static int extractTemp(final int decimalSepPos, final long invNumberBits, final long numberBits) {
+ // Awesome idea of merykitty:
+ int min28 = (28 - decimalSepPos);
+ // Calculates the sign
final long signed = (invNumberBits << 59) >> 63;
final long minusFilter = ~(signed & 0xFF);
- final long digits = ((numberBits & minusFilter) << (28 - decimalSepPos)) & 0x0F000F0F00L;
- final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // filter just the result
+ // Use the pre-calculated decimal position to adjust the values
+ final long digits = ((numberBits & minusFilter) << min28) & 0x0F000F0F00L;
+ // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result
+ final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
+ // And perform abs()
final int temp = (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
return temp;
}
+ private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
+
+ // Takes a long and finds the bytes where this exact pattern is present.
+ // Cool bit manipulation technique: SWAR (SIMD as a Register).
private static long getDelimiterMask(final long word) {
- long match = word ^ SEPARATOR_PATTERN;
- return (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
+ final long match = word ^ SEPARATOR_PATTERN;
+ return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
+ // I've put some brackets separating the first and second part, this is faster.
+ // Now they run simultaneous after 'match' is altered, instead of waiting on each other.
}
- private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
- private static final long DOT_BITS = 0x10101000;
- private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
-
/**
* For case multiple hashes are equal (however unlikely) check the actual key (using longs)
*/
- static boolean arrayEquals(final long[] a, final long[] b, final int length) {
- for (int i = 0; i < length; i++) {
- if (a[i] != b[i])
+ private static boolean memoryEqualsEntry(final long startAddress, final long[] entry, final long finalBytes, final int amountLong) {
+ for (int i = 0; i < (amountLong - 1); i++) {
+ int step = i << 3; // step by 8 bytes
+ if (UNSAFE.getLong(startAddress + step) != entry[i])
return false;
}
- return true;
+ // If all previous 'whole' 8-packed byte-long values are equal
+ // We still need to check the final bytes that don't fit.
+ // and we've already calculated them for the hash.
+ return finalBytes == entry[amountLong - 1];
}
}