aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
authorVan Phu DO <abeobk@gmail.com>2024-01-26 06:57:04 +0900
committerGitHub <noreply@github.com>2024-01-25 22:57:04 +0100
commit271bdfb0329df636988455e450bb48a45f5b917f (patch)
treebbf9d06fdafa53421ba64008fcf69acb0256dbd0 /src/main/java
parentce9455a584413b575a2eb23633eb92bb415a0618 (diff)
Simplify Node class with less field, improve hash mix speed (#584)
* Simplify Node class with less field, improve hash mix speed * remove some ops, a bit faster * more inline, little bit faster but not sure
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java126
1 files changed, 60 insertions, 66 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index 293a88c..ed859f3 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -39,6 +39,7 @@ public class CalculateAverage_abeobk {
private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
+ private static final int MAX_STATIONS = 10000;
private static final Unsafe UNSAFE = initUnsafe();
private static final long[] HASH_MASKS = new long[]{
0x0L,
@@ -66,6 +67,33 @@ public class CalculateAverage_abeobk {
}
}
+ static class Stat {
+ Node node;
+ String key;
+
+ public final String toString() {
+ return (node.min / 10.0) + "/"
+ + (Math.round(((double) node.sum / node.count)) / 10.0) + "/"
+ + (node.max / 10.0);
+ }
+
+ Stat(Node n) {
+ node = n;
+ byte[] sbuf = new byte[MAX_STR_LEN];
+ long word = UNSAFE.getLong(n.addr);
+ long semipos_code = getSemiPosCode(word);
+ int keylen = 0;
+ while (semipos_code == 0) {
+ keylen += 8;
+ word = UNSAFE.getLong(n.addr + keylen);
+ semipos_code = getSemiPosCode(word);
+ }
+ keylen += Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ UNSAFE.copyMemory(null, n.addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
+ key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
+ }
+ }
+
static class Node {
long addr;
long word0;
@@ -73,37 +101,23 @@ public class CalculateAverage_abeobk {
long sum;
int count;
short min, max;
- int keylen;
- String key;
- void calcKey() {
- byte[] sbuf = new byte[MAX_STR_LEN];
- UNSAFE.copyMemory(null, addr, sbuf, Unsafe.ARRAY_BYTE_BASE_OFFSET, keylen);
- key = new String(sbuf, 0, keylen, StandardCharsets.UTF_8);
- }
-
- public String toString() {
- return String.format("%.1f/%.1f/%.1f", min * 0.1, sum * 0.1 / count, max * 0.1);
- }
-
- Node(long a, long t, short val, int kl) {
+ Node(long a, long t, short val) {
addr = a;
tail = t;
- keylen = kl;
sum = min = max = val;
count = 1;
}
- Node(long a, long w0, long t, short val, int kl) {
+ Node(long a, long w0, long t, short val) {
addr = a;
word0 = w0;
tail = t;
- keylen = kl;
sum = min = max = val;
count = 1;
}
- void add(short val) {
+ final void add(short val) {
sum += val;
count++;
if (val >= max) {
@@ -115,7 +129,7 @@ public class CalculateAverage_abeobk {
}
}
- void merge(Node other) {
+ final void merge(Node other) {
sum += other.sum;
count += other.count;
if (other.max > max) {
@@ -126,8 +140,8 @@ public class CalculateAverage_abeobk {
}
}
- boolean contentEquals(long other_addr, long other_word0, long other_tail) {
- if (tail != other_tail || word0 != other_word0)
+ final boolean contentEquals(long other_addr, long other_word0, long other_tail, int keylen) {
+ if (word0 != other_word0 || tail != other_tail)
return false;
// this is faster than comparision if key is short
long xsum = 0;
@@ -161,11 +175,8 @@ public class CalculateAverage_abeobk {
// speed/collision balance
static final int xxh32(long hash) {
- final int p1 = 0x85EBCA77; // prime
- int low = (int) hash;
- int high = (int) (hash >>> 33);
- int h = (low * p1) ^ high;
- return h ^ (h >>> 17);
+ long h = hash * 37;
+ return (int) (h ^ (h >>> 29));
}
// great idea from merykitty (Quan Anh Mai)
@@ -185,11 +196,10 @@ public class CalculateAverage_abeobk {
static final Node[] parse(int thread_id, long start, long end) {
int cls = 0;
long addr = start;
- var map = new Node[BUCKET_SIZE + 10000]; // extra space for collisions
+ var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
// parse loop
while (addr < end) {
long row_addr = addr;
- long hash = 0;
long word0 = UNSAFE.getLong(addr);
long semipos_code = getSemiPosCode(word0);
@@ -202,14 +212,14 @@ public class CalculateAverage_abeobk {
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
addr += (dot_pos >>> 3) + 3;
- long tail = (word0 & HASH_MASKS[semi_pos]);
+ long tail = word0 & HASH_MASKS[semi_pos];
int bucket = xxh32(tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, tail, val, semi_pos);
+ map[bucket] = new Node(row_addr, tail, val);
break;
}
if (node.tail == tail) {
@@ -223,28 +233,25 @@ public class CalculateAverage_abeobk {
continue;
}
- hash ^= word0;
addr += 8;
long word = UNSAFE.getLong(addr);
semipos_code = getSemiPosCode(word);
// 43% chance
if (semipos_code != 0) {
int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos;
- int keylen = (int) (addr - row_addr);
- long num_word = UNSAFE.getLong(addr + 1);
+ addr += semi_pos + 1;
+ long num_word = UNSAFE.getLong(addr);
int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 4;
+ addr += (dot_pos >>> 3) + 3;
long tail = (word & HASH_MASKS[semi_pos]);
- hash ^= tail;
- int bucket = xxh32(hash) & BUCKET_MASK;
+ int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val, keylen);
+ map[bucket] = new Node(row_addr, word0, tail, val);
break;
}
if (node.word0 == word0 && node.tail == tail) {
@@ -258,6 +265,9 @@ public class CalculateAverage_abeobk {
continue;
}
+ // why not going for more? tested, slower
+
+ long hash = word0;
while (semipos_code == 0) {
hash ^= word;
addr += 8;
@@ -273,17 +283,16 @@ public class CalculateAverage_abeobk {
addr += (dot_pos >>> 3) + 4;
long tail = (word & HASH_MASKS[semi_pos]);
- hash ^= tail;
- int bucket = xxh32(hash) & BUCKET_MASK;
+ int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
short val = parseNum(num_word, dot_pos);
while (true) {
var node = map[bucket];
if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val, keylen);
+ map[bucket] = new Node(row_addr, word0, tail, val);
break;
}
- if (node.contentEquals(row_addr, word0, tail)) {
+ if (node.contentEquals(row_addr, word0, tail, keylen)) {
node.add(val);
break;
}
@@ -292,6 +301,7 @@ public class CalculateAverage_abeobk {
cls++;
}
}
+
if (SHOW_ANALYSIS) {
debug("Thread %d collision = %d", thread_id, cls);
}
@@ -307,8 +317,6 @@ public class CalculateAverage_abeobk {
workerCommand.add("--worker");
new ProcessBuilder()
.command(workerCommand)
- .inheritIO()
- .redirectOutput(ProcessBuilder.Redirect.PIPE)
.start()
.getInputStream()
.transferTo(System.out);
@@ -333,43 +341,29 @@ public class CalculateAverage_abeobk {
// processing
var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
- TreeMap<String, Node> ms = new TreeMap<>();
- int[] lenhist = new int[64]; // length histogram
-
- List<List<Node>> maps = IntStream.range(0, cpu_cnt)
+ List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
.mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
.map(map -> {
- List<Node> nodes = new ArrayList<>();
+ List<Stat> stats = new ArrayList<>();
for (var node : map) {
if (node == null)
continue;
- node.calcKey();
- nodes.add(node);
+ stats.add(new Stat(node));
}
- return nodes;
+ return stats;
})
.parallel()
.toList();
- for (var nodes : maps) {
- for (var node : nodes) {
- if (SHOW_ANALYSIS) {
- int kl = node.keylen & (lenhist.length - 1);
- lenhist[kl] += node.count;
- }
- var stat = ms.putIfAbsent(node.key, node);
+ TreeMap<String, Stat> ms = new TreeMap<>();
+ for (var stats : maps) {
+ for (var s : stats) {
+ var stat = ms.putIfAbsent(s.key, s);
if (stat != null)
- stat.merge(node);
+ stat.node.merge(s.node);
}
}
- if (SHOW_ANALYSIS) {
- debug("Total = " + Arrays.stream(lenhist).sum());
- debug("Length_histogram = "
- + Arrays.toString(Arrays.stream(lenhist).map(x -> (int) (x * 1.0e-7)).toArray()));
- return;
- }
-
// print result
System.out.println(ms);
System.out.close();