aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
author Artsiom Korzun <72259616+artsiomkorzun@users.noreply.github.com> 2024-01-11 09:00:24 +0100
committer GitHub <noreply@github.com> 2024-01-11 09:00:24 +0100
commit 8602a355048a8d34d95611e7f80c00e1ea9b853a (patch)
tree 9136283e5e6afdad7d39545fe780674a49656e49 /src/main
parent 085168a0b3c73b64409afcf58a1f0a67f746a30a (diff)
improved artsiomkorzun solution (#176)
improved artsiomkorzun solution improved artsiomkorzun solution Co-authored-by: Artsiom Korzun <akorzun@deltixlab.com>
Diffstat (limited to 'src/main')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java 444
1 files changed, 249 insertions, 195 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
index 516a6ab..c9b7144 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
@@ -15,28 +15,46 @@
*/
package dev.morling.onebrc;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
+import sun.misc.Unsafe;
+
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.reflect.Field;
+import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
-import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.Comparator;
+import java.util.Map;
+import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
public class CalculateAverage_artsiomkorzun {
private static final Path FILE = Path.of("./measurements.txt");
- private static final long FILE_SIZE = size(FILE);
+ private static final MemorySegment MAPPED_FILE = map(FILE);
private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
private static final int SEGMENT_SIZE = 16 * 1024 * 1024;
- private static final int SEGMENT_COUNT = (int) ((FILE_SIZE + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
+ private static final int SEGMENT_COUNT = (int) ((MAPPED_FILE.byteSize() + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024;
+ private static final long COMMA_PATTERN = pattern(';');
+ private static final long DOT_BITS = 0x10101000;
+ private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);
+
+ private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder();
+ private static final Unsafe UNSAFE;
+
+ static {
+ try {
+ Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafe.setAccessible(true);
+ UNSAFE = (Unsafe) unsafe.get(Unsafe.class);
+ }
+ catch (Throwable e) {
+ throw new RuntimeException(e);
+ }
+ }
public static void main(String[] args) throws Exception {
// for (int i = 0; i < 10; i++) {
@@ -63,196 +81,231 @@ public class CalculateAverage_artsiomkorzun {
aggregators[i].join();
}
- Aggregates aggregates = result.get();
- aggregates.sort();
-
- print(aggregates);
- }
-
- private static void print(Aggregates aggregates) {
- StringBuilder builder = new StringBuilder(aggregates.size() * 15 + 32);
- builder.append("{");
- aggregates.visit(aggregate -> {
- if (builder.length() > 1) {
- builder.append(", ");
- }
-
- builder.append(aggregate);
- });
- builder.append("}");
- System.out.println(builder);
+ Map<String, Aggregate> aggregates = result.get().aggregate();
+ System.out.println(text(aggregates));
}
- private static long size(Path file) {
- try {
- return Files.size(file);
+ private static MemorySegment map(Path file) {
+ try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
+ long size = channel.size();
+ return channel.map(FileChannel.MapMode.READ_ONLY, 0, size, Arena.global());
}
- catch (IOException e) {
+ catch (Throwable e) {
throw new RuntimeException(e);
}
}
- private static class Row {
- final byte[] station = new byte[256];
- int length;
- int hash;
- int temperature;
-
- @Override
- public String toString() {
- return new String(station, 0, length) + ":" + temperature;
- }
+ private static long pattern(char c) {
+ long b = c & 0xFFL;
+ return b | (b << 8) | (b << 16) | (b << 24) | (b << 32) | (b << 40) | (b << 48) | (b << 56);
}
- private static class Aggregate implements Comparable<Aggregate> {
- final byte[] station;
- final int hash;
- int min;
- int max;
- long sum;
- int count;
-
- public Aggregate(Row row) {
- this.station = Arrays.copyOf(row.station, row.length);
- this.hash = row.hash;
- this.min = row.temperature;
- this.max = row.temperature;
- this.sum = row.temperature;
- this.count = 1;
- }
+ private static long getLongBigEndian(long address) {
+ long value = UNSAFE.getLong(address);
- public void add(Row row) {
- min = Math.min(min, row.temperature);
- max = Math.max(max, row.temperature);
- sum += row.temperature;
- count++;
+ if (BYTE_ORDER == ByteOrder.LITTLE_ENDIAN) {
+ value = Long.reverseBytes(value);
}
- public void merge(Aggregate right) {
- min = Math.min(min, right.min);
- max = Math.max(max, right.max);
- sum += right.sum;
- count += right.count;
+ return value;
+ }
+
+ private static long getLongLittleEndian(long address) {
+ long value = UNSAFE.getLong(address);
+
+ if (BYTE_ORDER == ByteOrder.BIG_ENDIAN) {
+ value = Long.reverseBytes(value);
}
- @Override
- public int compareTo(Aggregate that) {
- byte[] lhs = this.station;
- byte[] rhs = that.station;
- int limit = Math.min(lhs.length, rhs.length);
+ return value;
+ }
- for (int offset = 0; offset < limit; offset++) {
- int left = lhs[offset];
- int right = rhs[offset];
+ private static String text(Map<String, Aggregate> aggregates) {
+ StringBuilder text = new StringBuilder(aggregates.size() * 32 + 2);
+ text.append('{');
- if (left != right) {
- return (left & 0xFF) - (right & 0xFF);
- }
+ for (Map.Entry<String, Aggregate> entry : aggregates.entrySet()) {
+ if (text.length() > 1) {
+ text.append(", ");
}
- return lhs.length - rhs.length;
+ Aggregate aggregate = entry.getValue();
+ text.append(entry.getKey()).append('=')
+ .append(round(aggregate.min)).append('/')
+ .append(round(1.0 * aggregate.sum / aggregate.cnt)).append('/')
+ .append(round(aggregate.max));
}
- @Override
- public String toString() {
- return new String(station) + "=" + round(min) + "/" + round(1.0 * sum / count) + "/" + round(max);
- }
+ text.append('}');
+ return text.toString();
+ }
- private static double round(double v) {
- return Math.round(v) / 10.0;
- }
+ private static double round(double v) {
+ return Math.round(v) / 10.0;
}
- private static class Aggregates {
+ private static class Row {
+ long address;
+ int length;
+ int hash;
+ int value;
+ }
- private static final int GROW_FACTOR = 4;
- private static final float LOAD_FACTOR = 0.55f;
+ private record Aggregate(int min, int max, long sum, int cnt) {
+ }
- private Aggregate[] aggregates = new Aggregate[1024];
- private int limit = (int) (aggregates.length * LOAD_FACTOR);
- private int size;
+ private static class Aggregates {
- public int size() {
- return size;
- }
+ private static final int SIZE = 16 * 1024;
+ private final long pointer;
- public void visit(Consumer<Aggregate> consumer) {
- if (size > 0) {
- for (Aggregate aggregate : aggregates) {
- if (aggregate != null) {
- consumer.accept(aggregate);
- }
- }
+ public Aggregates() {
+ int size = 32 * SIZE;
+ long address = UNSAFE.allocateMemory(size + 8096);
+ pointer = (address + 4095) & (~4095);
+ UNSAFE.setMemory(pointer, size, (byte) 0);
+
+ long word = pack(Short.MAX_VALUE, Short.MIN_VALUE, 0);
+ for (int i = 0; i < SIZE; i++) {
+ long entry = pointer + 32 * i;
+ UNSAFE.putLong(entry + 24, word);
}
}
public void add(Row row) {
- int index = row.hash & (aggregates.length - 1);
+ long index = index(row.hash);
+ long header = ((long) row.hash << 32) | (row.length);
while (true) {
- Aggregate aggregate = aggregates[index];
-
- if (aggregate == null) {
- aggregates[index] = new Aggregate(row);
- if (++size >= limit) {
- grow();
- }
- break;
- }
-
- if (row.hash == aggregate.hash && Arrays.equals(row.station, 0, row.length, aggregate.station, 0, aggregate.station.length)) {
- aggregate.add(row);
+ long address = pointer + (index << 5);
+ long head = UNSAFE.getLong(address);
+ long ref = UNSAFE.getLong(address + 8);
+ boolean isHit = (head == 0) || (head == header && equal(ref, row.address, row.length));
+
+ if (isHit) {
+ long sum = UNSAFE.getLong(address + 16) + row.value;
+ long word = UNSAFE.getLong(address + 24);
+ int min = Math.min(min(word), row.value);
+ int max = Math.max(max(word), row.value);
+ int cnt = cnt(word) + 1;
+
+ UNSAFE.putLong(address, header);
+ UNSAFE.putLong(address + 8, row.address);
+ UNSAFE.putLong(address + 16, sum);
+ UNSAFE.putLong(address + 24, pack(min, max, cnt));
break;
}
- index = (index + 1) & (aggregates.length - 1);
+ index = (index + 1) & (SIZE - 1);
}
}
- public void merge(Aggregate right) {
- int index = right.hash & (aggregates.length - 1);
+ public void merge(Aggregates rights) {
+ for (int rightIndex = 0; rightIndex < SIZE; rightIndex++) {
+ long rightAddress = rights.pointer + (rightIndex << 5);
+ long header = UNSAFE.getLong(rightAddress);
+ long reference = UNSAFE.getLong(rightAddress + 8);
- while (true) {
- Aggregate aggregate = aggregates[index];
+ if (header == 0) {
+ continue;
+ }
- if (aggregate == null) {
- aggregates[index] = right;
- if (++size >= limit) {
- grow();
+ int hash = (int) (header >>> 32);
+ int length = (int) (header);
+ long index = index(hash);
+
+ while (true) {
+ long address = pointer + (index << 5);
+ long head = UNSAFE.getLong(address);
+ long ref = UNSAFE.getLong(address + 8);
+ boolean isHit = (head == 0) || (head == header && equal(ref, reference, length));
+
+ if (isHit) {
+ long sum = UNSAFE.getLong(address + 16) + UNSAFE.getLong(rightAddress + 16);
+ long left = UNSAFE.getLong(address + 24);
+ long right = UNSAFE.getLong(rightAddress + 24);
+ int min = Math.min(min(left), min(right));
+ int max = Math.max(max(left), max(right));
+ int cnt = cnt(left) + cnt(right);
+
+ UNSAFE.putLong(address, header);
+ UNSAFE.putLong(address + 8, reference);
+ UNSAFE.putLong(address + 16, sum);
+ UNSAFE.putLong(address + 24, pack(min, max, cnt));
+ break;
}
- break;
+
+ index = (index + 1) & (SIZE - 1);
}
+ }
+ }
- if (right.hash == aggregate.hash && Arrays.equals(right.station, aggregate.station)) {
- aggregate.merge(right);
- break;
+ public Map<String, Aggregate> aggregate() {
+ TreeMap<String, Aggregate> set = new TreeMap<>();
+
+ for (int index = 0; index < SIZE; index++) {
+ long address = pointer + (index << 5);
+ long head = UNSAFE.getLong(address);
+ long ref = UNSAFE.getLong(address + 8);
+
+ if (head == 0) {
+ continue;
}
- index = (index + 1) & (aggregates.length - 1);
+ int length = (int) (head);
+ byte[] array = new byte[length];
+ UNSAFE.copyMemory(null, ref, array, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
+ String key = new String(array);
+
+ long sum = UNSAFE.getLong(address + 16);
+ long word = UNSAFE.getLong(address + 24);
+
+ Aggregate aggregate = new Aggregate(min(word), max(word), sum, cnt(word));
+ set.put(key, aggregate);
}
+
+ return set;
}
- public Aggregates sort() {
- Arrays.sort(aggregates, Comparator.nullsLast(Aggregate::compareTo));
- return this;
+ private static long pack(int min, int max, int cnt) {
+ return ((long) min << 48) | (((long) max & 0xFFFF) << 32) | cnt;
}
- private void grow() {
- Aggregate[] oldAggregates = aggregates;
- aggregates = new Aggregate[oldAggregates.length * GROW_FACTOR];
- limit = (int) (aggregates.length * LOAD_FACTOR);
+ private static int cnt(long word) {
+ return (int) word;
+ }
- for (Aggregate aggregate : oldAggregates) {
- if (aggregate != null) {
- int index = aggregate.hash & (aggregates.length - 1);
+ private static int max(long word) {
+ return (short) (word >>> 32);
+ }
- while (aggregates[index] != null) {
- index = (index + 1) & (aggregates.length - 1);
- }
+ private static int min(long word) {
+ return (short) (word >>> 48);
+ }
- aggregates[index] = aggregate;
+ private static long index(int hash) {
+ return (hash ^ (hash >> 16)) & (SIZE - 1);
+ }
+
+ private static boolean equal(long leftAddress, long rightAddress, int length) {
+ int index = 0;
+
+ while (length > 8) {
+ long left = UNSAFE.getLong(leftAddress + index);
+ long right = UNSAFE.getLong(rightAddress + index);
+
+ if (left != right) {
+ return false;
}
+
+ length -= 8;
+ index += 8;
}
+
+ int shift = 64 - (length << 3);
+ long left = getLongBigEndian(leftAddress + index) >>> shift;
+ long right = getLongBigEndian(rightAddress + index) >>> shift;
+ return (left == right);
}
}
@@ -272,87 +325,88 @@ public class CalculateAverage_artsiomkorzun {
Aggregates aggregates = new Aggregates();
Row row = new Row();
- try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) {
- for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) {
- aggregate(channel, segment, aggregates, row);
- }
- }
- catch (Throwable e) {
- throw new RuntimeException(e);
+ for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) {
+ aggregate(aggregates, row, segment);
}
while (!result.compareAndSet(null, aggregates)) {
Aggregates rights = result.getAndSet(null);
if (rights != null) {
- aggregates = merge(aggregates, rights);
+ aggregates.merge(rights);
}
}
}
- private static void aggregate(FileChannel channel, int segment, Aggregates aggregates, Row row) throws Exception {
+ private static void aggregate(Aggregates aggregates, Row row, int segment) {
long position = (long) SEGMENT_SIZE * segment;
- int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position);
- int limit = Math.min(SEGMENT_SIZE, size - 1);
+ int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, MAPPED_FILE.byteSize() - position);
+ long address = MAPPED_FILE.address() + position;
+ long limit = address + Math.min(SEGMENT_SIZE, size - 1);
- MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, position, size);
-
- if (position > 0) {
- next(buffer);
+ if (segment > 0) {
+ address = next(address);
}
- for (int offset = buffer.position(); offset <= limit;) {
- offset = parse(buffer, row, offset);
+ while (address <= limit) {
+ // NOTE: this parsing can cause a segfault at a page boundary,
+ // e.g. when the file size is 4096 and the last entry is "X;0.0", which is shorter than 8 bytes:
+ // the 8-byte read then straddles two pages, one of which is not mapped.
+ // For some reason it works on my machine; left to investigate.
+ address = parseKey(address, row);
+ address = parseValue(address, row);
aggregates.add(row);
}
}
- private static Aggregates merge(Aggregates lefts, Aggregates rights) {
- if (rights.size() < lefts.size()) {
- Aggregates temp = lefts;
- lefts = rights;
- rights = temp;
- }
-
- rights.visit(lefts::merge);
- return lefts;
- }
-
- private static void next(ByteBuffer buffer) {
- while (buffer.get() != '\n') {
+ private static long next(long address) {
+ while (UNSAFE.getByte(address++) != '\n') {
// continue
}
+ return address;
}
- private static int parse(ByteBuffer buffer, Row row, int offset) {
- byte[] station = row.station;
+ // idea: royvanrijn
+ // explanation: https://richardstartin.github.io/posts/finding-bytes
+ private static long parseKey(long address, Row row) {
int length = 0;
- int hash = 0;
-
- for (byte b; (b = buffer.get(offset++)) != ';';) {
- station[length++] = b;
- hash = 71 * hash + b;
- }
+ long hash = 0;
+ long word;
- row.length = length;
- row.hash = hash;
+ while (true) {
+ word = getLongLittleEndian(address + length);
+ long match = word ^ COMMA_PATTERN;
+ long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;
+
+ if (mask == 0) {
+ hash = 71 * hash + word;
+ length += 8;
+ continue;
+ }
- int sign = 1;
+ int bit = Long.numberOfTrailingZeros(mask);
+ length += (bit >>> 3);
+ hash = 71 * hash + (word & (0x00FFFFFFFFFFFFFFL >>> (63 - bit)));
- if (buffer.get(offset) == '-') {
- sign = -1;
- offset++;
- }
+ row.address = address;
+ row.length = length;
+ row.hash = Long.hashCode(hash);
- int value = buffer.get(offset++) - '0';
-
- if (buffer.get(offset) != '.') {
- value = 10 * value + buffer.get(offset++) - '0';
+ return address + length + 1;
}
+ }
- value = 10 * value + buffer.get(offset + 1) - '0';
- row.temperature = value * sign;
- return offset + 3;
+ // idea: merykitty
+ private static long parseValue(long address, Row row) {
+ long word = getLongLittleEndian(address);
+ long inverted = ~word;
+ int dot = Long.numberOfTrailingZeros(inverted & DOT_BITS);
+ long signed = (inverted << 59) >> 63;
+ long mask = ~(signed & 0xFF);
+ long digits = ((word & mask) << (28 - dot)) & 0x0F000F0F00L;
+ long abs = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
+ row.value = (int) ((abs ^ signed) - signed);
+ return address + (dot >> 3) + 3;
}
}
}