aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJamie Stansfield <80010454+isolgpus@users.noreply.github.com>2024-01-05 22:10:43 +0000
committerGitHub <noreply@github.com>2024-01-05 23:10:43 +0100
commit4614b81eb6edda5aa7ac4336a8466975c0f6be72 (patch)
tree6993049054ce669db91e9993971c01a5d665ea24 /src
parent4fe00bf7a75530bd363b7599b023d63a4eda0ddf (diff)
isolgpus: submission 1
* isolgpus: submission 1 * isolgpus: fix min value bug (breaks if a negative temperature never appears) * isolgpus: remove unused collector * isolgpus: fix split on chunk bug * isolgpus: change name equality algo to a cheaper check. * isolgpus: fix chunking state to cope with last byte of last chunk * isolgpus: hash as we go, instead of at the end * isolgpus: adjust thread count to core count * isolgpus: change cores to 8 statically --------- Co-authored-by: Jamie Stansfield <jalstansfield@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java293
1 files changed, 293 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java b/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java
new file mode 100644
index 0000000..65528c4
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_isolgpus.java
@@ -0,0 +1,293 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.nio.BufferUnderflowException;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Paths;
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.stream.Collectors;
+
+public class CalculateAverage_isolgpus {
+
+ public static final int HISTOGRAMS_LENGTH = 1024 * 32;
+ public static final int HISTOGRAMS_MASK = HISTOGRAMS_LENGTH - 1;
+ public static final int THREAD_COUNT = 8;
+ private static final String FILE = "./measurements.txt";
+ public static final byte SEPERATOR = 59;
+ public static final byte OFFSET = 48;
+ public static final byte NEGATIVE = 45;
+ public static final byte DECIMAL_POINT = 46;
+ public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 100; // bit of wiggle room
+ public static final byte NEW_LINE = 10;
+
+ public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
+ ExecutorService executorService = Executors.newFixedThreadPool(THREAD_COUNT);
+
+ File file = Paths.get(FILE).toFile();
+ long length = file.length();
+ long chunksCount = Math.max(THREAD_COUNT, (int) Math.ceil(length / (double) MAX_CHUNK_SIZE));
+
+ long estimatedChunkSize = length / chunksCount;
+
+ FileChannel channel = new RandomAccessFile(file, "r").getChannel();
+
+ List<Future<MeasurementCollector[]>> futures = new ArrayList<>();
+ for (int i = 0; i < chunksCount; i++) {
+ int finalI = i;
+ futures.add(executorService.submit(() -> handleChunk(channel, estimatedChunkSize * finalI, estimatedChunkSize, length)));
+ }
+
+ List<MeasurementCollector[]> measurementCollectors = new ArrayList<>();
+ for (Future<MeasurementCollector[]> result : futures) {
+ measurementCollectors.add(result.get());
+ }
+ executorService.shutdown();
+
+ Map<String, MeasurementCollector> measurementCollectorsByCity = mergeMeasurements(measurementCollectors);
+ List<MeasurementResult> results = measurementCollectorsByCity.values().stream().map(MeasurementResult::from).toList();
+
+ System.out.println("{" + results.stream().map(MeasurementResult::toString).collect(Collectors.joining(", ")) + "}");
+
+ }
+
+ private static Map<String, MeasurementCollector> mergeMeasurements(List<MeasurementCollector[]> resultsFromAllChunk) {
+ Map<String, MeasurementCollector> mergedResults = new TreeMap<>(Comparator.naturalOrder());
+
+ for (int i = 0; i < HISTOGRAMS_LENGTH; i++) {
+ for (MeasurementCollector[] resultFromSpecificChunk : resultsFromAllChunk) {
+ MeasurementCollector measurementCollectorFromChunk = resultFromSpecificChunk[i];
+ while (measurementCollectorFromChunk != null) {
+ MeasurementCollector currentMergedResult = mergedResults.get(new String(measurementCollectorFromChunk.name));
+ if (currentMergedResult == null) {
+ currentMergedResult = new MeasurementCollector(measurementCollectorFromChunk.name);
+ mergedResults.put(new String(currentMergedResult.name), currentMergedResult);
+ }
+ currentMergedResult.merge(measurementCollectorFromChunk);
+ measurementCollectorFromChunk = measurementCollectorFromChunk.link;
+ }
+ }
+ }
+
+ return mergedResults;
+ }
+
+ // ----n---
+ private static MeasurementCollector[] handleChunk(FileChannel channel, long estimatedStart, long lengthOfChunk, long maxLengthOfFile) throws IOException {
+ // -1 to see if we're starting on a brand new message
+ // +200 for wiggle room to finish the final message
+
+ long seekStart = Math.max(estimatedStart - 1, 0);
+ long length = Math.min(lengthOfChunk + 200, maxLengthOfFile - seekStart);
+
+ MappedByteBuffer r = channel.map(FileChannel.MapMode.READ_ONLY, seekStart, length);
+
+ byte[] nameBuffer = new byte[100];
+ boolean isNegative;
+ byte[] valueBuffer = new byte[3];
+ MeasurementCollector[] measurementCollectors = new MeasurementCollector[HISTOGRAMS_LENGTH];
+ int valueIndex = 0;
+ int nameBufferIndex = 0;
+ int nameSum = 0;
+ boolean parsingName = true;
+ long i = 0;
+ int hashResult = 0;
+
+ // seek to the start of the next message
+ if (estimatedStart != 0) {
+ while (r.get() != NEW_LINE) {
+ i++;
+ }
+ i++;
+ }
+
+ try {
+
+ while (i <= lengthOfChunk || !parsingName) {
+ byte aChar;
+ if (parsingName) {
+
+ while ((aChar = r.get()) != SEPERATOR) {
+ nameBuffer[nameBufferIndex++] = aChar;
+ nameSum += aChar;
+ hashResult = 31 * hashResult + aChar;
+ }
+ parsingName = false;
+ i += nameBufferIndex + 1;
+ }
+ else {
+ isNegative = (aChar = r.get()) == NEGATIVE;
+ valueIndex = readNumber(isNegative, valueBuffer, valueIndex, aChar, r);
+
+ byte decimalValue = r.get();
+
+ int value = resolveValue(valueIndex, valueBuffer, decimalValue, isNegative);
+ // new line character
+ r.get();
+
+ MeasurementCollector measurementCollector = resolveMeasurementCollector(measurementCollectors, hashResult, nameBuffer, nameBufferIndex, nameSum);
+
+ measurementCollector.feed(value);
+ i += valueIndex + (isNegative ? 4 : 3);
+ valueIndex = 0;
+ nameBufferIndex = 0;
+ nameSum = 0;
+ parsingName = true;
+ hashResult = 0;
+ }
+ }
+
+ }
+ catch (BufferUnderflowException e) {
+ if (i != maxLengthOfFile - seekStart) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+
+ return measurementCollectors;
+ }
+
+ private static MeasurementCollector resolveMeasurementCollector(MeasurementCollector[] measurementCollectors, int hash, byte[] nameBuffer, int nameBufferIndex,
+ int nameSum) {
+ MeasurementCollector measurementCollector = measurementCollectors[hash & HISTOGRAMS_MASK];
+ if (measurementCollector == null) {
+ measurementCollector = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex));
+ measurementCollectors[hash & HISTOGRAMS_MASK] = measurementCollector;
+ }
+ else {
+ // collision unhappy path, try to avoid
+ while (!nameEquals(measurementCollector.name, measurementCollector.nameSum, nameSum, nameBufferIndex)) {
+ if (measurementCollector.link == null) {
+ measurementCollector.link = new MeasurementCollector(Arrays.copyOf(nameBuffer, nameBufferIndex));
+ measurementCollector = measurementCollector.link;
+ break;
+ }
+ else {
+ measurementCollector = measurementCollector.link;
+ }
+ }
+
+ }
+ return measurementCollector;
+ }
+
+ private static boolean nameEquals(byte[] existingName, int existingNameSum, int incomingNameSum, int nameBufferIndex) {
+
+ if (existingName.length != nameBufferIndex) {
+ return false;
+ }
+
+ return incomingNameSum == existingNameSum;
+ }
+
+ private static int resolveValue(int valueIndex, byte[] valueBuffer, byte decimalValue, boolean isNegative) {
+ int value;
+ if (valueIndex == 1) {
+ value = ((valueBuffer[0] - OFFSET) * 10) + (decimalValue - OFFSET);
+ }
+ else // it's 2 digits
+ {
+ value = ((valueBuffer[0] - OFFSET) * 100) + ((valueBuffer[1] - OFFSET) * 10) + (decimalValue - OFFSET);
+ }
+
+ if (isNegative) {
+ value = Math.negateExact(value);
+ }
+ return value;
+ }
+
+ private static int readNumber(boolean isNegative, byte[] valueBuffer, int valueIndex, byte aChar, MappedByteBuffer r) {
+ if (!isNegative) {
+ valueBuffer[valueIndex++] = aChar;
+ }
+
+ // maybe one or two more
+ while ((aChar = r.get()) != DECIMAL_POINT) {
+ valueBuffer[valueIndex++] = aChar;
+ }
+ return valueIndex;
+ }
+
+ private static class MeasurementCollector {
+ private final byte[] name;
+ private final int nameSum;
+ public MeasurementCollector link;
+ private long sum;
+ private int count;
+ private int min = Integer.MAX_VALUE;
+ private int max = Integer.MIN_VALUE;
+
+ public MeasurementCollector(byte[] name) {
+
+ this.name = name;
+ int nameSum = 0;
+ for (int i = 0; i < name.length; i++) {
+ nameSum += name[i];
+ }
+ this.nameSum = nameSum;
+ }
+
+ public void feed(int value) {
+ sum += value;
+ count++;
+ min = Math.min(value, min);
+ max = Math.max(value, max);
+ }
+
+ public void merge(MeasurementCollector measurementCollector) {
+ this.sum += measurementCollector.sum;
+ this.count += measurementCollector.count;
+ this.min = Math.min(measurementCollector.min, this.min);
+ this.max = Math.max(measurementCollector.max, this.max);
+ }
+ }
+
+ private static class MeasurementResult {
+ private final String name;
+ private final double mean;
+ private final BigDecimal max;
+ private final BigDecimal min;
+
+ public MeasurementResult(String name, double mean, BigDecimal max, BigDecimal min) {
+
+ this.name = name;
+ this.mean = mean;
+ this.max = max;
+ this.min = min;
+ }
+
+ @Override
+ public String toString() {
+ // Abha=-24.9/18.0/61.7
+ return name + "=" + min + "/" + mean + "/" + max;
+ }
+
+ public static MeasurementResult from(MeasurementCollector mc) {
+ double mean = Math.round((double) mc.sum / (double) mc.count) / 10d;
+ BigDecimal max = BigDecimal.valueOf(mc.max).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
+ BigDecimal min = BigDecimal.valueOf(mc.min).divide(BigDecimal.TEN, 1, RoundingMode.HALF_UP);
+ return new MeasurementResult(new String(mc.name), mean, max, min);
+ }
+ }
+}