aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java283
1 files changed, 176 insertions, 107 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
index 4f6c8fd..f92f414 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
@@ -20,7 +20,6 @@ import sun.misc.Unsafe;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.reflect.Field;
-import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
@@ -38,11 +37,10 @@ public class CalculateAverage_artsiomkorzun {
private static final int SEGMENT_SIZE = 32 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024;
- private static final long COMMA_PATTERN = pattern(';');
+ private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
- private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder();
private static final Unsafe UNSAFE;
static {
@@ -95,19 +93,15 @@ public class CalculateAverage_artsiomkorzun {
}
}
- private static long pattern(char c) {
- long b = c & 0xFFL;
- return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56);
- }
-
- private static long getLongLittleEndian(long address) {
- long value = UNSAFE.getLong(address);
-
- if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
- value = Long.reverseBytes(value);
- }
-
- return value;
+ private static long word(long address) {
+ return UNSAFE.getLong(address);
+ /*
+ * if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
+ * value = Long.reverseBytes(value);
+ * }
+ *
+ * return value;
+ */
}
private static String text(Map<String, Aggregate> aggregates) {
@@ -140,7 +134,7 @@ public class CalculateAverage_artsiomkorzun {
private static class Aggregates {
private static final int ENTRIES = 64 * 1024;
- private static final int SIZE = 32 * ENTRIES;
+ private static final int SIZE = 128 * ENTRIES;
private final long pointer;
@@ -150,62 +144,82 @@ public class CalculateAverage_artsiomkorzun {
UNSAFE.setMemory(pointer, SIZE, (byte) 0);
}
- public void add(long reference, int length, int hash, int value) {
+ public long find(long word, int hash) {
+ long address = pointer + offset(hash);
+ long w = word(address + 24);
+ return (w == word) ? address : 0;
+ }
+
+ public long find(long word1, long word2, int hash) {
+ long address = pointer + offset(hash);
+ long w1 = word(address + 24);
+ long w2 = word(address + 32);
+ return (word1 == w1) && (word2 == w2) ? address : 0;
+ }
+
+ public long put(long reference, long word, int length, int hash) {
for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset;
- long ref = UNSAFE.getLong(address);
-
- if (ref == 0) {
- alloc(reference, length, hash, value, address);
- break;
+ if (equal(reference, word, address + 24, length)) {
+ return address;
}
- if (equal(ref, reference, length)) {
- long sum = UNSAFE.getLong(address + 16) + value;
- int cnt = UNSAFE.getInt(address + 24) + 1;
- short min = (short) Math.min(UNSAFE.getShort(address + 28), value);
- short max = (short) Math.max(UNSAFE.getShort(address + 30), value);
-
- UNSAFE.putLong(address + 16, sum);
- UNSAFE.putInt(address + 24, cnt);
- UNSAFE.putShort(address + 28, min);
- UNSAFE.putShort(address + 30, max);
- break;
+ int len = UNSAFE.getInt(address);
+ if (len == 0) {
+ alloc(reference, length, hash, address);
+ return address;
}
}
}
+ public static void update(long address, int value) {
+ long sum = UNSAFE.getLong(address + 8) + value;
+ int cnt = UNSAFE.getInt(address + 16) + 1;
+ short min = UNSAFE.getShort(address + 20);
+ short max = UNSAFE.getShort(address + 22);
+
+ UNSAFE.putLong(address + 8, sum);
+ UNSAFE.putInt(address + 16, cnt);
+
+ if (value < min) {
+ UNSAFE.putShort(address + 20, (short) value);
+ }
+
+ if (value > max) {
+ UNSAFE.putShort(address + 22, (short) value);
+ }
+ }
+
public void merge(Aggregates rights) {
- for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 32) {
+ for (int rightOffset = 0; rightOffset < SIZE; rightOffset += 128) {
long rightAddress = rights.pointer + rightOffset;
- long reference = UNSAFE.getLong(rightAddress);
+ int length = UNSAFE.getInt(rightAddress);
- if (reference == 0) {
+ if (length == 0) {
continue;
}
- int hash = UNSAFE.getInt(rightAddress + 8);
- int length = UNSAFE.getInt(rightAddress + 12);
+ int hash = UNSAFE.getInt(rightAddress + 4);
for (int offset = offset(hash);; offset = next(offset)) {
long address = pointer + offset;
- long ref = UNSAFE.getLong(address);
+ int len = UNSAFE.getInt(address);
- if (ref == 0) {
- UNSAFE.copyMemory(rightAddress, address, 32);
+ if (len == 0) {
+ UNSAFE.copyMemory(rightAddress, address, 24 + length);
break;
}
- if (equal(ref, reference, length)) {
- long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
- int cnt = UNSAFE.getInt(address + 24) + UNSAFE.getInt(rightAddress + 24);
- short min = (short) Math.min(UNSAFE.getShort(address + 28), UNSAFE.getShort(rightAddress + 28));
- short max = (short) Math.max(UNSAFE.getShort(address + 30), UNSAFE.getShort(rightAddress + 30));
+ if (len == length && equal(address + 24, rightAddress + 24, length)) {
+ long sum = UNSAFE.getLong(address + 8) + UNSAFE.getLong(rightAddress + 8);
+ int cnt = UNSAFE.getInt(address + 16) + UNSAFE.getInt(rightAddress + 16);
+ short min = (short) Math.min(UNSAFE.getShort(address + 20), UNSAFE.getShort(rightAddress + 20));
+ short max = (short) Math.max(UNSAFE.getShort(address + 22), UNSAFE.getShort(rightAddress + 22));
- UNSAFE.putLong(address + 16, sum);
- UNSAFE.putInt(address + 24, cnt);
- UNSAFE.putShort(address + 28, min);
- UNSAFE.putShort(address + 30, max);
+ UNSAFE.putLong(address + 8, sum);
+ UNSAFE.putInt(address + 16, cnt);
+ UNSAFE.putShort(address + 20, min);
+ UNSAFE.putShort(address + 22, max);
break;
}
}
@@ -215,20 +229,19 @@ public class CalculateAverage_artsiomkorzun {
public Map<String, Aggregate> aggregate() {
TreeMap<String, Aggregate> set = new TreeMap<>();
- for (int offset = 0; offset < SIZE; offset += 32) {
+ for (int offset = 0; offset < SIZE; offset += 128) {
long address = pointer + offset;
- long ref = UNSAFE.getLong(address);
+ int length = UNSAFE.getInt(address);
- if (ref != 0) {
- int length = UNSAFE.getInt(address + 12) - 1;
+ if (length != 0) {
byte[] array = new byte[length];
- UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
+ UNSAFE.copyMemory(null, address + 24, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
String key = new String(array);
- long sum = UNSAFE.getLong(address + 16);
- int cnt = UNSAFE.getInt(address + 24);
- short min = UNSAFE.getShort(address + 28);
- short max = UNSAFE.getShort(address + 30);
+ long sum = UNSAFE.getLong(address + 8);
+ int cnt = UNSAFE.getInt(address + 16);
+ short min = UNSAFE.getShort(address + 20);
+ short max = UNSAFE.getShort(address + 22);
Aggregate aggregate = new Aggregate(min, max, sum, cnt);
set.put(key, aggregate);
@@ -238,26 +251,24 @@ public class CalculateAverage_artsiomkorzun {
return set;
}
- private static void alloc(long reference, int length, int hash, int value, long address) {
- UNSAFE.putLong(address, reference);
- UNSAFE.putInt(address + 8, hash);
- UNSAFE.putInt(address + 12, length);
- UNSAFE.putLong(address + 16, value);
- UNSAFE.putInt(address + 24, 1);
- UNSAFE.putShort(address + 28, (short) value);
- UNSAFE.putShort(address + 30, (short) value);
+ private static void alloc(long reference, int length, int hash, long address) {
+ UNSAFE.putInt(address, length);
+ UNSAFE.putInt(address + 4, hash);
+ UNSAFE.putShort(address + 20, Short.MAX_VALUE);
+ UNSAFE.putShort(address + 22, Short.MIN_VALUE);
+ UNSAFE.copyMemory(reference, address + 24, length);
}
private static int offset(int hash) {
- return ((hash) & (ENTRIES - 1)) << 5;
+ return ((hash) & (ENTRIES - 1)) << 7;
}
private static int next(int prev) {
- return (prev + 32) & (SIZE - 1);
+ return (prev + 128) & (SIZE - 1);
}
- private static boolean equal(long leftAddress, long rightAddress, int length) {
- while (length > 8) {
+ private static boolean equal(long leftAddress, long leftWord, long rightAddress, int length) {
+ while (length >= 8) {
long left = UNSAFE.getLong(leftAddress);
long right = UNSAFE.getLong(rightAddress);
@@ -270,10 +281,24 @@ public class CalculateAverage_artsiomkorzun {
length -= 8;
}
- int shift = (8 - length) << 3;
- long left = getLongLittleEndian(leftAddress) << shift;
- long right = getLongLittleEndian(rightAddress) << shift;
- return (left == right);
+ return leftWord == word(rightAddress);
+ }
+
+ private static boolean equal(long leftAddress, long rightAddress, int length) {
+ do {
+ long left = UNSAFE.getLong(leftAddress);
+ long right = UNSAFE.getLong(rightAddress);
+
+ if (left != right) {
+ return false;
+ }
+
+ leftAddress += 8;
+ rightAddress += 8;
+ length -= 8;
+ } while (length > 0);
+
+ return true;
}
}
@@ -320,45 +345,89 @@ public class CalculateAverage_artsiomkorzun {
// as a result a read will be split across pages, where one of them is not mapped
// but for some reason it works on my machine, leaving to investigate
- for (long start = position, hash = 0; position <= limit;) {
- int length; // idea: royvanrijn, explanation: https://richardstartin.github.io/posts/finding-bytes
- {
- long word = getLongLittleEndian(position);
- long match = word ^ COMMA_PATTERN;
- long mask = (match - 0x0101010101010101L) & ~match & 0x8080808080808080L;
-
- if (mask == 0) {
- hash ^= word;
- position += 8;
- continue;
- }
+ while (position <= limit) { // branchy version, credit: thomaswue
+ int length;
+ int hash;
- int bit = Long.numberOfTrailingZeros(mask);
- position += (bit >>> 3) + 1; // +sep
- hash ^= (word << (69 - bit));
- length = (int) (position - start);
- }
+ long ptr = 0;
+ long word = word(position);
+ long separator = separator(word);
- int value; // idea: merykitty
- {
- long word = getLongLittleEndian(position);
- long inverted = ~word;
- int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
- long signed = (inverted << 59) >> 63;
- long mask = ~(signed & 0xFF);
- long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
- long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
- value = (int) ((abs ^ signed) - signed);
- position += (dot >> 3) + 3;
+ if (separator != 0) {
+ length = length(separator);
+ word = mask(word, separator);
+ hash = mix(word);
+ ptr = aggregates.find(word, hash);
+ }
+ else {
+ long word0 = word;
+ word = word(position + 8);
+ separator = separator(word);
+
+ if (separator != 0) {
+ length = length(separator) + 8;
+ word = mask(word, separator);
+ hash = mix(word ^ word0);
+ ptr = aggregates.find(word0, word, hash);
+ }
+ else {
+ length = 16;
+ long h = word ^ word0;
+
+ while (true) {
+ word = word(position + length);
+ separator = separator(word);
+
+ if (separator == 0) {
+ length += 8;
+ h ^= word;
+ continue;
+ }
+
+ length += length(separator);
+ word = mask(word, separator);
+ hash = mix(h ^ word);
+ break;
+ }
+ }
}
- aggregates.add(start, length, mix(hash), value);
+ if (ptr == 0) {
+ ptr = aggregates.put(position, word, length, hash);
+ }
- start = position;
- hash = 0;
+ position = update(ptr, position + length + 1);
}
}
+ private static long update(long ptr, long position) {
+ // idea: merykitty
+ long word = word(position);
+ long inverted = ~word;
+ int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
+ long signed = (inverted << 59) >> 63;
+ long mask = ~(signed & 0xFF);
+ long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
+ long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
+ int value = (int) ((abs ^ signed) - signed);
+
+ Aggregates.update(ptr, value);
+ return position + (dot >> 3) + 3;
+ }
+
+ private static long separator(long word) {
+ long match = word ^ COMMA_PATTERN;
+ return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
+ }
+
+ private static long mask(long word, long separator) {
+ return word & ((separator >>> 7) - 1) & 0x00FFFFFFFFFFFFFFL;
+ }
+
+ private static int length(long separator) {
+ return Long.numberOfTrailingZeros(separator) >>> 3;
+ }
+
private static long next(long position) {
while (UNSAFE.getByte(position++) != '\n') {
// continue