aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev
diff options
context:
space:
mode:
authorElliot Barlas <elliotbarlas@gmail.com>2024-01-10 07:02:23 -0800
committerGunnar Morling <gunnar.morling@googlemail.com>2024-01-10 20:03:14 +0100
commit44414a33dc442ebb23f3fa15871ead413d35e7d7 (patch)
tree144fdce5fdd942f5317fdebe517f81e98b705910 /src/main/java/dev
parentc89490aaa1d7852e574ca5014397fa3e38ac8dad (diff)
Consume four bytes at a time from buffer using getInt. Store key with unsafe int array rather than byte array. Use custom equals rather than Arrays equals.
Diffstat (limited to 'src/main/java/dev')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java225
1 files changed, 142 insertions, 83 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java
index 63ff69f..b2a89d0 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ebarlas.java
@@ -15,6 +15,8 @@
*/
package dev.morling.onebrc;
+import sun.misc.Unsafe;
+
import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
@@ -23,7 +25,6 @@ import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
-import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
@@ -33,13 +34,22 @@ public class CalculateAverage_ebarlas {
private static final int HASH_FACTOR = 433;
private static final int HASH_TBL_SIZE = 16_383; // range of allowed hash values, inclusive
- public static void main(String[] args) throws IOException, InterruptedException {
- if (args.length != 2) {
- System.out.println("Usage: java CalculateAverage <input-file> <partitions>");
- System.exit(1);
+ private static final Unsafe UNSAFE = makeUnsafe();
+
+ private static Unsafe makeUnsafe() {
+ try {
+ var f = Unsafe.class.getDeclaredField("theUnsafe");
+ f.setAccessible(true);
+ return (Unsafe) f.get(null);
}
- var path = Paths.get(args[0]);
- var numPartitions = Integer.parseInt(args[1]);
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) throws IOException, InterruptedException {
+ var path = Paths.get("measurements.txt");
+ var numPartitions = Math.max(8, Runtime.getRuntime().availableProcessors());
var channel = FileChannel.open(path, StandardOpenOption.READ);
var partitionSize = channel.size() / numPartitions;
var partitions = new Partition[numPartitions];
@@ -75,13 +85,31 @@ public class CalculateAverage_ebarlas {
var result = new TreeMap<String, String>();
for (var st : stats) {
if (st != null) {
- var key = new String(st.key, StandardCharsets.UTF_8);
+ var key = new String(convert(st.keyAddr, st.keyLen, st.lastBytes), StandardCharsets.UTF_8);
result.put(key, format(st));
}
}
System.out.println(result);
}
+ private static byte[] convert(long keyAddr, int keyLen, int keyLastBytes) {
+ var len = keyLastBytes == 4
+ ? keyLen * 4 // fully packed
+ : (keyLen - 1) * 4 + keyLastBytes; // last int partially packed
+ var bytes = new byte[len];
+ var idx = 0;
+ for (long i = 0; i < keyLen; i++) {
+ var offset = i << 2;
+ var n = UNSAFE.getInt(keyAddr + offset);
+ var bound = i == keyLen - 1 ? keyLastBytes : 4;
+ for (int j = 0; j < bound; j++) {
+ bytes[idx++] = (byte) (n & 0xFF);
+ n >>>= 8;
+ }
+ }
+ return bytes;
+ }
+
private static String format(Stats st) { // adheres to expected output format
return round(st.min / 10.0) + "/" + round((st.sum / 10.0) / st.count) + "/" + round(st.max / 10.0);
}
@@ -96,7 +124,7 @@ public class CalculateAverage_ebarlas {
var current = partitions.get(i).stats;
for (int j = 0; j < current.length; j++) {
if (current[j] != null) {
- var t = findInTable(target, current[j].hash, current[j].key, current[j].key.length);
+ var t = findInTable(target, current[j].hash, current[j].keyAddr, current[j].keyLen, current[j].lastBytes);
t.min = Math.min(t.min, current[j].min);
t.max = Math.max(t.max, current[j].max);
t.sum += current[j].sum;
@@ -112,7 +140,7 @@ public class CalculateAverage_ebarlas {
var pNext = partitions.get(i);
var pPrev = partitions.get(i - 1);
var merged = mergeFooterAndHeader(pPrev.footer, pNext.header);
- if (merged != null) {
+ if (merged != null && merged.length != 0) {
if (merged[merged.length - 1] == '\n') { // fold into prev partition
doProcessBuffer(ByteBuffer.wrap(merged).order(ByteOrder.LITTLE_ENDIAN), true, pPrev.stats);
}
@@ -148,80 +176,70 @@ public class CalculateAverage_ebarlas {
}
private static int reallyDoProcessBuffer(ByteBuffer buffer, Stats[] stats) {
- var keyBuf = new byte[MAX_KEY_SIZE]; // buffer for key
+ long keyBaseAddr = UNSAFE.allocateMemory(MAX_KEY_SIZE);
int keyStart = 0; // start of key in buffer used for footer calc
try { // abort with exception to allow optimistic line processing
while (true) { // one line per iteration
keyStart = buffer.position(); // preserve line start
- int n = buffer.getInt(); // first four bytes of key
- byte b1 = (byte) (n & 0xFF);
- byte b2 = (byte) ((n >> 8) & 0xFF);
- byte b3 = (byte) ((n >> 16) & 0xFF);
- byte b = (byte) ((n >> 24) & 0xFF);
- int keyPos;
- int keyHash = keyBuf[0] = b1;
- if (b2 != ';' && b3 != ';') { // true for keys of length 3 or more
- keyBuf[1] = b2;
- keyBuf[2] = b3;
- keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b2) + b3;
- keyPos = 3;
- while (b != ';') {
- keyHash = HASH_FACTOR * keyHash + b;
- keyBuf[keyPos++] = b;
- b = buffer.get();
+ int keyHash = 0; // key hash code
+ long keyAddr = keyBaseAddr; // address for next int
+ int keyArrLen = 0; // number of key 4-byte ints
+ int keyLastBytes; // occupancy in last byte (1, 2, 3, or 4)
+ int val; // temperature value
+ while (true) {
+ int n = buffer.getInt();
+ byte b0 = (byte) (n & 0xFF);
+ byte b1 = (byte) ((n >> 8) & 0xFF);
+ byte b2 = (byte) ((n >> 16) & 0xFF);
+ byte b3 = (byte) ((n >> 24) & 0xFF);
+ if (b0 == ';') { // ...;1.1
+ val = getVal(buffer, b1, b2, b3, buffer.get());
+ keyLastBytes = 4;
+ break;
}
- }
- else { // slow path, rewind and consume byte-by-byte
- buffer.position(keyStart + 1);
- keyPos = 1;
- while ((b = buffer.get()) != ';') {
- keyHash = HASH_FACTOR * keyHash + b;
- keyBuf[keyPos++] = b;
+ else if (b1 == ';') { // ...a;1.1
+ val = getVal(buffer, b2, b3, buffer.get(), buffer.get());
+ UNSAFE.putInt(keyAddr, b0);
+ keyLastBytes = 1;
+ keyArrLen++;
+ keyHash = HASH_FACTOR * keyHash + b0;
+ break;
+ }
+ else if (b2 == ';') { // ...ab;1.1
+ val = getVal(buffer, b3, buffer.get(), buffer.get(), buffer.get());
+ UNSAFE.putInt(keyAddr, n & 0x0000FFFF);
+ keyLastBytes = 2;
+ keyArrLen++;
+ keyHash = HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1;
+ break;
+ }
+ else if (b3 == ';') { // ...abc;1.1
+ UNSAFE.putInt(keyAddr, n & 0x00FFFFFF);
+ keyLastBytes = 3;
+ keyArrLen++;
+ keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2;
+ n = buffer.getInt();
+ b0 = (byte) (n & 0xFF);
+ b1 = (byte) ((n >> 8) & 0xFF);
+ b2 = (byte) ((n >> 16) & 0xFF);
+ b3 = (byte) ((n >> 24) & 0xFF);
+ val = getVal(buffer, b0, b1, b2, b3);
+ break;
+ }
+ else {
+ UNSAFE.putInt(keyAddr, n);
+ keyArrLen++;
+ keyAddr += 4;
+ keyHash = HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * (HASH_FACTOR * keyHash + b0) + b1) + b2) + b3;
}
}
var idx = keyHash & HASH_TBL_SIZE;
var st = stats[idx];
if (st == null) { // nothing in table, eagerly claim spot
- st = stats[idx] = newStats(keyBuf, keyPos, keyHash);
- }
- else if (!Arrays.equals(st.key, 0, st.key.length, keyBuf, 0, keyPos)) {
- st = findInTable(stats, keyHash, keyBuf, keyPos);
- }
- var value = buffer.getInt();
- b = (byte) (value & 0xFF); // digit or dash
- int val;
- if (b == '-') { // dash branch
- val = ((byte) ((value >> 8) & 0xFF)) - '0'; // digit after dash
- b = (byte) ((value >> 16) & 0xFF); // second digit or decimal
- if (b != '.') { // second digit
- val = val * 10 + (b - '0'); // calc second digit
- // skip decimal (at >> 24)
- b = buffer.get(); // digit after decimal
- val = val * 10 + (b - '0'); // calc digit after decimal
- }
- else { // decimal branch
- // skip decimal (at >> 16)
- b = (byte) ((value >> 24) & 0xFF); // digit after decimal
- val = val * 10 + (b - '0'); // calc digit after decimal
- }
- buffer.get(); // newline
- val = -val;
+ st = stats[idx] = newStats(keyBaseAddr, keyArrLen, keyLastBytes, keyHash);
}
- else { // first digit branch
- val = b - '0'; // calc first digit
- b = (byte) ((value >> 8) & 0xFF); // second digit or decimal
- if (b != '.') { // second digit branch
- val = val * 10 + (b - '0'); // calc second digit
- // skip decimal (at >> 16)
- b = (byte) ((value >> 24) & 0xFF); // digit after decimal
- val = val * 10 + (b - '0'); // calc digit after decimal
- buffer.get(); // newline
- }
- else { // decimal branch
- b = (byte) ((value >> 16) & 0xFF); // digit after decimal
- val = val * 10 + (b - '0'); // calc digit after decimal
- // skip newline (at >> 24)
- }
+ else if (!equals(st.keyAddr, st.keyLen, keyBaseAddr, keyArrLen)) {
+ st = findInTable(stats, keyHash, keyBaseAddr, keyArrLen, keyLastBytes);
}
st.min = Math.min(st.min, val);
st.max = Math.max(st.max, val);
@@ -235,23 +253,60 @@ public class CalculateAverage_ebarlas {
return keyStart;
}
- private static Stats findInTable(Stats[] stats, int hash, byte[] key, int len) { // open-addressing scan
+ private static boolean equals(long key1, int len1, long key2, int len2) {
+ if (len1 != len2) {
+ return false;
+ }
+ for (long i = 0; i < len1; i++) {
+ var offset = i << 2;
+ if (UNSAFE.getInt(key1 + offset) != UNSAFE.getInt(key2 + offset)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static int getVal(ByteBuffer buffer, byte b0, byte b1, byte b2, byte b3) {
+ if (b0 == '-') {
+ if (b2 != '.') { // 6 bytes: -dd.dn
+ var b = buffer.get();
+ buffer.get(); // newline
+ return -(((b1 - '0') * 10 + (b2 - '0')) * 10 + (b - '0'));
+ }
+ else { // 5 bytes: -d.dn
+ buffer.get(); // newline
+ return -((b1 - '0') * 10 + (b3 - '0'));
+ }
+ }
+ else {
+ if (b1 != '.') { // 5 bytes: dd.dn
+ buffer.get(); // newline
+ return ((b0 - '0') * 10 + (b1 - '0')) * 10 + (b3 - '0');
+ }
+ else { // 4 bytes: d.dn
+ return (b0 - '0') * 10 + (b2 - '0');
+ }
+ }
+ }
+
+ private static Stats findInTable(Stats[] stats, int hash, long keyAddr, int keyLen, int keyLastBytes) { // open-addressing scan
var idx = hash & HASH_TBL_SIZE;
var st = stats[idx];
- while (st != null && !Arrays.equals(st.key, 0, st.key.length, key, 0, len)) {
+ while (st != null && !equals(st.keyAddr, st.keyLen, keyAddr, keyLen)) {
idx = (idx + 1) % (HASH_TBL_SIZE + 1);
st = stats[idx];
}
if (st != null) {
return st;
}
- return stats[idx] = newStats(key, len, hash);
+ return stats[idx] = newStats(keyAddr, keyLen, keyLastBytes, hash);
}
- private static Stats newStats(byte[] buffer, int len, int hash) {
- var k = new byte[len];
- System.arraycopy(buffer, 0, k, 0, len);
- return new Stats(k, hash);
+ private static Stats newStats(long keyAddr, int keyLen, int keyLastBytes, int hash) {
+ var bytes = keyLen << 2;
+ long k = UNSAFE.allocateMemory(bytes);
+ UNSAFE.copyMemory(keyAddr, k, bytes);
+ return new Stats(k, keyLen, keyLastBytes, hash);
}
private static byte[] readFooter(ByteBuffer buffer, int lineStart) { // read from line start to current pos (end-of-input)
@@ -281,15 +336,19 @@ public class CalculateAverage_ebarlas {
}
private static class Stats { // min, max, and sum values are modeled with integral types that represent tenths of a unit
- final byte[] key;
+ final long keyAddr; // address of 4-byte integer array
+ final int keyLen; // number of 4-byte integers starting at address
+ final int lastBytes; // number of bytes packed into last key int (1, 2, 3 or 4)
final int hash;
int min = Integer.MAX_VALUE;
int max = Integer.MIN_VALUE;
long sum;
long count;
- Stats(byte[] key, int hash) {
- this.key = key;
+ Stats(long keyAddr, int keyLen, int lastBytes, int hash) {
+ this.keyAddr = keyAddr;
+ this.keyLen = keyLen;
+ this.lastBytes = lastBytes;
this.hash = hash;
}
}