aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
authorMarkus Ebner <seijikun@users.noreply.github.com>2024-01-05 19:35:15 +0100
committerGitHub <noreply@github.com>2024-01-05 19:35:15 +0100
commit36dac255cf2b36811c5fc9b6fc9ca37e17bc34b6 (patch)
tree117df00ea142eac4230ac048313be871f68e198f /src/main/java/dev/morling/onebrc
parente3f6c3aaf795cab2f0818c96bd46c573a2e14cf1 (diff)
Update seijikun implementation
* Use Integer calculation instead of double, add unit-test * Bring back StationIdent optimization Originally, StationIdent was using byte[] to store names, so the extra String allocation could be avoided. However, that produced incorrect sorting. Sorting is now moved to the result merging step. Here, names are converted to Strings. * Implement readStationName with SIMD 256bit * Rebase and cleanup test code, now that it's in the project * Fix seijikun formatting * Fix test failure in specific jobCnt edge-cases * Also switch to graalvm
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java221
1 files changed, 145 insertions, 76 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java
index bdea518..c5678b1 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_seijikun.java
@@ -15,9 +15,18 @@
*/
package dev.morling.onebrc;
-import java.io.*;
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.VectorOperators;
+
+import java.io.IOException;
+import java.io.PrintStream;
+import java.io.RandomAccessFile;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.TreeMap;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@@ -27,24 +36,36 @@ public class CalculateAverage_seijikun {
private static final String FILE = "./measurements.txt";
private static class MeasurementAggregator {
- private double min = Double.POSITIVE_INFINITY;
- private double max = Double.NEGATIVE_INFINITY;
- private double sum;
- private long count;
+ private int min = Integer.MAX_VALUE;
+ private int max = Integer.MIN_VALUE;
+ // final long startTs = System.currentTimeMillis();
+ private long sum = 0;
+ private long count = 0;
+
+ private double mean = 0;
+
+ public void finish() {
+ double sum = this.sum / 10.0;
+ mean = sum / (double) count;
+ }
public void printInto(PrintStream out) {
- out.printf("%.1f/%.1f/%.1f", min, (sum / (double) count), max);
+ double min = (double) this.min / 10.0;
+ double max = (double) this.max / 10.0;
+ out.printf("%.1f/%.1f/%.1f", min, mean, max);
}
}
- public static class StationIdent implements Comparable<StationIdent> {
- private final int nameLength;
- private final String name;
+ public static class StationIdent {
+ private final byte[] name;
private final int nameHash;
public StationIdent(byte[] name, int nameHash) {
- this.nameLength = name.length;
- this.name = new String(name);
+ this.name = name;
+ // TODO: DEBUG
+ // if(Arrays.asList(this.name).contains(';')) {
+ // throw new RuntimeException();
+ // }
this.nameHash = nameHash;
}
@@ -56,15 +77,10 @@ public class CalculateAverage_seijikun {
@Override
public boolean equals(Object obj) {
var other = (StationIdent) obj;
- if (other.nameLength != nameLength) {
+ if (other.name.length != name.length) {
return false;
}
- return name.equals(other.name);
- }
-
- @Override
- public int compareTo(StationIdent o) {
- return name.compareTo(o.name);
+ return Arrays.equals(name, other.name);
}
}
@@ -77,9 +93,11 @@ public class CalculateAverage_seijikun {
private final long endOffset;
// state
+ private int chunkSize = 0;
private MappedByteBuffer buffer = null;
+ private MemorySegment memorySegment = null;
private int ptr = 0;
- private TreeMap<StationIdent, MeasurementAggregator> workSet;
+ private HashMap<StationIdent, MeasurementAggregator> workSet;
public ChunkReader(RandomAccessFile file, long startOffset, long endOffset) {
this.file = file;
@@ -87,36 +105,67 @@ public class CalculateAverage_seijikun {
this.endOffset = endOffset;
}
+ // private StationIdent readStationName() {
+ // int startPtr = ptr;
+ // int hashCode = 0;
+ // int hashBytePtr = 0;
+ // byte c;
+ // while ((c = buffer.get(ptr++)) != ';') {
+ // hashCode ^= ((int) c) << (hashBytePtr * 8);
+ // hashBytePtr = (hashBytePtr + 1) % 4;
+ // }
+ // byte[] stationNameBfr = new byte[ptr - startPtr - 1];
+ // buffer.get(startPtr, stationNameBfr);
+ // return new StationIdent(stationNameBfr, hashCode);
+ // }
+
private StationIdent readStationName() {
- int startPtr = ptr;
- int hashCode = 0;
- int hashBytePtr = 0;
- byte c;
- while ((c = buffer.get(ptr++)) != ';') {
- hashCode ^= ((int) c) << (hashBytePtr * 8);
- hashBytePtr = (hashBytePtr + 1) % 4;
+ final var VECTOR_SPECIES = ByteVector.SPECIES_256;
+
+ if (chunkSize - ptr < VECTOR_SPECIES.length()) { // fallback
+ int startPtr = ptr;
+ while (buffer.get(ptr++) != ';') {
+ }
+ byte[] stationNameBfr = new byte[ptr - startPtr - 1];
+ buffer.get(startPtr, stationNameBfr);
+ return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
+ }
+ else { // SIMD
+ int sepIdx = 0;
+
+ while (true) {
+ ByteVector tmp = ByteVector.fromMemorySegment(VECTOR_SPECIES, memorySegment, ptr + sepIdx, ByteOrder.LITTLE_ENDIAN);
+ final var cmpResult = tmp.compare(VectorOperators.EQ, ';');
+ if (cmpResult.anyTrue()) {
+ sepIdx += cmpResult.firstTrue();
+ break;
+ }
+ else {
+ sepIdx += tmp.length();
+ }
+ }
+
+ int endPtr = ptr + sepIdx;
+ byte[] stationNameBfr = new byte[endPtr - ptr];
+ buffer.get(ptr, stationNameBfr);
+ ptr = endPtr + 1;
+ return new StationIdent(stationNameBfr, Arrays.hashCode(stationNameBfr) ^ stationNameBfr.length);
}
- byte[] stationNameBfr = new byte[ptr - startPtr - 1];
- buffer.get(startPtr, stationNameBfr);
- return new StationIdent(stationNameBfr, hashCode);
}
- private double readTemperature() {
- double ret = 0, div = 1;
+ private int readTemperature() {
+ int ret = 0;
byte c = buffer.get(ptr++);
- boolean neg = (c == '-');
- if (neg)
+ final boolean neg = (c == '-');
+ if (neg) {
c = buffer.get(ptr++);
+ }
do {
- ret = ret * 10 + c - '0';
- } while ((c = buffer.get(ptr++)) >= '0' && c <= '9');
-
- if (c == '.') {
- while ((c = buffer.get(ptr++)) != '\n') {
- ret += (c - '0') / (div *= 10);
+ if (c != '.') {
+ ret = ret * 10 + c - '0';
}
- }
+ } while ((c = buffer.get(ptr++)) != '\n');
if (neg)
return -ret;
@@ -125,14 +174,18 @@ public class CalculateAverage_seijikun {
@Override
public void run() {
- workSet = new TreeMap<>();
- int chunkSize = (int) (endOffset - startOffset);
+ workSet = new HashMap<>();
+ if (endOffset - startOffset > Integer.MAX_VALUE) {
+ throw new RuntimeException("Mapping a block larger than 2GB is not possible with Java! Welcome to 2024 :)");
+ }
+ chunkSize = (int) (endOffset - startOffset);
try {
buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startOffset, chunkSize);
+ memorySegment = MemorySegment.ofBuffer(buffer);
while (ptr < chunkSize) {
var station = readStationName();
- var temp = readTemperature();
+ int temp = readTemperature();
var stationWorkSet = workSet.get(station);
if (stationWorkSet == null) {
stationWorkSet = new MeasurementAggregator();
@@ -144,26 +197,42 @@ public class CalculateAverage_seijikun {
stationWorkSet.count += 1;
}
}
- catch (IOException e) {
+ catch (Throwable e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}
}
- public static void main(String[] args) throws IOException, InterruptedException {
- RandomAccessFile file = new RandomAccessFile(FILE, "r");
+ private static void printWorkSet(TreeMap<String, MeasurementAggregator> result, PrintStream out) {
+ out.write('{');
+ final var iterator = result.entrySet().iterator();
+ while (iterator.hasNext()) {
+ var entry = iterator.next();
+ out.print(entry.getKey());
+ out.write('=');
+ entry.getValue().printInto(out);
+ if (iterator.hasNext()) {
+ out.print(", ");
+ }
+ }
+ out.println('}');
+ }
- int jobCnt = Runtime.getRuntime().availableProcessors();
+ private static int createChunks(final RandomAccessFile file, final ChunkReader[] chunks) throws IOException {
+ final long fileEndPtr = file.length();
+ final long chunkSize = Math.max(1, fileEndPtr / chunks.length);
- var chunks = new ChunkReader[jobCnt];
- long chunkSize = file.length() / jobCnt;
+ int jobCnt = 0;
long chunkStartPtr = 0;
- byte[] tmpBuffer = new byte[128];
- for (int i = 0; i < jobCnt; ++i) {
- long chunkEndPtr = chunkStartPtr + chunkSize;
- if (i != (jobCnt - 1)) { // align chunks to newlines
- file.seek(chunkEndPtr - 1);
+ final byte[] tmpBuffer = new byte[128];
+ while (chunkStartPtr < fileEndPtr) {
+ long chunkEndPtr = Math.min(chunkStartPtr + chunkSize, fileEndPtr);
+
+ // Seek into file at the calculated chunk end ptr, then extend it until the next
+ // new-line or EOF
+ if (chunkEndPtr < fileEndPtr) {
+ file.seek(Math.max(0, chunkEndPtr - 1));
file.read(tmpBuffer);
int offset = 0;
while (tmpBuffer[offset] != '\n') {
@@ -171,28 +240,38 @@ public class CalculateAverage_seijikun {
}
chunkEndPtr += offset;
}
- else { // last chunk ends at file end
- chunkEndPtr = file.length();
- }
- chunks[i] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
+
+ chunks[jobCnt] = new ChunkReader(file, chunkStartPtr, chunkEndPtr);
+ jobCnt += 1;
chunkStartPtr = chunkEndPtr;
}
+ return jobCnt;
+ }
- try (var executor = Executors.newFixedThreadPool(jobCnt)) {
+ public static void main(String[] args) throws IOException, InterruptedException {
+ final RandomAccessFile file = new RandomAccessFile(FILE, "r");
+
+ int jobCnt = Runtime.getRuntime().availableProcessors();
+
+ final var chunks = new ChunkReader[jobCnt];
+ jobCnt = createChunks(file, chunks);
+
+ try (final var executor = Executors.newFixedThreadPool(jobCnt)) {
for (int i = 0; i < jobCnt; ++i) {
executor.submit(chunks[i]);
}
executor.shutdown();
- var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
+ final var ignored = executor.awaitTermination(1, TimeUnit.DAYS);
}
// merge chunks
- var result = chunks[0].workSet;
- for (int i = 1; i < jobCnt; ++i) {
+ final var result = new TreeMap<String, MeasurementAggregator>();
+ for (int i = 0; i < jobCnt; ++i) {
chunks[i].workSet.forEach((ident, otherStationWorkSet) -> {
- var stationWorkSet = result.get(ident);
+ final var identStr = new String(ident.name);
+ final var stationWorkSet = result.get(identStr);
if (stationWorkSet == null) {
- result.put(ident, otherStationWorkSet);
+ result.put(identStr, otherStationWorkSet);
}
else {
stationWorkSet.min = Math.min(stationWorkSet.min, otherStationWorkSet.min);
@@ -202,19 +281,9 @@ public class CalculateAverage_seijikun {
}
});
}
+ result.forEach((ignored, meas) -> meas.finish());
// print in required format
- System.out.write('{');
- var iterator = result.entrySet().iterator();
- while (iterator.hasNext()) {
- var entry = iterator.next();
- System.out.print(entry.getKey().name);
- System.out.write('=');
- entry.getValue().printInto(System.out);
- if (iterator.hasNext()) {
- System.out.print(", ");
- }
- }
- System.out.println('}');
+ printWorkSet(result, System.out);
}
}