aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java151
1 files changed, 77 insertions, 74 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java
index 0a24b0a..ad57b50 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java
@@ -16,6 +16,8 @@
package dev.morling.onebrc;
import java.util.TreeMap;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
@@ -31,18 +33,15 @@ import jdk.incubator.vector.VectorSpecies;
import sun.misc.Unsafe;
public class CalculateAverage_yourwass {
-
static final class Record {
- public String city;
- public long cityAddr;
- public long cityLength;
- public int min;
- public int max;
- public int count;
- public long sum;
+ private long cityAddr;
+ private long cityLength;
+ private int min;
+ private int max;
+ private int count;
+ private long sum;
Record(final long cityAddr, final long cityLength) {
- this.city = null;
this.cityAddr = cityAddr;
this.cityLength = cityLength;
this.min = 1000;
@@ -62,6 +61,8 @@ public class CalculateAverage_yourwass {
}
}
+ private final static Lock _mutex = new ReentrantLock(true);
+ private final static TreeMap<String, Record> aggregateResults = new TreeMap<>();
private static short lookupDecimal[];
private static byte lookupFraction[];
private static byte lookupDotPositive[];
@@ -70,6 +71,8 @@ public class CalculateAverage_yourwass {
private static final VectorSpecies<Byte> SPECIES = ByteVector.SPECIES_PREFERRED;
private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p
private static final String FILE = "measurements.txt";
+ private static long unsafeResults;
+ private static int RECORDSIZE = 36;
private static final Unsafe UNSAFE = getUnsafe();
private static Unsafe getUnsafe() {
@@ -113,11 +116,9 @@ public class CalculateAverage_yourwass {
}
// open file
- final long fileSize, mmapAddr;
- try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
- fileSize = fileChannel.size();
- mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
- }
+ final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
+ final long fileSize = fileChannel.size();
+ final long mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
// VAS: Virtual Address Space, as a MemorySegment up to and including the mmapped file.
// If the mmapped MemorySegment is used for Vector creation as is, then there are two problems:
// 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic
@@ -127,36 +128,24 @@ public class CalculateAverage_yourwass {
// XXX there lies the possibility for an out of bounds read at the end of file, which is not handled here.
VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length());
- // start and wait for threads to finish
+ // allocate memory for results
final int nThreads = Runtime.getRuntime().availableProcessors();
+ unsafeResults = UNSAFE.allocateMemory(RECORDSIZE * MAXINDEX * nThreads);
+ UNSAFE.setMemory(unsafeResults, RECORDSIZE * MAXINDEX * nThreads, (byte) 0);
+
+ // start and wait for threads to finish
Thread[] threadList = new Thread[nThreads];
- final Record[][] results = new Record[nThreads][];
final long chunkSize = fileSize / nThreads;
for (int i = 0; i < nThreads; i++) {
final int threadIndex = i;
final long startAddr = mmapAddr + i * chunkSize;
final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize;
- threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads));
+ threadList[i] = new Thread(() -> threadMain(threadIndex, startAddr, endAddr, nThreads));
threadList[i].start();
}
for (int i = 0; i < nThreads; i++)
threadList[i].join();
- // aggregate results and sort
- // TODO have to compare with concurrent-parallel stream structures:
- // * concurrent hashtable that have to sort afterwards
- // * concurrent skiplist that is sorted but has O(n) insert
- // * ..other?
- final TreeMap<String, Record> aggregateResults = new TreeMap<>();
- for (int thread = 0; thread < nThreads; thread++) {
- for (int index = 0; index < MAXINDEX; index++) {
- Record record = results[thread][index];
- if (record == null)
- continue;
- aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record));
- }
- }
-
// prepare string and print
StringBuilder sb = new StringBuilder();
sb.append("{");
@@ -167,12 +156,13 @@ public class CalculateAverage_yourwass {
float max = record.max;
max /= 10.f;
double avg = Math.round((record.sum * 1.0) / record.count) / 10.;
- sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
+ sb.append(entry.getKey()).append("=").append(min).append("/").append(avg).append("/").append(max).append(", ");
}
int stringLength = sb.length();
sb.setCharAt(stringLength - 2, '}');
sb.setCharAt(stringLength - 1, '\n');
System.out.print(sb.toString());
+ System.out.close();
}
private static final boolean citiesDiffer(final long a, final long b, final long len) {
@@ -185,7 +175,7 @@ public class CalculateAverage_yourwass {
return false;
}
- private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) {
+ private static void threadMain(int id, long startAddr, long endAddr, long nThreads) {
// snap to newlines
if (id != 0)
while (UNSAFE.getByte(startAddr++) != '\n')
@@ -194,23 +184,24 @@ public class CalculateAverage_yourwass {
while (UNSAFE.getByte(endAddr++) != '\n')
;
+ final long threadResults = unsafeResults + id * MAXINDEX * RECORDSIZE;
final Record[] results = new Record[MAXINDEX];
final long VECTORBYTESIZE = SPECIES.length();
final ByteOrder BYTEORDER = ByteOrder.nativeOrder();
final ByteVector delim = ByteVector.broadcast(SPECIES, ';');
- long nextCityAddr = startAddr; // XXX from these three variables,
- long cityAddr = nextCityAddr; // only two are necessary, but if one
- long ptr = 0; // is eliminated, on my pc the benchmark gets worse..
- while (nextCityAddr < endAddr) {
+ long cityAddr = startAddr;
+ long ptr = 0;
+ while (cityAddr < endAddr) {
// parse city
- long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER)
- .compare(VectorOperators.EQ, delim).toLong();
- if (mask == 0) {
+ ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
+ long mask = parsed.compare(VectorOperators.EQ, delim).toLong();
+ while (mask == 0) {
ptr += VECTORBYTESIZE;
- continue;
+ mask = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr + ptr, BYTEORDER).compare(VectorOperators.EQ, delim).toLong();
}
final long cityLength = ptr + Long.numberOfTrailingZeros(mask);
final long tempAddr = cityAddr + cityLength + 1;
+ ptr = 0;
// compute hash table index
int index;
@@ -222,67 +213,79 @@ public class CalculateAverage_yourwass {
& 0xFFFF;
else
index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00;
-
// resolve collisions with linear probing
// use vector api here also, but only if city name fits in one vector length, for faster default case
- Record record = results[index];
+ long record = threadResults + index * RECORDSIZE;
+ long recordCityLength = UNSAFE.getLong(record);
if (cityLength <= VECTORBYTESIZE) {
- ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER);
- while (record != null) {
- if (cityLength == record.cityLength) {
- long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER)
+ while (recordCityLength > 0) {
+ if (cityLength == recordCityLength) {
+ long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, UNSAFE.getLong(record + 8), BYTEORDER)
.compare(VectorOperators.EQ, parsed).toLong();
if (Long.numberOfTrailingZeros(~sameMask) >= cityLength)
break;
}
- record = results[++index];
+ index++;
+ record = threadResults + index * RECORDSIZE;
+ recordCityLength = UNSAFE.getLong(record);
}
}
else { // slower normal case for city names with length > VECTORBYTESIZE
- while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength)))
- record = results[++index];
+ while (recordCityLength > 0 && (cityLength != recordCityLength || citiesDiffer(UNSAFE.getLong(record + 8), cityAddr, cityLength))) {
+ index++;
+ record = threadResults + index * RECORDSIZE;
+ recordCityLength = UNSAFE.getLong(record);
+ }
}
- // add record for new keys
- // TODO have to avoid memory allocations on hot path
- if (record == null) {
- results[index] = new Record(cityAddr, cityLength);
- record = results[index];
+ // add record for new key
+ if (recordCityLength == 0) {
+ UNSAFE.putLong(record, cityLength);
+ UNSAFE.putLong(record + 8, cityAddr);
+ UNSAFE.putInt(record + 16, 1000);
+ UNSAFE.putInt(record + 20, -1000);
}
// parse temp with lookup tables
int temp;
if (UNSAFE.getByte(tempAddr) == '-') {
temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)];
- nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
+ cityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)];
}
else {
temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)];
- nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
+ cityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)];
}
- cityAddr = nextCityAddr;
- ptr = 0;
- // merge record
- if (temp < record.min)
- record.min = temp;
- if (temp > record.max)
- record.max = temp;
- record.sum += temp;
- record.count += 1;
+ // merge
+ if (temp < UNSAFE.getInt(record + 16))
+ UNSAFE.putInt(record + 16, temp);
+ if (temp > UNSAFE.getInt(record + 20))
+ UNSAFE.putInt(record + 20, temp);
+ UNSAFE.putLong(record + 24, UNSAFE.getLong(record + 24) + temp);
+ UNSAFE.putInt(record + 32, UNSAFE.getInt(record + 32) + 1);
}
// create strings from raw data
- // TODO should avoid this copy
+ // and aggregate results onto TreeMap
+ int idx = 0;
byte b[] = new byte[100];
+ _mutex.lock();
for (int i = 0; i < MAXINDEX; i++) {
- Record r = results[i];
- if (r == null)
+ if (UNSAFE.getLong(threadResults + i * RECORDSIZE) == 0)
continue;
- UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength);
- r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8);
+ final long recordAddress = threadResults + i * RECORDSIZE;
+
+ results[idx] = new Record(UNSAFE.getLong(recordAddress + 8), UNSAFE.getLong(recordAddress));
+ results[idx].min = UNSAFE.getInt(recordAddress + 16);
+ results[idx].max = UNSAFE.getInt(recordAddress + 20);
+ results[idx].sum = UNSAFE.getLong(recordAddress + 24);
+ results[idx].count = UNSAFE.getInt(recordAddress + 32);
+ UNSAFE.copyMemory(null, UNSAFE.getLong(recordAddress + 8), b, Unsafe.ARRAY_BYTE_BASE_OFFSET, UNSAFE.getLong(recordAddress));
+ final Record record = results[idx];
+ aggregateResults.compute(new String(b, 0, (int) results[idx].cityLength, StandardCharsets.UTF_8), (k, v) -> (v == null) ? record : v.merge(record));
+ idx++;
}
- return results;
+ _mutex.unlock();
}
-
}