aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java125
1 files changed, 53 insertions, 72 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java
index 5c43824..a7df56e 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_roman_r_m.java
@@ -33,37 +33,35 @@ public class CalculateAverage_roman_r_m {
private static Unsafe UNSAFE;
- // based on http://0x80.pl/notesen/2023-03-06-swar-find-any.html
- static long hasZeroByte(long l) {
- return ((l - 0x0101010101010101L) & ~(l) & 0x8080808080808080L);
- }
-
- static long firstSetByteIndex(long l) {
- return ((((l - 1) & 0x101010101010101L) * 0x101010101010101L) >> 56) - 1;
- }
-
- static long broadcast(byte b) {
+ private static long broadcast(byte b) {
return 0x101010101010101L * b;
}
- static long SEMICOLON_MASK = broadcast((byte) ';');
- static long LINE_END_MASK = broadcast((byte) '\n');
+ private static final long SEMICOLON_MASK = broadcast((byte) ';');
+ private static final long LINE_END_MASK = broadcast((byte) '\n');
+ private static final long DOT_MASK = broadcast((byte) '.');
+
+ // from netty
- static long find(long l, long mask) {
- long xor = l ^ mask;
- long match = hasZeroByte(xor);
- return match != 0 ? firstSetByteIndex(match) : -1;
+ /**
+ * Applies a compiled pattern to given word.
+ * Returns a word where each byte that matches the pattern has the highest bit set.
+ */
+ private static long applyPattern(final long word, final long pattern) {
+ long input = word ^ pattern;
+ long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL;
+ return ~(tmp | input | 0x7F7F7F7F7F7F7F7FL);
}
static long nextNewline(long from, MemorySegment ms) {
long start = from;
long i;
long next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start);
- while ((i = find(next, LINE_END_MASK)) < 0) {
+ while ((i = applyPattern(next, LINE_END_MASK)) == 0) {
start += 8;
next = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, start);
}
- return start + i;
+ return start + Long.numberOfTrailingZeros(i) / 8;
}
static class Worker {
@@ -84,55 +82,53 @@ public class CalculateAverage_roman_r_m {
private void parseName(ByteString station) {
long start = offset;
- long pos = -1;
-
- while (end - offset > 8) {
- long next = UNSAFE.getLong(offset);
- pos = find(next, SEMICOLON_MASK);
- if (pos >= 0) {
- offset += pos;
- break;
- }
- else {
- offset += 8;
- }
- }
- if (pos < 0) {
- while (UNSAFE.getByte(offset++) != ';') {
- }
- offset--;
+ long pattern;
+ long next = UNSAFE.getLong(offset);
+ while ((pattern = applyPattern(next, SEMICOLON_MASK)) == 0) {
+ offset += 8;
+ next = UNSAFE.getLong(offset);
}
+ int bytes = Long.numberOfTrailingZeros(pattern) / 8;
+ offset += bytes;
int len = (int) (offset - start);
station.offset = start;
station.len = len;
station.hash = 0;
+ station.tail = next & ((1L << (8 * bytes)) - 1);
offset++;
}
- long parseNumberFast() {
+ int parseNumberFast() {
long encodedVal = UNSAFE.getLong(offset);
- var len = find(encodedVal, LINE_END_MASK);
- offset += len + 1;
+ int neg = 1 - Integer.bitCount((int) (encodedVal & 0x10));
+ encodedVal >>>= 8 * neg;
+
+ var len = applyPattern(encodedVal, DOT_MASK);
+ len = Long.numberOfTrailingZeros(len) / 8;
encodedVal ^= broadcast((byte) 0x30);
- long c0 = len == 4 ? 100 : 10;
- long c1 = 10 * (len - 3);
- long c2 = 4 - len;
- long c3 = len - 3;
- long a = (encodedVal & 0xFF) * c0;
- long b = ((encodedVal & 0xFF00) >>> 8) * c1;
- long c = ((encodedVal & 0xFF0000L) >>> 16) * c2;
- long d = ((encodedVal & 0xFF000000L) >>> 24) * c3;
+ int intPart = (int) (encodedVal & ((1 << (8 * len)) - 1));
+ intPart <<= 8 * (2 - len);
+ intPart *= (100 * 256 + 10);
+ intPart = (intPart & 0x3FF80) >>> 8;
+
+ int frac = (int) ((encodedVal >>> (8 * (len + 1))) & 0xFF);
- return a + b + c + d;
+ offset += neg + len + 3; // 1 for . + 1 for fractional part + 1 for new line char
+ int sign = 1 - 2 * neg;
+ int val = intPart + frac;
+ return sign * val;
}
- long parseNumberSlow() {
- long val = UNSAFE.getByte(offset++) - '0';
+ int parseNumberSlow() {
+ int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
+ offset += neg;
+
+ int val = UNSAFE.getByte(offset++) - '0';
byte b;
while ((b = UNSAFE.getByte(offset++)) != '.') {
val = val * 10 + (b - '0');
@@ -140,22 +136,17 @@ public class CalculateAverage_roman_r_m {
b = UNSAFE.getByte(offset);
val = val * 10 + (b - '0');
offset += 2;
+ val *= 1 - 2 * neg;
return val;
}
- long parseNumber() {
- long val;
- int neg = 1 - Integer.bitCount(UNSAFE.getByte(offset) & 0x10);
- offset += neg;
-
- if (end - offset > 8) {
- val = parseNumberFast();
+ int parseNumber() {
+ if (end - offset >= 8) {
+ return parseNumberFast();
}
else {
- val = parseNumberSlow();
+ return parseNumberSlow();
}
- val *= 1 - 2 * neg;
- return val;
}
public TreeMap<String, ResultRow> run() {
@@ -218,6 +209,7 @@ public class CalculateAverage_roman_r_m {
private long offset;
private int len = 0;
private int hash = 0;
+ private long tail = 0L;
ByteString(MemorySegment ms) {
this.ms = ms;
@@ -235,6 +227,7 @@ public class CalculateAverage_roman_r_m {
copy.offset = this.offset;
copy.len = this.len;
copy.hash = this.hash;
+ copy.tail = this.tail;
return copy;
}
@@ -259,19 +252,7 @@ public class CalculateAverage_roman_r_m {
return false;
}
}
- if (len >= 8) {
- long l1 = UNSAFE.getLong(offset + len - 8);
- long l2 = UNSAFE.getLong(that.offset + len - 8);
- return l1 == l2;
- }
- for (; i < len; i++) {
- byte i1 = UNSAFE.getByte(offset + i);
- byte i2 = UNSAFE.getByte(that.offset + i);
- if (i1 != i2) {
- return false;
- }
- }
- return true;
+ return this.tail == that.tail;
}
@Override