aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorStefan Sprenger <stefan@datacater.io>2024-01-14 19:06:01 +0100
committerGitHub <noreply@github.com>2024-01-14 19:06:01 +0100
commit3fbc4a2fa89e199ab1289cd8ac5d496009120c82 (patch)
tree63326446408d8bda81a5212a18a8405e9c6005e0 /src/main
parentfc6fca43152b82acfbe927602c92abc3bcda1dd6 (diff)
Update submission (#385)
* feat(flippingbits): Improve parsing of station names * chore(flippingbits): Remove obsolete import * feat(flippingbits): Use custom hash map * feat(flippingbits): Use UNSAFE * fix(flippingbits): Support very small files * chore(flippingbits): Few cleanups * chore(flippingbits): Align names * fix(flippingbits): Initialize hash with first byte * fix(flippingbits): Fix initialization of hash value
Diffstat (limited to 'src/main')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java290
1 files changed, 201 insertions, 89 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
index 2510d85..3489877 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
@@ -18,8 +18,13 @@ package dev.morling.onebrc;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorOperators;
+import sun.misc.Unsafe;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
+
import java.io.IOException;
import java.io.RandomAccessFile;
+import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
@@ -34,14 +39,31 @@ public class CalculateAverage_flippingbits {
private static final String FILE = "./measurements.txt";
- private static final long CHUNK_SIZE = 10 * 1024 * 1024; // 10 MB
+ private static final long MINIMUM_FILE_SIZE_PARTITIONING = 10 * 1024 * 1024; // 10 MB
private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();
- private static final int MAX_STATION_NAME_LENGTH = 100;
+ private static final int NUM_STATIONS = 10_000;
+
+ private static final int HASH_MAP_OFFSET_CAPACITY = 200_000;
+
+ private static final Unsafe UNSAFE = initUnsafe();
+
+ private static int HASH_PRIME_NUMBER = 31;
+
+ private static Unsafe initUnsafe() {
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ return (Unsafe) theUnsafe.get(Unsafe.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
public static void main(String[] args) throws IOException {
- var result = Arrays.asList(getSegments()).stream()
+ var result = Arrays.asList(getSegments()).parallelStream()
.map(segment -> {
try {
return processSegment(segment[0], segment[1]);
@@ -50,126 +72,137 @@ public class CalculateAverage_flippingbits {
throw new RuntimeException(e);
}
})
- .parallel()
- .reduce((firstMap, secondMap) -> {
- for (var entry : secondMap.entrySet()) {
- PartitionAggregate firstAggregate = firstMap.get(entry.getKey());
- if (firstAggregate == null) {
- firstMap.put(entry.getKey(), entry.getValue());
- }
- else {
- firstAggregate.mergeWith(entry.getValue());
- }
- }
- return firstMap;
- })
- .map(TreeMap::new).get();
+ .reduce(FasterHashMap::mergeWith)
+ .get();
+
+ var sortedMap = new TreeMap<String, Station>();
+ for (Station station : result.getEntries()) {
+ sortedMap.put(station.getName(), station);
+ }
- System.out.println(result);
+ System.out.println(sortedMap);
}
private static long[][] getSegments() throws IOException {
try (var file = new RandomAccessFile(FILE, "r")) {
- var fileSize = file.length();
+ var channel = file.getChannel();
+
+ var fileSize = channel.size();
+ var startAddress = channel
+ .map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global())
+ .address();
+
// Split file into segments, so we can work around the size limitation of channels
- var numSegments = (int) (fileSize / CHUNK_SIZE);
+ var numSegments = (fileSize > MINIMUM_FILE_SIZE_PARTITIONING)
+ ? Runtime.getRuntime().availableProcessors()
+ : 1;
+ var segmentSize = fileSize / numSegments;
- var boundaries = new long[numSegments + 1][2];
- var endPointer = 0L;
+ var boundaries = new long[numSegments][2];
+ var endPointer = startAddress;
- for (var i = 0; i < numSegments; i++) {
+ for (var i = 0; i < numSegments - 1; i++) {
// Start of segment
- boundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);
-
- // Seek end of segment, limited by the end of the file
- file.seek(Math.min(boundaries[i][0] + CHUNK_SIZE - 1, fileSize));
+ boundaries[i][0] = endPointer;
// Extend segment until end of line or file
- while (file.read() != '\n') {
+ endPointer = endPointer + segmentSize;
+ while (UNSAFE.getByte(endPointer) != '\n') {
+ endPointer++;
}
// End of segment
- endPointer = file.getFilePointer();
- boundaries[i][1] = endPointer;
+ boundaries[i][1] = endPointer++;
}
- boundaries[numSegments][0] = Math.max(endPointer, numSegments * CHUNK_SIZE);
- boundaries[numSegments][1] = fileSize;
+ boundaries[numSegments - 1][0] = endPointer;
+ boundaries[numSegments - 1][1] = startAddress + fileSize;
return boundaries;
}
}
- private static Map<String, PartitionAggregate> processSegment(long startOfSegment, long endOfSegment)
- throws IOException {
- Map<String, PartitionAggregate> stationAggregates = new HashMap<>(50_000);
- var byteChunk = new byte[(int) (endOfSegment - startOfSegment)];
- var stationBuffer = new byte[MAX_STATION_NAME_LENGTH];
- try (var file = new RandomAccessFile(FILE, "r")) {
- file.seek(startOfSegment);
- file.read(byteChunk);
- var i = 0;
- while (i < byteChunk.length) {
- // Station name has at least one byte
- stationBuffer[0] = byteChunk[i];
- i++;
- // Read station name
- var j = 1;
- while (byteChunk[i] != ';') {
- stationBuffer[j] = byteChunk[i];
- j++;
- i++;
- }
- var station = new String(stationBuffer, 0, j, StandardCharsets.UTF_8);
+ private static FasterHashMap processSegment(long startOfSegment, long endOfSegment) throws IOException {
+ var fasterHashMap = new FasterHashMap();
+ for (var i = startOfSegment; i < endOfSegment; i += 3) {
+ // Read station name
+ int nameHash = UNSAFE.getByte(i);
+ final var nameStartAddress = i++;
+ var character = UNSAFE.getByte(i);
+ while (character != ';') {
+ nameHash = nameHash * HASH_PRIME_NUMBER + character;
i++;
+ character = UNSAFE.getByte(i);
+ }
+ var nameLength = (int) (i - nameStartAddress);
+ i++;
- // Read measurement
- var isNegative = byteChunk[i] == '-';
- var measurement = 0;
- if (isNegative) {
+ // Read measurement
+ var isNegative = UNSAFE.getByte(i) == '-';
+ var measurement = 0;
+ if (isNegative) {
+ i++;
+ character = UNSAFE.getByte(i);
+ while (character != '.') {
+ measurement = measurement * 10 + character - '0';
i++;
- while (byteChunk[i] != '.') {
- measurement = measurement * 10 + byteChunk[i] - '0';
- i++;
- }
- measurement = (measurement * 10 + byteChunk[i + 1] - '0') * -1;
+ character = UNSAFE.getByte(i);
}
- else {
- while (byteChunk[i] != '.') {
- measurement = measurement * 10 + byteChunk[i] - '0';
- i++;
- }
- measurement = measurement * 10 + byteChunk[i + 1] - '0';
+ measurement = (measurement * 10 + UNSAFE.getByte(i + 1) - '0') * -1;
+ }
+ else {
+ character = UNSAFE.getByte(i);
+ while (character != '.') {
+ measurement = measurement * 10 + character - '0';
+ i++;
+ character = UNSAFE.getByte(i);
}
-
- // Update aggregate
- var aggregate = stationAggregates.computeIfAbsent(station, x -> new PartitionAggregate());
- aggregate.addMeasurementAndComputeAggregate((short) measurement);
- i += 3;
+ measurement = measurement * 10 + UNSAFE.getByte(i + 1) - '0';
}
- stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);
+
+ fasterHashMap.addEntry(nameHash, nameLength, nameStartAddress, (short) measurement);
+ }
+
+ for (Station station : fasterHashMap.getEntries()) {
+ station.aggregateRemainingMeasurements();
}
- return stationAggregates;
+ return fasterHashMap;
}
- private static class PartitionAggregate {
- final short[] doubleLane = new short[SIMD_LANE_LENGTH * 2];
+ private static class Station {
+ final short[] measurements = new short[SIMD_LANE_LENGTH * 2];
// Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
- int count = 0;
+ int count = 1;
long sum = 0;
short min = Short.MAX_VALUE;
short max = Short.MIN_VALUE;
+ final long nameAddress;
+ final int nameLength;
+ final int nameHash;
+
+ public Station(int nameHash, int nameLength, long nameAddress, short measurement) {
+ this.nameHash = nameHash;
+ this.nameLength = nameLength;
+ this.nameAddress = nameAddress;
+ measurements[0] = measurement;
+ }
+
+ public String getName() {
+ byte[] name = new byte[nameLength];
+ UNSAFE.copyMemory(null, nameAddress, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength);
+ return new String(name, StandardCharsets.UTF_8);
+ }
public void addMeasurementAndComputeAggregate(short measurement) {
// Add measurement to buffer, which is later processed by SIMD instructions
- doubleLane[count % doubleLane.length] = measurement;
+ measurements[count % measurements.length] = measurement;
count++;
// Once lane is full, use SIMD instructions to calculate aggregates
- if (count % doubleLane.length == 0) {
- var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, 0);
- var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, doubleLane, SIMD_LANE_LENGTH);
+ if (count % measurements.length == 0) {
+ var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, 0);
+ var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, measurements, SIMD_LANE_LENGTH);
var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
min = (short) Math.min(min, simdMin);
@@ -182,19 +215,35 @@ public class CalculateAverage_flippingbits {
}
public void aggregateRemainingMeasurements() {
- for (var i = 0; i < count % doubleLane.length; i++) {
- var measurement = doubleLane[i];
+ for (var i = 0; i < count % measurements.length; i++) {
+ var measurement = measurements[i];
min = (short) Math.min(min, measurement);
max = (short) Math.max(max, measurement);
sum += measurement;
}
}
- public void mergeWith(PartitionAggregate otherAggregate) {
- min = (short) Math.min(min, otherAggregate.min);
- max = (short) Math.max(max, otherAggregate.max);
- count = count + otherAggregate.count;
- sum = sum + otherAggregate.sum;
+ public void mergeWith(Station otherStation) {
+ min = (short) Math.min(min, otherStation.min);
+ max = (short) Math.max(max, otherStation.max);
+ count = count + otherStation.count;
+ sum = sum + otherStation.sum;
+ }
+
+ public boolean nameEquals(long otherNameAddress) {
+ var swarLimit = (nameLength / Long.BYTES) * Long.BYTES;
+ var i = 0;
+ for (; i < swarLimit; i += Long.BYTES) {
+ if (UNSAFE.getLong(nameAddress + i) != UNSAFE.getLong(otherNameAddress + i)) {
+ return false;
+ }
+ }
+ for (; i < nameLength; i++) {
+ if (UNSAFE.getByte(nameAddress + i) != UNSAFE.getByte(otherNameAddress + i)) {
+ return false;
+ }
+ }
+ return true;
}
public String toString() {
@@ -206,4 +255,67 @@ public class CalculateAverage_flippingbits {
(max / 10.0));
}
}
+
+ /**
+ * Use two arrays for implementing the hash map:
+ * - The array `entries` holds the map values, in our case instances of the class Station.
+ * - The array `offsets` maps hashes of the keys to indexes in the `entries` array.
+ *
+ * We create `offsets` with a much larger capacity than `entries`, so we minimize collisions.
+ */
+ private static class FasterHashMap {
+ // Using 16-bit integers (shorts) for offsets supports up to 2^15 (=32,767) entries
+ // If you need to store more entries, consider replacing short with int
+ short[] offsets = new short[HASH_MAP_OFFSET_CAPACITY];
+ Station[] entries = new Station[NUM_STATIONS + 1];
+ int slotsInUse = 0;
+
+ private int getOffsetIdx(int nameHash, int nameLength, long nameAddress) {
+ var offsetIdx = nameHash & (offsets.length - 1);
+ var offset = offsets[offsetIdx];
+
+ while (offset != 0 &&
+ (nameLength != entries[offset].nameLength || !entries[offset].nameEquals(nameAddress))) {
+ offsetIdx = (offsetIdx + 1) % offsets.length;
+ offset = offsets[offsetIdx];
+ }
+
+ return offsetIdx;
+ }
+
+ public void addEntry(int nameHash, int nameLength, long nameAddress, short measurement) {
+ var offsetIdx = getOffsetIdx(nameHash, nameLength, nameAddress);
+ var offset = offsets[offsetIdx];
+
+ if (offset == 0) {
+ slotsInUse++;
+ entries[slotsInUse] = new Station(nameHash, nameLength, nameAddress, measurement);
+ offsets[offsetIdx] = (short) slotsInUse;
+ }
+ else {
+ entries[offset].addMeasurementAndComputeAggregate(measurement);
+ }
+ }
+
+ public FasterHashMap mergeWith(FasterHashMap otherMap) {
+ for (Station station : otherMap.getEntries()) {
+ var offsetIdx = getOffsetIdx(station.nameHash, station.nameLength, station.nameAddress);
+ var offset = offsets[offsetIdx];
+
+ if (offset == 0) {
+ slotsInUse++;
+ entries[slotsInUse] = station;
+ offsets[offsetIdx] = (short) slotsInUse;
+ }
+ else {
+ entries[offset].mergeWith(station);
+ }
+ }
+ return this;
+ }
+
+ public List<Station> getEntries() {
+ return Arrays.asList(entries).subList(1, slotsInUse + 1);
+ }
+ }
}