aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling
diff options
context:
space:
mode:
author Stefan Sprenger <info@stefan-sprenger.com> 2024-01-06 18:13:52 +0100
committer GitHub <noreply@github.com> 2024-01-06 18:13:52 +0100
commit 749d2d8f784fe63556f8aa284daa282f1410c471 (patch)
tree a730feb56b47d78ee4e457cc46d61cd3c8882ab3 /src/main/java/dev/morling
parent 1a9b1cb7daa2ff580b3aefac516a7cb8ef2087c7 (diff)
Add solution by flippingbits - Use SIMD for computing aggregates
* feat(flippingbits): First revision
* chore(flippingbits): Clean up output
* fix(flippingbits): Use ShortVector.SPECIES_MAX
Diffstat (limited to 'src/main/java/dev/morling')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java 211
1 files changed, 211 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
new file mode 100644
index 0000000..33ca88c
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_flippingbits.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.VectorOperators;
+
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.util.*;
+
+/**
+ * Approach:
+ * - Use memory-mapped file to speed up loading data into memory
+ * - Partition data, compute aggregates for partitions in parallel, and finally combine results from all partitions
+ * - Apply SIMD instructions for computing min/max/sum aggregates
+ * - Use Shorts for storing aggregates of partitions, so we maximize the SIMD parallelism
+ */
+public class CalculateAverage_flippingbits {
+
+ private static final String FILE = "./measurements.txt";
+
+ private static final long CHUNK_SIZE = 100 * 1024 * 1024; // 100 MB
+
+ private static final int SIMD_LANE_LENGTH = ShortVector.SPECIES_MAX.length();
+
+ public static void main(String[] args) throws IOException {
+ try (var file = new RandomAccessFile(FILE, "r")) {
+ // Calculate chunk boundaries
+ long[][] chunkBoundaries = getChunkBoundaries(file);
+ // Process chunks
+ var result = Arrays.asList(chunkBoundaries).stream()
+ .map(chunk -> {
+ try {
+ return processChunk(chunk[0], chunk[1]);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ })
+ .parallel()
+ .reduce((firstMap, secondMap) -> {
+ for (var entry : secondMap.entrySet()) {
+ PartitionAggregate firstAggregate = firstMap.get(entry.getKey());
+ if (firstAggregate == null) {
+ firstMap.put(entry.getKey(), entry.getValue());
+ }
+ else {
+ firstAggregate.mergeWith(entry.getValue());
+ }
+ }
+ return firstMap;
+ })
+ .map(hashMap -> new TreeMap(hashMap)).get();
+
+ System.out.println(result);
+ }
+ }
+
+ private static long[][] getChunkBoundaries(RandomAccessFile file) throws IOException {
+ var fileSize = file.length();
+ // Split file into chunks, so we can work around the size limitation of channels
+ var chunks = (int) (fileSize / CHUNK_SIZE);
+
+ long[][] chunkBoundaries = new long[chunks + 1][2];
+ var endPointer = 0L;
+
+ for (var i = 0; i <= chunks; i++) {
+ // Start of chunk
+ chunkBoundaries[i][0] = Math.min(Math.max(endPointer, i * CHUNK_SIZE), fileSize);
+
+ // Seek end of chunk, limited by the end of the file
+ file.seek(Math.min(chunkBoundaries[i][0] + CHUNK_SIZE - 1, fileSize));
+
+ // Extend chunk until end of line or file
+ while (true) {
+ var character = file.read();
+ if (character == '\n' || character == -1) {
+ break;
+ }
+ }
+
+ // End of chunk
+ endPointer = file.getFilePointer();
+ chunkBoundaries[i][1] = endPointer;
+ }
+
+ return chunkBoundaries;
+ }
+
+ private static Map<String, PartitionAggregate> processChunk(long startOfChunk, long endOfChunk)
+ throws IOException {
+ Map<String, PartitionAggregate> stationAggregates = new HashMap<>(10_000);
+ byte[] byteChunk = new byte[(int) (endOfChunk - startOfChunk)];
+ try (var file = new RandomAccessFile(FILE, "r")) {
+ file.seek(startOfChunk);
+ file.read(byteChunk);
+ var i = 0;
+ while (i < byteChunk.length) {
+ final var startPosStation = i;
+
+ // read station name
+ while (byteChunk[i] != ';') {
+ i++;
+ }
+ var station = new String(Arrays.copyOfRange(byteChunk, startPosStation, i));
+ i++;
+
+ // read measurement
+ final var startPosMeasurement = i;
+ while (byteChunk[i] != '\n') {
+ i++;
+ }
+
+ var measurement = Arrays.copyOfRange(byteChunk, startPosMeasurement, i);
+ var aggregate = stationAggregates.getOrDefault(station, new PartitionAggregate());
+ aggregate.addMeasurementAndComputeAggregate(measurement);
+ stationAggregates.put(station, aggregate);
+ i++;
+ }
+ stationAggregates.values().forEach(PartitionAggregate::aggregateRemainingMeasurements);
+ }
+
+ return stationAggregates;
+ }
+
+ private static class PartitionAggregate {
+ final short[] lane = new short[SIMD_LANE_LENGTH * 2];
+ // Assume that we do not have more than Integer.MAX_VALUE measurements for the same station per partition
+ int count = 0;
+ long sum = 0;
+ short min = Short.MAX_VALUE;
+ short max = Short.MIN_VALUE;
+
+ public void addMeasurementAndComputeAggregate(byte[] measurementBytes) {
+ // Parse measurement and exploit that we know the format of the floating-point values
+ var measurement = measurementBytes[measurementBytes.length - 1] - '0';
+ var digits = 1;
+ for (var i = measurementBytes.length - 3; i > 0; i--) {
+ var num = measurementBytes[i] - '0';
+ measurement = measurement + (num * (int) Math.pow(10, digits));
+ digits++;
+ }
+
+ // Check if measurement is negative
+ if (measurementBytes[0] == '-') {
+ measurement = measurement * -1;
+ }
+ else {
+ var num = measurementBytes[0] - '0';
+ measurement = measurement + (num * (int) Math.pow(10, digits));
+ }
+
+ // Add measurement to buffer, which is later processed by SIMD instructions
+ lane[count % lane.length] = (short) measurement;
+ count++;
+
+ // Once lane is full, use SIMD instructions to calculate aggregates
+ if (count % lane.length == 0) {
+ var firstVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, 0);
+ var secondVector = ShortVector.fromArray(ShortVector.SPECIES_MAX, lane, SIMD_LANE_LENGTH);
+
+ var simdMin = firstVector.min(secondVector).reduceLanes(VectorOperators.MIN);
+ min = (short) Math.min(min, simdMin);
+
+ var simdMax = firstVector.max(secondVector).reduceLanes(VectorOperators.MAX);
+ max = (short) Math.max(max, simdMax);
+
+ sum += firstVector.add(secondVector).reduceLanes(VectorOperators.ADD);
+ }
+ }
+
+ public void aggregateRemainingMeasurements() {
+ for (var i = 0; i < count % lane.length; i++) {
+ var measurement = lane[i];
+ min = (short) Math.min(min, measurement);
+ max = (short) Math.max(max, measurement);
+ sum += measurement;
+ }
+ }
+
+ public void mergeWith(PartitionAggregate otherAggregate) {
+ min = (short) Math.min(min, otherAggregate.min);
+ max = (short) Math.max(max, otherAggregate.max);
+ count = count + otherAggregate.count;
+ sum = sum + otherAggregate.sum;
+ }
+
+ public String toString() {
+ return String.format(
+ Locale.US,
+ "%.1f/%.1f/%.1f",
+ (min / 10.0),
+ (sum / 10.0) / count,
+ (max / 10.0));
+ }
+ }
+}