aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java348
1 files changed, 230 insertions, 118 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
index 0ec4856..5325816 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_serkan_ozal.java
@@ -59,7 +59,7 @@ public class CalculateAverage_serkan_ozal {
? ByteVector.SPECIES_128
: ByteVector.SPECIES_64;
private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize();
- private static final MemorySegment ALL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE);
+ private static final MemorySegment NULL = MemorySegment.NULL.reinterpret(Long.MAX_VALUE);
private static final ByteOrder NATIVE_BYTE_ORDER = ByteOrder.nativeOrder();
private static final char NEW_LINE_SEPARATOR = '\n';
@@ -290,7 +290,7 @@ public class CalculateAverage_serkan_ozal {
long regionStart = regionGiven ? (r.address() + task.start) : r.address();
long regionEnd = regionStart + task.size;
- doProcessRegion(r, r.address(), regionStart, regionEnd);
+ doProcessRegion(regionStart, regionEnd);
}
if (VERBOSE) {
@@ -334,105 +334,204 @@ public class CalculateAverage_serkan_ozal {
}
}
- private void doProcessRegion(MemorySegment region, long regionAddress, long regionStart, long regionEnd) {
+ private long findClosestLineEnd(long endPos, long minPos) { // Scans backwards from endPos for a NEW_LINE_SEPARATOR byte and returns the address right after it
+ int i = 0; // number of bytes stepped back from endPos so far
+ int maxI = Math.min(MAX_LINE_LENGTH, (int) (endPos - minPos)); // look-back is bounded by MAX_LINE_LENGTH and by the distance to minPos
+ while (i <= maxI && U.getByte(endPos - i) != NEW_LINE_SEPARATOR) {
+ i++;
+ }
+ return endPos - i + 1; // address following the byte where the scan stopped; if no separator was seen, i == maxI + 1 and this is endPos - maxI
+ }
+
+ // Credits: merykitty. Branch-free parse of the numeric text held in `word`, stores it via map.putValue and returns the advanced pointer.
+ private long extractValue(long regionPtr, long word, OpenMap map, int entryOffset) { // NOTE(review): assumes word holds the value's ASCII bytes in little-endian order (visible callers byte-swap on big-endian) - confirm
+ // Parse and extract value
+ int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000); // bit index (12, 20 or 28) of the first of bytes 1..3 whose bit 4 is clear; ASCII '.' (0x2E) has it clear, digits (0x30-0x39) have it set
+ int shift = 28 - decimalSepPos; // left shift that moves the separator byte to byte 3 wherever it was found
+ long signed = (~word << 59) >> 63; // -1 when bit 4 of byte 0 is clear (as in ASCII '-', 0x2D), otherwise 0
+ long designMask = ~(signed & 0xFF); // all ones, with the low byte zeroed when signed == -1 so the sign character is dropped
+ long digits = ((word & designMask) << shift) & 0x0F000F0F00L; // keeps only the low nibbles of bytes 1, 2 and 4 after alignment
+ long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; // the multiply weights the three nibbles by 100, 10 and 1 (0x64, 0x0a, 0x01); the sum is read from bits 32..41
+ int value = (int) ((absValue ^ signed) - signed); // negates absValue when signed == -1, leaves it unchanged when signed == 0
+
+ // Put extracted value into map
+ map.putValue(entryOffset, value);
+
+ // Return new position: (decimalSepPos >>> 3) is the separator's byte index; the + 3 presumably skips it, one fractional digit and the line terminator - TODO confirm
+ return regionPtr + (decimalSepPos >>> 3) + 3;
+ }
+
+ private void doProcessRegion(long regionStart, long regionEnd) {
final int vectorSize = BYTE_SPECIES.vectorByteSize();
- final long regionMainLimit = regionEnd - BYTE_SPECIES_SIZE;
- long regionPtr;
+ final long size = regionEnd - regionStart;
+ final long segmentSize = size / 2;
+
+ final long regionStart1 = regionStart;
+ final long regionEnd1 = Math.max(regionStart1, findClosestLineEnd(regionStart1 + segmentSize, regionStart));
+
+ final long regionStart2 = regionEnd1;
+ final long regionEnd2 = regionEnd;
+
+ long regionPtr1, regionPtr2;
// Read and process region - main
- for (regionPtr = regionStart; regionPtr < regionMainLimit;) {
- regionPtr = doProcessLine(regionPtr, vectorSize);
- }
+ // Inspired by: @jerrinot
+ // - two lines at a time (according to my experiments, this is the optimum value in terms of register spilling)
+ // - most of the implementation is inlined
+ // - so get the benefit of ILP (Instruction Level Parallelism) better
+ for (regionPtr1 = regionStart1, regionPtr2 = regionStart2; regionPtr1 < regionEnd1 && regionPtr2 < regionEnd2;) {
+ // Search key/value separators and find keys' start and end positions
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ long keyStartPtr1 = regionPtr1;
+ long keyStartPtr2 = regionPtr2;
+
+ ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER);
+ ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER);
+
+ int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
+ int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
+
+ if (keyLength1 != vectorSize && keyLength2 != vectorSize) {
+ regionPtr1 += (keyLength1 + 1);
+ regionPtr2 += (keyLength2 + 1);
+ }
+ else {
+ if (keyLength1 != vectorSize) {
+ regionPtr1 += (keyLength1 + 1);
+ }
+ else {
+ regionPtr1 += vectorSize;
+ for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++)
+ ;
+ keyLength1 = (int) (regionPtr1 - keyStartPtr1);
+ regionPtr1++;
+ }
+ if (keyLength2 != vectorSize) {
+ regionPtr2 += (keyLength2 + 1);
+ }
+ else {
+ regionPtr2 += vectorSize;
+ for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++)
+ ;
+ keyLength2 = (int) (regionPtr2 - keyStartPtr2);
+ regionPtr2++;
+ }
+ }
- // Read and process region - tail
- for (long i = regionPtr, j = regionPtr; i < regionEnd;) {
- byte b = U.getByte(i);
- if (b == KEY_VALUE_SEPARATOR) {
- long baseOffset = map.putKey(null, j, (int) (i - j));
- i = extractValue(i + 1, map, baseOffset);
- j = i;
+ // Read first words as they will be used while extracting values later
+ long word1 = U.getLong(regionPtr1);
+ long word2 = U.getLong(regionPtr2);
+ if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
+ word1 = Long.reverseBytes(word1);
+ word2 = Long.reverseBytes(word2);
+ }
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // Calculate key hashes and find entry indexes
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ int x1, y1, x2, y2;
+ if (keyLength1 >= Integer.BYTES && keyLength2 >= Integer.BYTES) {
+ x1 = U.getInt(keyStartPtr1);
+ y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES);
+ x2 = U.getInt(keyStartPtr2);
+ y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES);
}
else {
- i++;
+ if (keyLength1 >= Integer.BYTES) {
+ x1 = U.getInt(keyStartPtr1);
+ y1 = U.getInt(keyStartPtr1 + keyLength1 - Integer.BYTES);
+ }
+ else {
+ x1 = U.getByte(keyStartPtr1);
+ y1 = U.getByte(keyStartPtr1 + keyLength1 - Byte.BYTES);
+ }
+ if (keyLength2 >= Integer.BYTES) {
+ x2 = U.getInt(keyStartPtr2);
+ y2 = U.getInt(keyStartPtr2 + keyLength2 - Integer.BYTES);
+ }
+ else {
+ x2 = U.getByte(keyStartPtr2);
+ y2 = U.getByte(keyStartPtr2 + keyLength2 - Byte.BYTES);
+ }
}
- }
- }
- private long doProcessLine(long regionPtr, int vectorSize) {
- // Find key/value separator
- ////////////////////////////////////////////////////////////////////////////////////////////////////////
- long keyStartPtr = regionPtr;
+ int keyHash1 = (Integer.rotateLeft(x1 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y1) * OpenMap.HASH_SEED;
+ int keyHash2 = (Integer.rotateLeft(x2 * OpenMap.HASH_SEED, OpenMap.HASH_ROTATE) ^ y2) * OpenMap.HASH_SEED;
- // Vectorized search for key/value separator
- ByteVector keyVector = ByteVector.fromMemorySegment(BYTE_SPECIES, ALL, regionPtr, NATIVE_BYTE_ORDER);
+ int entryIdx1 = (keyHash1 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT;
+ int entryIdx2 = (keyHash2 & OpenMap.ENTRY_HASH_MASK) << OpenMap.ENTRY_SIZE_SHIFT;
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
- int keyLength = keyVector.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
- // Check whether key/value separator is found in the first vector (city name is <= vector size)
- if (keyLength != vectorSize) {
- regionPtr += (keyLength + 1);
- }
- else {
- regionPtr += vectorSize;
- for (; U.getByte(regionPtr) != KEY_VALUE_SEPARATOR; regionPtr++)
- ;
- keyLength = (int) (regionPtr - keyStartPtr);
- regionPtr++;
- // I have tried vectorized search for key/value separator in the remaining part,
- // but since majority (99%) of the city names <= 16 bytes
- // and other a few longer city names (have length < 16 and <= 32) not close to 32 bytes,
- // byte by byte search is better in terms of performance (according to my experiments) and simplicity.
- }
- ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Put keys and calculate entry offsets to put values
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1);
+ int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2);
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
- // Put key and get map offset to put value
- long entryOffset = map.putKey(keyVector, keyStartPtr, keyLength);
+ // Extract values by parsing and put them into map
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1);
+ regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2);
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////
+ }
- // Extract value, put it into map and return next position in the region to continue processing from there
- return extractValue(regionPtr, map, entryOffset);
+ // Read and process region - tail
+ doProcessTail(regionPtr1, regionEnd1, regionPtr2, regionEnd2, vectorSize);
}
- }
- // Credits: merykitty
- private static long extractValue(long regionPtr, OpenMap map, long entryOffset) {
- long word = U.getLong(regionPtr);
- if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
- word = Long.reverseBytes(word);
+ private void doProcessTail(long regionPtr1, long regionEnd1, long regionPtr2, long regionEnd2, int vectorSize) { // Processes whatever is left of each of the two segments, one line per iteration, first segment then second
+ while (regionPtr1 < regionEnd1) { // drain the first segment
+ long keyStartPtr1 = regionPtr1;
+ ByteVector keyVector1 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr1, NATIVE_BYTE_ORDER); // NULL is the zero-based segment reinterpreted to Long.MAX_VALUE, so the offset is used as a raw address
+ int keyLength1 = keyVector1.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue(); // lane index of the first separator, or the lane count when no lane matches
+ if (keyLength1 != vectorSize) { // separator found inside the first vector
+ regionPtr1 += (keyLength1 + 1); // step over the key and the separator
+ }
+ else {
+ regionPtr1 += vectorSize;
+ for (; U.getByte(regionPtr1) != KEY_VALUE_SEPARATOR; regionPtr1++) // byte-by-byte search past the first vector
+ ;
+ keyLength1 = (int) (regionPtr1 - keyStartPtr1);
+ regionPtr1++; // step over the separator
+ }
+ int entryIdx1 = map.calculateEntryIndex(keyStartPtr1, keyLength1);
+ int entryOffset1 = map.putKey(keyVector1, keyStartPtr1, keyLength1, entryIdx1);
+ long word1 = U.getLong(regionPtr1); // 8 bytes starting right after the separator, handed to extractValue for parsing
+ if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
+ word1 = Long.reverseBytes(word1);
+ }
+ regionPtr1 = extractValue(regionPtr1, word1, map, entryOffset1); // stores the value and returns the next position to process
+ }
+ while (regionPtr2 < regionEnd2) { // drain the second segment; same steps as the loop above
+ long keyStartPtr2 = regionPtr2;
+ ByteVector keyVector2 = ByteVector.fromMemorySegment(BYTE_SPECIES, NULL, regionPtr2, NATIVE_BYTE_ORDER);
+ int keyLength2 = keyVector2.compare(VectorOperators.EQ, KEY_VALUE_SEPARATOR).firstTrue();
+ if (keyLength2 != vectorSize) {
+ regionPtr2 += (keyLength2 + 1);
+ }
+ else {
+ regionPtr2 += vectorSize;
+ for (; U.getByte(regionPtr2) != KEY_VALUE_SEPARATOR; regionPtr2++)
+ ;
+ keyLength2 = (int) (regionPtr2 - keyStartPtr2);
+ regionPtr2++;
+ }
+ int entryIdx2 = map.calculateEntryIndex(keyStartPtr2, keyLength2);
+ int entryOffset2 = map.putKey(keyVector2, keyStartPtr2, keyLength2, entryIdx2);
+ long word2 = U.getLong(regionPtr2);
+ if (NATIVE_BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
+ word2 = Long.reverseBytes(word2);
+ }
+ regionPtr2 = extractValue(regionPtr2, word2, map, entryOffset2);
+ }
+ }
- // Parse and extract value
- int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000);
- int shift = 28 - decimalSepPos;
- long signed = (~word << 59) >> 63;
- long designMask = ~(signed & 0xFF);
- long digits = ((word & designMask) << shift) & 0x0F000F0F00L;
- long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- int value = (int) ((absValue ^ signed) - signed);
-
- // Put extracted value into map
- map.putValue(entryOffset, value);
-
- // Return new position
- return regionPtr + (decimalSepPos >>> 3) + 3;
}
/**
- * Region processor request
+ * Region processor task
*/
- private static final class Request {
-
- private final Arena arena;
- private final Queue<Task> sharedTasks;
- private final Result result;
-
- private Request(Arena arena, Queue<Task> sharedTasks, Result result) {
- this.arena = arena;
- this.sharedTasks = sharedTasks;
- this.result = result;
- }
-
- }
-
private static final class Task {
private final FileChannel fileChannel;
@@ -452,6 +551,23 @@ public class CalculateAverage_serkan_ozal {
}
/**
+ * Region processor request
+ */
+ private static final class Request { // Immutable holder bundling an Arena, a shared Queue of Task objects and a Result
+
+ private final Arena arena;
+ private final Queue<Task> sharedTasks; // NOTE(review): presumably the work queue that region processors poll - confirm against the consumer
+ private final Result result; // NOTE(review): presumably the sink that processed data is merged into - confirm
+
+ private Request(Arena arena, Queue<Task> sharedTasks, Result result) { // stores the three references as given; no copying or validation
+ this.arena = arena;
+ this.sharedTasks = sharedTasks;
+ this.result = result;
+ }
+
+ }
+
+ /**
* Region processor response
*/
private static final class Response {
@@ -555,6 +671,9 @@ public class CalculateAverage_serkan_ozal {
}
+ /**
+ * Custom map implementation to store results
+ */
private static final class OpenMap {
// Layout
@@ -585,21 +704,22 @@ public class CalculateAverage_serkan_ozal {
private static final int ENTRY_MASK = MAP_SIZE - 1;
private static final int KEY_ARRAY_OFFSET = KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET;
+ private static final int HASH_SEED = 0x9E3779B9;
+ private static final int HASH_ROTATE = 5;
+
private final byte[] data;
- private final long[] entryOffsets;
+ private final int[] entryOffsets;
private int entryOffsetIdx;
private OpenMap() {
this.data = new byte[MAP_SIZE];
// Max number of unique keys are 10K, so 1 << 14 (16384) is long enough to hold offsets for all of them
- this.entryOffsets = new long[1 << 14];
+ this.entryOffsets = new int[1 << 14];
this.entryOffsetIdx = 0;
}
// Credits: merykitty
- private static int calculateKeyHash(long address, int keyLength) {
- int seed = 0x9E3779B9;
- int rotate = 5;
+ private int calculateEntryIndex(long address, int keyLength) {
int x, y;
if (keyLength >= Integer.BYTES) {
x = U.getInt(address);
@@ -609,19 +729,17 @@ public class CalculateAverage_serkan_ozal {
x = U.getByte(address);
y = U.getByte(address + keyLength - Byte.BYTES);
}
- return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed;
+ // Calculate key hash
+ int keyHash = (Integer.rotateLeft(x * HASH_SEED, HASH_ROTATE) ^ y) * HASH_SEED;
+ // Get the position of the entry in the linear map based on calculated hash
+ return (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT;
}
- private long putKey(ByteVector keyVector, long keyStartAddress, int keyLength) {
- // Calculate hash of key
- int keyHash = calculateKeyHash(keyStartAddress, keyLength);
- // and get the position of the entry in the linear map based on calculated hash
- int idx = (keyHash & ENTRY_HASH_MASK) << ENTRY_SIZE_SHIFT;
-
+ private int putKey(ByteVector keyVector, long keyStartAddress, int keyLength, int entryIdx) {
// Start searching from the calculated position
// and continue until find an available slot in case of hash collision
// TODO Prevent infinite loop if all the slots are in use for other keys
- for (long entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + idx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) {
+ for (int entryOffset = Unsafe.ARRAY_BYTE_BASE_OFFSET + entryIdx;; entryOffset = (entryOffset + ENTRY_SIZE) & ENTRY_MASK) {
int keySize = U.getInt(data, entryOffset + KEY_SIZE_OFFSET);
// Check whether current index is empty (no another key is inserted yet)
if (keySize == 0) {
@@ -633,32 +751,26 @@ public class CalculateAverage_serkan_ozal {
entryOffsets[entryOffsetIdx++] = entryOffset;
return entryOffset;
}
- int keyStartArrayOffset = (int) entryOffset + KEY_ARRAY_OFFSET;
// Check for hash collision (hashes are same, but keys are different).
// If there is no collision (both hashes and keys are equals), return current slot's offset.
// Otherwise, continue iterating until find an available slot.
- if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, keyStartArrayOffset)) {
+ if (keySize == keyLength && keysEqual(keyVector, keyStartAddress, keyLength, entryOffset + KEY_ARRAY_OFFSET)) {
return entryOffset;
}
}
}
private boolean keysEqual(ByteVector keyVector, long keyStartAddress, int keyLength, int keyStartArrayOffset) {
- int keyCheckIdx = 0;
- if (keyVector != null) {
- // Use vectorized search for the comparison of keys.
- // Since majority of the city names >= 8 bytes and <= 16 bytes,
- // this way is more efficient (according to my experiments) than any other comparisons (byte by byte or 2 longs).
- ByteVector entryKeyVector = ByteVector.fromArray(BYTE_SPECIES, data, keyStartArrayOffset);
- long eqMask = keyVector.compare(VectorOperators.EQ, entryKeyVector).toLong();
- int eqCount = Long.numberOfTrailingZeros(~eqMask);
- if (eqCount >= keyLength) {
- return true;
- }
- else if (keyLength <= BYTE_SPECIES_SIZE) {
- return false;
- }
- keyCheckIdx = BYTE_SPECIES_SIZE;
+ // Use vectorized search for the comparison of keys.
+ // Since majority of the city names >= 8 bytes and <= 16 bytes,
+ // this way is more efficient (according to my experiments) than any other comparisons (byte by byte or 2 longs).
+ ByteVector entryKeyVector = ByteVector.fromArray(BYTE_SPECIES, data, keyStartArrayOffset);
+ int eqCount = keyVector.compare(VectorOperators.EQ, entryKeyVector).trueCount();
+ if (eqCount == keyLength) {
+ return true;
+ }
+ else if (keyLength <= BYTE_SPECIES_SIZE) {
+ return false;
}
// Compare remaining parts of the keys
@@ -671,7 +783,7 @@ public class CalculateAverage_serkan_ozal {
long keyStartOffset = keyStartArrayOffset + Unsafe.ARRAY_BYTE_BASE_OFFSET;
int alignedKeyLength = normalizedKeyLength & 0xFFFFFFF8;
int i;
- for (i = keyCheckIdx; i < alignedKeyLength; i += Long.BYTES) {
+ for (i = BYTE_SPECIES_SIZE; i < alignedKeyLength; i += Long.BYTES) {
if (U.getLong(keyStartAddress + i) != U.getLong(data, keyStartOffset + i)) {
return false;
}
@@ -690,18 +802,18 @@ public class CalculateAverage_serkan_ozal {
return wordA == wordB;
}
- private void putValue(long entryOffset, int value) {
- long countOffset = entryOffset + COUNT_OFFSET;
+ private void putValue(int entryOffset, int value) {
+ int countOffset = entryOffset + COUNT_OFFSET;
U.putInt(data, countOffset, U.getInt(data, countOffset) + 1);
- long minValueOffset = entryOffset + MIN_VALUE_OFFSET;
+ int minValueOffset = entryOffset + MIN_VALUE_OFFSET;
if (value < U.getShort(data, minValueOffset)) {
U.putShort(data, minValueOffset, (short) value);
}
- long maxValueOffset = entryOffset + MAX_VALUE_OFFSET;
+ int maxValueOffset = entryOffset + MAX_VALUE_OFFSET;
if (value > U.getShort(data, maxValueOffset)) {
U.putShort(data, maxValueOffset, (short) value);
}
- long sumOffset = entryOffset + VALUE_SUM_OFFSET;
+ int sumOffset = entryOffset + VALUE_SUM_OFFSET;
U.putLong(data, sumOffset, U.getLong(data, sumOffset) + value);
}
@@ -709,13 +821,13 @@ public class CalculateAverage_serkan_ozal {
// Merge this local map into global result map
Arrays.sort(entryOffsets, 0, entryOffsetIdx);
for (int i = 0; i < entryOffsetIdx; i++) {
- long entryOffset = entryOffsets[i];
+ int entryOffset = entryOffsets[i];
int keyLength = U.getInt(data, entryOffset + KEY_SIZE_OFFSET);
if (keyLength == 0) {
// No entry is available for this index, so continue iterating
continue;
}
- int entryArrayIdx = (int) (entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET);
+ int entryArrayIdx = entryOffset + KEY_OFFSET - Unsafe.ARRAY_BYTE_BASE_OFFSET;
String key = new String(data, entryArrayIdx, keyLength, StandardCharsets.UTF_8);
int count = U.getInt(data, entryOffset + COUNT_OFFSET);
short minValue = U.getShort(data, entryOffset + MIN_VALUE_OFFSET);