aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java257
1 files changed, 179 insertions, 78 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index 34a5552..ec6c9e5 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -24,11 +24,12 @@ import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
+import java.util.Arrays;
import java.util.TreeMap;
import sun.misc.Unsafe;
public class CalculateAverage_abeobk {
- private static final boolean SHOW_COLLISIONS = false;
+ private static final boolean SHOW_ANALYSIS = false;
private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16;
@@ -99,13 +100,13 @@ public class CalculateAverage_abeobk {
boolean contentEquals(long other_addr, long other_tail) {
if (tail != other_tail) // compare tail & length at the same time
return false;
- long my_addr = addr;
- int nl = (int) (tail >> 59);
- for (int i = 0; i < nl; i++, my_addr += 8, other_addr += 8) {
- if (UNSAFE.getLong(my_addr) != UNSAFE.getLong(other_addr))
- return false;
+ // this is faster than comparision if key is short
+ long xsum = 0;
+ int n = ((int) (tail >>> 56)) & 0xF8;
+ for (int i = 0; i < n; i += 8) {
+ xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
}
- return true;
+ return xsum == 0;
}
}
@@ -123,6 +124,7 @@ public class CalculateAverage_abeobk {
return ptrs;
}
+ // idea from royvanrijn
static final long getSemiPosCode(final long word) {
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
@@ -133,17 +135,164 @@ public class CalculateAverage_abeobk {
// zero collision on test data
static final int xxh32(long hash) {
final int p1 = 0x85EBCA77; // prime
- final int p2 = 0xC2B2AE3D; // prime
+ final int p2 = 0x165667B1; // prime
int low = (int) hash;
- int high = (int) (hash >>> 32);
- low ^= low >> 15;
- low *= p1;
- high ^= high >> 13;
- high *= p2;
- var h = low ^ high;
+ int high = (int) (hash >>> 31);
+ int h = low + high;
+ h ^= h >> 15;
+ h *= p1;
+ h ^= h >> 13;
+ h *= p2;
+ h ^= h >> 11;
return h;
}
+ // great idea from merykitty (Quan Anh Mai)
+ static final int parseNum(long num_word, int dot_pos) {
+ int shift = 28 - dot_pos;
+ long signed = (~num_word << 59) >> 63;
+ long dsmask = ~(signed & 0xFF);
+ long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
+ long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ return (int) ((abs_val ^ signed) - signed);
+ }
+
+ // optimize for contest
+ // save as much slow memory access as possible
+ // about 50% key < 8chars, 25% key bettween 8-10 chars
+ // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
+ static final Node[] parse(int thread_id, long start, long end, int[] cls) {
+ long addr = start;
+ var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
+ // parse loop
+ while (addr < end) {
+ long row_addr = addr;
+ long tail = 0;
+ long hash = 0;
+ int val = 0;
+ int bucket = 0;
+
+ long word = UNSAFE.getLong(addr);
+ long semipos_code = getSemiPosCode(word);
+
+ // about 50% chance key < 8 chars
+ if (semipos_code != 0) {
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos;
+ tail = (word & HASH_MASKS[semi_pos]);
+ bucket = xxh32(tail) & BUCKET_MASK;
+ long keylen = (addr - row_addr);
+ tail |= (keylen << 56);
+ long num_word = UNSAFE.getLong(++addr);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ val = parseNum(num_word, dot_pos);
+ addr += (dot_pos >>> 3) + 3;
+
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, tail, val);
+ break;
+ }
+ if (node.tail == tail) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls[thread_id]++;
+ }
+ continue;
+ }
+
+ hash ^= word;
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ semipos_code = getSemiPosCode(word);
+ // frist byte semicolon ~13%
+ if (semipos_code == 0x80) {
+ bucket = xxh32(hash) & BUCKET_MASK;
+ tail = 8L << 56;
+ long num_word = word >>> 8;
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ val = parseNum(num_word, dot_pos);
+ addr += (dot_pos >>> 3) + 4;
+
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, tail, val);
+ break;
+ }
+ if (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr)) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls[thread_id]++;
+ }
+ continue;
+ }
+
+ while (semipos_code == 0) {
+ hash ^= word;
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ semipos_code = getSemiPosCode(word);
+ }
+
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos;
+ tail = (word & HASH_MASKS[semi_pos]);
+ hash ^= tail;
+ bucket = xxh32(hash) & BUCKET_MASK;
+ long keylen = (addr - row_addr);
+ tail |= (keylen << 56);
+
+ ++addr;
+ long num_word = UNSAFE.getLong(addr);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ val = parseNum(num_word, dot_pos);
+ addr += (dot_pos >>> 3) + 3;
+
+ if (keylen < 16) {
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, tail, val);
+ break;
+ }
+ if (node.tail == tail && (UNSAFE.getLong(node.addr) == UNSAFE.getLong(row_addr))) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls[thread_id]++;
+ }
+ continue;
+ }
+
+ // longer key
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, tail, val);
+ break;
+ }
+ if (node.contentEquals(row_addr, tail)) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls[thread_id]++;
+ }
+ }
+ return map;
+ }
+
public static void main(String[] args) throws InterruptedException, IOException {
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
@@ -158,71 +307,14 @@ public class CalculateAverage_abeobk {
var threads = new Thread[cpu_cnt];
var maps = new Node[cpu_cnt][];
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
- int[] cls = new int[cpu_cnt];
+
+ int[] cls = new int[cpu_cnt]; // collision
+ int[] lenhist = new int[64]; // length histogram
for (int i = 0; i < cpu_cnt; i++) {
int thread_id = i;
- long start = ptrs[i];
- long end = ptrs[i + 1];
- maps[i] = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
-
- (threads[i] = new Thread(() -> {
- long addr = start;
- var map = maps[thread_id];
- // parse loop
- while (addr < end) {
- long hash = 0;
- long word = 0;
- long row_addr = addr;
- int semi_pos = 8;
- word = UNSAFE.getLong(addr);
- long semipos_code = getSemiPosCode(word);
-
- while (semipos_code == 0) {
- hash ^= word;
- addr += 8;
- word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- }
-
- semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- long tail = word & HASH_MASKS[semi_pos];
- hash ^= tail;
- addr += semi_pos;
-
- int hash32 = xxh32(hash);
- long keylen = (addr - row_addr);
- tail = tail | (keylen << 56);
-
- addr++;
-
- // great idea from merykitty (Quan Anh Mai)
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
- int shift = 28 - dot_pos;
- long signed = (~num_word << 59) >> 63;
- long dsmask = ~(signed & 0xFF);
- long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
- long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- int val = (int) ((abs_val ^ signed) - signed);
-
- int bucket = (hash32 & BUCKET_MASK);
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, tail, val);
- break;
- }
- if (node.contentEquals(row_addr, tail)) {
- node.add(val);
- break;
- }
- bucket++;
- if (SHOW_COLLISIONS)
- cls[thread_id]++;
- }
- }
+ (threads[thread_id] = new Thread(() -> {
+ maps[thread_id] = parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1], cls);
})).start();
}
@@ -230,7 +322,7 @@ public class CalculateAverage_abeobk {
for (var thread : threads)
thread.join();
- if (SHOW_COLLISIONS) {
+ if (SHOW_ANALYSIS) {
for (int i = 0; i < cpu_cnt; i++) {
System.out.println("thread-" + i + " collision = " + cls[i]);
}
@@ -242,13 +334,22 @@ public class CalculateAverage_abeobk {
for (var node : map) {
if (node == null)
continue;
+ if (SHOW_ANALYSIS) {
+ int kl = (int) (node.tail >>> 56) & (lenhist.length - 1);
+ lenhist[kl] += node.count;
+ }
var stat = ms.putIfAbsent(node.key(), node);
if (stat != null)
stat.merge(node);
}
}
- if (!SHOW_COLLISIONS)
+ if (SHOW_ANALYSIS) {
+ System.out.println("total=" + Arrays.stream(lenhist).sum());
+ System.out.println("length_histogram = "
+ + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
+ }
+ else
System.out.println(ms);
}
}