about | summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java | 456
1 files changed, 247 insertions, 209 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index 2340bca..88de5d2 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -34,7 +34,6 @@ import java.util.stream.IntStream;
import sun.misc.Unsafe;
public class CalculateAverage_abeobk {
- private static final boolean SHOW_ANALYSIS = false;
private static final int CPU_CNT = Runtime.getRuntime().availableProcessors();
private static final String FILE = "./measurements.txt";
@@ -42,7 +41,7 @@ public class CalculateAverage_abeobk {
private static final long BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000;
- private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
+ private static final long CHUNK_SZ = 1 << 22;
private static final Unsafe UNSAFE = initUnsafe();
private static final long[] HASH_MASKS = new long[]{
0x0L,
@@ -60,10 +59,6 @@ public class CalculateAverage_abeobk {
private static int chunk_cnt;
private static long start_addr, end_addr;
- private static final void debug(String s, Object... args) {
- System.out.println(String.format(s, args));
- }
-
private static Unsafe initUnsafe() {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
@@ -75,12 +70,117 @@ public class CalculateAverage_abeobk {
}
}
- // use native type, less conversion
- static class Node {
+ /*
+ * MAIN FUNCTION
+ */
+ public static void main(String[] args) throws InterruptedException, IOException {
+ // thomaswue trick
+ if (args.length == 0 || !("--worker".equals(args[0]))) {
+ spawnWorker();
+ return;
+ }
+
+ var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
+ long file_size = file.size();
+ start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
+ end_addr = start_addr + file_size;
+
+ // only use all cpus on large file
+ int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
+ chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
+
+ // spawn workers
+ for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) {
+ w.join();
+ }
+
+ // collect results
+ TreeMap<String, Node> ms = new TreeMap<>();
+ for (var crr : mapref.get()) {
+ if (crr == null)
+ continue;
+ var prev = ms.putIfAbsent(crr.key(), crr);
+ if (prev != null)
+ prev.merge(crr);
+ }
+ // print result
+ System.out.println(ms);
+ System.out.close();
+ }
+
+ /*
+ * HELPER FUNCTIONS
+ */
+
+ // Get semicolon pos code
+ static final long getSemiCode(final long w) {
+ long x = w ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
+ return (x - 0x0101010101010101L) & (~x & 0x8080808080808080L);
+ }
+
+ // Get new line pos code
+ static final long getLFCode(final long w) {
+ long x = w ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n
+ return (x - 0x0101010101010101L) & (~x & 0x8080808080808080L);
+ }
+
+ // Get decimal point pos code
+ static final int getDotCode(final long w) {
+ return Long.numberOfTrailingZeros(~w & 0x10101000);
+ }
+
+ // Convert semicolon pos code to position
+ static final int getSemiPos(final long spc) {
+ return Long.numberOfTrailingZeros(spc) >>> 3;
+ }
+
+ // Find next line address
+ static final long nextLF(long addr) {
+ long word = UNSAFE.getLong(addr);
+ long lfpos_code = getLFCode(word);
+ while (lfpos_code == 0) {
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ lfpos_code = getLFCode(word);
+ }
+ return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1;
+ }
+
+ // Parse number
+ // great idea from merykitty (Quan Anh Mai)
+ static final long num(long w, int d) {
+ int shift = 28 - d;
+ long signed = (~w << 59) >> 63;
+ long dsmask = ~(signed & 0xFF);
+ long digits = ((w & dsmask) << shift) & 0x0F000F0F00L;
+ long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ return ((abs_val ^ signed) - signed);
+ }
+
+ // Hash mixer
+ static final long mix(long hash) {
+ long h = hash * 37;
+ return (h ^ (h >>> 29));
+ }
+
+ // Spawn worker (thomaswue trick)
+ private static void spawnWorker() throws IOException {
+ ProcessHandle.Info info = ProcessHandle.current().info();
+ ArrayList<String> workerCommand = new ArrayList<>();
+ info.command().ifPresent(workerCommand::add);
+ info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
+ workerCommand.add("--worker");
+ new ProcessBuilder()
+ .command(workerCommand)
+ .start()
+ .getInputStream()
+ .transferTo(System.out);
+ }
+
+ final static class Node {
long addr;
long hash;
long word0;
- long tail;
long sum;
long min, max;
int keylen;
@@ -98,23 +198,36 @@ public class CalculateAverage_abeobk {
return new String(sbuf, 0, (int) keylen, StandardCharsets.UTF_8);
}
- Node(long a, long t, int kl, long h) {
+ Node(long a, long h, int kl, long v) {
+ addr = a;
+ min = max = v;
+ keylen = kl;
+ hash = h;
+ }
+
+ Node(long a, long h, int kl) {
addr = a;
- tail = t;
+ hash = h;
min = 999;
max = -999;
keylen = kl;
+ }
+
+ Node(long a, long w0, long h, int kl, long v) {
+ addr = a;
+ word0 = w0;
hash = h;
+ min = max = v;
+ keylen = kl;
}
- Node(long a, long w0, long t, int kl, long h) {
+ Node(long a, long w0, long h, int kl) {
addr = a;
word0 = w0;
+ hash = h;
min = 999;
max = -999;
- tail = t;
keylen = kl;
- hash = h;
}
final void add(long val) {
@@ -139,8 +252,8 @@ public class CalculateAverage_abeobk {
}
}
- final boolean contentEquals(long other_addr, long other_word0, long other_tail, long kl) {
- if (word0 != other_word0 || tail != other_tail)
+ final boolean contentEquals(long other_addr, long other_word0, long other_hash, long kl) {
+ if (word0 != other_word0 || hash != other_hash)
return false;
// this is faster than comparision if key is short
long xsum = 0;
@@ -152,7 +265,7 @@ public class CalculateAverage_abeobk {
}
final boolean contentEquals(Node other) {
- if (tail != other.tail)
+ if (hash != other.hash)
return false;
long n = keylen & 0xF8;
for (long i = 0; i < n; i += 8) {
@@ -163,150 +276,13 @@ public class CalculateAverage_abeobk {
}
}
- // idea from royvanrijn
- static final long getSemiPosCode(final long word) {
- long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
- return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
- }
-
- static final long getLFCode(final long word) {
- long xor_semi = word ^ 0x0A0A0A0A0A0A0A0AL; // xor with \n\n\n\n\n\n\n\n
- return (xor_semi - 0x0101010101010101L) & (~xor_semi & 0x8080808080808080L);
- }
-
- static final long nextLine(long addr) {
- long word = UNSAFE.getLong(addr);
- long lfpos_code = getLFCode(word);
- while (lfpos_code == 0) {
- addr += 8;
- word = UNSAFE.getLong(addr);
- lfpos_code = getLFCode(word);
- }
- return addr + (Long.numberOfTrailingZeros(lfpos_code) >>> 3) + 1;
- }
-
- // speed/collision balance
- static final long xxh32(long hash) {
- long h = hash * 37;
- return (h ^ (h >>> 29));
- }
-
- static final class ChunkParser {
- long addr;
- long end;
- Node[] map;
-
- ChunkParser(Node[] m, long a, long e) {
- map = m;
- addr = a;
- end = e;
- }
-
- final boolean ok() {
- return addr < end;
- }
-
- final long word() {
- return UNSAFE.getLong(addr);
- }
-
- final long val() {
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
- // great idea from merykitty (Quan Anh Mai)
- int shift = 28 - dot_pos;
- long signed = (~num_word << 59) >> 63;
- long dsmask = ~(signed & 0xFF);
- long digits = ((num_word & dsmask) << shift) & 0x0F000F0F00L;
- long abs_val = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- return ((abs_val ^ signed) - signed);
- }
-
- // optimize for contest
- // save as much slow memory access as possible
- // about 50% key < 8chars, 25% key bettween 8-10 chars
- // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
- final Node key(long word0, long semipos_code) {
- long row_addr = addr;
- // about 50% chance key < 8 chars
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long tail = word0 & HASH_MASKS[semi_pos];
- long hash = xxh32(tail);
- int bucket = (int) (hash & BUCKET_MASK);
- while (true) {
- Node node = map[bucket];
- if (node == null) {
- return (map[bucket] = new Node(row_addr, tail, semi_pos, hash));
- }
- if (node.tail == tail) {
- return node;
- }
- bucket++;
- }
- }
-
- addr += 8;
- long word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- // 43% chance
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long tail = (word & HASH_MASKS[semi_pos]);
- long hash = xxh32(word0 ^ tail);
- int bucket = (int) (hash & BUCKET_MASK);
- while (true) {
- Node node = map[bucket];
- if (node == null) {
- return (map[bucket] = new Node(row_addr, word0, tail, semi_pos + 8, hash));
- }
- if (node.word0 == word0 && node.tail == tail) {
- return node;
- }
- bucket++;
- }
- }
-
- // why not going for more? tested, slower
- long hash = word0;
- while (semipos_code == 0) {
- hash ^= word;
- addr += 8;
- word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- }
-
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos;
- long keylen = addr - row_addr;
- addr++;
- long tail = (word & HASH_MASKS[semi_pos]);
- hash = xxh32(hash ^ tail);
- int bucket = (int) (hash & BUCKET_MASK);
-
- while (true) {
- Node node = map[bucket];
- if (node == null) {
- return (map[bucket] = new Node(row_addr, word0, tail, (int) keylen, hash));
- }
- if (node.contentEquals(row_addr, word0, tail, keylen)) {
- return node;
- }
- bucket++;
- }
- }
- }
-
// Thread pool worker
static final class Worker extends Thread {
final int thread_id; // for debug use only
- int cls = 0;
Worker(int i) {
thread_id = i;
+ this.setPriority(Thread.MAX_PRIORITY);
this.start();
}
@@ -322,15 +298,15 @@ public class CalculateAverage_abeobk {
// find start of line
if (id > 0) {
- addr = nextLine(addr);
+ addr = nextLF(addr);
}
final int num_segs = 3;
long seglen = (end - addr) / num_segs;
long a0 = addr;
- long a1 = nextLine(addr + 1 * seglen);
- long a2 = nextLine(addr + 2 * seglen);
+ long a1 = nextLF(addr + 1 * seglen);
+ long a2 = nextLF(addr + 2 * seglen);
ChunkParser p0 = new ChunkParser(map, a0, a1);
ChunkParser p1 = new ChunkParser(map, a1, a2);
ChunkParser p2 = new ChunkParser(map, a2, end);
@@ -339,9 +315,9 @@ public class CalculateAverage_abeobk {
long w0 = p0.word();
long w1 = p1.word();
long w2 = p2.word();
- long sc0 = getSemiPosCode(w0);
- long sc1 = getSemiPosCode(w1);
- long sc2 = getSemiPosCode(w2);
+ long sc0 = getSemiCode(w0);
+ long sc1 = getSemiCode(w1);
+ long sc2 = getSemiCode(w2);
Node n0 = p0.key(w0, sc0);
Node n1 = p1.key(w1, sc1);
Node n2 = p2.key(w2, sc2);
@@ -355,21 +331,21 @@ public class CalculateAverage_abeobk {
while (p0.ok()) {
long w = p0.word();
- long sc = getSemiPosCode(w);
+ long sc = getSemiCode(w);
Node n = p0.key(w, sc);
long v = p0.val();
n.add(v);
}
while (p1.ok()) {
long w = p1.word();
- long sc = getSemiPosCode(w);
+ long sc = getSemiCode(w);
Node n = p1.key(w, sc);
long v = p1.val();
n.add(v);
}
while (p2.ok()) {
long w = p2.word();
- long sc = getSemiPosCode(w);
+ long sc = getSemiCode(w);
Node n = p2.key(w, sc);
long v = p2.val();
n.add(v);
@@ -396,65 +372,127 @@ public class CalculateAverage_abeobk {
break;
}
bucket++;
- if (SHOW_ANALYSIS)
- cls++;
}
}
}
}
-
- if (SHOW_ANALYSIS) {
- debug("Thread %d collision = %d", thread_id, cls);
- }
}
}
- // thomaswue trick
- private static void spawnWorker() throws IOException {
- ProcessHandle.Info info = ProcessHandle.current().info();
- ArrayList<String> workerCommand = new ArrayList<>();
- info.command().ifPresent(workerCommand::add);
- info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
- workerCommand.add("--worker");
- new ProcessBuilder()
- .command(workerCommand)
- .start()
- .getInputStream()
- .transferTo(System.out);
- }
+ static final class ChunkParser {
+ long addr;
+ long end;
+ Node[] map;
- public static void main(String[] args) throws InterruptedException, IOException {
- // thomaswue trick
- if (args.length == 0 || !("--worker".equals(args[0]))) {
- spawnWorker();
- return;
+ ChunkParser(Node[] m, long a, long e) {
+ map = m;
+ addr = a;
+ end = e;
}
- var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
- long file_size = file.size();
- start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
- end_addr = start_addr + file_size;
+ final boolean ok() {
+ return addr < end;
+ }
- // only use all cpus on large file
- int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
- chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
+ final long word() {
+ return UNSAFE.getLong(addr);
+ }
- // spawn workers
- for (var w : IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList()) {
- w.join();
+ final void skip(int n) {
+ addr += n;
}
- // collect results
- TreeMap<String, Node> ms = new TreeMap<>();
- for (var crr : mapref.get()) {
- if (crr == null)
- continue;
- var prev = ms.putIfAbsent(crr.key(), crr);
- if (prev != null)
- prev.merge(crr);
+ final void skip(long n) {
+ addr += n;
+ }
+
+ final long val0() {
+ long w = word();
+ int d = getDotCode(w);
+ return num(w, d);
+ }
+
+ final long val() {
+ long w = word();
+ int d = getDotCode(w);
+ skip((d >>> 3) + 3);
+ return num(w, d);
+ }
+
+ // optimize for contest
+ // save as much slow memory access as possible
+ // about 50% key < 8 chars, 25% key between 8-10 chars
+ // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
+ final Node key(long word0, long semipos_code) {
+ long row_addr = addr;
+ // about 50% chance key < 8 chars
+ if (semipos_code != 0) {
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ skip(semi_pos + 1);
+ long tail = word0 & HASH_MASKS[semi_pos];
+ long hash = mix(tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) {
+ return (map[bucket] = new Node(row_addr, hash, semi_pos));
+ }
+ if (node.hash == hash) {
+ return node;
+ }
+ bucket++;
+ }
+ }
+
+ skip(8);
+ long word = UNSAFE.getLong(addr);
+ semipos_code = getSemiCode(word);
+ // 43% chance
+ if (semipos_code != 0) {
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ skip(semi_pos + 1);
+ long tail = word0 ^ (word & HASH_MASKS[semi_pos]);
+ long hash = mix(tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) {
+ return (map[bucket] = new Node(row_addr, word0, hash, semi_pos + 8));
+ }
+ if (node.word0 == word0 && node.hash == hash) {
+ return node;
+ }
+ bucket++;
+ }
+ }
+
+ // why not going for more? tested, slower
+ long hash = word0;
+ while (semipos_code == 0) {
+ hash ^= word;
+ skip(8);
+ word = UNSAFE.getLong(addr);
+ semipos_code = getSemiCode(word);
+ }
+
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ skip(semi_pos);
+ long keylen = addr - row_addr;
+ skip(1);
+ long tail = hash ^ (word & HASH_MASKS[semi_pos]);
+ hash = mix(tail);
+ int bucket = (int) (hash & BUCKET_MASK);
+
+ while (true) {
+ Node node = map[bucket];
+ if (node == null) {
+ return (map[bucket] = new Node(row_addr, word0, hash, (int) keylen));
+ }
+ if (node.contentEquals(row_addr, word0, hash, keylen)) {
+ return node;
+ }
+ bucket++;
+ }
}
- // print result
- System.out.println(ms);
- System.out.close();
}
}
\ No newline at end of file