aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java313
1 files changed, 138 insertions, 175 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
index c9b7144..4f6c8fd 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
@@ -35,7 +35,7 @@ public class CalculateAverage_artsiomkorzun {
private static final MemorySegment MAPPED_FILE = map(FILE);
private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
- private static final int SEGMENT_SIZE = 16 * 1024 * 1024;
+ private static final int SEGMENT_SIZE = 32 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024;
private static final long COMMA_PATTERN = pattern(';');
@@ -100,16 +100,6 @@ public class CalculateAverage_artsiomkorzun {
return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56);
}
- private static long getLongBigEndian(long address) {
- long value = UNSAFE.getLong(address);
-
- if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) {
- value = Long.reverseBytes(value);
- }
-
- return value;
- }
-
private static long getLongLittleEndian(long address) {
long value = UNSAFE.getLong(address);
@@ -144,98 +134,80 @@ public class CalculateAverage_artsiomkorzun {
return Math.round(v) / 10.0;
}
- private static class Row {
- long address;
- int length;
- int hash;
- int value;
- }
-
private record Aggregate(int min, int max, long sum, int cnt) {
}
private static class Aggregates {
- private static final int SIZE = 16 * 1024;
+ private static final int ENTRIES = 64 * 1024;
+ private static final int SIZE = 32 * ENTRIES;
+
private final long pointer;
public Aggregates() {
- int size = 32 * SIZE;
- long address = UNSAFE.allocateMemory(size + 8096);
+ long address = UNSAFE.allocateMemory(SIZE + 8096);
pointer = (address + 4095) & (~4095);
- UNSAFE.setMemory(pointer, size, (byte) 0);
-
- long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0);
- for (int i = 0; i < SIZE; i++) {
- long entry = pointer + 32 * i;
- UNSAFE.putLong(entry + 24, word);
- }
+ UNSAFE.setMemory(pointer, SIZE, (byte) 0);
}
- public void add(Row row) {
- long index = index(row.hash);
- long header = ((long) row.hash << 32) | (row.length);
-
- while (true) {
- long address = pointer + (index << 5);
- long head = UNSAFE.getLong(address);
- long ref = UNSAFE.getLong(address + 8);
- boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length));
-
- if (isHit) {
- long sum = UNSAFE.getLong(address + 16) + row.value;
- long word = UNSAFE.getLong(address + 24);
- int min = Math.min(min(word), row.value);
- int max = Math.max(max(word), row.value);
- int cnt = cnt(word) + 1;
-
- UNSAFE.putLong(address, header);
- UNSAFE.putLong(address + 8, row.address);
- UNSAFE.putLong(address + 16, sum);
- UNSAFE.putLong(address + 24, pack(min, max, cnt));
+ public void add(long reference, int length, int hash, int value) {
+ for (int offset = offset(hash);; offset = next(offset)) {
+ long address = pointer + offset;
+ long ref = UNSAFE.getLong(address);
+
+ if (ref == 0) {
+ alloc(reference, length, hash, value, address);
break;
}
- index = (index + 1) & (SIZE - 1);
+ if (equal(ref, reference, length)) {
+ long sum = UNSAFE.getLong(address + 16) + value;
+ int cnt = UNSAFE.getInt(address + 24) + 1;
+ short min = (short) Math.min(UNSAFE.getShort(address + 28), value);
+ short max = (short) Math.max(UNSAFE.getShort(address + 30), value);
+
+ UNSAFE.putLong(address + 16, sum);
+ UNSAFE.putInt(address + 24, cnt);
+ UNSAFE.putShort(address + 28, min);
+ UNSAFE.putShort(address + 30, max);
+ break;
+ }
}
}
public void merge(Aggregates rights) {
- for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) {
- long rightAddress = rights.pointer + (rightIndex << 5);
- long header = UNSAFE.getLong(rightAddress);
- long reference = UNSAFE.getLong(rightAddress + 8);
+ for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) {
+ long rightAddress = rights.pointer + rightOffset;
+ long reference = UNSAFE.getLong(rightAddress);
- if (header == 0) {
+ if (reference == 0) {
continue;
}
- int hash = (int) (header >>> 32);
- int length = (int) (header);
- long index = index(hash);
+ int hash = UNSAFE.getInt(rightAddress + 8);
+ int length = UNSAFE.getInt(rightAddress + 12);
- while (true) {
- long address = pointer + (index << 5);
- long head = UNSAFE.getLong(address);
- long ref = UNSAFE.getLong(address + 8);
- boolean isHit = (head == 0) || (head == header && equal(ref, reference, length));
+ for (int offset = offset(hash);; offset = next(offset)) {
+ long address = pointer + offset;
+ long ref = UNSAFE.getLong(address);
- if (isHit) {
+ if (ref == 0) {
+ UNSAFE.copyMemory(rightAddress, address, 32);
+ break;
+ }
+
+ if (equal(ref, reference, length)) {
long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
- long left = UNSAFE.getLong(address + 24);
- long right = UNSAFE.getLong(rightAddress + 24);
- int min = Math.min(min(left), min(right));
- int max = Math.max(max(left), max(right));
- int cnt = cnt(left) + cnt(right);
-
- UNSAFE.putLong(address, header);
- UNSAFE.putLong(address + 8, reference);
+ int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24);
+ short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28));
+ short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30));
+
UNSAFE.putLong(address + 16, sum);
- UNSAFE.putLong(address + 24, pack(min, max, cnt));
+ UNSAFE.putInt(address + 24, cnt);
+ UNSAFE.putShort(address + 28, min);
+ UNSAFE.putShort(address + 30, max);
break;
}
-
- index = (index + 1) & (SIZE - 1);
}
}
}
@@ -243,68 +215,64 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>();
- for (int index = 0; index < SIZE; index++) {
- long address = pointer + (index << 5);
- long head = UNSAFE.getLong(address);
- long ref = UNSAFE.getLong(address + 8);
-
- if (head == 0) {
- continue;
- }
+ for (int offset = 0; offset < SIZE; offset += 32) {
+ long address = pointer + offset;
+ long ref = UNSAFE.getLong(address);
- int length = (int) (head);
- byte[] array = new byte[length];
- UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
- String key = new String(array);
+ if (ref != 0) {
+ int length = UNSAFE.getInt(address + 12) - 1;
+ byte[] array = new byte[length];
+ UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
+ String key = new String(array);
- long sum = UNSAFE.getLong(address + 16);
- long word = UNSAFE.getLong(address + 24);
+ long sum = UNSAFE.getLong(address + 16);
+ int cnt = UNSAFE.getInt(address + 24);
+ short min = UNSAFE.getShort(address + 28);
+ short max = UNSAFE.getShort(address + 30);
- Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word));
- set.put(key, aggregate);
+ Aggregate aggregate = new Aggregate(min, max, sum, cnt);
+ set.put(key, aggregate);
+ }
}
return set;
}
- private static long pack(int min, int max, int cnt) {
- return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt;
- }
-
- private static int cnt(long word) {
- return (int) word;
- }
-
- private static int max(long word) {
- return (short) (word >>> 32);
+ private static void alloc(long reference, int length, int hash, int value, long address) {
+ UNSAFE.putLong(address, reference);
+ UNSAFE.putInt(address + 8, hash);
+ UNSAFE.putInt(address + 12, length);
+ UNSAFE.putLong(address + 16, value);
+ UNSAFE.putInt(address + 24, 1);
+ UNSAFE.putShort(address + 28, (short) value);
+ UNSAFE.putShort(address + 30, (short) value);
}
- private static int min(long word) {
- return (short) (word >>> 48);
+ private static int offset(int hash) {
+ return ((hash) & (ENTRIES - 1)) << 5;
}
- private static long index(int hash) {
- return (hash ^ (hash >> 16)) & (SIZE - 1);
+ private static int next(int prev) {
+ return (prev + 32) & (SIZE - 1);
}
private static boolean equal(long leftAddress, long rightAddress, int length) {
- int index = 0;
-
while (length > 8) {
- long left = UNSAFE.getLong(leftAddress + index);
- long right = UNSAFE.getLong(rightAddress + index);
+ long left = UNSAFE.getLong(leftAddress);
+ long right = UNSAFE.getLong(rightAddress);
if (left != right) {
return false;
}
+ leftAddress += 8;
+ rightAddress += 8;
length -= 8;
- index += 8;
}
- int shift = 64 - (length << 3);
- long left = getLongBigEndian(leftAddress + index) >>> shift;
- long right = getLongBigEndian(rightAddress + index) >>> shift;
+ int shift = (8 - length) << 3;
+ long left = getLongLittleEndian(leftAddress) << shift;
+ long right = getLongLittleEndian(rightAddress) << shift;
return (left == right);
}
}
@@ -323,10 +291,18 @@ public class CalculateAverage_artsiomkorzun {
@Override
public void run() {
Aggregates aggregates = new Aggregates();
- Row row = new Row();
for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) {
- aggregate(aggregates, row, segment);
+ long position = (long) SEGMENT_SIZE * segment;
+ int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position);
+ long address = MAPPED_FILE.address() + position;
+ long limit = address + Math.min(SEGMENT_SIZE, size - 1);
+
+ if (segment > 0) {
+ address = next(address);
+ }
+
+ aggregate(aggregates, address, limit);
}
while (!result.compareAndSet(null, aggregates)) {
@@ -338,75 +314,62 @@ public class CalculateAverage_artsiomkorzun {
}
}
- private static void aggregate(Aggregates aggregates, Row row, int segment) {
- long position = (long) SEGMENT_SIZE * segment;
- int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position);
- long address = MAPPED_FILE.address() + position;
- long limit = address + Math.min(SEGMENT_SIZE, size - 1);
-
- if (segment > 0) {
- address = next(address);
- }
-
- while (address <= limit) {
- // this parsing can produce seg fault at page boundaries
- // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes
- // as a result a read will be split across pages, where one of them is not mapped
- // but for some reason it works on my machine, leaving to investigate
- address = parseKey(address, row);
- address = parseValue(address, row);
- aggregates.add(row);
- }
- }
+ private static void aggregate(Aggregates aggregates, long position, long limit) {
+ // this parsing can produce seg fault at page boundaries
+ // e.g. file size is 4096 and the last entry is X=0.0, which is less than 8 bytes
+ // as a result a read will be split across pages, where one of them is not mapped
+ // but for some reason it works on my machine, leaving to investigate
+
+ for (long start = position, hash = 0; position <= limit;) {
+ int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes
+ {
+ long word = getLongLittleEndian(position);
+ long match = word ^ COMMA_PATTERN;
+ long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
+
+ if (mask == 0) {
+ hash ^= word;
+ position += 8;
+ continue;
+ }
- private static long next(long address) {
- while (UNSAFE.getByte(address++) != '\n') {
- // continue
- }
- return address;
- }
+ int bit = Long.numberOfTrailingZeros(mask);
+ position += (bit >>> 3) + 1; // +sep
+ hash ^= (word << (69 - bit));
+ length = (int) (position - start);
+ }
- // idea: royvanrijn
- // explanation: https://richardstartin.github.io/posts/finding-bytes
- private static long parseKey(long address, Row row) {
- int length = 0;
- long hash = 0;
- long word;
-
- while (true) {
- word = getLongLittleEndian(address + length);
- long match = word ^ COMMA_PATTERN;
- long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
-
- if (mask == 0) {
- hash = 71 * hash + word;
- length += 8;
- continue;
+ int value; // idea: merykitty
+ {
+ long word = getLongLittleEndian(position);
+ long inverted = ~word;
+ int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
+ long signed = (inverted << 59) >> 63;
+ long mask = ~(signed & 0xFF);
+ long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
+ long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
+ value = (int) ((abs ^ signed) - signed);
+ position += (dot >> 3) + 3;
}
- int bit = Long.numberOfTrailingZeros(mask);
- length += (bit >>> 3);
- hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit)));
+ aggregates.add(start, length, mix(hash), value);
- row.address = address;
- row.length = length;
- row.hash = Long.hashCode(hash);
+ start = position;
+ hash = 0;
+ }
+ }
- return address + length + 1;
+ private static long next(long position) {
+ while (UNSAFE.getByte(position++) != '\n') {
+ // continue
}
+ return position;
}
- // idea: merykitty
- private static long parseValue(long address, Row row) {
- long word = getLongLittleEndian(address);
- long inverted = ~word;
- int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
- long signed = (inverted << 59) >> 63;
- long mask = ~(signed & 0xFF);
- long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
- long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
- row.value = (int) ((abs ^ signed) - signed);
- return address + (dot >> 3) + 3;
+ private static int mix(long x) {
+ long h = x * -7046029254386353131L;
+ h ^= h >>> 32;
+ return (int) (h ^ h >>> 16);
}
}
}