aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
authorDr Ian Preston <157221403+ianopolousfast@users.noreply.github.com>2024-01-23 15:37:33 +0000
committerGitHub <noreply@github.com>2024-01-23 16:37:33 +0100
commit8bae1b87810f75ddf307ba0b84400d97e3e6f851 (patch)
treefde952f84b21a10053fdd83247a278560982e5a1 /src/main/java
parentf7febea2f6277263665365a4cbd0b36343159245 (diff)
Use simd for name comparison (#568)
Co-authored-by: Ian Preston <ianopolous@protonmail.com>
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java119
1 files changed, 32 insertions, 87 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
index 8944a47..f1b4e7b 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
@@ -34,7 +34,7 @@ import static java.lang.foreign.ValueLayout.*;
/* A fast implementation with no unsafe.
* Features:
* * memory mapped file using preview Arena FFI
- * * semicolon finding using incubator vector api
+ * * semicolon finding and name comparison using incubator vector api
* * read chunks in parallel
* * minimise allocation
* * no unsafe
@@ -80,12 +80,11 @@ public class CalculateAverage_ianopolousfast {
System.out.println(merged);
}
- public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) {
- int len = (int) (end - start);
- if (len != existing.name.length)
- return false;
- for (int i = offset; i < len; i++) {
- if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++))
+ public static boolean matchingStationBytes(long start, long end, MemorySegment buffer, Stat existing) {
+ for (int index = 0; index < end - start; index += BYTE_SPECIES.vectorByteSize()) {
+ ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, start + index, ByteOrder.nativeOrder(), BYTE_SPECIES.indexInRange(start + index, end));
+ ByteVector found = ByteVector.fromArray(BYTE_SPECIES, existing.name, index);
+ if (!found.eq(line).allTrue())
return false;
}
return true;
@@ -98,21 +97,19 @@ public class CalculateAverage_ianopolousfast {
return (finalHash & (len - 1));
}
- public static Stat parseStation(long start, long end, long first8, long second8,
- MemorySegment buffer) {
+ public static Stat createStation(long start, long end, MemorySegment buffer) {
byte[] stationBuffer = new byte[(int) (end - start)];
for (long off = start; off < end; off++)
stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off);
- return new Stat(stationBuffer, first8, second8);
+ return new Stat(stationBuffer);
}
- public static Stat dedupeStation(long start, long end, long hash, long first8, long second8,
- MemorySegment buffer, List<List<Stat>> stations) {
+ public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, List<List<Stat>> stations) {
int index = hashToIndex(hash, MAX_STATIONS);
List<Stat> matches = stations.get(index);
if (matches == null) {
List<Stat> value = new ArrayList<>();
- Stat res = parseStation(start, end, first8, second8, buffer);
+ Stat res = createStation(start, end, buffer);
value.add(res);
stations.set(index, value);
return res;
@@ -120,54 +117,10 @@ public class CalculateAverage_ianopolousfast {
else {
for (int i = 0; i < matches.size(); i++) {
Stat s = matches.get(i);
- if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s))
+ if (matchingStationBytes(start, end, buffer, s))
return s;
}
- Stat res = parseStation(start, end, first8, second8, buffer);
- matches.add(res);
- return res;
- }
- }
-
- public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List<List<Stat>> stations) {
- int index = hashToIndex(hash, MAX_STATIONS);
- List<Stat> matches = stations.get(index);
- if (matches == null) {
- List<Stat> value = new ArrayList<>();
- Stat station = parseStation(start, end, first8, 0, buffer);
- value.add(station);
- stations.set(index, value);
- return station;
- }
- else {
- for (int i = 0; i < matches.size(); i++) {
- Stat s = matches.get(i);
- if (first8 == s.first8 && s.name.length <= 8)
- return s;
- }
- Stat station = parseStation(start, end, first8, 0, buffer);
- matches.add(station);
- return station;
- }
- }
-
- public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List<List<Stat>> stations) {
- int index = hashToIndex(hash, MAX_STATIONS);
- List<Stat> matches = stations.get(index);
- if (matches == null) {
- List<Stat> value = new ArrayList<>();
- Stat res = parseStation(start, end, first8, second8, buffer);
- value.add(res);
- stations.set(index, value);
- return res;
- }
- else {
- for (int i = 0; i < matches.size(); i++) {
- Stat s = matches.get(i);
- if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16)
- return s;
- }
- Stat res = parseStation(start, end, first8, second8, buffer);
+ Stat res = createStation(start, end, buffer);
matches.add(res);
return res;
}
@@ -181,32 +134,22 @@ public class CalculateAverage_ianopolousfast {
ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
+ long first8 = buffer.get(LONG_LAYOUT, lineStart);
if (keySize == BYTE_SPECIES.vectorByteSize()) {
while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
keySize++;
}
- long first8 = buffer.get(LONG_LAYOUT, lineStart);
- if (keySize < 8)
- return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
- if (keySize < 16)
- return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
long hash = first8 ^ second8; // todo include other bytes
- return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
+ return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
}
- long first8 = buffer.get(LONG_LAYOUT, lineStart);
if (keySize <= 8) {
first8 = maskHighBytes(first8, keySize & 0x07);
- return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
- }
- long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
- if (keySize < 16) {
- second8 = maskHighBytes(second8, keySize & 0x07);
- return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
}
+ long second8 = keySize <= 8 ? 0 : maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07);
long hash = first8 ^ second8; // todo include later bytes
- return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
+ return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations);
}
public static int getDot(long d) {
@@ -261,13 +204,10 @@ public class CalculateAverage_ianopolousfast {
// in the inner loop (reducing branches)
// We need at least the vector lane size bytes back
if (endByte == buffer.byteSize()) {
- endByte -= 1; // skip final new line
// reverse at least vector lane width
- while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) {
+ endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0);
+ while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
endByte--;
- while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
- endByte--;
- }
if (endByte > 0)
endByte++;
@@ -278,28 +218,33 @@ public class CalculateAverage_ianopolousfast {
int index = 0;
while (endByte + index < buffer.byteSize()) {
Stat station = parseStation(index, end, stations);
- index = (int) processTemperature(index + station.name.length + 1, end, station);
+ index = (int) processTemperature(index + station.namelen + 1, end, station);
}
}
+ innerloop(startByte, endByte, buffer, stations);
+ return stations;
+ }
+
+ private static void innerloop(long startByte, long endByte, MemorySegment buffer, List<List<Stat>> stations) {
while (startByte < endByte) {
Stat station = parseStation(startByte, buffer, stations);
- startByte = processTemperature(startByte + station.name.length + 1, buffer, station);
+ startByte = processTemperature(startByte + station.namelen + 1, buffer, station);
}
- return stations;
}
public static class Stat {
final byte[] name;
+ final int namelen;
int count = 0;
short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
long total = 0;
- final long first8, second8;
- public Stat(byte[] name, long first8, long second8) {
- this.name = name;
- this.first8 = first8;
- this.second8 = second8;
+ public Stat(byte[] name) {
+ int vecSize = BYTE_SPECIES.vectorByteSize();
+ int arrayLen = (name.length + vecSize - 1) / vecSize * vecSize;
+ this.name = Arrays.copyOfRange(name, 0, arrayLen);
+ this.namelen = name.length;
}
public void add(short value) {
@@ -326,7 +271,7 @@ public class CalculateAverage_ianopolousfast {
}
public String name() {
- return new String(name);
+ return new String(Arrays.copyOfRange(name, 0, namelen));
}
public String toString() {