aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev
diff options
context:
space:
mode:
authorgonix <d.giedrius+github@gmail.com>2024-01-16 23:49:39 +0200
committerGitHub <noreply@github.com>2024-01-16 22:49:39 +0100
commit7f5f808176c13e080e50fb6649c24a9f0010f8cb (patch)
tree3a29275f59b45bf297d4226f34d3857994badca1 /src/main/java/dev
parent455b85c5af1cba9ddb6e6dc686c091d1e000a432 (diff)
CalculateAverage_gonix initial attempt (#413)
Diffstat (limited to 'src/main/java/dev')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java354
1 file changed, 354 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
new file mode 100644
index 0000000..8349d00
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java
@@ -0,0 +1,354 @@
+/*
+ * Copyright 2023 The original authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class CalculateAverage_gonix {
+
+ private static final String FILE = "./measurements.txt";
+
+ public static void main(String[] args) throws IOException {
+
+ var file = new RandomAccessFile(FILE, "r");
+
+ var res = buildChunks(file).stream().parallel()
+ .flatMap(chunk -> new Aggregator().processChunk(chunk).stream())
+ .collect(Collectors.toMap(
+ Aggregator.Entry::getKey,
+ Aggregator.Entry::getValue,
+ Aggregator.Entry::add,
+ TreeMap::new));
+
+ System.out.println(res);
+ }
+
+ private static List<MappedByteBuffer> buildChunks(RandomAccessFile file) throws IOException {
+ var fileSize = file.length();
+ var chunkSize = Math.min(Integer.MAX_VALUE - 512, fileSize / Runtime.getRuntime().availableProcessors());
+ if (chunkSize <= 0) {
+ chunkSize = fileSize;
+ }
+ var chunks = new ArrayList<MappedByteBuffer>((int) (fileSize / chunkSize) + 1);
+ var start = 0L;
+ while (start < fileSize) {
+ var pos = start + chunkSize;
+ if (pos < fileSize) {
+ file.seek(pos);
+ while (file.read() != '\n') {
+ pos += 1;
+ }
+ pos += 1;
+ }
+ else {
+ pos = fileSize;
+ }
+ var buf = file.getChannel().map(FileChannel.MapMode.READ_ONLY, start, pos - start);
+ buf.order(ByteOrder.nativeOrder());
+ chunks.add(buf);
+ start = pos;
+ }
+ return chunks;
+ }
+}
+
/**
 * Per-chunk aggregator of station statistics.
 *
 * Stations live in a poor man's open-addressing hash map: {@code index} maps a
 * masked hash to an offset inside the flat {@code mem} array. Per station,
 * {@code mem} holds:
 *
 *   mem[off]                     key length in bytes
 *   mem[off+1 ..]                key bytes packed 8 per long; the last key
 *                                long is zero-padded, and when the length is
 *                                an exact multiple of 8 an extra all-zero
 *                                long is stored, so the stats ALWAYS start at
 *                                off + (len >> 3) + 2
 *   mem[off+(len>>3)+2 + FLD_*]  max, min, sum, count
 *
 * Temperatures are stored as integer tenths of a degree (e.g. "12.3" -> 123).
 *
 * NOTE(review): the byte-at-a-time masking below treats the lowest byte of a
 * long as the first byte in the file, i.e. it relies on the buffers being in
 * little-endian native order — confirm before running on big-endian hardware.
 */
class Aggregator {
    // Problem constraints: at most 10_000 distinct stations, names up to
    // 100 characters (up to 4 UTF-8 bytes each).
    private static final int MAX_STATIONS = 10_000;
    // Worst-case longs per station: (100 chars * 4 bytes) / 8 bytes-per-long,
    // plus 1 length slot and 4 stat fields (the "+ 5").
    private static final int MAX_STATION_SIZE = (100 * 4) / 8 + 5;
    // Must be a power of two (masking) and well above MAX_STATIONS to keep
    // the linear-probe chains short.
    private static final int INDEX_SIZE = 1024 * 1024;
    private static final int INDEX_MASK = INDEX_SIZE - 1;
    // Stat-field offsets relative to the first field slot.
    private static final int FLD_MAX = 0;
    private static final int FLD_MIN = 1;
    private static final int FLD_SUM = 2;
    private static final int FLD_COUNT = 3;

    // Poor man's hash map: hash code to offset in `mem`. Offset 0 means
    // "empty slot" (mem slot 0 is deliberately never used, see constructor).
    private final int[] index;

    // Contiguous storage of key (station name) and stats fields of all
    // unique stations.
    // The idea here is to improve locality so that stats fields would
    // possibly be already in the CPU cache after we are done comparing
    // the key.
    private final long[] mem;
    private int memUsed;

    Aggregator() {
        assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2";
        assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS";

        index = new int[INDEX_SIZE];
        // "+ 1" and starting memUsed at 1 reserve offset 0 as the
        // empty-slot sentinel in `index`.
        mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)];
        memUsed = 1;
    }

    /**
     * Aggregates every "station;temperature\n" line of the chunk into this
     * instance. Returns {@code this} for chaining.
     */
    Aggregator processChunk(MappedByteBuffer buf) {
        // To avoid checking if it is safe to read a whole long near the
        // end of a chunk, we copy last couple of lines to a padded buffer
        // and process that part separately.
        int limit = buf.limit();
        // Find the last '\n' at or before limit-16 (a line is at most
        // 100*4 name bytes + ";-99.9\n", so 16 bytes of tail is enough
        // padding headroom for the 8-byte reads below).
        int pos = Math.max(limit - 16, -1);
        while (pos >= 0 && buf.get(pos) != '\n') {
            pos--;
        }
        pos++; // first byte after that newline (0 if none found)
        if (pos > 0) {
            // Bulk of the chunk: safe to over-read by up to 8 bytes.
            processChunkLongs(buf, pos);
        }
        // Copy the remaining tail into a heap buffer with 8 spare bytes so
        // getLong() near its end never reads out of bounds.
        int tailLen = limit - pos;
        var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.nativeOrder());
        buf.get(pos, tailBuf.array(), 0, tailLen);
        processChunkLongs(tailBuf, tailLen);
        return this;
    }

    /**
     * Parses lines in {@code buf[0, limit)} 8 bytes at a time. The buffer
     * must be readable for at least 8 bytes past every position up to
     * {@code limit} (guaranteed by the caller's tail handling).
     */
    Aggregator processChunkLongs(ByteBuffer buf, int limit) {
        int pos = 0;
        while (pos < limit) {

            // --- scan the station name up to ';', hashing as we go ---
            int start = pos;
            int hash = 0;
            while (true) {
                // This is a bit ugly, but it is faster than reading by byte.
                // Each branch checks whether byte k of the long is ';'; if
                // so it folds the preceding key bytes into the hash and
                // leaves `pos` pointing AT the ';'.
                // NOTE(review): the ">>> 33" below (rather than 32) drops a
                // bit of the 5th byte; harmless for a hash as long as
                // hashing and probing stay self-consistent — it is only
                // ever compared against hashes computed the same way.
                long tmpLong = buf.getLong(pos);
                if ((tmpLong & 0xFF) == ';') {
                    break;
                }
                if (((tmpLong >>> 8) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (tmpLong & 0xFF);
                    pos += 1;
                    break;
                }
                if (((tmpLong >>> 16) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (tmpLong & 0xFFFF);
                    pos += 2;
                    break;
                }
                if (((tmpLong >>> 24) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (tmpLong & 0xFFFFFF);
                    pos += 3;
                    break;
                }
                if (((tmpLong >>> 32) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF);
                    pos += 4;
                    break;
                }
                if (((tmpLong >>> 40) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFF);
                    pos += 5;
                    break;
                }
                if (((tmpLong >>> 48) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFF);
                    pos += 6;
                    break;
                }
                if (((tmpLong >>> 56) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFF);
                    pos += 7;
                    break;
                }
                // No ';' in these 8 bytes: fold the whole long in and advance.
                hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF);
                pos += 8;
            }
            // Final avalanche step to spread the low bits before masking.
            hash = (33 * hash) ^ (hash >>> 15);
            int len = pos - start;
            assert (buf.get(pos) == ';') : "Expected ';'";
            pos++; // skip ';'

            // --- parse the temperature: [-]D.D or [-]DD.D, newline-terminated ---
            int measurement;
            {
                long tmpLong = buf.getLong(pos);
                int sign = 1;
                if ((tmpLong & 0xFF) == '-') {
                    sign = -1;
                    tmpLong >>>= 8; // drop the '-' so digit offsets line up
                    pos++;
                }
                int value;
                if (((tmpLong >>> 8) & 0xFF) == '.') {
                    // "D.D\n" — value in tenths; advance past digit,'.',digit,'\n'.
                    value = (int) (((tmpLong & 0xFF) - '0') * 10 + (((tmpLong >>> 16) & 0xFF) - '0'));
                    pos += 4;
                }
                else {
                    // "DD.D\n" — byte 2 is the '.', so the fraction digit is byte 3.
                    value = (int) (((tmpLong & 0xFF) - '0') * 100 + (((tmpLong >>> 8) & 0xFF) - '0') * 10 + (((tmpLong >>> 24) & 0xFF) - '0'));
                    pos += 5;
                }
                measurement = sign * value;
            }
            assert (buf.get(pos - 1) == '\n') : "Expected '\\n'";

            add(buf, start, len, hash, measurement);
        }

        return this;
    }

    /** Streams one {@link Entry} per distinct station seen by this aggregator. */
    public Stream<Entry> stream() {
        return Arrays.stream(index)
                .filter(offset -> offset != 0) // 0 = empty slot
                .mapToObj(offset -> new Entry(mem, offset));
    }

    /**
     * Finds (or creates) the station keyed by buf[start, start+len) via
     * linear probing, then folds {@code measurement} into its stats.
     */
    private void add(ByteBuffer buf, int start, int len, int hash, int measurement) {
        int idx = hash & INDEX_MASK;
        while (true) {
            if (index[idx] != 0) {
                int offset = index[idx];
                if (keyEqual(offset, buf, start, len)) {
                    // Fields start after the length slot and the key longs;
                    // the "+ 2" matches the layout written by create().
                    int pos = offset + (len >> 3) + 2;
                    // Stats are stored as longs but hold int-range values,
                    // hence the (int) narrowing before min/max.
                    mem[pos + FLD_MIN] = Math.min((int) measurement, (int) mem[pos + FLD_MIN]);
                    mem[pos + FLD_MAX] = Math.max((int) measurement, (int) mem[pos + FLD_MAX]);
                    mem[pos + FLD_SUM] += measurement;
                    mem[pos + FLD_COUNT] += 1;
                    return;
                }
                // Hash collision with a different key: probe the next slot.
            }
            else {
                index[idx] = create(buf, start, len, hash, measurement);
                return;
            }
            idx = (idx + 1) & INDEX_MASK;
        }
    }

    /**
     * Appends a new station record (key + initial stats) to {@code mem} and
     * returns its offset, to be stored in {@code index}.
     */
    private int create(ByteBuffer buf, int start, int len, int hash, int measurement) {
        int offset = memUsed;

        mem[offset] = len;

        // Copy whole 8-byte words of the key.
        int memPos = offset + 1;
        int memEndEarly = memPos + (len >> 3);
        int bufPos = start;
        int bufEnd = start + len;
        while (memPos < memEndEarly) {
            mem[memPos] = buf.getLong(bufPos);
            memPos += 1;
            bufPos += 8;
        }
        if (bufPos < bufEnd) {
            // Final partial word: shift left then unsigned-right to zero the
            // bytes beyond the key, so stored keys compare exactly.
            int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8
            long tmpLong = buf.getLong(bufPos) << shift >>> shift;
            mem[memPos] = tmpLong;
        }
        else {
            // "consume" extra long - makes math a bit simpler to calculate
            // fields offset for update.
            mem[memPos] = 0;
        }

        memPos += 1;
        mem[memPos + FLD_MIN] = measurement;
        mem[memPos + FLD_MAX] = measurement;
        mem[memPos + FLD_SUM] = measurement;
        mem[memPos + FLD_COUNT] = 1;
        memUsed = memPos + 4;

        return offset;
    }

    /**
     * Compares the stored key at {@code offset} against buf[start, start+len),
     * 8 bytes at a time, using the same zero-padding trick as create().
     */
    private boolean keyEqual(int offset, ByteBuffer buf, int start, int len) {
        if (len != mem[offset]) {
            return false;
        }
        int memPos = offset + 1;
        int memEndEarly = memPos + (len >> 3);
        int bufPos = start;
        int bufEnd = start + len;
        while (memPos < memEndEarly) {
            if (mem[memPos] != buf.getLong(bufPos)) {
                return false;
            }
            memPos += 1;
            bufPos += 8;
        }
        if (bufPos < bufEnd) {
            // Mask the trailing bytes beyond the key before comparing,
            // mirroring the padding applied when the key was stored.
            int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8
            long tmpLong = buf.getLong(bufPos) << shift >>> shift;
            if (mem[memPos] != tmpLong) {
                return false;
            }
        }
        return true;
    }

    /**
     * View over one station's record inside an aggregator's {@code mem}.
     * Doubles as the map value in the final merge; {@link #add} combines
     * stats from another aggregator's entry for the same station.
     */
    public static class Entry {
        private final long[] mem;
        private final int offset;
        private String key; // lazily decoded station name, cached

        Entry(long[] mem, int offset) {
            this.mem = mem;
            this.offset = offset;
        }

        /** Decodes (once) and returns the station name as a UTF-8 string. */
        public String getKey() {
            if (key == null) {
                int pos = this.offset;
                int keyLen = (int) mem[pos++];
                // +8 spare bytes: the key longs are written whole, so the
                // buffer must hold a multiple of 8 bytes.
                var tmpBuf = ByteBuffer.allocate(keyLen + 8).order(ByteOrder.nativeOrder());
                for (int i = 0; i < keyLen; i += 8) {
                    tmpBuf.putLong(mem[pos++]);
                }
                key = new String(tmpBuf.array(), 0, keyLen, StandardCharsets.UTF_8);
            }
            return key;
        }

        /**
         * Merges another chunk's stats for the same station into this entry
         * (used as the merge function of Collectors.toMap). Returns this.
         */
        public Entry add(Entry other) {
            int keyLen = (int) mem[offset];
            // Same-station entries have equal key lengths, so one field
            // offset works for both records.
            int fldOffset = (keyLen >> 3) + 2;
            int pos = offset + fldOffset;
            int otherPos = other.offset + fldOffset;
            long[] otherMem = other.mem;
            mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]);
            mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]);
            mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM];
            mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT];
            return this;
        }

        /** Identity accessor so Entry can serve as the map value in toMap(). */
        public Entry getValue() {
            return this;
        }

        /** Formats "min/mean/max" in degrees (stored values are tenths). */
        @Override
        public String toString() {
            int keyLen = (int) mem[offset];
            int pos = offset + (keyLen >> 3) + 2;
            return round(mem[pos + FLD_MIN])
                    + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT])
                    + "/" + round(mem[pos + FLD_MAX]);
        }

        // Rounds a tenths-of-a-degree value to one decimal place of degrees.
        private static double round(double value) {
            return Math.round(value) / 10.0;
        }
    }
}