aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java179
1 files changed, 99 insertions, 80 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index 06cbc17..c08a9d8 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -26,9 +26,9 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import sun.misc.Unsafe;
@@ -39,7 +39,7 @@ public class CalculateAverage_abeobk {
private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16;
- private static final int BUCKET_MASK = BUCKET_SIZE - 1;
+ private static final long BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000;
private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
@@ -56,9 +56,9 @@ public class CalculateAverage_abeobk {
0xffffffffffffffffL, };
private static AtomicInteger chunk_id = new AtomicInteger(0);
+ private static AtomicReference<Node[]> mapref = new AtomicReference<>(null);
private static int chunk_cnt;
private static long start_addr, end_addr;
- private static Stat[][] all_res;
private static final void debug(String s, Object... args) {
System.out.println(String.format(s, args));
@@ -75,57 +75,49 @@ public class CalculateAverage_abeobk {
}
}
- static class Stat {
- Node node;
- String key;
-
- public final String toString() {
- return (node.min / 10.0) + "/"
- + (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
- + (node.max / 10.0);
- }
-
- Stat(Node n) {
- node = n;
- byte[] sbuf = new byte[MAX_STR_LEN];
- long word = UNSAFE.getLong(n.addr);
- long semipos_code = getSemiPosCode(word);
- int keylen = 0;
- while (semipos_code == 0) {
- keylen += 8;
- word = UNSAFE.getLong(n.addr + keylen);
- semipos_code = getSemiPosCode(word);
- }
- keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
- UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
- key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
- }
- }
-
+ // use native type, less conversion
static class Node {
long addr;
+ long hash;
long word0;
long tail;
long sum;
+ long min, max;
+ int keylen;
int count;
- short min, max;
- Node(long a, long t, short val) {
+ public final String toString() {
+ return (min / 10.0) + "/"
+ + (Math.round(((double) sum / count)) / 10.0) + "/"
+ + (max / 10.0);
+ }
+
+ final String key() {
+ byte[] sbuf = new byte[MAX_STR_LEN];
+ UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
+ return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8);
+ }
+
+ Node(long a, long t, int kl, long h, long val) {
addr = a;
tail = t;
sum = min = max = val;
count = 1;
+ keylen = kl;
+ hash = h;
}
- Node(long a, long w0, long t, short val) {
+ Node(long a, long w0, long t, int kl, long h, long val) {
addr = a;
word0 = w0;
tail = t;
sum = min = max = val;
count = 1;
+ keylen = kl;
+ hash = h;
}
- final void add(short val) {
+ final void add(long val) {
sum += val;
count++;
if (val >= max) {
@@ -148,17 +140,28 @@ public class CalculateAverage_abeobk {
}
}
- final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
+ final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) {
if (word0 != other_word0 || tail != other_tail)
return false;
// this is faster than comparision if key is short
long xsum = 0;
- int n = keylen & 0xF8;
- for (int i = 8; i < n; i += 8) {
+ long n = kl & 0xF8;
+ for (long i = 8; i < n; i += 8) {
xsum |= (UNSAFE.getLong(addr + i) ^ UNSAFE.getLong(other_addr + i));
}
return xsum == 0;
}
+
+ final boolean contentEquals(Node other) {
+ if (tail != other.tail)
+ return false;
+ long n = keylen & 0xF8;
+ for (long i = 0; i < n; i += 8) {
+ if (UNSAFE.getLong(addr + i) != UNSAFE.getLong(other.addr + i))
+ return false;
+ }
+ return true;
+ }
}
// idea from royvanrijn
@@ -168,24 +171,24 @@ public class CalculateAverage_abeobk {
}
// speed/collision balance
- static final int xxh32(long hash) {
+ static final long xxh32(long hash) {
long h = hash * 37;
- return (int) (h ^ (h >>> 29));
+ return (h ^ (h >>> 29));
}
// great idea from merykitty (Quan Anh Mai)
- static final short parseNum(long num_word, int dot_pos) {
+ static final long parseNum(long num_word, int dot_pos) {
int shift = 28 - dot_pos;
long signed = (~num_word << 59) >> 63;
long dsmask = ~(signed & 0xFF);
long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- return (short) ((abs_val ^ signed) - signed);
+ return ((abs_val ^ signed) - signed);
}
// Thread pool worker
static final class Worker extends Thread {
- final int thread_id;
+ final int thread_id; // for debug use only
Worker(int i) {
thread_id = i;
@@ -195,16 +198,15 @@ public class CalculateAverage_abeobk {
@Override
public void run() {
var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
- int cnt = 0;
int id;
int cls = 0;
// process in small chunk to maintain disk locality (artsiomkorzun trick)
- // but keep going instead of merging
while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
long addr = start_addr + id * CHUNK_SZ;
long end = Math.min(addr + CHUNK_SZ, end_addr);
- // adjust start
+
+ // find start of line
if (id > 0) {
while (UNSAFE.getByte(addr++) != '\n')
;
@@ -230,14 +232,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3;
long tail = word0 & HASH_MASKS[semi_pos];
- int bucket = xxh32(tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
+ long hash = xxh32(tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, tail, val);
- cnt++;
+ map[bucket] = new Node(row_addr, tail, semi_pos, hash, val);
break;
}
if (node.tail == tail) {
@@ -263,14 +265,14 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]);
- int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
+ long hash = xxh32(word0 ^ tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val);
- cnt++;
+ map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash, val);
break;
}
if (node.word0 == word0 && node.tail == tail) {
@@ -295,20 +297,20 @@ public class CalculateAverage_abeobk {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
addr += semi_pos;
- int keylen = (int) (addr - row_addr);
+ long keylen = addr - row_addr;
long num_word = UNSAFE.getLong(addr + 1);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]);
- int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
+ hash = xxh32(hash ^ tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ long val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val);
- cnt++;
+ map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash, val);
break;
}
if (node.contentEquals(row_addr, word0, tail, keylen)) {
@@ -322,18 +324,36 @@ public class CalculateAverage_abeobk {
}
}
- if (SHOW_ANALYSIS) {
- debug("Thread %d collision = %d", thread_id, cls);
+ // merge is cheaper than string casting (artsiomkorzun)
+ while (!mapref.compareAndSet(null, map)) {
+ var other_map = mapref.getAndSet(null);
+ if (other_map != null) {
+ for (int i = 0; i < other_map.length; i++) {
+ var other = other_map[i];
+ if (other == null)
+ continue;
+ int bucket = (int) (other.hash & BUCKET_MASK);
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = other;
+ break;
+ }
+ if (node.contentEquals(other)) {
+ node.merge(other);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls++;
+ }
+ }
+ }
}
- Stat[] stats = new Stat[cnt];
- int i = 0;
- for (var node : map) {
- if (node != null) {
- stats[i++] = new Stat(node);
- }
+ if (SHOW_ANALYSIS) {
+ debug("Thread %d collision = %d", thread_id, cls);
}
- all_res[thread_id] = stats;
}
}
@@ -366,23 +386,22 @@ public class CalculateAverage_abeobk {
// only use all cpus on large file
int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
- all_res = new Stat[cpu_cnt][];
- List<Worker> workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList();
- for (var w : workers)
+ // spawn workers
+ for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) {
w.join();
-
- // collect all results
- TreeMap<String, Stat> ms = new TreeMap<>();
- for (var res : all_res) {
- for (var s : res) {
- var stat = ms.putIfAbsent(s.key, s);
- if (stat != null)
- stat.node.merge(s.node);
- }
}
- // print output
+ // collect results
+ TreeMap<String, Node> ms = new TreeMap<>();
+ for (var crr : mapref.get()) {
+ if (crr == null)
+ continue;
+ var prev = ms.putIfAbsent(crr.key(), crr);
+ if (prev != null)
+ prev.merge(crr);
+ }
+ // print result
System.out.println(ms);
System.out.close();
}