about | summary | refs | log | tree | commit | diff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java 308
1 files changed, 180 insertions, 128 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index c08a9d8..2340bca 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -98,21 +98,21 @@ public class CalculateAverage_abeobk {
return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8);
}
- Node(long a, long t, int kl, long h, long val) {
+ Node(long a, long t, int kl, long h) { // builds a node without an initial value; sum and count are no longer assigned here, callers feed values through add()
addr = a;
tail = t;
- sum = min = max = val;
- count = 1;
+ min = 999; // starts high so that a smaller value passed to add() replaces it; assumes no value exceeds 999 - TODO confirm against the input format
+ max = -999; // starts low so that a larger value passed to add() replaces it; assumes no value is below -999 - TODO confirm
keylen = kl;
hash = h;
}
- Node(long a, long w0, long t, int kl, long h, long val) {
+ Node(long a, long w0, long t, int kl, long h) { // variant that also stores the key's first 8-byte word; sum and count are no longer assigned here, callers feed values through add()
addr = a;
word0 = w0;
+ min = 999; // starts high so that a smaller value passed to add() replaces it; assumes no value exceeds 999 - TODO confirm against the input format
+ max = -999; // starts low so that a larger value passed to add() replaces it; assumes no value is below -999 - TODO confirm
tail = t;
- sum = min = max = val;
- count = 1;
keylen = kl;
hash = h;
}
@@ -120,9 +120,8 @@ public class CalculateAverage_abeobk {
final void add(long val) {
sum += val;
count++;
- if (val >= max) {
+ if (val > max) {
max = val;
- return;
}
if (val < min) {
min = val;
@@ -170,25 +169,141 @@ public class CalculateAverage_abeobk {
return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
}
+ static final long getLFCode(final long word) { // returns 0 when no byte of word equals '\n'; otherwise the lowest set bit is bit 7 of the lowest-order '\n' byte
+ long xor_semi = word ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n, so bytes equal to '\n' become 0x00
+ return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L); // zero-byte detection: only the lowest flagged byte is reliable, higher ones may be borrow artifacts
+ }
+
+ static final long nextLine(long addr) { // returns the address just past the first '\n' found at or after addr
+ long word = UNSAFE.getLong(addr); // scans 8 bytes per step
+ long lfpos_code = getLFCode(word);
+ while (lfpos_code == 0) { // no '\n' in this word; NOTE(review): the loop has no upper bound and relies on a '\n' existing in readable memory at or after addr
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ lfpos_code = getLFCode(word);
+ }
+ return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1; // trailing zeros / 8 = byte index of the '\n' (assumes little-endian loads - TODO confirm); +1 steps past it
+ }
+
// speed/collision balance
static final long xxh32(long hash) {
long h = hash * 37;
return (h ^ (h >>> 29));
}
- // great idea from merykitty (Quan Anh Mai)
- static final long parseNum(long num_word, int dot_pos) {
- int shift = 28 - dot_pos;
- long signed = (~num_word << 59) >> 63;
- long dsmask = ~(signed & 0xFF);
- long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
- long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- return ((abs_val ^ signed) - signed);
+ static final class ChunkParser { // cursor over the address range [addr, end): key() resolves a row's key to a Node in map, val() parses the number that follows
+ long addr; // current read position; advanced by key() and val()
+ long end; // exclusive limit tested by ok()
+ Node[] map; // table probed linearly by key(); may be shared with other parsers (the constructor just stores the reference)
+
+ ChunkParser(Node[] m, long a, long e) { // m: node table, a: start address, e: end address (exclusive)
+ map = m;
+ addr = a;
+ end = e;
+ }
+
+ final boolean ok() { // true while the cursor has not reached end
+ return addr < end;
+ }
+
+ final long word() { // reads the 8 bytes at the cursor without advancing it
+ return UNSAFE.getLong(addr);
+ }
+
+ final long val() { // parses the number at the cursor, advances the cursor past it, and returns it as a signed integer (presumably tenths, one fractional digit - confirm against the input format)
+ long num_word = UNSAFE.getLong(addr);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000); // tests bit 4 of bytes 1..3 (bits 12, 20, 28): set in ASCII digits, clear in '.', so this yields the bit position of the decimal point
+ addr += (dot_pos >>> 3) + 3; // moves to the '.' and then 3 bytes further (presumably the '.', the fraction digit and the line terminator)
+ // great idea from merykitty (Quan Anh Mai)
+ int shift = 28 - dot_pos; // shift that brings the '.' into byte 3 regardless of how many characters precede it
+ long signed = (~num_word << 59) >> 63; // -1 when bit 4 of the first byte is clear (as in '-'), 0 when it is set (as in a digit)
+ long dsmask = ~(signed & 0xFF); // clears the first byte when signed is -1, dropping the sign character
+ long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L; // keeps the low nibbles of bytes 1, 2 and 4 (the digits around the aligned '.')
+ long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF; // 0x640a0001 packs the multipliers 100 (0x64), 10 (0x0a) and 1; the weighted sum lands at bits 32..41
+ return ((abs_val ^ signed) - signed); // negates abs_val when signed is -1, otherwise returns it unchanged
+ }
+
+ // optimize for contest
+ // save as much slow memory access as possible
+ // about 50% key < 8chars, 25% key between 8-10 chars
+ // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
+ final Node key(long word0, long semipos_code) { // finds or creates the Node for the key at the cursor and advances the cursor past the delimiter; callers pass word0 = word() and semipos_code = getSemiPosCode(word0)
+ long row_addr = addr; // start of the key, stored in new nodes
+ // about 50% chance key < 8 chars
+ if (semipos_code != 0) { // delimiter found inside the first word
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3; // byte index of the delimiter, which is also the key length here
+ addr += semi_pos + 1; // step past the key and its delimiter
+ long tail = word0 & HASH_MASKS[semi_pos]; // HASH_MASKS is defined elsewhere; presumably keeps only the bytes before the delimiter
+ long hash = xxh32(tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) { // empty slot: first occurrence of this key
+ return (map[bucket] = new Node(row_addr, tail, semi_pos, hash));
+ }
+ if (node.tail == tail) { // NOTE(review): only tail is compared on this path; a longer key's node with an equal tail in the same probe run would also match
+ return node;
+ }
+ bucket++; // linear probe with no wrap-around; depends on the table having spare slots past the masked range
+ }
+ }
+
+ addr += 8;
+ long word = UNSAFE.getLong(addr); // second 8-byte word of the key
+ semipos_code = getSemiPosCode(word);
+ // 43% chance
+ if (semipos_code != 0) { // delimiter found inside the second word
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos + 1;
+ long tail = (word & HASH_MASKS[semi_pos]);
+ long hash = xxh32(word0 ^ tail); // both words contribute to the hash
+ int bucket = (int) (hash & BUCKET_MASK);
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) {
+ return (map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash)); // key length = 8 bytes of word0 plus the bytes before the delimiter
+ }
+ if (node.word0 == word0 && node.tail == tail) { // equality decided from the two stored words, without re-reading the key bytes
+ return node;
+ }
+ bucket++; // linear probe with no wrap-around
+ }
+ }
+
+ // why not going for more? tested, slower
+ long hash = word0;
+ while (semipos_code == 0) { // always entered at least once here: folds each full word into the hash until a word contains the delimiter
+ hash ^= word;
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ semipos_code = getSemiPosCode(word);
+ }
+
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos; // cursor now sits on the delimiter
+ long keylen = addr - row_addr; // key length in bytes, excluding the delimiter
+ addr++; // step past the delimiter
+ long tail = (word & HASH_MASKS[semi_pos]);
+ hash = xxh32(hash ^ tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) {
+ return (map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash));
+ }
+ if (node.contentEquals(row_addr, word0, tail, keylen)) { // contentEquals is defined elsewhere; presumably compares the full key bytes - confirm
+ return node;
+ }
+ bucket++; // linear probe with no wrap-around
+ }
+ }
+ }
// Thread pool worker
static final class Worker extends Thread {
final int thread_id; // for debug use only
+ int cls = 0;
Worker(int i) {
thread_id = i;
@@ -198,9 +313,8 @@ public class CalculateAverage_abeobk {
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
- int id;
- int cls = 0;
+ int id;
// process in small chunk to maintain disk locality (artsiomkorzun trick)
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
@@ -208,119 +322,57 @@ public class CalculateAverage_abeobk {
// find start of line
if (id > 0) {
- while (UNSAFE.getByte(addr++) != '\n')
- ;
+ addr = nextLine(addr);
}
- // parse loop
- // optimize for contest
- // save as much slow memory access as possible
- // about 50% key < 8chars, 25% key bettween 8-10 chars
- // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
- while (addr < end) {
- long row_addr = addr;
-
- long word0 = UNSAFE.getLong(addr);
- long semipos_code = getSemiPosCode(word0);
-
- // about 50% chance key < 8 chars
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
-
- long tail = word0 & HASH_MASKS[semi_pos];
- long hash = xxh32(tail);
- int bucket = (int) (hash & BUCKET_MASK);
- long val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, tail, semi_pos, hash, val);
- break;
- }
- if (node.tail == tail) {
- node.add(val);
- break;
- }
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
- }
- continue;
- }
-
- addr += 8;
- long word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- // 43% chance
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
-
- long tail = (word & HASH_MASKS[semi_pos]);
- long hash = xxh32(word0 ^ tail);
- int bucket = (int) (hash & BUCKET_MASK);
- long val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val);
- break;
- }
- if (node.word0 == word0 && node.tail == tail) {
- node.add(val);
- break;
- }
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
- }
- continue;
- }
-
- // why not going for more? tested, slower
- long hash = word0;
- while (semipos_code == 0) {
- hash ^= word;
- addr += 8;
- word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- }
+ final int num_segs = 3;
+ long seglen = (end - addr) / num_segs;
+
+ long a0 = addr;
+ long a1 = nextLine(addr + 1 * seglen);
+ long a2 = nextLine(addr + 2 * seglen);
+ ChunkParser p0 = new ChunkParser(map, a0, a1);
+ ChunkParser p1 = new ChunkParser(map, a1, a2);
+ ChunkParser p2 = new ChunkParser(map, a2, end);
+
+ while (p0.ok() && p1.ok() && p2.ok()) {
+ long w0 = p0.word();
+ long w1 = p1.word();
+ long w2 = p2.word();
+ long sc0 = getSemiPosCode(w0);
+ long sc1 = getSemiPosCode(w1);
+ long sc2 = getSemiPosCode(w2);
+ Node n0 = p0.key(w0, sc0);
+ Node n1 = p1.key(w1, sc1);
+ Node n2 = p2.key(w2, sc2);
+ long v0 = p0.val();
+ long v1 = p1.val();
+ long v2 = p2.val();
+ n0.add(v0);
+ n1.add(v1);
+ n2.add(v2);
+ }
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos;
- long keylen = addr - row_addr;
- long num_word = UNSAFE.getLong(addr + 1);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 4;
-
- long tail = (word & HASH_MASKS[semi_pos]);
- hash = xxh32(hash ^ tail);
- int bucket = (int) (hash & BUCKET_MASK);
- long val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val);
- break;
- }
- if (node.contentEquals(row_addr, word0, tail, keylen)) {
- node.add(val);
- break;
- }
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
- }
+ while (p0.ok()) {
+ long w = p0.word();
+ long sc = getSemiPosCode(w);
+ Node n = p0.key(w, sc);
+ long v = p0.val();
+ n.add(v);
+ }
+ while (p1.ok()) {
+ long w = p1.word();
+ long sc = getSemiPosCode(w);
+ Node n = p1.key(w, sc);
+ long v = p1.val();
+ n.add(v);
+ }
+ while (p2.ok()) {
+ long w = p2.word();
+ long sc = getSemiPosCode(w);
+ Node n = p2.key(w, sc);
+ long v = p2.val();
+ n.add(v);
}
}