aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java231
1 files changed, 131 insertions, 100 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java
index 1325255..194dbcc 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jparera.java
@@ -1,3 +1,5 @@
+//COMPILE_OPTIONS -source 21 --enable-preview --add-modules jdk.incubator.vector
+//RUNTIME_OPTIONS --enable-preview --add-modules jdk.incubator.vector
/*
* Copyright 2023 The original authors
*
@@ -19,6 +21,8 @@ import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
@@ -26,7 +30,6 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.List;
import java.util.TreeMap;
import java.util.function.Function;
@@ -34,25 +37,41 @@ import java.util.stream.Collectors;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;
+import jdk.incubator.vector.VectorOperators;
public class CalculateAverage_jparera {
private static final String FILE = "./measurements.txt";
- private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED;
+ private static final VarHandle BYTE_HANDLE = MethodHandles
+ .memorySegmentViewVarHandle(ValueLayout.JAVA_BYTE);
+
+ private static final VarHandle INT_HANDLE = MethodHandles
+ .memorySegmentViewVarHandle(ValueLayout.JAVA_INT_UNALIGNED);
- private static final int BYTE_SPECIES_SIZE = BYTE_SPECIES.vectorByteSize();
+ private static final VarHandle LONG_LE_HANDLE = MethodHandles
+ .memorySegmentViewVarHandle(ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN));
+
+ private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED;
private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length();
- private static final ValueLayout.OfLong LONG_U_LE = ValueLayout.JAVA_LONG_UNALIGNED
- .withOrder(ByteOrder.LITTLE_ENDIAN);
+ private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder();
+
+ private static final byte LF = '\n';
- public static void main(String[] args) throws IOException {
+ private static final byte SEPARATOR = ';';
+
+ private static final byte DECIMAL_SEPARATOR = '.';
+
+ private static final byte NEG = '-';
+
+ public static void main(String[] args) throws IOException, InterruptedException {
try (var fc = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
try (var arena = Arena.ofShared()) {
var fs = fc.map(MapMode.READ_ONLY, 0, fc.size(), arena);
- var map = chunks(fs)
- .parallelStream()
+ var cpus = Runtime.getRuntime().availableProcessors();
+ var output = chunks(fs, cpus).stream()
+ .parallel()
.map(Chunk::parse)
.flatMap(List::stream)
.collect(Collectors.toMap(
@@ -60,20 +79,19 @@ public class CalculateAverage_jparera {
Function.identity(),
Entry::merge,
TreeMap::new));
- System.out.println(map);
+ System.out.println(output);
}
}
}
- private static Collection<Chunk> chunks(MemorySegment ms) {
- var cpus = Runtime.getRuntime().availableProcessors();
- long expectedChunkSize = Math.ceilDiv(ms.byteSize(), cpus);
- var chunks = new ArrayList<Chunk>();
+ private static List<Chunk> chunks(MemorySegment ms, int splits) {
long fileSize = ms.byteSize();
+ long expectedChunkSize = Math.ceilDiv(fileSize, splits);
+ var chunks = new ArrayList<Chunk>();
long offset = 0;
while (offset < fileSize) {
var end = Math.min(offset + expectedChunkSize, fileSize);
- while (end < fileSize && ms.get(ValueLayout.JAVA_BYTE, end++) != '\n') {
+ while (end < fileSize && (byte) BYTE_HANDLE.get(ms, end++) != LF) {
}
long len = end - offset;
chunks.add(new Chunk(ms.asSlice(offset, len)));
@@ -83,25 +101,27 @@ public class CalculateAverage_jparera {
}
private static final class Chunk {
- private static final byte SEPARATOR = ';';
+ private static final int KEY_LOG2_BYTES = 7;
- private static final byte DECIMAL_SEPARATOR = '.';
+ private static final int KEY_BYTES = 1 << KEY_LOG2_BYTES;
- private static final byte LF = '\n';
+ private static final int ENTRIES_LOG2_CAPACITY = 16;
- private static final byte MINUS = '-';
+ private static final int ENTRIES_CAPACITY = 1 << ENTRIES_LOG2_CAPACITY;
- private static final int KEY_LOG2_BYTES = 7;
+ private static final int ENTRIES_MASK = ENTRIES_CAPACITY - 1;
- private static final int KEY_BYTES = 1 << KEY_LOG2_BYTES;
+ private final MemorySegment segment;
- private static final int MAP_CAPACITY = 1 << 16;
+ private final long size;
- private static final int BUCKET_MASK = MAP_CAPACITY - 1;
+ private final Entry[] entries = new Entry[ENTRIES_CAPACITY];
- private final MemorySegment segment;
+ private final byte[] keys = new byte[ENTRIES_CAPACITY * KEY_BYTES];
+
+ private final MemorySegment kms = MemorySegment.ofArray(this.keys);
- private final Entry[] entries = new Entry[MAP_CAPACITY];
+ private static final int KEYS_MASK = (ENTRIES_CAPACITY * KEY_BYTES) - 1;
private long offset;
@@ -111,26 +131,23 @@ public class CalculateAverage_jparera {
Chunk(MemorySegment segment) {
this.segment = segment;
+ this.size = segment.byteSize();
}
public List<Entry> parse() {
- long size = this.segment.byteSize();
long safe = size - KEY_BYTES;
while (offset < safe) {
- var e = vectorizedEntry();
- int value = vectorizedValue();
- e.add(value);
+ vectorizedEntry().add(vectorizedValue());
}
next();
while (hasCurrent()) {
- var e = entry();
- int value = value();
- e.add(value);
+ entry().add(value());
}
var output = new ArrayList<Entry>(entries.length);
- for (int i = 0; i < entries.length; i++) {
+ for (int i = 0, o = 0; i < entries.length; i++, o += KEY_BYTES) {
var e = entries[i];
if (e != null) {
+ e.setkey(keys, o);
output.add(e);
}
}
@@ -138,29 +155,48 @@ public class CalculateAverage_jparera {
}
private Entry vectorizedEntry() {
- var start = this.offset;
- var first = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start, ByteOrder.nativeOrder());
- int equals = first.eq(SEPARATOR).firstTrue();
- int len = equals;
- for (int i = BYTE_SPECIES_SIZE; equals == BYTE_SPECIES_LANES; i += BYTE_SPECIES_SIZE) {
- var next = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start + i, ByteOrder.nativeOrder());
- equals = next.eq(SEPARATOR).firstTrue();
+ var separators = ByteVector.broadcast(BYTE_SPECIES, SEPARATOR);
+ int len = 0;
+ for (int i = 0;; i += BYTE_SPECIES_LANES) {
+ var block = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, offset + i, NATIVE_ORDER);
+ int equals = block.compare(VectorOperators.EQ, separators).firstTrue();
len += equals;
+ if (equals != BYTE_SPECIES_LANES) {
+ break;
+ }
}
+ var start = this.offset;
this.offset = start + len + 1;
- int index = hash(this.segment, start, len);
+ int hash = hash(segment, start, len);
+ int index = (hash - (hash >>> -ENTRIES_LOG2_CAPACITY)) & ENTRIES_MASK;
+ int keyOffset = index << KEY_LOG2_BYTES;
int count = 0;
- while (count < BUCKET_MASK) {
- index = index & BUCKET_MASK;
+ while (count < ENTRIES_MASK) {
+ index = index & ENTRIES_MASK;
+ keyOffset = keyOffset & KEYS_MASK;
var e = this.entries[index];
if (e == null) {
- return this.entries[index] = new Entry(len, this.segment.asSlice(start, KEY_BYTES));
+ MemorySegment.copy(this.segment, start, kms, keyOffset, len);
+ return this.entries[index] = new Entry(len, hash);
}
- else if (e.keyLength() == len && vectorizedEquals(e, first, start, len)) {
- return e;
+ else if (e.hash == hash && e.keyLength == len) {
+ int total = 0;
+ for (int i = 0; i < KEY_BYTES; i += BYTE_SPECIES_LANES) {
+ var ekey = ByteVector.fromArray(BYTE_SPECIES, keys, keyOffset + i);
+ var okey = ByteVector.fromMemorySegment(BYTE_SPECIES, this.segment, start + i, NATIVE_ORDER);
+ int equals = ekey.compare(VectorOperators.NE, okey).firstTrue();
+ total += equals;
+ if (equals != BYTE_SPECIES_LANES) {
+ break;
+ }
+ }
+ if (total >= len) {
+ return e;
+ }
}
- index++;
count++;
+ index++;
+ keyOffset += KEY_BYTES;
}
throw new IllegalStateException("Map is full!");
}
@@ -173,19 +209,33 @@ public class CalculateAverage_jparera {
next();
}
expect(SEPARATOR);
- int index = hash(segment, start, len);
+ int hash = hash(segment, start, len);
+ int index = (hash - (hash >>> -ENTRIES_LOG2_CAPACITY)) & ENTRIES_MASK;
+ int keyOffset = index << KEY_LOG2_BYTES;
int count = 0;
- while (count < BUCKET_MASK) {
- index = index & BUCKET_MASK;
+ while (count < ENTRIES_MASK) {
+ index = index & ENTRIES_MASK;
+ keyOffset = keyOffset & KEYS_MASK;
var e = this.entries[index];
if (e == null) {
- return this.entries[index] = new Entry(len, this.segment.asSlice(start, len));
+ MemorySegment.copy(this.segment, start, kms, keyOffset, len);
+ return this.entries[index] = new Entry(len, hash);
}
- else if (e.keyLength() == len && equals(e, start, len)) {
- return e;
+ else if (e.hash == hash && e.keyLength == len) {
+ int total = 0;
+ for (int i = 0; i < len; i++) {
+ if (((byte) BYTE_HANDLE.get(this.segment, start + i)) != this.keys[keyOffset + i]) {
+ break;
+ }
+ total++;
+ }
+ if (total >= len) {
+ return e;
+ }
}
- index++;
count++;
+ index++;
+ keyOffset += KEY_BYTES;
}
throw new IllegalStateException("Map is full!");
}
@@ -193,9 +243,9 @@ public class CalculateAverage_jparera {
private static final long MULTIPLY_ADD_DIGITS = 100 * (1L << 24) + 10 * (1L << 16) + 1;
private int vectorizedValue() {
- long dw = this.segment.get(LONG_U_LE, this.offset);
- boolean negative = ((dw & 0xFF) ^ MINUS) == 0;
+ long dw = (long) LONG_LE_HANDLE.get(this.segment, this.offset);
int zeros = Long.numberOfTrailingZeros(~dw & 0x10101000L);
+ boolean negative = ((dw & 0xFF) ^ NEG) == 0;
dw = ((negative ? (dw & ~0xFF) : dw) << (28 - zeros)) & 0x0F000F0F00L;
int value = (int) (((dw * MULTIPLY_ADD_DIGITS) >>> 32) & 0x3FF);
this.offset += (zeros >>> 3) + 3;
@@ -205,7 +255,7 @@ public class CalculateAverage_jparera {
private int value() {
int value = 0;
var negative = false;
- if (consume(MINUS)) {
+ if (consume(NEG)) {
negative = true;
}
while (hasCurrent()) {
@@ -224,41 +274,18 @@ public class CalculateAverage_jparera {
return negative ? -value : value;
}
- private boolean vectorizedEquals(Entry entry, ByteVector okey, long offset, int len) {
- var ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), 0, ByteOrder.nativeOrder());
- int equals = ekey.eq(okey).not().firstTrue();
- if (equals != BYTE_SPECIES_LANES) {
- return equals >= len;
- }
- long eo = BYTE_SPECIES_SIZE;
- int total = BYTE_SPECIES_LANES;
- while (equals == BYTE_SPECIES_LANES & eo < KEY_BYTES) {
- offset += BYTE_SPECIES_SIZE;
- ekey = ByteVector.fromMemorySegment(BYTE_SPECIES, entry.segment(), eo, ByteOrder.nativeOrder());
- okey = ByteVector.fromMemorySegment(BYTE_SPECIES, segment, offset, ByteOrder.nativeOrder());
- equals = ekey.eq(okey).not().firstTrue();
- total += equals;
- eo += BYTE_SPECIES_SIZE;
- }
- return total >= len;
- }
-
- private boolean equals(Entry entry, long offset, int len) {
- return MemorySegment.mismatch(this.segment, offset, offset + len, entry.segment(), 0, len) == -1;
- }
-
private static final int GOLDEN_RATIO = 0x9E3779B9;
private static final int HASH_LROTATE = 5;
private static int hash(MemorySegment ms, long start, int len) {
int x, y;
if (len >= Integer.BYTES) {
- x = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start);
- y = ms.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES);
+ x = (int) INT_HANDLE.get(ms, start);
+ y = (int) INT_HANDLE.get(ms, start + len - Integer.BYTES);
}
else {
- x = ms.get(ValueLayout.JAVA_BYTE, start);
- y = ms.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES);
+ x = (byte) BYTE_HANDLE.get(ms, start) & 0xFF;
+ y = (byte) BYTE_HANDLE.get(ms, start + len - Byte.BYTES) & 0xFF;
}
return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO;
}
@@ -282,8 +309,8 @@ public class CalculateAverage_jparera {
}
private void next() {
- if (offset < segment.byteSize()) {
- this.current = segment.get(ValueLayout.JAVA_BYTE, offset++);
+ if (offset < size) {
+ this.current = (byte) BYTE_HANDLE.get(segment, offset++);
}
else {
this.hasCurrent = false;
@@ -292,9 +319,9 @@ public class CalculateAverage_jparera {
}
private static final class Entry {
- private final int keyLength;
+ final int keyLength;
- private final MemorySegment segment;
+ final int hash;
private int min = Integer.MAX_VALUE;
@@ -304,21 +331,19 @@ public class CalculateAverage_jparera {
private int count;
- Entry(int keyLength, MemorySegment segment) {
- this.keyLength = keyLength;
- this.segment = segment;
- }
+ private String key;
- int keyLength() {
- return keyLength;
+ Entry(int keyLength, int hash) {
+ this.keyLength = keyLength;
+ this.hash = hash;
}
- MemorySegment segment() {
- return segment;
+ public String key() {
+ return key;
}
- public String key() {
- return new String(segment.asSlice(0, keyLength).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8);
+ void setkey(byte[] keys, int offset) {
+ this.key = new String(keys, offset, keyLength, StandardCharsets.UTF_8);
}
public void add(int value) {
@@ -339,13 +364,19 @@ public class CalculateAverage_jparera {
@Override
public String toString() {
var average = Math.round(((sum / 10.0) / count) * 10.0);
- return decimal(min) + "/" + decimal(average) + "/" + decimal(max);
+ return decimal(min) + '/' + decimal(average) + '/' + decimal(max);
}
private static String decimal(long value) {
- boolean negative = value < 0;
+ var builder = new StringBuilder();
+ if (value < 0) {
+ builder.append((char) NEG);
+ }
value = Math.abs(value);
- return (negative ? "-" : "") + (value / 10) + "." + (value % 10);
+ builder.append(value / 10);
+ builder.append((char) DECIMAL_SEPARATOR);
+ builder.append(value % 10);
+ return builder.toString();
}
}
}