aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java329
1 files changed, 186 insertions, 143 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
index d825e77..0e91253 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java
@@ -24,15 +24,12 @@ import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
import java.util.Collection;
-import java.util.Objects;
+import java.util.Map;
import java.util.TreeMap;
-import java.util.stream.Stream;
import static java.nio.channels.FileChannel.MapMode.READ_ONLY;
import static java.nio.charset.StandardCharsets.UTF_8;
-import static java.util.stream.Collectors.toMap;
public class CalculateAverage_armandino {
@@ -42,19 +39,59 @@ public class CalculateAverage_armandino {
private static final int INITIAL_MAP_CAPACITY = 8192;
private static final byte SEMICOLON = 59;
private static final byte NL = 10;
- private static final byte DOT = 46;
- private static final byte MINUS = 45;
- private static final byte ZERO_DIGIT = 48;
private static final int PRIME = 1117;
+
+ private static final int KEY_OFFSET = 0, // 100b
+ HASH_OFFSET = 100, // int
+ KEY_LENGTH_OFFSET = 104, // short
+ MIN_OFFSET = 106, // short
+ MAX_OFFSET = 108, // short
+ COUNT_OFFSET = 110, // int
+ SUM_OFFSET = 114; // long
+
+ private static final long ENTRY_SIZE = 100 // key: offset=0
+ + 4 // keyHash: offset=100
+ + 2 // keyLength: offset=104
+ + 2 // min: 108; offset=106
+ + 2 // max: 110; offset=108
+ + 4 // count: 114; offset=110
+ + 8; // sum: 122; offset=118
+
private static final Unsafe UNSAFE = getUnsafe();
public static void main(String[] args) throws Exception {
var channel = FileChannel.open(FILE, StandardOpenOption.READ);
- var results = Arrays.stream(split(channel)).parallel()
- .map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end))
- .flatMap(SimpleMap::stream)
- .collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new));
+ Chunk[] chunks = split(channel);
+ ChunkProcessor[] processors = new ChunkProcessor[chunks.length];
+
+ for (int i = 0; i < processors.length; i++) {
+ processors[i] = new ChunkProcessor(chunks[i].start, chunks[i].end);
+ processors[i].start();
+ }
+
+ Map<String, Stats> results = new TreeMap<>();
+
+ for (int i = 0; i < processors.length; i++) {
+ processors[i].join();
+ final long end = processors[i].map.mapEnd;
+
+ for (long addr = processors[i].map.mapStart; addr < end; addr += ENTRY_SIZE) {
+ final short keyLength = UNSAFE.getShort(addr + KEY_LENGTH_OFFSET);
+
+ if (keyLength == 0)
+ continue;
+
+ final byte[] keyBytes = new byte[keyLength];
+ UNSAFE.copyMemory(null, addr, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength);
+ final short min = UNSAFE.getShort(addr + MIN_OFFSET);
+ final short max = UNSAFE.getShort(addr + MAX_OFFSET);
+ final int count = UNSAFE.getInt(addr + COUNT_OFFSET);
+ final long sum = UNSAFE.getLong(addr + SUM_OFFSET);
+ final Stats s = new Stats(new String(keyBytes, 0, keyLength, UTF_8), min, max, count, sum);
+ results.merge(s.key, s, CalculateAverage_armandino::mergeStats);
+ }
+ }
print(results.values());
}
@@ -67,87 +104,69 @@ public class CalculateAverage_armandino {
return x;
}
- private static class ChunkProcessor {
- private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY);
+ private static class ChunkProcessor extends Thread {
+ private final UnsafeMap map = new UnsafeMap(INITIAL_MAP_CAPACITY);
+
+ final long chunkStart;
+ final long chunkEnd;
- private SimpleMap process(final long chunkStart, final long chunkEnd) {
+ private ChunkProcessor(long chunkStart, long chunkEnd) {
+ this.chunkStart = chunkStart;
+ this.chunkEnd = chunkEnd;
+ }
+
+ @Override
+ public void run() {
long i = chunkStart;
while (i < chunkEnd) {
final long keyAddress = i;
int keyHash = 0;
- int measurement = 0;
byte b;
while ((b = UNSAFE.getByte(i++)) != SEMICOLON) {
keyHash = PRIME * keyHash + b;
}
- final int keyLength = (int) (i - keyAddress - 1);
+ final short keyLength = (short) (i - keyAddress - 1);
+ final long numberWord = UNSAFE.getLong(i);
+ final int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
+ final short measurement = parseNumber(decimalSepPos, numberWord);
+ final int addOffset = (decimalSepPos >>> 3) + 3;
+ i += addOffset;
- if ((b = UNSAFE.getByte(i++)) == MINUS) {
- while ((b = UNSAFE.getByte(i++)) != DOT) {
- measurement = measurement * 10 + b - ZERO_DIGIT;
- }
-
- b = UNSAFE.getByte(i);
- measurement = measurement * 10 + b - ZERO_DIGIT;
- measurement = -measurement;
- i += 2;
- }
- else {
- measurement = b - ZERO_DIGIT; // D1
- b = UNSAFE.getByte(i); // dot or D2
-
- if (b == DOT) {
- measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F
- i += 3;
- }
- else {
- measurement = measurement * 10 + b - ZERO_DIGIT; // D2
- measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F
- i += 4; // skip NL
- }
- }
-
- final Stats stats = map.putStats(keyHash, keyAddress, keyLength);
- stats.min = Math.min(stats.min, measurement);
- stats.max = Math.max(stats.max, measurement);
- stats.sum += measurement;
- stats.count++;
+ map.addEntry(keyHash, keyAddress, keyLength, measurement);
}
+ }
- return map;
+ // credit: merykitty
+ private static short parseNumber(int decimalSepPos, long numberWord) {
+ int shift = 28 - decimalSepPos;
+ // signed is -1 if negative, 0 otherwise
+ long signed = (~numberWord << 59) >> 63;
+ long designMask = ~(signed & 0xFF);
+ // Align the number to a specific position and transform the ascii to digit value
+ long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L;
+ // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit)
+ // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
+ // 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100
+ long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ return (short) ((absValue ^ signed) - signed);
}
}
- private static class Stats implements Comparable<Stats> {
- private String key;
- private final long keyAddress;
- private final int keyLength;
- private final int keyHash;
- private int min = Integer.MAX_VALUE;
- private int max = Integer.MIN_VALUE;
+ private static class Stats {
+ private final String key;
+ private int min;
+ private int max;
private int count;
private long sum;
- private Stats(long keyAddress, int keyLength, int keyHash) {
- this.keyAddress = keyAddress;
- this.keyLength = keyLength;
- this.keyHash = keyHash;
- }
-
- String getKey() {
- if (key == null) {
- var keyBytes = new byte[keyLength];
- UNSAFE.copyMemory(null, keyAddress, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength);
- key = new String(keyBytes, 0, keyLength, UTF_8);
- }
- return key;
- }
-
- @Override
- public int compareTo(final Stats o) {
- return getKey().compareTo(o.getKey());
+ Stats(final String key, final int min, final int max, final int count, final long sum) {
+ this.min = min;
+ this.max = max;
+ this.count = count;
+ this.sum = sum;
+ this.key = key;
}
void print(final PrintStream out) {
@@ -219,90 +238,114 @@ public class CalculateAverage_armandino {
}
}
- private static class SimpleMap {
- private Stats[] table;
+ private static class UnsafeMap {
- SimpleMap(int initialCapacity) {
- table = new Stats[initialCapacity];
- }
+ long mapStart;
+ long mapEnd;
+ int capacity; // num entries
- Stream<Stats> stream() {
- return Arrays.stream(table).filter(Objects::nonNull);
+ UnsafeMap(int numEntries) {
+ capacity = numEntries;
+ final long size = ENTRY_SIZE * numEntries;
+ mapStart = UNSAFE.allocateMemory(size);
+ mapEnd = mapStart + size;
+ UNSAFE.setMemory(mapStart, size, (byte) 0);
}
- Stats putStats(final int keyHash, final long keyAddress, final int keyLength) {
- final int pos = (table.length - 1) & keyHash;
-
- Stats stats = table[pos];
- if (stats == null)
- return createAt(table, keyAddress, keyLength, keyHash, pos);
- if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength))
- return stats;
-
- int i = pos;
- while (++i < table.length) {
- stats = table[i];
- if (stats == null)
- return createAt(table, keyAddress, keyLength, keyHash, i);
- if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
- return stats;
+ void addEntry(final int keyHash, final long keyAddress, final short keyLength, final short measurement) {
+ final int pos = (capacity - 1) & keyHash;
+
+ long addr = mapStart + pos * ENTRY_SIZE;
+ int hash = UNSAFE.getInt(addr + HASH_OFFSET);
+
+ if (hash == 0) { // new entry
+ initEntry(addr, keyAddress, keyLength, measurement, keyHash);
+ return;
+ }
+ if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) {
+ updateEntry(addr, measurement);
+ return;
+ }
+
+ // this can be improved to avoid clustering at the start.
+ // should only affect the 10k test
+ addr = mapStart;
+
+ while (addr < mapEnd) {
+ addr += ENTRY_SIZE;
+ hash = UNSAFE.getInt(addr + HASH_OFFSET);
+
+ if (hash == 0) {
+ initEntry(addr, keyAddress, keyLength, measurement, keyHash);
+ return;
+ }
+ if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) {
+ updateEntry(addr, measurement);
+ return;
+ }
}
- i = pos;
- while (i-- > 0) {
- stats = table[i];
- if (stats == null)
- return createAt(table, keyAddress, keyLength, keyHash, i);
- if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength))
- return stats;
+ resize(keyHash, keyAddress, keyLength, measurement);
+ }
+
+ private void resize(final int keyHash, final long keyAddress, final short keyLength, final short measurement) {
+ UnsafeMap newMap = new UnsafeMap(capacity * 2);
+
+ for (long addr = mapStart; addr < mapEnd; addr += ENTRY_SIZE) {
+ final short oKeyLength = UNSAFE.getShort(addr + KEY_LENGTH_OFFSET);
+ final int oKeyHsh = UNSAFE.getInt(addr + HASH_OFFSET);
+ final short oMin = UNSAFE.getShort(addr + MIN_OFFSET);
+ final short oMax = UNSAFE.getShort(addr + MAX_OFFSET);
+ final int oCount = UNSAFE.getInt(addr + COUNT_OFFSET);
+ final long oSum = UNSAFE.getLong(addr + SUM_OFFSET);
+
+ final int newPos = (newMap.capacity - 1) & oKeyHsh;
+ long newAddr = newMap.mapStart + newPos * ENTRY_SIZE;
+
+ UNSAFE.putShort(newAddr + KEY_LENGTH_OFFSET, oKeyLength);
+ UNSAFE.putInt(newAddr + HASH_OFFSET, oKeyHsh);
+ UNSAFE.putShort(newAddr + MIN_OFFSET, oMin);
+ UNSAFE.putShort(newAddr + MAX_OFFSET, oMax);
+ UNSAFE.putInt(newAddr + COUNT_OFFSET, oCount);
+ UNSAFE.putLong(newAddr + SUM_OFFSET, oSum);
}
- resize();
- return putStats(keyHash, keyAddress, keyLength);
+
+ newMap.addEntry(keyHash, keyAddress, keyLength, measurement);
+
+ this.mapStart = newMap.mapStart;
+ this.mapEnd = newMap.mapEnd;
+ this.capacity = newMap.capacity;
}
- private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) {
- Stats stats = new Stats(keyAddress, keyLength, key);
- table[i] = stats;
- return stats;
+ private static void initEntry(final long entry, final long keyAddress, final short keyLength, final short measurement, final int keyHash) {
+ UNSAFE.copyMemory(keyAddress, entry, keyLength);
+ UNSAFE.putInt(entry + HASH_OFFSET, keyHash);
+ UNSAFE.putShort(entry + KEY_LENGTH_OFFSET, keyLength);
+ UNSAFE.putShort(entry + MIN_OFFSET, Short.MAX_VALUE);
+ UNSAFE.putShort(entry + MAX_OFFSET, Short.MIN_VALUE);
+
+ updateEntry(entry, measurement);
}
- private static boolean keysEqual(Stats stats, long keyAddress, final int keyLength) {
- // credit: abeobk
- long xsum = 0;
- int n = keyLength & 0xF8;
- for (int i = 0; i < n; i += 8) {
- xsum |= (UNSAFE.getLong(stats.keyAddress + i) ^ UNSAFE.getLong(keyAddress + i));
- }
- return xsum == 0;
+ private static void updateEntry(final long entry, final short measurement) {
+ UNSAFE.putShort(entry + MIN_OFFSET,
+ (short) Math.min(UNSAFE.getShort(entry + MIN_OFFSET), measurement));
+ UNSAFE.putShort(entry + MAX_OFFSET,
+ (short) Math.max(UNSAFE.getShort(entry + MAX_OFFSET), measurement));
+ UNSAFE.putInt(entry + COUNT_OFFSET,
+ UNSAFE.getInt(entry + COUNT_OFFSET) + 1);
+ UNSAFE.putLong(entry + SUM_OFFSET,
+ UNSAFE.getLong(entry + SUM_OFFSET) + measurement);
}
+ }
- private void resize() {
- var copy = new SimpleMap(table.length * 2);
- for (Stats s : table) {
- if (s != null) {
- final int pos = (copy.table.length - 1) & s.keyHash;
- int i = pos;
- if (copy.table[i] == null) {
- copy.table[i] = s;
- continue;
- }
- while (i < copy.table.length && copy.table[i] != null) {
- i++;
- }
- if (i == copy.table.length) {
- i = pos;
- while (i >= 0 && copy.table[i] != null) {
- i--;
- }
- }
- if (i < 0) {
- // if we reach here it's a bug!
- throw new IllegalStateException("table is full");
- }
- copy.table[i] = s;
- }
- }
- table = copy.table;
+ private static boolean keysEqual(long key1Address, long key2Address, final int keyLength) {
+ // credit: abeobk
+ long xsum = 0;
+ int n = keyLength & 0xF8;
+ for (int i = 0; i < n; i += 8) {
+ xsum |= (UNSAFE.getLong(key1Address + i) ^ UNSAFE.getLong(key2Address + i));
}
+ return xsum == 0;
}
}