about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author Nick Palmer <nick@palmr.co.uk> 2024-01-03 22:18:40 +0000
committer Gunnar Morling <gunnar.morling@googlemail.com> 2024-01-04 23:54:04 +0100
commit 6aa63e1bd5e2d580324b8ddd58b69d11761b2bf3 (patch)
tree 86be29958725a144033b985d716c020e8eae83b2 /src
parent b2cd84c6bc2b6ba0b84ea52f1ef2f2d5d8592c24 (diff)
Attempt nicer threading via streams and spliterators
Diffstat (limited to 'src')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java 234
1 files changed, 102 insertions, 132 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java b/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java
index bb57d33..c687031 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java
@@ -21,162 +21,118 @@ import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
public class CalculateAverage_palmr {
-
private static final String FILE = "./measurements.txt";
- public static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine
- public static final int LITTLE_CHUNK_SIZE = 128; // Enough bytes to cover a station name and measurement value :fingers-crossed:
- public static final int STATION_NAME_BUFFER_SIZE = 50;
- public static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors());
+ private static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine
+ private static final int STATION_NAME_BUFFER_SIZE = 50;
+ private static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors());
+ private static final char SEPARATOR_CHAR = ';';
+ private static final char NEWLINE_CHAR = '\n';
+ private static final char MINUS_CHAR = '-';
+ private static final char DECIMAL_POINT_CHAR = '.';
public static void main(String[] args) throws IOException {
@SuppressWarnings("resource") // It's faster to leak the file than be well-behaved
- RandomAccessFile file = new RandomAccessFile(FILE, "r");
- FileChannel channel = file.getChannel();
- long fileSize = channel.size();
-
- long threadChunk = fileSize / THREAD_COUNT;
-
- Thread[] threads = new Thread[THREAD_COUNT];
- ByteArrayKeyedMap[] results = new ByteArrayKeyedMap[THREAD_COUNT];
- for (int i = 0; i < THREAD_COUNT; i++) {
- final int j = i;
- long startPoint = j * threadChunk;
- long endPoint = startPoint + threadChunk;
- Thread thread = new Thread(() -> {
- try {
- results[j] = readAndParse(channel, startPoint, endPoint, fileSize);
- }
- catch (Throwable t) {
- System.err.println("It's broken :(");
- // noinspection CallToPrintStackTrace
- t.printStackTrace();
- }
- });
- threads[i] = thread;
- thread.start();
- }
-
- final Map<String, MeasurementAggregator> finalAggregator = new TreeMap<>();
+ final var file = new RandomAccessFile(FILE, "r");
+ final var channel = file.getChannel();
+
+ final TreeMap<String, MeasurementAggregator> results = StreamSupport.stream(ThreadChunk.chunk(file, THREAD_COUNT), true)
+ .map(chunk -> parseChunk(chunk, channel))
+ .flatMap(bakm -> bakm.getAsUnorderedList().stream())
+ .collect(Collectors.toMap(m -> new String(m.stationNameBytes, StandardCharsets.UTF_8), m -> m, MeasurementAggregator::merge, TreeMap::new));
+ System.out.println(results);
+ }
- for (int i = 0; i < THREAD_COUNT; i++) {
- try {
- threads[i].join();
- }
- catch (InterruptedException e) {
- throw new RuntimeException(e);
+ private record ThreadChunk(long startPoint, long endPoint, long size) {
+ public static Spliterator<CalculateAverage_palmr.ThreadChunk> chunk(final RandomAccessFile file, final int chunkCount) throws IOException {
+ final var fileSize = file.length();
+ final var idealChunkSize = fileSize / THREAD_COUNT;
+ final var chunks = new CalculateAverage_palmr.ThreadChunk[chunkCount];
+
+ var startPoint = 0L;
+ for (int i = 0; i < chunkCount; i++) {
+ var endPoint = Math.min(startPoint + idealChunkSize, fileSize);
+ file.seek(endPoint);
+ while (endPoint < fileSize && file.readByte() != NEWLINE_CHAR) {
+ endPoint++;
+ }
+ final var actualSize = endPoint - startPoint;
+ chunks[i] = new CalculateAverage_palmr.ThreadChunk(startPoint, endPoint, actualSize);
+ startPoint += actualSize;
}
- results[i].getAsUnorderedList().forEach(v -> {
- String stationName = new String(v.stationNameBytes, StandardCharsets.UTF_8);
- finalAggregator.compute(stationName, (_, x) -> {
- if (x == null) {
- return v;
- }
- else {
- x.count += v.count;
- x.min = Math.min(x.min, v.min);
- x.max = Math.max(x.max, v.max);
- x.sum += v.sum;
- return x;
- }
- });
- });
+ return Spliterators.spliterator(chunks,
+ Spliterator.ORDERED |
+ Spliterator.DISTINCT |
+ Spliterator.SORTED |
+ Spliterator.NONNULL |
+ Spliterator.IMMUTABLE |
+ Spliterator.CONCURRENT
+ );
}
- System.out.println(finalAggregator);
}
- private static ByteArrayKeyedMap readAndParse(final FileChannel channel,
- final long startPoint,
- final long endPoint,
- final long fileSize) {
- final State state = new State();
+ private static ByteArrayKeyedMap parseChunk(ThreadChunk chunk, FileChannel channel) {
+ final var state = new State();
- boolean skipFirstEntry = startPoint != 0;
-
- long offset = startPoint;
- while (offset < endPoint) {
- parseData(channel, state, offset, Math.min(CHUNK_SIZE, fileSize - offset), false, skipFirstEntry);
- skipFirstEntry = false;
+ var offset = chunk.startPoint;
+ while (offset < chunk.endPoint) {
+ parseData(channel, state, offset, Math.min(CHUNK_SIZE, chunk.endPoint - offset));
offset += CHUNK_SIZE;
}
- if (offset < fileSize) {
- // Make sure we finish reading any partially read entry by going a little in to the next chunk, stopping at the first newline
- parseData(channel, state, offset, Math.min(LITTLE_CHUNK_SIZE, fileSize - offset), true, false);
- }
-
return state.aggregators;
}
private static void parseData(final FileChannel channel,
final State state,
final long offset,
- final long bufferSize,
- final boolean stopAtNewline,
- final boolean skipFirstEntry) {
- ByteBuffer byteBuffer;
+ final long bufferSize) {
+ final ByteBuffer byteBuffer;
try {
byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize);
- }
- catch (IOException e) {
- throw new RuntimeException(e);
- }
-
- boolean isSkippingToFirstCleanEntry = skipFirstEntry;
-
- while (byteBuffer.hasRemaining()) {
- byte currentChar = byteBuffer.get();
-
- if (isSkippingToFirstCleanEntry) {
- if (currentChar == '\n') {
- isSkippingToFirstCleanEntry = false;
- }
- continue;
- }
+ while (byteBuffer.hasRemaining()) {
+ final var currentChar = byteBuffer.get();
- if (currentChar == ';') {
- state.parsingValue = true;
- }
- else if (currentChar == '\n') {
- if (state.stationPointerEnd != 0) {
- double value = state.measurementValue * state.exponent;
-
- MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode);
- aggregator.count++;
- aggregator.min = Math.min(aggregator.min, value);
- aggregator.max = Math.max(aggregator.max, value);
- aggregator.sum += value;
- }
-
- if (stopAtNewline) {
- return;
- }
+ if (currentChar == SEPARATOR_CHAR) {
+ state.parsingValue = true;
+ } else if (currentChar == NEWLINE_CHAR) {
+ if (state.stationPointerEnd != 0) {
+ final var value = state.measurementValue * state.exponent;
- // reset
- state.reset();
- }
- else {
- if (!state.parsingValue) {
- state.stationBuffer[state.stationPointerEnd++] = currentChar;
- state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff);
- }
- else {
- if (currentChar == '-') {
- state.exponent = -0.1;
+ MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode);
+ aggregator.count++;
+ aggregator.min = Math.min(aggregator.min, value);
+ aggregator.max = Math.max(aggregator.max, value);
+ aggregator.sum += value;
}
- else if (currentChar != '.') {
- state.measurementValue = state.measurementValue * 10 + (currentChar - '0');
+
+ // reset
+ state.reset();
+ } else {
+ if (!state.parsingValue) {
+ state.stationBuffer[state.stationPointerEnd++] = currentChar;
+ state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff);
+ } else {
+ if (currentChar == MINUS_CHAR) {
+ state.exponent = -0.1;
+ } else if (currentChar != DECIMAL_POINT_CHAR) {
+ state.measurementValue = state.measurementValue * 10 + (currentChar - '0');
+ }
}
}
}
+ } catch (IOException e) {
+ throw new RuntimeException(e);
}
}
- static final class State {
+ private static final class State {
ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap();
boolean parsingValue = false;
byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE];
@@ -208,37 +164,51 @@ public class CalculateAverage_palmr {
}
public String toString() {
- return round(min) + "/" + round(sum / count) + "/" + round(max);
+ return STR."\{round(min)}/\{round(sum / count)}/\{round(max)}";
}
- private double round(double value) {
+ private double round(final double value) {
return Math.round(value * 10.0) / 10.0;
}
+
+ private MeasurementAggregator merge(final MeasurementAggregator b) {
+ this.count += b.count;
+ this.min = Math.min(this.min, b.min);
+ this.max = Math.max(this.max, b.max);
+ this.sum += b.sum;
+ return this;
+ }
}
+ /**
+ * Very basic hash table implementation, only implementing computeIfAbsent since that's all the code needs.
+ * It's sized to give minimal collisions with the example test set. This may not hold true if the stations list
+ * changes, but it should still perform fairly well.
+ * It uses Open Addressing, meaning it's just one array, rather than Separate Chaining, which is what the default Java HashMap uses.
+ * It also uses linear probing for collision resolution, which, given the minimal collision count, should hold up well.
+ */
private static class ByteArrayKeyedMap {
private final int BUCKET_COUNT = 0xFFF; // 413 unique stations in the data set, & 0xFFF ~= 399 (only 14 collisions (given our hashcode implementation))
private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_COUNT + 1];
private final List<MeasurementAggregator> compactUnorderedBuckets = new ArrayList<>(413);
public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) {
- int index = keyHashCode & BUCKET_COUNT;
+ var index = keyHashCode & BUCKET_COUNT;
while (true) {
MeasurementAggregator maybe = buckets[index];
- if (maybe == null) {
- final byte[] copiedKey = Arrays.copyOf(key, keyLength);
- MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode);
- buckets[index] = measurementAggregator;
- compactUnorderedBuckets.add(measurementAggregator);
- return measurementAggregator;
- }
- else {
+ if (maybe != null) {
if (Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) {
return maybe;
}
index++;
index &= BUCKET_COUNT;
+ } else {
+ final var copiedKey = Arrays.copyOf(key, keyLength);
+ MeasurementAggregator measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode);
+ buckets[index] = measurementAggregator;
+ compactUnorderedBuckets.add(measurementAggregator);
+ return measurementAggregator;
}
}
}