about | summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java | 191
1 files changed, 110 insertions, 81 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
index fe487fc..51ea415 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
@@ -155,39 +155,52 @@ public class CalculateAverage_mtopolnik {
}
}
+ private static final int MAX_TEMPERATURE_LEN = 5;
+ private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1;
+ private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8);
+
private void processChunk() {
while (cursor < inputSize) {
+ boolean withinSafeZone;
long word1;
long word2;
- if (cursor + 2 * Long.BYTES <= inputSize) {
- word1 = UNSAFE.getLong(inputBase + cursor);
- word2 = UNSAFE.getLong(inputBase + cursor + Long.BYTES);
+ long nameLen;
+ long nameStartAddress = inputBase + cursor;
+ if (cursor + DANGER_ZONE_LENGTH <= inputSize) {
+ withinSafeZone = true;
+ word1 = UNSAFE.getLong(nameStartAddress);
+ word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES);
+ nameLen = nameLen(word1, word2, withinSafeZone);
+ word1 = maskWord(word1, nameLen);
+ word2 = maskWord(word2, nameLen - Long.BYTES);
}
else {
+ withinSafeZone = false;
UNSAFE.putLong(nameBufBase, 0);
UNSAFE.putLong(nameBufBase + Long.BYTES, 0);
- UNSAFE.copyMemory(inputBase + cursor, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
+ UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
word1 = UNSAFE.getLong(nameBufBase);
word2 = UNSAFE.getLong(nameBufBase + Long.BYTES);
+ nameLen = nameLen(word1, word2, withinSafeZone);
}
- long posOfSemicolon = posOfSemicolon(word1, word2);
- word1 = maskWord(word1, posOfSemicolon - cursor);
- word2 = maskWord(word2, posOfSemicolon - cursor - Long.BYTES);
long hash = hash(word1);
- long namePos = cursor;
- long nameLen = posOfSemicolon - cursor;
- assert nameLen <= 100 : "nameLen > 100";
- int temperature = parseTemperatureAndAdvanceCursor(posOfSemicolon);
- updateStats(hash, namePos, nameLen, word1, word2, temperature);
+ assert nameLen > 0 && nameLen <= 100 : nameLen;
+ long tempStartAddress = nameStartAddress + nameLen + 1;
+ int temperature = withinSafeZone
+ ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress)
+ : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress);
+ updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone);
}
}
- private void updateStats(long hash, long namePos, long nameLen, long nameWord1, long nameWord2, int temperature) {
+ private void updateStats(
+ long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2,
+ int temperature, boolean withinSafeZone) {
int tableIndex = (int) (hash & TABLE_INDEX_MASK);
while (true) {
stats.gotoIndex(tableIndex);
- if (stats.hash() == hash && stats.nameLen() == nameLen
- && nameEquals(stats.nameAddress(), inputBase + namePos, nameLen, nameWord1, nameWord2)) {
+ if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals(
+ stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) {
stats.setSum(stats.sum() + temperature);
stats.setCount(stats.count() + 1);
stats.setMin((short) Integer.min(stats.min(), temperature));
@@ -204,72 +217,58 @@ public class CalculateAverage_mtopolnik {
stats.setCount(1);
stats.setMin((short) temperature);
stats.setMax((short) temperature);
- UNSAFE.copyMemory(inputBase + namePos, stats.nameAddress(), nameLen);
+ UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen);
return;
}
}
- private int parseTemperatureAndAdvanceCursor(long semicolonPos) {
- long startOffset = semicolonPos + 1;
- if (startOffset <= inputSize - Long.BYTES) {
- return parseTemperatureSwarAndAdvanceCursor(startOffset);
- }
- return parseTemperatureSimpleAndAdvanceCursor(startOffset);
- }
-
// Credit: merykitty
- private int parseTemperatureSwarAndAdvanceCursor(long startOffset) {
- long word = UNSAFE.getLong(inputBase + startOffset);
+ private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) {
+ long word = UNSAFE.getLong(tempStartAddress);
final long negated = ~word;
final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000);
+ cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase;
final long signed = (negated << 59) >> 63;
final long removeSignMask = ~(signed & 0xFF);
final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- final int temperature = (int) ((absValue ^ signed) - signed);
- cursor = startOffset + (dotPos / 8) + 3;
- return temperature;
+ return (int) ((absValue ^ signed) - signed);
}
- private int parseTemperatureSimpleAndAdvanceCursor(long startOffset) {
+ private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) {
final byte minus = (byte) '-';
final byte zero = (byte) '0';
final byte dot = (byte) '.';
- // Temperature plus the following newline is at least 4 chars, so this is always safe:
- int fourCh = UNSAFE.getInt(inputBase + startOffset);
- final int mask = 0xFF;
- byte ch = (byte) (fourCh & mask);
- int shift = 0;
+ byte ch = UNSAFE.getByte(tempStartAddress);
+ long address = tempStartAddress;
int temperature;
int sign;
if (ch == minus) {
sign = -1;
- shift += 8;
- ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ address++;
+ ch = UNSAFE.getByte(address);
}
else {
sign = 1;
}
temperature = ch - zero;
- shift += 8;
- ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ address++;
+ ch = UNSAFE.getByte(address);
if (ch == dot) {
- shift += 8;
- ch = (byte) ((fourCh & (mask << shift)) >>> shift);
+ address++;
+ ch = UNSAFE.getByte(address);
}
else {
temperature = 10 * temperature + (ch - zero);
- shift += 16;
- // The last character may be past the four loaded bytes, load it from memory.
- // Checking that with another `if` is self-defeating for performance.
- ch = UNSAFE.getByte(inputBase + startOffset + (shift / 8));
+ address += 2;
+ ch = UNSAFE.getByte(address);
}
temperature = 10 * temperature + (ch - zero);
- // `shift` holds the number of bits in the temperature field.
+ // `address` now points at the last digit of the temperature field.
// A newline character follows the temperature, and so we advance
// the cursor past the newline to the start of the next line.
- cursor = startOffset + (shift / 8) + 2;
+ cursor = (address + 2) - inputBase;
return sign * temperature;
}
@@ -286,15 +285,27 @@ public class CalculateAverage_mtopolnik {
return hash;
}
- private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2) {
+ private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2,
+ boolean withinSafeZone) {
boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr);
boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES);
- if (mismatch1 | mismatch2) {
- return false;
+ if (len <= 2 * Long.BYTES) {
+ return !(mismatch1 | mismatch2);
}
- for (int i = 2 * Long.BYTES; i < len; i++) {
- if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
- return false;
+ if (withinSafeZone) {
+ int i = 2 * Long.BYTES;
+ for (; i <= len - Long.BYTES; i += Long.BYTES) {
+ if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) {
+ return false;
+ }
+ }
+ return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i);
+ }
+ else {
+ for (int i = 2 * Long.BYTES; i < len; i++) {
+ if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
+ return false;
+ }
}
}
return true;
@@ -311,44 +322,62 @@ public class CalculateAverage_mtopolnik {
// Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/
// and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169
- long posOfSemicolon(long word1, long word2) {
- long diff = word1 ^ BROADCAST_SEMICOLON;
- long matchBits1 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
- diff = word2 ^ BROADCAST_SEMICOLON;
- long matchBits2 = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
- if ((matchBits1 | matchBits2) != 0) {
- int trailing1 = Long.numberOfTrailingZeros(matchBits1);
- int match1IsNonZero = trailing1 & 63;
- match1IsNonZero |= match1IsNonZero >>> 3;
- match1IsNonZero |= match1IsNonZero >>> 1;
- match1IsNonZero |= match1IsNonZero >>> 1;
- // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to
- // raise the lowest bit in traling2 if trailing1 is nonzero. This forces
- // trailing2 to be zero if trailing1 is non-zero.
- int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
- return cursor + ((trailing1 | trailing2) >> 3);
+ long nameLen(long word1, long word2, boolean withinSafeZone) {
+ {
+ long matchBits1 = matchBits(word1);
+ long matchBits2 = matchBits(word2);
+ if ((matchBits1 | matchBits2) != 0) {
+ int trailing1 = Long.numberOfTrailingZeros(matchBits1);
+ int match1IsNonZero = trailing1 & 63;
+ match1IsNonZero |= match1IsNonZero >>> 3;
+ match1IsNonZero |= match1IsNonZero >>> 1;
+ match1IsNonZero |= match1IsNonZero >>> 1;
+ // Now the lowest bit of match1IsNonZero is 1 if matchBits1 is non-zero, else 0. Use it to
+ // raise the lowest bit of matchBits2 if matchBits1 is non-zero. This forces
+ // trailing2 to be zero whenever word1 already contains the semicolon.
+ int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
+ // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero,
+ // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap.
+ return (trailing1 | trailing2) >> 3;
+ }
}
- long offset = cursor + 2 * Long.BYTES;
- for (; offset <= inputSize - Long.BYTES; offset += Long.BYTES) {
- var block = UNSAFE.getLong(inputBase + offset);
- diff = block ^ BROADCAST_SEMICOLON;
- long matchBits = (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
- if (matchBits != 0) {
- return offset + Long.numberOfTrailingZeros(matchBits) / 8;
+ long nameStartAddress = inputBase + cursor;
+ long address = nameStartAddress + 2 * Long.BYTES;
+ long limit = inputBase + inputSize;
+ if (withinSafeZone) {
+ for (; address < limit; address += Long.BYTES) {
+ var block = maskWord(UNSAFE.getLong(address), limit - address);
+ long matchBits = matchBits(block);
+ if (matchBits != 0) {
+ return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress;
+ }
}
+ throw new RuntimeException("Semicolon not found");
}
- return posOfSemicolonSimple(offset);
+ return addrOfSemicolonSafe(address, limit) - nameStartAddress;
}
- private long posOfSemicolonSimple(long offset) {
- for (; offset < inputSize; offset++) {
- if (UNSAFE.getByte(inputBase + offset) == SEMICOLON) {
- return offset;
+ private static long addrOfSemicolonSafe(long address, long limit) {
+ for (; address < limit - Long.BYTES + 1; address += Long.BYTES) {
+ var block = UNSAFE.getLong(address);
+ long matchBits = matchBits(block);
+ if (matchBits != 0) {
+ return address + (Long.numberOfTrailingZeros(matchBits) >> 3);
+ }
+ }
+ for (; address < limit; address++) {
+ if (UNSAFE.getByte(address) == SEMICOLON) {
+ return address;
}
}
throw new RuntimeException("Semicolon not found");
}
+ private static long matchBits(long word) {
+ long diff = word ^ BROADCAST_SEMICOLON;
+ return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
+ }
+
// Copies the results from native memory to Java heap and puts them into the results array.
private void exportResults() {
var exportedStats = new ArrayList<StationStats>(10_000);