aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java377
1 files changed, 224 insertions, 153 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
index 21abbb1..dce3a33 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
@@ -15,188 +15,143 @@
*/
package dev.morling.onebrc;
+import sun.misc.Unsafe;
+
import java.io.IOException;
import java.io.PrintStream;
-import java.nio.ByteBuffer;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Objects;
+import java.util.TreeMap;
+import java.util.stream.Stream;
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.stream.Collectors.toMap;
public class CalculateAverage_armandino {
- private static final String FILE = "./measurements.txt";
+ private static final Path FILE = Path.of("./measurements.txt");
- private static final int MAX_KEY_LENGTH = 100;
+ private static final int NUM_CHUNKS = Math.max(8, Runtime.getRuntime().availableProcessors());
+ private static final int INITIAL_MAP_CAPACITY = 8192;
private static final byte SEMICOLON = 59;
private static final byte NL = 10;
private static final byte DOT = 46;
private static final byte MINUS = 45;
+ private static final byte ZERO_DIGIT = 48;
+ private static final Unsafe UNSAFE = getUnsafe();
public static void main(String[] args) throws Exception {
- Aggregator aggregator = new Aggregator();
- aggregator.process();
- aggregator.printStats();
- }
-
- private static class Aggregator {
-
- private final Map<Integer, Stats> map = new ConcurrentHashMap<>(2048);
-
- private record Chunk(long start, long end) {
- }
+ var channel = FileChannel.open(FILE, StandardOpenOption.READ);
- void process() throws Exception {
- var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
- final Chunk[] chunks = split(channel);
- final Thread[] threads = new Thread[chunks.length];
+ var results = Arrays.stream(split(channel)).parallel()
+ .map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end))
+ .flatMap(SimpleMap::stream)
+ .collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new));
- for (int i = 0; i < chunks.length; i++) {
- final Chunk chunk = chunks[i];
-
- threads[i] = Thread.ofVirtual().start(() -> {
- try {
- var bb = channel.map(READ_ONLY, chunk.start, chunk.end - chunk.start);
- process(bb);
- }
- catch (IOException e) {
- throw new RuntimeException(e);
- }
- });
- }
-
- for (Thread t : threads) {
- t.join();
- }
- }
+ print(results.values());
+ }
- private static Chunk[] split(final FileChannel channel) throws IOException {
- final long fileSize = channel.size();
- if (fileSize < 10000) {
- return new Chunk[]{ new Chunk(0, fileSize) };
- }
+ private static Stats mergeStats(final Stats x, final Stats y) {
+ x.min = Math.min(x.min, y.min);
+ x.max = Math.max(x.max, y.max);
+ x.count += y.count;
+ x.sum += y.sum;
+ return x;
+ }
- final int numChunks = 8;
- final long chunkSize = fileSize / numChunks;
- final var chunks = new Chunk[numChunks];
+ private static class ChunkProcessor {
+ private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY);
- for (int i = 0; i < numChunks; i++) {
- long start = 0;
- long end = chunkSize;
+ private SimpleMap process(final long chunkStart, final long chunkEnd) {
+ long i = chunkStart;
+ while (i < chunkEnd) {
+ final long keyAddress = i;
+ int keyHash = 0;
+ int measurement = 0;
+ byte b;
- if (i > 0) {
- start = chunks[i - 1].end + 1;
- end = Math.min(start + chunkSize, fileSize);
+ while ((b = UNSAFE.getByte(i++)) != SEMICOLON) {
+ keyHash = 31 * keyHash + b;
}
- end = end == fileSize ? end : seekNextNewline(channel, end);
- chunks[i] = new Chunk(start, end);
- }
- return chunks;
- }
+ final int keyLength = (int) (i - keyAddress - 1);
- private static long seekNextNewline(final FileChannel channel, final long end) throws IOException {
- var bb = ByteBuffer.allocate(MAX_KEY_LENGTH);
- channel.position(end).read(bb);
-
- for (int i = 0; i < bb.limit(); i++) {
- if (bb.get(i) == NL) {
- return end + i;
- }
- }
-
- throw new IllegalStateException("Couldn't find next newline");
- }
-
- private void process(final ByteBuffer bb) {
- final var sample = new Sample();
- var isKey = true;
-
- for (long i = 0, sz = bb.limit(); i < sz; i++) {
-
- final byte b = bb.get();
+ if ((b = UNSAFE.getByte(i++)) == MINUS) {
+ while ((b = UNSAFE.getByte(i++)) != DOT) {
+ measurement = measurement * 10 + b - ZERO_DIGIT;
+ }
- if (b == SEMICOLON) {
- isKey = false;
- }
- else if (b == NL) {
- isKey = true;
- addSample(sample);
- sample.reset();
- }
- else if (isKey) {
- sample.pushKey(b);
- }
- else if (b == DOT) {
- // skip
- }
- else if (b == MINUS) {
- sample.sign = -1;
+ b = UNSAFE.getByte(i);
+ measurement = measurement * 10 + b - ZERO_DIGIT;
+ measurement = -measurement;
+ i += 2;
}
else {
- sample.pushMeasurement(b);
- }
- }
- }
+ measurement = b - ZERO_DIGIT; // D1
+ b = UNSAFE.getByte(i); // dot or D2
- private void addSample(final Sample sample) {
- final Stats stats = map.computeIfAbsent(sample.keyHash,
- k -> new Stats(new String(sample.keyBytes, 0, sample.keyLength, UTF_8)));
-
- final var val = sample.getMeasurement();
-
- if (val < stats.min)
- stats.min = val;
-
- if (val > stats.max)
- stats.max = val;
-
- stats.sum += val;
- stats.count++;
- }
-
- void printStats() {
- var sorted = new ArrayList<>(map.values());
- Collections.sort(sorted);
-
- int size = sorted.size();
-
- System.out.print('{');
-
- for (Stats stats : sorted) {
- stats.print(System.out);
- if (--size > 0) {
- System.out.print(", ");
+ if (b == DOT) {
+ measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F
+ i += 3;
+ }
+ else {
+ measurement = measurement * 10 + b - ZERO_DIGIT; // D2
+ measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F
+ i += 4; // skip NL
+ }
}
+
+ final Stats stats = map.putStats(keyHash, keyAddress, keyLength);
+ stats.min = Math.min(stats.min, measurement);
+ stats.max = Math.max(stats.max, measurement);
+ stats.sum += measurement;
+ stats.count++;
}
- System.out.println('}');
+ return map;
}
}
private static class Stats implements Comparable<Stats> {
- private final String city;
+ private String key;
+ private final byte[] keyBytes;
+ private final int keyLength;
+ private final int keyHash;
private int min = Integer.MAX_VALUE;
private int max = Integer.MIN_VALUE;
- private long sum;
private int count;
+ private long sum;
+
+ private Stats(long keyAddress, int keyLength, int keyHash) {
+ this.keyLength = keyLength;
+ this.keyBytes = new byte[keyLength];
+ this.keyHash = keyHash;
+
+ for (int i = 0; i < keyLength; i++) {
+ keyBytes[i] = UNSAFE.getByte(keyAddress++);
+ }
+ }
- private Stats(String city) {
- this.city = city;
+ String getKey() {
+ if (key == null) {
+ key = new String(keyBytes, 0, keyLength, UTF_8);
+ }
+ return key;
}
@Override
public int compareTo(final Stats o) {
- return city.compareTo(o.city);
+ return getKey().compareTo(o.getKey());
}
void print(final PrintStream out) {
- out.print(city);
+ out.print(key);
out.print('=');
out.print(round(min / 10f));
out.print('/');
@@ -210,32 +165,148 @@ public class CalculateAverage_armandino {
}
}
- private static class Sample {
- private final byte[] keyBytes = new byte[MAX_KEY_LENGTH];
- private int keyLength;
- private int keyHash;
- private int measurement;
- private int sign = 1;
+ private static void print(final Collection<Stats> sorted) {
+ int size = sorted.size();
+ System.out.print('{');
+ for (Stats stats : sorted) {
+ stats.print(System.out);
+ if (--size > 0) {
+ System.out.print(", ");
+ }
+ }
+ System.out.println('}');
+ }
- void pushKey(byte b) {
- keyBytes[keyLength++] = b;
- keyHash = 31 * keyHash + b;
+ private static Chunk[] split(final FileChannel channel) throws IOException {
+ final long fileSize = channel.size();
+ long start = channel.map(READ_ONLY, 0, fileSize, Arena.global()).address();
+ final long endAddress = start + fileSize;
+ if (fileSize < 10000) {
+ return new Chunk[]{ new Chunk(start, endAddress) };
}
- void pushMeasurement(byte b) {
- final int i = b - '0';
- measurement = measurement * 10 + i;
+ final long chunkSize = fileSize / NUM_CHUNKS;
+ final var chunks = new Chunk[NUM_CHUNKS];
+ long end = start + chunkSize;
+
+ for (int i = 0; i < NUM_CHUNKS; i++) {
+ if (i > 0) {
+ start = chunks[i - 1].end;
+ end = Math.min(start + chunkSize, endAddress);
+ }
+ if (end < endAddress) {
+ while (UNSAFE.getByte(end) != NL) {
+ end++;
+ }
+ end++;
+ }
+ chunks[i] = new Chunk(start, end);
}
+ return chunks;
+ }
+
+ private record Chunk(long start, long end) {
+ }
+
+ private static Unsafe getUnsafe() {
+ try {
+ Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafe.setAccessible(true);
+ return (Unsafe) unsafe.get(null);
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static class SimpleMap {
+ private Stats[] table;
- int getMeasurement() {
- return sign * measurement;
+ SimpleMap(int initialCapacity) {
+ table = new Stats[initialCapacity];
+ }
+
+ Stream<Stats> stream() {
+ return Arrays.stream(table).filter(Objects::nonNull);
+ }
+
+ private void resize() {
+ var copy = new SimpleMap(table.length * 2);
+ for (Stats s : table) {
+ if (s != null) {
+ final int pos = (copy.table.length - 1) & s.keyHash;
+ int i = pos;
+
+ if (copy.table[i] == null) {
+ copy.table[i] = s;
+ continue;
+ }
+
+ while (i < copy.table.length && copy.table[i] != null) {
+ i++;
+ }
+ if (i == copy.table.length) {
+ i = pos;
+ while (i >= 0 && copy.table[i] != null) {
+ i--;
+ }
+ }
+ if (i < 0) {
+ // shouldn't happen because put() is called after increasing size
+ throw new IllegalStateException("table is full");
+ }
+ copy.table[i] = s;
+ }
+ }
+ table = copy.table;
+ }
+
+ Stats putStats(final int keyHash, final long keyAddress, final int keyLength) {
+ final int pos = (table.length - 1) & keyHash;
+
+ Stats stats = table[pos];
+ if (stats == null)
+ return createAt(table, keyAddress, keyLength, keyHash, pos);
+ if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength))
+ return stats;
+
+ int i = pos;
+ while (++i < table.length) {
+ stats = table[i];
+ if (stats == null)
+ return createAt(table, keyAddress, keyLength, keyHash, i);
+ if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
+ return stats;
+ }
+
+ i = pos;
+ while (i-- > 0) {
+ stats = table[i];
+ if (stats == null)
+ return createAt(table, keyAddress, keyLength, keyHash, i);
+ if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
+ return stats;
+ }
+ resize();
+ return putStats(keyHash, keyAddress, keyLength);
+ }
+
+ private boolean keysEqual(Stats stats, long keyAddress, final int keyLength) {
+ if (stats.keyLength != keyLength) {
+ return false;
+ }
+ for (int i = 0; i < keyLength; i++) {
+ if (stats.keyBytes[i] != UNSAFE.getByte(keyAddress++)) {
+ return false;
+ }
+ }
+ return true;
}
- void reset() {
- keyHash = 0;
- keyLength = 0;
- measurement = 0;
- sign = 1;
+ private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) {
+ Stats stats = new Stats(keyAddress, keyLength, key);
+ table[i] = stats;
+ return stats;
}
}
}