aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Quan Anh Mai <49088128+merykitty@users.noreply.github.com> 2024-01-11 02:24:19 +0700
committer GitHub <noreply@github.com> 2024-01-10 20:24:19 +0100
commit 97b1f014ad0c2b0be1ed010886028543cb6a6060 (patch)
tree bd4733999c4e932e5cf901a0c4f1c58135634cd2 /src
parent 786a52034c659c09faeddbcdaeebd410b2ccc8b1 (diff)
merykitty's second attempt
Diffstat (limited to 'src')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java 329
1 files changed, 138 insertions, 191 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
index e86ecee..1f5acf3 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
@@ -25,8 +25,6 @@ import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
-import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import jdk.incubator.vector.ByteVector;
@@ -35,13 +33,21 @@ import jdk.incubator.vector.VectorSpecies;
public class CalculateAverage_merykitty {
private static final String FILE = "./measurements.txt";
- private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED;
+ private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32
+ ? ByteVector.SPECIES_256
+ : ByteVector.SPECIES_128;
private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
private static final long KEY_MAX_SIZE = 100;
- private record ResultRow(double min, double mean, double max) {
+ private static class Aggregator {
+ private int keySize;
+ private long min = Integer.MAX_VALUE;
+ private long max = Integer.MIN_VALUE;
+ private long sum;
+ private long count;
+
public String toString() {
- return round(min) + "/" + round(mean) + "/" + round(max);
+ return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.);
}
private double round(double value) {
@@ -49,96 +55,100 @@ public class CalculateAverage_merykitty {
}
}
- private static class Aggregator {
- private long min = Integer.MAX_VALUE;
- private long max = Integer.MIN_VALUE;
- private long sum;
- private long count;
- }
-
// An open-address map that is specialized for this task
private static class PoorManMap {
- static final int R_LOAD_FACTOR = 2;
-
- private static class PoorManMapNode {
- byte[] data;
- long size;
- int hash;
- Aggregator aggr;
-
- PoorManMapNode(MemorySegment data, long offset, long size, int hash) {
- this.hash = hash;
- this.size = size;
- this.data = new byte[BYTE_SPECIES.vectorByteSize() + (int) KEY_MAX_SIZE];
- this.aggr = new Aggregator();
- MemorySegment.copy(data, offset, MemorySegment.ofArray(this.data), BYTE_SPECIES.vectorByteSize(), size);
- }
- }
- MemorySegment data;
- PoorManMapNode[] nodes;
- int size;
+ // 100-byte key + 4-byte hash + 4-byte size +
+ // 2-byte min + 2-byte max + 8-byte sum + 8-byte count
+ private static final int KEY_SIZE = 128;
+
+ // Assumes the map holds at most 10000 distinct keys, so it is never resized
+ private static final int CAPACITY = 1 << 17;
+ private static final int BUCKET_MASK = CAPACITY - 1;
- PoorManMap(MemorySegment data) {
- this.data = data;
- this.nodes = new PoorManMapNode[1 << 10];
+ byte[] keyData;
+ Aggregator[] nodes;
+
+ PoorManMap() {
+ this.keyData = new byte[CAPACITY * KEY_SIZE];
+ this.nodes = new Aggregator[CAPACITY];
}
- Aggregator indexSimple(long offset, long size, int hash) {
- hash = rehash(hash);
- int bucketMask = nodes.length - 1;
- int bucket = hash & bucketMask;
- for (;; bucket = (bucket + 1) & bucketMask) {
- PoorManMapNode node = nodes[bucket];
+ void observe(Aggregator node, long value) {
+ node.min = Math.min(node.min, value);
+ node.max = Math.max(node.max, value);
+ node.sum += value;
+ node.count++;
+ }
+
+ Aggregator indexSimple(MemorySegment data, long offset, int size) {
+ int x;
+ int y;
+ if (size >= Integer.BYTES) {
+ x = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
+ y = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset + size - Integer.BYTES);
+ }
+ else {
+ x = data.get(ValueLayout.JAVA_BYTE, offset);
+ y = data.get(ValueLayout.JAVA_BYTE, offset + size - Byte.BYTES);
+ }
+ int hash = hash(x, y);
+ int bucket = hash & BUCKET_MASK;
+ for (;; bucket = (bucket + 1) & BUCKET_MASK) {
+ var node = this.nodes[bucket];
if (node == null) {
- this.size++;
- if (this.size * R_LOAD_FACTOR > nodes.length) {
- grow();
- bucketMask = nodes.length - 1;
- for (bucket = hash & bucketMask; nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
- }
- }
- node = new PoorManMapNode(this.data, offset, size, hash);
- nodes[bucket] = node;
- return node.aggr;
+ return insertInto(bucket, data, offset, size);
}
- else if (keyEqualScalar(node, offset, size, hash)) {
- return node.aggr;
+ else if (keyEqualScalar(bucket, data, offset, size)) {
+ return node;
}
}
}
- void grow() {
- var oldNodes = this.nodes;
- var newNodes = new PoorManMapNode[oldNodes.length * 2];
- int bucketMask = newNodes.length - 1;
- for (var node : oldNodes) {
+ Aggregator insertInto(int bucket, MemorySegment data, long offset, int size) {
+ var node = new Aggregator();
+ node.keySize = size;
+ this.nodes[bucket] = node;
+ MemorySegment.copy(data, offset, MemorySegment.ofArray(this.keyData), (long) bucket * KEY_SIZE, size);
+ return node;
+ }
+
+ void mergeInto(Map<String, Aggregator> target) {
+ for (int i = 0; i < CAPACITY; i++) {
+ var node = this.nodes[i];
if (node == null) {
continue;
}
- int bucket = node.hash & bucketMask;
- for (; newNodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
- }
- newNodes[bucket] = node;
+
+ String key = new String(this.keyData, i * KEY_SIZE, node.keySize, StandardCharsets.UTF_8);
+ target.compute(key, (k, v) -> {
+ if (v == null) {
+ v = new Aggregator();
+ }
+
+ v.min = Math.min(v.min, node.min);
+ v.max = Math.max(v.max, node.max);
+ v.sum += node.sum;
+ v.count += node.count;
+ return v;
+ });
}
- this.nodes = newNodes;
}
- static int rehash(int x) {
- x = ((x >>> 16) ^ x) * 0x45d9f3b;
- x = ((x >>> 16) ^ x) * 0x45d9f3b;
- x = (x >>> 16) ^ x;
- return x;
+ static int hash(int x, int y) {
+ int seed = 0x9E3779B9;
+ int rotate = 5;
+ return (Integer.rotateLeft(x * seed, rotate) ^ y) * seed; // FxHash
}
- private boolean keyEqualScalar(PoorManMapNode node, long offset, long size, int hash) {
- if (node.hash != hash || node.size != size) {
+ private boolean keyEqualScalar(int bucket, MemorySegment data, long offset, int size) {
+ if (this.nodes[bucket].keySize != size) {
return false;
}
// Be simple
for (int i = 0; i < size; i++) {
- int c1 = node.data[BYTE_SPECIES.vectorByteSize() + i];
+ int c1 = this.keyData[bucket * KEY_SIZE + i];
int c2 = data.get(ValueLayout.JAVA_BYTE, offset + i);
if (c1 != c2) {
return false;
@@ -152,7 +162,7 @@ public class CalculateAverage_merykitty {
// 1 - 2 digits to the left and 1 digit to the right of the separator to a
// fixed-precision format. It returns the offset of the next line (presumably following
// the final digit and a '\n')
- private static long parseDataPoint(Aggregator aggr, MemorySegment data, long offset) {
+ private static long parseDataPoint(PoorManMap aggrMap, Aggregator node, MemorySegment data, long offset) {
long word = data.get(JAVA_LONG_LT, offset);
// The 4th binary digit of the ascii of a digit is 1 while
// that of the '.' is 0. This finds the decimal separator
@@ -176,16 +186,13 @@ public class CalculateAverage_merykitty {
// That was close :)
long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
long value = (absValue ^ signed) - signed;
- aggr.min = Math.min(value, aggr.min);
- aggr.max = Math.max(value, aggr.max);
- aggr.sum += value;
- aggr.count++;
+ aggrMap.observe(node, value);
return offset + (decimalSepPos >>> 3) + 3;
}
// Tail processing version of the above, do not over-fetch and be simple
- private static long parseDataPointTail(Aggregator aggr, MemorySegment data, long offset) {
- int point = 0;
+ private static long parseDataPointSimple(PoorManMap aggrMap, Aggregator node, MemorySegment data, long offset) {
+ int value = 0;
boolean negative = false;
if (data.get(ValueLayout.JAVA_BYTE, offset) == '-') {
negative = true;
@@ -195,110 +202,80 @@ public class CalculateAverage_merykitty {
int c = data.get(ValueLayout.JAVA_BYTE, offset);
if (c == '.') {
c = data.get(ValueLayout.JAVA_BYTE, offset + 1);
- point = point * 10 + (c - '0');
+ value = value * 10 + (c - '0');
offset += 3;
break;
}
- point = point * 10 + (c - '0');
+ value = value * 10 + (c - '0');
}
- point = negative ? -point : point;
- aggr.min = Math.min(point, aggr.min);
- aggr.max = Math.max(point, aggr.max);
- aggr.sum += point;
- aggr.count++;
+ value = negative ? -value : value;
+ aggrMap.observe(node, value);
return offset;
}
- // An iteration of the main parse loop, parse some lines starting from offset.
- // This requires offset to be the start of a line and there is spare space so
+ // An iteration of the main parse loop, parse a line starting from offset.
+ // This requires offset to be the start of the line and there is spare space so
// that we have relative freedom in processing
- // It returns the offset of the next line that it needs to be processed
+ // It returns the offset of the next line that needs processing
private static long iterate(PoorManMap aggrMap, MemorySegment data, long offset) {
- // This method fetches a segment of the file starting from offset and returns after
- // finishing processing that segment
var line = ByteVector.fromMemorySegment(BYTE_SPECIES, data, offset, ByteOrder.nativeOrder());
// Find the delimiter ';'
- long semicolons = line.compare(VectorOperators.EQ, ';').toLong();
+ int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
- // If we cannot find the delimiter in the current segment, that means the key is
- // longer than the segment, fall back to scalar processing
- if (semicolons == 0) {
- long semicolonPos = BYTE_SPECIES.vectorByteSize();
- for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) {
+ // If we cannot find the delimiter in the vector, that means the key is
+ // longer than the vector, fall back to scalar processing
+ if (keySize == BYTE_SPECIES.vectorByteSize()) {
+ while (data.get(ValueLayout.JAVA_BYTE, offset + keySize) != ';') {
+ keySize++;
}
- int hash = line.reinterpretAsInts().lane(0);
- var aggr = aggrMap.indexSimple(offset, semicolonPos, hash);
- return parseDataPoint(aggr, data, offset + 1 + semicolonPos);
+ var node = aggrMap.indexSimple(data, offset, keySize);
+ return parseDataPoint(aggrMap, node, data, offset + 1 + keySize);
}
- long currOffset = offset;
- while (true) {
- // Process line by line, currOffset is the offset of the current line in
- // the file, localOffset is the offset of the current line with respect
- // to the start of the iteration segment
- int localOffset = (int) (currOffset - offset);
-
- // The key length
- long semicolonPos = Long.numberOfTrailingZeros(semicolons) - localOffset;
- int hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, currOffset);
- if (semicolonPos < Integer.BYTES) {
- hash = (byte) hash;
+ // We inline the searching of the value in the hash map
+ int x;
+ int y;
+ if (keySize >= Integer.BYTES) {
+ x = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
+ y = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset + keySize - Integer.BYTES);
+ }
+ else {
+ x = data.get(ValueLayout.JAVA_BYTE, offset);
+ y = data.get(ValueLayout.JAVA_BYTE, offset + keySize - Byte.BYTES);
+ }
+ int hash = PoorManMap.hash(x, y);
+ int bucket = hash & PoorManMap.BUCKET_MASK;
+ Aggregator node;
+ for (;; bucket = (bucket + 1) & PoorManMap.BUCKET_MASK) {
+ node = aggrMap.nodes[bucket];
+ if (node == null) {
+ node = aggrMap.insertInto(bucket, data, offset, keySize);
+ break;
}
-
- // We inline the searching of the value in the hash map
- Aggregator aggr;
- hash = PoorManMap.rehash(hash);
- int bucketMask = aggrMap.nodes.length - 1;
- int bucket = hash & bucketMask;
- for (;; bucket = (bucket + 1) & bucketMask) {
- PoorManMap.PoorManMapNode node = aggrMap.nodes[bucket];
- if (node == null) {
- aggrMap.size++;
- if (aggrMap.size * PoorManMap.R_LOAD_FACTOR > aggrMap.nodes.length) {
- aggrMap.grow();
- bucketMask = aggrMap.nodes.length - 1;
- for (bucket = hash & bucketMask; aggrMap.nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
- }
- }
- node = new PoorManMap.PoorManMapNode(data, currOffset, semicolonPos, hash);
- aggrMap.nodes[bucket] = node;
- aggr = node.aggr;
- break;
- }
-
- if (node.hash != hash || node.size != semicolonPos) {
- continue;
- }
-
- // The technique here is to align the key in both vectors so that we can do an
- // element-wise comparison and check if all characters match
- var nodeKey = ByteVector.fromArray(BYTE_SPECIES, node.data, BYTE_SPECIES.length() - localOffset);
- var eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong();
- long validMask = (-1L >>> -semicolonPos) << localOffset;
- if ((eqMask & validMask) == validMask) {
- aggr = node.aggr;
- break;
- }
+ if (node.keySize != keySize) {
+ continue;
}
- long nextOffset = parseDataPoint(aggr, data, currOffset + 1 + semicolonPos);
- semicolons &= (semicolons - 1);
- if (semicolons == 0) {
- return nextOffset;
+ var nodeKey = ByteVector.fromArray(BYTE_SPECIES, aggrMap.keyData, bucket * PoorManMap.KEY_SIZE);
+ long eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong();
+ long validMask = -1L >>> -keySize;
+ if ((eqMask & validMask) == validMask) {
+ break;
}
- currOffset = nextOffset;
}
+
+ return parseDataPoint(aggrMap, node, data, offset + keySize + 1);
}
// Process all lines that start in [offset, limit)
private static PoorManMap processFile(MemorySegment data, long offset, long limit) {
- var aggrMap = new PoorManMap(data);
+ var aggrMap = new PoorManMap();
// Find the start of a new line
if (offset != 0) {
offset--;
- for (; offset < limit;) {
+ while (offset < limit) {
if (data.get(ValueLayout.JAVA_BYTE, offset++) == '\n') {
break;
}
@@ -318,18 +295,12 @@ public class CalculateAverage_merykitty {
// Now we are at the tail, just be simple
while (offset < limit) {
- long semicolonPos = 0;
- for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) {
- }
- int hash;
- if (semicolonPos >= Integer.BYTES) {
- hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, offset);
+ int keySize = 0;
+ while (data.get(ValueLayout.JAVA_BYTE, offset + keySize) != ';') {
+ keySize++;
}
- else {
- hash = data.get(ValueLayout.JAVA_BYTE, offset);
- }
- var aggr = aggrMap.indexSimple(offset, semicolonPos, hash);
- offset = parseDataPointTail(aggr, data, offset + 1 + semicolonPos);
+ var node = aggrMap.indexSimple(data, offset, keySize);
+ offset = parseDataPointSimple(aggrMap, node, data, offset + 1 + keySize);
}
return aggrMap;
@@ -337,7 +308,7 @@ public class CalculateAverage_merykitty {
public static void main(String[] args) throws InterruptedException, IOException {
int processorCnt = Runtime.getRuntime().availableProcessors();
- var res = HashMap.<String, Aggregator> newHashMap(processorCnt);
+ var res = new TreeMap<String, Aggregator>();
try (var file = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
var arena = Arena.ofShared()) {
var data = file.map(MapMode.READ_ONLY, 0, file.size(), arena);
@@ -348,9 +319,7 @@ public class CalculateAverage_merykitty {
int index = i;
long offset = i * chunkSize;
long limit = Math.min((i + 1) * chunkSize, data.byteSize());
- var thread = new Thread(() -> {
- resultList[index] = processFile(data, offset, limit);
- });
+ var thread = new Thread(() -> resultList[index] = processFile(data, offset, limit));
threadList[index] = thread;
thread.start();
}
@@ -360,32 +329,10 @@ public class CalculateAverage_merykitty {
// Collect the results
for (var aggrMap : resultList) {
- for (var node : aggrMap.nodes) {
- if (node == null) {
- continue;
- }
- byte[] keyData = Arrays.copyOfRange(node.data, BYTE_SPECIES.vectorByteSize(), BYTE_SPECIES.vectorByteSize() + (int) node.size);
- String key = new String(keyData, StandardCharsets.UTF_8);
- var aggr = node.aggr;
- var resAggr = new Aggregator();
- var existingAggr = res.putIfAbsent(key, resAggr);
- if (existingAggr != null) {
- resAggr = existingAggr;
- }
- resAggr.min = Math.min(resAggr.min, aggr.min);
- resAggr.max = Math.max(resAggr.max, aggr.max);
- resAggr.sum += aggr.sum;
- resAggr.count += aggr.count;
- }
+ aggrMap.mergeInto(res);
}
}
- Map<String, ResultRow> measurements = new TreeMap<>();
- for (var entry : res.entrySet()) {
- String key = entry.getKey();
- var aggr = entry.getValue();
- measurements.put(key, new ResultRow((double) aggr.min / 10, (double) aggr.sum / (aggr.count * 10), (double) aggr.max / 10));
- }
- System.out.println(measurements);
+ System.out.println(res);
}
}