about summary refs log tree commit diff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java 327
1 files changed, 172 insertions, 155 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
index ed859f3..06cbc17 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_abeobk.java
@@ -28,18 +28,21 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import sun.misc.Unsafe;
public class CalculateAverage_abeobk {
private static final boolean SHOW_ANALYSIS = false;
+ private static final int CPU_CNT = Runtime.getRuntime().availableProcessors();
private static final String FILE = "./measurements.txt";
private static final int BUCKET_SIZE = 1 << 16;
private static final int BUCKET_MASK = BUCKET_SIZE - 1;
private static final int MAX_STR_LEN = 100;
private static final int MAX_STATIONS = 10000;
+ private static final long CHUNK_SZ = 1 << 22; // 4MB chunk
private static final Unsafe UNSAFE = initUnsafe();
private static final long[] HASH_MASKS = new long[]{
0x0L,
@@ -52,6 +55,11 @@ public class CalculateAverage_abeobk {
0xffffffffffffffL,
0xffffffffffffffffL, };
+ private static AtomicInteger chunk_id = new AtomicInteger(0);
+ private static int chunk_cnt;
+ private static long start_addr, end_addr;
+ private static Stat[][] all_res;
+
private static final void debug(String s, Object... args) {
System.out.println(String.format(s, args));
}
@@ -153,20 +161,6 @@ public class CalculateAverage_abeobk {
}
}
- // split into chunks
- static long[] slice(long start_addr, long end_addr, long chunk_size, int cpu_cnt) {
- long[] ptrs = new long[cpu_cnt + 1];
- ptrs[0] = start_addr;
- for (int i = 1; i < cpu_cnt; i++) {
- long addr = start_addr + i * chunk_size;
- while (addr < end_addr && UNSAFE.getByte(addr++) != '\n')
- ;
- ptrs[i] = Math.min(addr, end_addr);
- }
- ptrs[cpu_cnt] = end_addr;
- return ptrs;
- }
-
// idea from royvanrijn
static final long getSemiPosCode(final long word) {
long xor_semi = word ^ 0x3b3b3b3b3b3b3b3bL; // xor with ;;;;;;;;
@@ -189,123 +183,158 @@ public class CalculateAverage_abeobk {
return (short) ((abs_val ^ signed) - signed);
}
- // optimize for contest
- // save as much slow memory access as possible
- // about 50% key < 8chars, 25% key bettween 8-10 chars
- // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
- static final Node[] parse(int thread_id, long start, long end) {
- int cls = 0;
- long addr = start;
- var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
- // parse loop
- while (addr < end) {
- long row_addr = addr;
-
- long word0 = UNSAFE.getLong(addr);
- long semipos_code = getSemiPosCode(word0);
-
- // about 50% chance key < 8 chars
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
-
- long tail = word0 & HASH_MASKS[semi_pos];
- int bucket = xxh32(tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, tail, val);
- break;
+ // Thread pool worker
+ static final class Worker extends Thread {
+ final int thread_id;
+
+ Worker(int i) {
+ thread_id = i;
+ this.start();
+ }
+
+ @Override
+ public void run() {
+ var map = new Node[BUCKET_SIZE + MAX_STATIONS]; // extra space for collisions
+ int cnt = 0;
+ int id;
+ int cls = 0;
+
+ // process in small chunk to maintain disk locality (artsiomkorzun trick)
+ // but keep going instead of merging
+ while ((id = chunk_id.getAndIncrement()) < chunk_cnt) {
+ long addr = start_addr + id * CHUNK_SZ;
+ long end = Math.min(addr + CHUNK_SZ, end_addr);
+ // adjust start: a row that begins exactly on the boundary is ours (the previous chunk stops at addr == end), so test the byte before addr
+ if (id > 0) {
+ while (addr < end_addr && UNSAFE.getByte(addr - 1) != '\n') // bounded like the old slice(): never scan past the mapping
+ addr++;
+ }
+
+ // parse loop
+ // optimize for contest
+ // save as much slow memory access as possible
+ // about 50% key < 8 chars, 25% key between 8-10 chars
+ // keylength histogram (%) = [0, 0, 0, 0, 4, 10, 21, 15, 13, 11, 6, 6, 4, 2...
+ while (addr < end) {
+ long row_addr = addr;
+
+ long word0 = UNSAFE.getLong(addr);
+ long semipos_code = getSemiPosCode(word0);
+
+ // about 50% chance key < 8 chars
+ if (semipos_code != 0) {
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos + 1;
+ long num_word = UNSAFE.getLong(addr);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ addr += (dot_pos >>> 3) + 3;
+
+ long tail = word0 & HASH_MASKS[semi_pos];
+ int bucket = xxh32(tail) & BUCKET_MASK;
+ short val = parseNum(num_word, dot_pos);
+
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, tail, val);
+ cnt++;
+ break;
+ }
+ if (node.tail == tail) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls++;
+ }
+ continue;
}
- if (node.tail == tail) {
- node.add(val);
- break;
+
+ addr += 8;
+ long word = UNSAFE.getLong(addr);
+ semipos_code = getSemiPosCode(word);
+ // 43% chance
+ if (semipos_code != 0) {
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos + 1;
+ long num_word = UNSAFE.getLong(addr);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ addr += (dot_pos >>> 3) + 3;
+
+ long tail = (word & HASH_MASKS[semi_pos]);
+ int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
+ short val = parseNum(num_word, dot_pos);
+
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, word0, tail, val);
+ cnt++;
+ break;
+ }
+ if (node.word0 == word0 && node.tail == tail) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls++;
+ }
+ continue;
}
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
- }
- continue;
- }
- addr += 8;
- long word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
- // 43% chance
- if (semipos_code != 0) {
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos + 1;
- long num_word = UNSAFE.getLong(addr);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 3;
-
- long tail = (word & HASH_MASKS[semi_pos]);
- int bucket = xxh32(word0 ^ tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val);
- break;
+ // why not going for more? tested, slower
+ long hash = word0;
+ while (semipos_code == 0) {
+ hash ^= word;
+ addr += 8;
+ word = UNSAFE.getLong(addr);
+ semipos_code = getSemiPosCode(word);
}
- if (node.word0 == word0 && node.tail == tail) {
- node.add(val);
- break;
+
+ int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
+ addr += semi_pos;
+ int keylen = (int) (addr - row_addr);
+ long num_word = UNSAFE.getLong(addr + 1);
+ int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
+ addr += (dot_pos >>> 3) + 4;
+
+ long tail = (word & HASH_MASKS[semi_pos]);
+ int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
+ short val = parseNum(num_word, dot_pos);
+
+ while (true) {
+ var node = map[bucket];
+ if (node == null) {
+ map[bucket] = new Node(row_addr, word0, tail, val);
+ cnt++;
+ break;
+ }
+ if (node.contentEquals(row_addr, word0, tail, keylen)) {
+ node.add(val);
+ break;
+ }
+ bucket++;
+ if (SHOW_ANALYSIS)
+ cls++;
}
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
}
- continue;
}
- // why not going for more? tested, slower
-
- long hash = word0;
- while (semipos_code == 0) {
- hash ^= word;
- addr += 8;
- word = UNSAFE.getLong(addr);
- semipos_code = getSemiPosCode(word);
+ if (SHOW_ANALYSIS) {
+ debug("Thread %d collision = %d", thread_id, cls);
}
- int semi_pos = Long.numberOfTrailingZeros(semipos_code) >>> 3;
- addr += semi_pos;
- int keylen = (int) (addr - row_addr);
- long num_word = UNSAFE.getLong(addr + 1);
- int dot_pos = Long.numberOfTrailingZeros(~num_word & 0x10101000);
- addr += (dot_pos >>> 3) + 4;
-
- long tail = (word & HASH_MASKS[semi_pos]);
- int bucket = xxh32(hash ^ tail) & BUCKET_MASK;
- short val = parseNum(num_word, dot_pos);
-
- while (true) {
- var node = map[bucket];
- if (node == null) {
- map[bucket] = new Node(row_addr, word0, tail, val);
- break;
- }
- if (node.contentEquals(row_addr, word0, tail, keylen)) {
- node.add(val);
- break;
+ Stat[] stats = new Stat[cnt];
+ int i = 0;
+ for (var node : map) {
+ if (node != null) {
+ stats[i++] = new Stat(node);
}
- bucket++;
- if (SHOW_ANALYSIS)
- cls++;
}
+ all_res[thread_id] = stats;
}
-
- if (SHOW_ANALYSIS) {
- debug("Thread %d collision = %d", thread_id, cls);
- }
- return map;
}
// thomaswue trick
@@ -329,44 +358,32 @@ public class CalculateAverage_abeobk {
return;
}
- try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
- long start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
- long file_size = file.size();
- long end_addr = start_addr + file_size;
-
- // only use all cpus on large file
- int cpu_cnt = file_size < 1e6 ? 1 : Runtime.getRuntime().availableProcessors();
- long chunk_size = Math.ceilDiv(file_size, cpu_cnt);
-
- // processing
- var ptrs = slice(start_addr, end_addr, chunk_size, cpu_cnt);
-
- List<List<Stat>> maps = IntStream.range(0, cpu_cnt)
- .mapToObj(thread_id -> parse(thread_id, ptrs[thread_id], ptrs[thread_id + 1]))
- .map(map -> {
- List<Stat> stats = new ArrayList<>();
- for (var node : map) {
- if (node == null)
- continue;
- stats.add(new Stat(node));
- }
- return stats;
- })
- .parallel()
- .toList();
-
- TreeMap<String, Stat> ms = new TreeMap<>();
- for (var stats : maps) {
- for (var s : stats) {
- var stat = ms.putIfAbsent(s.key, s);
- if (stat != null)
- stat.node.merge(s.node);
- }
+ var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
+ long file_size = file.size();
+ start_addr = file.map(MapMode.READ_ONLY, 0, file.size(), Arena.global()).address();
+ end_addr = start_addr + file_size;
+
+ // only use all cpus on large file
+ int cpu_cnt = file_size < 1e6 ? 1 : CPU_CNT;
+ chunk_cnt = (int) Math.ceilDiv(file_size, CHUNK_SZ);
+ all_res = new Stat[cpu_cnt][];
+
+ List<Worker> workers = IntStream.range(0, cpu_cnt).mapToObj(i -> new Worker(i)).toList();
+ for (var w : workers)
+ w.join();
+
+ // collect all results
+ TreeMap<String, Stat> ms = new TreeMap<>();
+ for (var res : all_res) {
+ for (var s : res) {
+ var stat = ms.putIfAbsent(s.key, s);
+ if (stat != null)
+ stat.node.merge(s.node);
}
-
- // print result
- System.out.println(ms);
- System.out.close();
}
+
+ // print output
+ System.out.println(ms);
+ System.out.close();
}
} \ No newline at end of file