aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java476
1 files changed, 272 insertions, 204 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
index eb3d191..e33fe7e 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yavuztas.java
@@ -15,77 +15,80 @@
*/
package dev.morling.onebrc;
-import java.io.Closeable;
+import sun.misc.Unsafe;
+
import java.io.IOException;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.HashMap;
-import java.util.Map;
import java.util.TreeMap;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import java.util.function.BiConsumer;
+import java.util.function.Consumer;
public class CalculateAverage_yavuztas {
private static final Path FILE = Path.of("./measurements.txt");
- static class Measurement {
-
- // Only accessed by a single thread, so it is safe to share
- private static final StringBuilder STRING_BUILDER = new StringBuilder(14);
-
- private int min; // calculations over int is faster than double, we convert to double in the end only once
- private int max;
- private long sum;
- private long count = 1;
+ private static final Unsafe UNSAFE = unsafe();
- public Measurement(int initial) {
- this.min = initial;
- this.max = initial;
- this.sum = initial;
+ // Tried all there: MappedByteBuffer, MemorySegment and Unsafe
+ // Accessing the memory using Unsafe is still the fastest in my experience
+ private static Unsafe unsafe() {
+ try {
+ final Field f = Unsafe.class.getDeclaredField("theUnsafe");
+ f.setAccessible(true);
+ return (Unsafe) f.get(null);
}
-
- public String toString() {
- STRING_BUILDER.setLength(0); // clear the builder to reuse
- STRING_BUILDER.append(this.min / 10.0); // convert to double while generating the string output
- STRING_BUILDER.append("/");
- STRING_BUILDER.append(round((this.sum / 10.0) / this.count));
- STRING_BUILDER.append("/");
- STRING_BUILDER.append(this.max / 10.0);
- return STRING_BUILDER.toString();
- }
-
- private double round(double value) {
- return Math.round(value * 10.0) / 10.0;
+ catch (Exception e) {
+ throw new RuntimeException(e);
}
}
- static class KeyBuffer {
+ // Only one object, both for measurements and keys, less object creation in hotpots is always faster
+ static class Record {
- ByteBuffer buffer;
+ // keep memory starting address for each segment
+ // since we use Unsafe, this is enough to align and fetch the data
+ long segment;
+ int start;
int length;
int hash;
- public KeyBuffer(ByteBuffer buffer, int length, int hash) {
- this.buffer = buffer;
+ private int min = 1000; // calculations over int is faster than double, we convert to double in the end only once
+ private int max = -1000;
+ private long sum;
+ private long count;
+
+ public Record(long segment, int start, int length, int hash) {
+ this.segment = segment;
+ this.start = start;
this.length = length;
this.hash = hash;
}
@Override
public boolean equals(Object o) {
- final KeyBuffer keyBuffer = (KeyBuffer) o;
- if (this.length != keyBuffer.length || this.hash != keyBuffer.hash)
+ final Record record = (Record) o;
+ return equals(record.segment, record.start, record.length, record.hash);
+ }
+
+ /**
+ * Stateless equals, no Record object needed
+ */
+ public boolean equals(long segment, int start, int length, int hash) {
+ if (this.length != length || this.hash != hash)
return false;
- return this.buffer.equals(keyBuffer.buffer);
+ int i = 0; // bytes mismatch check
+ while (i < this.length
+ && UNSAFE.getByte(this.segment + this.start + i) == UNSAFE.getByte(segment + start + i)) {
+ i++;
+ }
+ return i == this.length;
}
@Override
@@ -96,219 +99,284 @@ public class CalculateAverage_yavuztas {
@Override
public String toString() {
final byte[] bytes = new byte[this.length];
- this.buffer.get(bytes);
- return new String(bytes, 0, this.length, StandardCharsets.UTF_8);
- }
- }
+ int i = 0;
+ while (i < this.length) {
+ bytes[i] = UNSAFE.getByte(this.segment + this.start + i++);
+ }
- static class FixedRegionDataAccessor {
+ return new String(bytes, StandardCharsets.UTF_8);
+ }
- long startPos;
- long size;
- ByteBuffer buffer;
- int position; // relative
+ public Record collect(int temp) {
+ this.min = Math.min(this.min, temp);
+ this.max = Math.max(this.max, temp);
+ this.sum += temp;
+ this.count++;
+ return this;
+ }
- public FixedRegionDataAccessor(long startPos, long size, ByteBuffer buffer) {
- this.startPos = startPos;
- this.size = size;
- this.buffer = buffer;
+ public void merge(Record other) {
+ this.min = Math.min(this.min, other.min);
+ this.max = Math.max(this.max, other.max);
+ this.sum += other.sum;
+ this.count += other.count;
}
- void traverse(BiConsumer<KeyBuffer, Integer> consumer) {
- int keyHash;
- int length;
- while (this.buffer.hasRemaining()) {
+ public String measurements() {
+ // here is only executed once for each unique key, so StringBuilder creation doesn't harm
+ final StringBuilder sb = new StringBuilder(14);
+ sb.append(this.min / 10.0);
+ sb.append("/");
+ sb.append(round((this.sum / 10.0) / this.count));
+ sb.append("/");
+ sb.append(this.max / 10.0);
+ return sb.toString();
+ }
+ }
- this.position = this.buffer.position(); // save line start pos
+ // Inspired by @spullara - customized hashmap on purpose
+ // The main difference is we hold only one array instead of two
+ static class RecordMap {
- byte b;
- keyHash = 0;
- length = 0;
- while ((b = this.buffer.get()) != ';') { // read until semicolon
- keyHash = 31 * keyHash + b; // calculate key hash ahead, eleminates one more loop later
- length++;
- }
+ static final int SIZE = 1 << 15; // 32k - bigger bucket size less collisions
+ static final int BITMASK = SIZE - 1;
+ Record[] keys = new Record[SIZE];
- final ByteBuffer station = this.buffer.slice(this.position, length);
- final KeyBuffer key = new KeyBuffer(station, length, keyHash);
+ static int hashBucket(int hash) {
+ hash = hash ^ (hash >>> 16); // naive bit spreading but surprisingly decreases collision :)
+ return hash & BITMASK; // fast modulo, to find bucket
+ }
- this.buffer.mark(); // semicolon pos
- skip(3); // skip more since minimum temperature length is 3
- length = 4; // +1 for semicolon
+ void putAndCollect(long segment, int start, int length, int hash, int temp) {
+ int bucket = hashBucket(hash);
+ Record existing = this.keys[bucket];
+ if (existing == null) {
+ this.keys[bucket] = new Record(segment, start, length, hash)
+ .collect(temp);
+ return;
+ }
- while (this.buffer.get() != '\n') {
- length++; // read until linebreak
- // TODO how to read temperature here
+ if (!existing.equals(segment, start, length, hash)) {
+ // collision, linear probing to find a slot
+ while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(segment, start, length, hash)) {
+ // can be stuck here if all the buckets are full :(
+ // However, since the data set is max 10K (unique) this shouldn't happen
+ // So, I'm happily leave here branchless :)
}
-
- this.buffer.reset(); // set to after semicolon
- consumer.accept(key, readTemperature(length));
+ if (existing == null) {
+ this.keys[bucket & BITMASK] = new Record(segment, start, length, hash)
+ .collect(temp);
+ return;
+ }
+ existing.collect(temp);
+ }
+ else {
+ existing.collect(temp);
}
}
- Map<KeyBuffer, Measurement> accumulate(Map<KeyBuffer, Measurement> initial) {
-
- traverse((station, temperature) -> {
- initial.compute(station, (k, m) -> {
- if (m == null) {
- return new Measurement(temperature);
- }
- // aggregate
- m.min = Math.min(m.min, temperature);
- m.max = Math.max(m.max, temperature);
- m.sum += temperature;
- m.count++;
- return m;
- });
- });
-
- return initial;
- }
-
- // caching Math.pow calculation improves a lot!
- // interestingly, instance field access is much faster than static field access
- final int[] powerOfTenCache = new int[]{ 1, 10, 100 };
-
- int readTemperature(int length) {
- int temp = 0;
- final byte b1 = this.buffer.get(); // get first byte
-
- int digits = length - 4; // digit position
- final boolean negative = b1 == '-';
- if (!negative) {
- temp += this.powerOfTenCache[digits + 1] * (b1 - 48); // add first digit ahead
+ void putOrMerge(Record key) {
+ int bucket = hashBucket(key.hash);
+ Record existing = this.keys[bucket];
+ if (existing == null) {
+ this.keys[bucket] = key;
+ return;
}
- byte b;
- while ((b = this.buffer.get()) != '.') { // read until dot
- temp += this.powerOfTenCache[digits--] * (b - 48);
+ if (!existing.equals(key)) {
+ // collision, linear probing to find a slot
+ while ((existing = this.keys[++bucket & BITMASK]) != null && !existing.equals(key)) {
+ // can be stuck here if all the buckets are full :(
+ // However, since the data set is max 10K (unique keys) this shouldn't happen
+ // So, I'm happily leave here branchless :)
+ }
+ if (existing == null) {
+ this.keys[bucket & BITMASK] = key;
+ return;
+ }
+ existing.merge(key);
+ }
+ else {
+ existing.merge(key);
}
- b = this.buffer.get(); // read after dot, only one digit no loop
- temp += this.powerOfTenCache[digits] * (b - 48);
- this.buffer.get(); // skip line break
-
- return (negative) ? -temp : temp;
}
- ByteBuffer getKeyRef(int length) {
- final ByteBuffer slice = this.buffer.slice().limit(length - 1);
- skip(length);
- return slice;
+ void forEach(Consumer<Record> consumer) {
+ int pos = 0;
+ Record key;
+ while (pos < this.keys.length) {
+ if ((key = this.keys[pos++]) == null) {
+ continue;
+ }
+ consumer.accept(key);
+ }
}
- void skip(int length) {
- final int pos = this.buffer.position();
- this.buffer.position(pos + length);
+ void merge(RecordMap other) {
+ other.forEach(this::putOrMerge);
}
}
- static class FastDataReader implements Closeable {
+ // One actor for one thread, no synchronization
+ static class RegionActor {
- private final FixedRegionDataAccessor[] accessors;
- private final ExecutorService mergerThread;
- private final ExecutorService accessorPool;
+ final FileChannel channel;
+ final long startPos;
+ final int size;
+ final RecordMap map = new RecordMap();
+ long segmentAddress;
+ int position;
+ Thread runner; // each actor has its own thread
- public FastDataReader(Path path) throws IOException {
- var concurrency = Runtime.getRuntime().availableProcessors();
- final long fileSize = Files.size(path);
- long regionSize = fileSize / concurrency;
+ public RegionActor(FileChannel channel, long startPos, int size) {
+ this.channel = channel;
+ this.startPos = startPos;
+ this.size = size;
+ }
- // handling extreme cases
- while (regionSize > Integer.MAX_VALUE) {
- concurrency *= 2;
- regionSize = fileSize / concurrency;
- }
- if (regionSize <= 256) { // small file, no need concurrency
- concurrency = 1;
- regionSize = fileSize;
- }
+ void accumulate() {
+ this.runner = new Thread(() -> {
+ try {
+ // get the segment memory address, this is the only thing we need for Unsafe
+ this.segmentAddress = this.channel.map(FileChannel.MapMode.READ_ONLY, this.startPos, this.size, Arena.global()).address();
+ }
+ catch (IOException e) {
+ // no-op - skip intentionally, no handling for the purpose of this challenge
+ }
- long startPosition = 0;
- this.accessors = new FixedRegionDataAccessor[concurrency];
- for (int i = 0; i < concurrency - 1; i++) {
- // map regions
- try (final FileChannel channel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) {
- final long maxSize = startPosition + regionSize > fileSize ? fileSize - startPosition : regionSize;
- final MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, startPosition, maxSize);
- this.accessors[i] = new FixedRegionDataAccessor(startPosition, maxSize, buffer);
- // adjust positions back and forth until we find a linebreak!
- final int closestPos = findClosestLineEnd((int) maxSize - 1, buffer);
- buffer.limit(closestPos + 1);
- startPosition += closestPos + 1;
+ int start;
+ int keyHash;
+ int length;
+ while (this.position < this.size) {
+ byte b;
+ start = this.position; // save line start position
+ keyHash = UNSAFE.getByte(this.segmentAddress + this.position++); // first byte is guaranteed not to be ';'
+ length = 1; // min key length
+ while ((b = UNSAFE.getByte(this.segmentAddress + this.position++)) != ';') { // read until semicolon
+ keyHash = calculateHash(keyHash, b); // calculate key hash ahead, eleminates one more loop later
+ length++;
+ }
+
+ final int temp = readTemperature();
+ this.map.putAndCollect(this.segmentAddress, start, length, keyHash, temp);
+
+ this.position++; // skip linebreak
}
- }
- // map the last region
- try (final FileChannel channel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) {
- final long maxSize = fileSize - startPosition; // last region will take the rest
- final MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, startPosition, maxSize);
- this.accessors[concurrency - 1] = new FixedRegionDataAccessor(startPosition, maxSize, buffer);
- }
- // create executors
- this.mergerThread = Executors.newSingleThreadExecutor();
- this.accessorPool = Executors.newFixedThreadPool(concurrency);
+ });
+ this.runner.start();
}
- void readAndCollect(Map<KeyBuffer, Measurement> output) {
- for (final FixedRegionDataAccessor accessor : this.accessors) {
- this.accessorPool.submit(() -> {
- final Map<KeyBuffer, Measurement> partial = accessor.accumulate(new HashMap<>(1 << 10, 1)); // aka 1k
- this.mergerThread.submit(() -> mergeMaps(output, partial));
- });
- }
+ static int calculateHash(int hash, int b) {
+ return 31 * hash + b;
}
- @Override
- public void close() {
- try {
- this.accessorPool.shutdown();
- this.accessorPool.awaitTermination(60, TimeUnit.SECONDS);
- this.mergerThread.shutdown();
- this.mergerThread.awaitTermination(60, TimeUnit.SECONDS);
+ // 1. Inspired by @yemreinci - Reading temparature value without Double.parse
+ // 2. Inspired by @obourgain - Fetching first 4 bytes ahead, then masking
+ int readTemperature() {
+ int temp = 0;
+ // read 4 bytes ahead
+ final int first4 = UNSAFE.getInt(this.segmentAddress + this.position);
+ this.position += 3;
+
+ final byte b1 = (byte) first4; // first byte
+ final byte b2 = (byte) ((first4 >> 8) & 0xFF); // second byte
+ final byte b3 = (byte) ((first4 >> 16) & 0xFF); // third byte
+ if (b1 == '-') {
+ if (b3 == '.') {
+ temp -= 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte
+ this.position++;
+ }
+ else {
+ this.position++; // skip dot
+ temp -= 100 * (b2 - '0') + 10 * (b3 - '0') + UNSAFE.getByte(this.segmentAddress + this.position++) - '0'; // fifth byte
+ }
}
- catch (Exception e) {
- this.accessorPool.shutdownNow();
- this.mergerThread.shutdownNow();
+ else {
+ if (b2 == '.') {
+ temp = 10 * (b1 - '0') + b3 - '0';
+ }
+ else {
+ temp = 100 * (b1 - '0') + 10 * (b2 - '0') + (byte) ((first4 >> 24) & 0xFF) - '0'; // fourth byte
+ this.position++;
+ }
}
+
+ return temp;
}
/**
- * Scans the given buffer to the left
+ * blocks until the map is fully collected
*/
- private static int findClosestLineEnd(int regionSize, ByteBuffer buffer) {
- int position = regionSize;
- int left = regionSize;
- while (buffer.get(position) != '\n') {
- position = --left;
- }
- return position;
+ RecordMap get() throws InterruptedException {
+ this.runner.join();
+ return this.map;
}
+ }
- private static Map<KeyBuffer, Measurement> mergeMaps(Map<KeyBuffer, Measurement> map1, Map<KeyBuffer, Measurement> map2) {
- map2.forEach((s, measurement) -> {
- map1.merge(s, measurement, (m1, m2) -> {
- m1.min = Math.min(m1.min, m2.min);
- m1.max = Math.max(m1.max, m2.max);
- m1.sum += m2.sum;
- m1.count += m2.count;
- return m1;
- });
- });
+ private static double round(double value) {
+ return Math.round(value * 10.0) / 10.0;
+ }
- return map1;
+ /**
+ * Scans the given buffer to the left
+ */
+ private static long findClosestLineEnd(long start, int size, FileChannel channel) throws IOException {
+ final long position = start + size;
+ final long left = Math.max(position - 101, 0);
+ final ByteBuffer buffer = ByteBuffer.allocate(101); // enough size to find at least one '\n'
+ if (channel.read(buffer.clear(), left) != -1) {
+ int bufferPos = buffer.position() - 1;
+ while (buffer.get(bufferPos) != '\n') {
+ bufferPos--;
+ size--;
+ }
}
-
+ return size;
}
public static void main(String[] args) throws IOException, InterruptedException {
- final Map<KeyBuffer, Measurement> output = new HashMap<>(1 << 10, 1); // aka 1k
- try (final FastDataReader reader = new FastDataReader(FILE)) {
- reader.readAndCollect(output);
+
+ var concurrency = Runtime.getRuntime().availableProcessors();
+ final long fileSize = Files.size(FILE);
+ long regionSize = fileSize / concurrency;
+
+ // handling extreme cases
+ while (regionSize > Integer.MAX_VALUE) {
+ concurrency *= 2;
+ regionSize /= 2;
+ }
+ if (fileSize <= 1 << 20) { // small file (1mb), no need concurrency
+ concurrency = 1;
+ regionSize = fileSize;
+ }
+
+ long startPos = 0;
+ final FileChannel channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ);
+ final RegionActor[] actors = new RegionActor[concurrency];
+ for (int i = 0; i < concurrency; i++) {
+ // calculate boundaries
+ long maxSize = (startPos + regionSize > fileSize) ? fileSize - startPos : regionSize;
+ // shift position to back until we find a linebreak
+ maxSize = findClosestLineEnd(startPos, (int) maxSize, channel);
+
+ final RegionActor region = (actors[i] = new RegionActor(channel, startPos, (int) maxSize));
+ region.accumulate();
+
+ startPos += maxSize;
}
- final TreeMap<String, Measurement> sorted = new TreeMap<>();
- output.forEach((s, measurement) -> sorted.put(s.toString(), measurement));
+ final RecordMap output = new RecordMap(); // output to merge all regions
+ for (RegionActor actor : actors) {
+ final RecordMap partial = actor.get(); // blocks until get the result
+ output.merge(partial);
+ }
+
+ // sort and print the result
+ final TreeMap<String, String> sorted = new TreeMap<>();
+ output.forEach(key -> sorted.put(key.toString(), key.measurements()));
System.out.println(sorted);
+
}
}