diff options
| author | gonix <d.giedrius+github@gmail.com> | 2024-01-16 23:49:39 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-16 22:49:39 +0100 |
| commit | 7f5f808176c13e080e50fb6649c24a9f0010f8cb (patch) | |
| tree | 3a29275f59b45bf297d4226f34d3857994badca1 | |
| parent | 455b85c5af1cba9ddb6e6dc686c091d1e000a432 (diff) | |
CalculateAverage_gonix initial attempt (#413)
| -rwxr-xr-x | calculate_average_gonix.sh | 20 | ||||
| -rw-r--r-- | src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java | 354 |
2 files changed, 374 insertions, 0 deletions
diff --git a/calculate_average_gonix.sh b/calculate_average_gonix.sh new file mode 100755 index 0000000..a6f9165 --- /dev/null +++ b/calculate_average_gonix.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# +# Copyright 2023 The original authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +JAVA_OPTS="--enable-preview" +java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gonix diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java new file mode 100644 index 0000000..8349d00 --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java @@ -0,0 +1,354 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package dev.morling.onebrc;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Computes min / mean / max per station for a measurements file.
 *
 * <p>The input is expected to consist of lines of the form
 * {@code <station>;<value>\n}, where {@code <value>} is {@code [-]d.d} or
 * {@code [-]dd.d} (the only two shapes the parser in {@link Aggregator}
 * understands). The file is split into line-aligned chunks, every chunk is
 * aggregated independently by its own {@link Aggregator}, and the partial
 * results are merged into a sorted map which is printed to standard output.
 */
public class CalculateAverage_gonix {

    private static final String FILE = "./measurements.txt";

    /**
     * Entry point: aggregates {@code ./measurements.txt} and prints the result.
     *
     * @param args ignored
     * @throws IOException if the file cannot be opened, read or mapped
     */
    public static void main(String[] args) throws IOException {
        // try-with-resources: the original never closed the file handle.
        try (var file = new RandomAccessFile(FILE, "r")) {
            var res = buildChunks(file).stream().parallel()
                    .flatMap(chunk -> new Aggregator().processChunk(chunk).stream())
                    .collect(Collectors.toMap(
                            Aggregator.Entry::getKey,
                            Aggregator.Entry::getValue,
                            Aggregator.Entry::add,
                            TreeMap::new));

            System.out.println(res);
        }
    }

    /**
     * Splits the file into read-only memory mapped chunks that each end on a
     * line boundary (or at end of file).
     *
     * <p>Every returned buffer is switched to little-endian byte order,
     * because {@link Aggregator} extracts individual bytes from {@code long}
     * reads assuming the first byte in memory is the least significant one.
     *
     * @param file the opened input file
     * @return the chunks in file order; empty for an empty file
     * @throws IOException if reading or mapping fails
     */
    private static List<MappedByteBuffer> buildChunks(RandomAccessFile file) throws IOException {
        long fileSize = file.length();
        if (fileSize == 0) {
            // Nothing to map; also avoids dividing by a zero chunk size below.
            return new ArrayList<>();
        }

        // Stay safely below the 2 GiB limit of a single mapping, leaving room
        // for the extension up to the next line break.
        long chunkSize = Math.min(Integer.MAX_VALUE - 512, fileSize / Runtime.getRuntime().availableProcessors());
        if (chunkSize <= 0) {
            chunkSize = fileSize;
        }

        var chunks = new ArrayList<MappedByteBuffer>((int) (fileSize / chunkSize) + 1);
        FileChannel channel = file.getChannel();
        long start = 0L;
        while (start < fileSize) {
            long end = start + chunkSize;
            end = (end < fileSize) ? findLineEnd(file, end) : fileSize;

            var buf = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start);
            buf.order(ByteOrder.LITTLE_ENDIAN);
            chunks.add(buf);
            start = end;
        }
        return chunks;
    }

    /**
     * Returns the offset just past the first {@code '\n'} found at or after
     * {@code from}, or the end-of-file offset when no line break follows.
     *
     * <p>The end-of-file check matters: {@link RandomAccessFile#read()} keeps
     * returning {@code -1} at EOF, so a scan that only looks for
     * {@code '\n'} would never terminate on a file without a trailing
     * newline.
     */
    private static long findLineEnd(RandomAccessFile file, long from) throws IOException {
        file.seek(from);
        long pos = from;
        int b;
        while ((b = file.read()) != -1) {
            pos++;
            if (b == '\n') {
                break;
            }
        }
        return pos;
    }
}

/**
 * Single-threaded aggregator for one chunk of the input.
 *
 * <p>Stations are kept in an open-addressing hash table ({@link #index})
 * whose slots point into one contiguous {@code long[]} ({@link #mem}). A
 * station record in {@code mem} is laid out as:
 *
 * <pre>
 *   [ key length in bytes ]
 *   [ (len / 8) + 1 longs holding the key bytes, zero padded ]
 *   [ FLD_MAX ][ FLD_MIN ][ FLD_SUM ][ FLD_COUNT ]
 * </pre>
 *
 * <p>All values are stored as tenths (e.g. {@code 12.3} is stored as
 * {@code 123}). Buffers handed to this class must be little-endian.
 * Instances are not thread-safe.
 */
class Aggregator {
    private static final int MAX_STATIONS = 10_000;
    private static final int MAX_STATION_SIZE = (100 * 4) / 8 + 5;
    private static final int INDEX_SIZE = 1024 * 1024;
    private static final int INDEX_MASK = INDEX_SIZE - 1;
    private static final int FLD_MAX = 0;
    private static final int FLD_MIN = 1;
    private static final int FLD_SUM = 2;
    private static final int FLD_COUNT = 3;

    // Poor man's hash map: hash code to offset in `mem`. Zero means "empty",
    // which is why record offsets start at 1.
    private final int[] index;

    // Contiguous storage of key (station name) and stats fields of all
    // unique stations.
    // The idea here is to improve locality so that stats fields would
    // possibly be already in the CPU cache after we are done comparing
    // the key.
    private final long[] mem;
    private int memUsed;

    Aggregator() {
        assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2";
        assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS";

        index = new int[INDEX_SIZE];
        mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)];
        memUsed = 1;
    }

    /**
     * Aggregates every line of the given chunk into this instance.
     *
     * <p>To avoid checking if it is safe to read a whole {@code long} near
     * the end of a chunk, the last couple of lines (at least the final ~16
     * bytes, extended back to a line start) are copied into a zero-padded
     * heap buffer and processed separately.
     *
     * @param buf a little-endian chunk that starts at a line start
     * @return this aggregator
     */
    Aggregator processChunk(MappedByteBuffer buf) {
        int limit = buf.limit();

        // Walk back from (limit - 16) to the previous line break; the tail
        // begins right after it (or at 0 when there is none).
        int tailStart = Math.max(limit - 16, -1);
        while (tailStart >= 0 && buf.get(tailStart) != '\n') {
            tailStart--;
        }
        tailStart++;

        if (tailStart > 0) {
            processChunkLongs(buf, tailStart);
        }

        int tailLen = limit - tailStart;
        var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.LITTLE_ENDIAN);
        buf.get(tailStart, tailBuf.array(), 0, tailLen);
        processChunkLongs(tailBuf, tailLen);
        return this;
    }

    /**
     * Parses lines from {@code buf[0, limit)} and aggregates them.
     *
     * <p>The buffer must be little-endian and must remain readable for up to
     * 8 bytes past any position inside a line, since bytes are fetched one
     * {@code long} at a time.
     *
     * @param buf   the data to parse
     * @param limit exclusive end of the data; expected to be a line boundary
     * @return this aggregator
     */
    Aggregator processChunkLongs(ByteBuffer buf, int limit) {
        int pos = 0;
        while (pos < limit) {

            int start = pos;
            int hash = 0;
            while (true) {
                // This is a bit ugly, but it is faster than reading by byte.
                // Each branch tests one byte of the word for ';' and mixes
                // the key bytes preceding it into the hash.
                long word = buf.getLong(pos);
                if ((word & 0xFF) == ';') {
                    break;
                }
                if (((word >>> 8) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (word & 0xFF);
                    pos += 1;
                    break;
                }
                if (((word >>> 16) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (word & 0xFFFF);
                    pos += 2;
                    break;
                }
                if (((word >>> 24) & 0xFF) == ';') {
                    hash = (33 * hash) ^ (int) (word & 0xFFFFFF);
                    pos += 3;
                    break;
                }
                // From here on the low four bytes all belong to the key; the
                // narrowing cast keeps exactly those 32 bits.
                int low = (int) word;
                if (((word >>> 32) & 0xFF) == ';') {
                    hash = (33 * hash) ^ low;
                    pos += 4;
                    break;
                }
                if (((word >>> 40) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ low) + (int) ((word >>> 33) & 0xFF);
                    pos += 5;
                    break;
                }
                if (((word >>> 48) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ low) + (int) ((word >>> 33) & 0xFFFF);
                    pos += 6;
                    break;
                }
                if (((word >>> 56) & 0xFF) == ';') {
                    hash = ((33 * hash) ^ low) + (int) ((word >>> 33) & 0xFFFFFF);
                    pos += 7;
                    break;
                }
                hash = ((33 * hash) ^ low) + (int) (word >>> 33);
                pos += 8;
            }
            hash = (33 * hash) ^ (hash >>> 15);
            int len = pos - start;
            assert (buf.get(pos) == ';') : "Expected ';'";
            pos++;

            // Parse "[-]d.d\n" or "[-]dd.d\n" into tenths.
            int measurement;
            {
                long num = buf.getLong(pos);
                int sign = 1;
                if ((num & 0xFF) == '-') {
                    sign = -1;
                    num >>>= 8;
                    pos++;
                }
                int value;
                if (((num >>> 8) & 0xFF) == '.') {
                    value = (int) (((num & 0xFF) - '0') * 10 + (((num >>> 16) & 0xFF) - '0'));
                    pos += 4;
                }
                else {
                    value = (int) (((num & 0xFF) - '0') * 100 + (((num >>> 8) & 0xFF) - '0') * 10 + (((num >>> 24) & 0xFF) - '0'));
                    pos += 5;
                }
                measurement = sign * value;
            }
            // pos > limit only happens for a final line without a trailing
            // line break; the value itself has still been parsed correctly.
            assert (pos > limit || buf.get(pos - 1) == '\n') : "Expected '\\n'";

            add(buf, start, len, hash, measurement);
        }

        return this;
    }

    /**
     * Streams one {@link Entry} per station seen so far.
     *
     * @return entries backed by this aggregator's storage, in no particular order
     */
    public Stream<Entry> stream() {
        return Arrays.stream(index)
                .filter(offset -> offset != 0)
                .mapToObj(offset -> new Entry(mem, offset));
    }

    /**
     * Records one measurement for the key at {@code buf[start, start + len)},
     * creating the station record on first sight. Collisions are resolved by
     * linear probing.
     */
    private void add(ByteBuffer buf, int start, int len, int hash, int measurement) {
        int idx = hash & INDEX_MASK;
        while (true) {
            int offset = index[idx];
            if (offset == 0) {
                index[idx] = create(buf, start, len, hash, measurement);
                return;
            }
            if (keyEqual(offset, buf, start, len)) {
                int pos = offset + (len >> 3) + 2;
                mem[pos + FLD_MIN] = Math.min(measurement, (int) mem[pos + FLD_MIN]);
                mem[pos + FLD_MAX] = Math.max(measurement, (int) mem[pos + FLD_MAX]);
                mem[pos + FLD_SUM] += measurement;
                mem[pos + FLD_COUNT] += 1;
                return;
            }
            idx = (idx + 1) & INDEX_MASK;
        }
    }

    /**
     * Appends a new station record to {@link #mem}, initialised with the
     * given first measurement.
     *
     * @return the offset of the new record in {@code mem}
     */
    private int create(ByteBuffer buf, int start, int len, int hash, int measurement) {
        int offset = memUsed;

        mem[offset] = len;

        int memPos = offset + 1;
        int memEndEarly = memPos + (len >> 3);
        int bufPos = start;
        int bufEnd = start + len;
        while (memPos < memEndEarly) {
            mem[memPos] = buf.getLong(bufPos);
            memPos += 1;
            bufPos += 8;
        }
        if (bufPos < bufEnd) {
            // Keep only the (len % 8) leading key bytes of the last word;
            // in little-endian order those are the low-order bytes.
            int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8
            mem[memPos] = buf.getLong(bufPos) << shift >>> shift;
        }
        else {
            // "consume" extra long - makes math a bit simpler to calculate
            // fields offset for update.
            mem[memPos] = 0;
        }

        memPos += 1;
        mem[memPos + FLD_MIN] = measurement;
        mem[memPos + FLD_MAX] = measurement;
        mem[memPos + FLD_SUM] = measurement;
        mem[memPos + FLD_COUNT] = 1;
        memUsed = memPos + 4;

        return offset;
    }

    /**
     * Tells whether the record at {@code offset} has exactly the key found at
     * {@code buf[start, start + len)}.
     */
    private boolean keyEqual(int offset, ByteBuffer buf, int start, int len) {
        if (len != mem[offset]) {
            return false;
        }
        int memPos = offset + 1;
        int memEndEarly = memPos + (len >> 3);
        int bufPos = start;
        int bufEnd = start + len;
        while (memPos < memEndEarly) {
            if (mem[memPos] != buf.getLong(bufPos)) {
                return false;
            }
            memPos += 1;
            bufPos += 8;
        }
        if (bufPos < bufEnd) {
            // Same masking as in create(): compare only the real key bytes.
            int shift = (8 - (len & 7)) << 3; // (8 - (len % 8)) * 8
            long lastWord = buf.getLong(bufPos) << shift >>> shift;
            return mem[memPos] == lastWord;
        }
        return true;
    }

    /**
     * View of one station record inside an aggregator's storage. Serves both
     * as map value and as the merge accumulator; {@link #toString()} renders
     * it as {@code min/mean/max}.
     */
    public static class Entry {
        private final long[] mem;
        private final int offset;
        private String key;

        Entry(long[] mem, int offset) {
            this.mem = mem;
            this.offset = offset;
        }

        /**
         * Returns the station name, decoded as UTF-8 on first use and cached.
         */
        public String getKey() {
            if (key == null) {
                int pos = this.offset;
                int keyLen = (int) mem[pos++];
                // Must use the same byte order the key words were read with,
                // so the original byte sequence is reproduced.
                var tmpBuf = ByteBuffer.allocate(keyLen + 8).order(ByteOrder.LITTLE_ENDIAN);
                for (int i = 0; i < keyLen; i += 8) {
                    tmpBuf.putLong(mem[pos++]);
                }
                key = new String(tmpBuf.array(), 0, keyLen, StandardCharsets.UTF_8);
            }
            return key;
        }

        /**
         * Merges the statistics of {@code other} into this entry (in place).
         * Both entries must describe the same station, as the field offset is
         * derived from this entry's key length only.
         *
         * @param other the entry to merge in
         * @return this entry
         */
        public Entry add(Entry other) {
            int keyLen = (int) mem[offset];
            int fldOffset = (keyLen >> 3) + 2;
            int pos = offset + fldOffset;
            int otherPos = other.offset + fldOffset;
            long[] otherMem = other.mem;
            mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]);
            mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]);
            mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM];
            mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT];
            return this;
        }

        /** Returns this entry itself; used as the value mapper when collecting. */
        public Entry getValue() {
            return this;
        }

        @Override
        public String toString() {
            int keyLen = (int) mem[offset];
            int pos = offset + (keyLen >> 3) + 2;
            return round(mem[pos + FLD_MIN])
                    + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT])
                    + "/" + round(mem[pos + FLD_MAX]);
        }

        /** Converts a value in tenths to a decimal rounded to one fraction digit. */
        private static double round(double value) {
            return Math.round(value) / 10.0;
        }
    }
}
