aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorDr Ian Preston <157221403+ianopolousfast@users.noreply.github.com>2024-01-20 19:09:40 +0000
committerGitHub <noreply@github.com>2024-01-20 20:09:40 +0100
commit062f2bbecf586d85ff44dec42cc63f94e49bc6b8 (patch)
tree24a98c41a371b8440b1bb4e63e9091c5803f2807 /src/main
parent114ba76d20f946ac6421aff73cd69387b0cb15b7 (diff)
Introducing the vector api. 1s faster on 4 core i7 (#506)
Co-authored-by: Ian Preston <ianopolous@protonmail.com>
Diffstat (limited to 'src/main')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java104
1 files changed, 50 insertions, 54 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
index 4bffe78..8944a47 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
@@ -15,6 +15,10 @@
*/
package dev.morling.onebrc;
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
@@ -30,19 +34,23 @@ import static java.lang.foreign.ValueLayout.*;
/* A fast implementation with no unsafe.
* Features:
* * memory mapped file using preview Arena FFI
+ * * semicolon finding using incubator vector api
* * read chunks in parallel
* * minimise allocation
* * no unsafe
*
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
* average_baseline: 4m48s
- * ianopolous: 16s
+ * ianopolous: 15s
*/
public class CalculateAverage_ianopolousfast {
public static final int MAX_LINE_LENGTH = 107;
public static final int MAX_STATIONS = 1 << 14;
private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
+ private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 32
+ ? ByteVector.SPECIES_256
+ : ByteVector.SPECIES_128;
public static void main(String[] args) throws Exception {
Arena arena = Arena.global();
@@ -165,58 +173,40 @@ public class CalculateAverage_ianopolousfast {
}
}
- public static long hasSemicolon(long d) {
- // from Hacker's Delight page 92
- d = d ^ 0x3b3b3b3b3b3b3b3bL;
- long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
- return ~(y | d | 0x7f7f7f7f7f7f7f7fL);
- }
-
- public static int getSemicolonIndex(long y) {
- // from Hacker's Delight page 92
- return Long.numberOfLeadingZeros(y) >> 3;
- }
-
static long maskHighBytes(long d, int nbytes) {
return d & (-1L << ((8 - nbytes) * 8));
}
public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
- // find semicolon and update hash as we go, reading a long at a time
- long d = buffer.get(LONG_LAYOUT, lineStart);
- long hasSemi = hasSemicolon(d);
- if (hasSemi != 0) {
- int semiIndex = getSemicolonIndex(hasSemi);
- d = maskHighBytes(d, semiIndex);
- return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations);
- }
- long first8 = d;
- long hash = d;
-
- d = buffer.get(LONG_LAYOUT, lineStart + 8);
- hasSemi = hasSemicolon(d);
- if (hasSemi != 0) {
- int semiIndex = getSemicolonIndex(hasSemi);
- if (semiIndex == 0)
- return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations);
- d = maskHighBytes(d, semiIndex);
- return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations);
+ ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder());
+ int keySize = line.compare(VectorOperators.EQ, ';').firstTrue();
+
+ if (keySize == BYTE_SPECIES.vectorByteSize()) {
+ while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') {
+ keySize++;
+ }
+ long first8 = buffer.get(LONG_LAYOUT, lineStart);
+ if (keySize < 8)
+ return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
+ long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
+ if (keySize < 16)
+ return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
+ long hash = first8 ^ second8; // todo include other bytes
+ return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
}
- int index = 8;
- long second8 = d;
- while (hasSemi == 0) {
- hash = hash ^ d;
- index += 8;
- d = buffer.get(LONG_LAYOUT, lineStart + index);
- hasSemi = hasSemicolon(d);
+ long first8 = buffer.get(LONG_LAYOUT, lineStart);
+ if (keySize <= 8) {
+ first8 = maskHighBytes(first8, keySize & 0x07);
+ return dedupeStation8(lineStart, lineStart + keySize, first8, first8, buffer, stations);
}
- int semiIndex = getSemicolonIndex(hasSemi);
- d = maskHighBytes(d, semiIndex);
- if (semiIndex > 0) {
- hash = hash ^ d;
+ long second8 = buffer.get(LONG_LAYOUT, lineStart + 8);
+ if (keySize < 16) {
+ second8 = maskHighBytes(second8, keySize & 0x07);
+ return dedupeStation16(lineStart, lineStart + keySize, first8 ^ second8, first8, second8, buffer, stations);
}
- return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations);
+ long hash = first8 ^ second8; // todo include later bytes
+ return dedupeStation(lineStart, lineStart + keySize, hash, first8, second8, buffer, stations);
}
public static int getDot(long d) {
@@ -266,24 +256,30 @@ public class CalculateAverage_ianopolousfast {
for (int i = 0; i < MAX_STATIONS; i++)
stations.add(null);
- // Handle reading the very last line in the file
- // this allows us to not worry about reading a long beyond the end
+ // Handle reading the very last few lines in the file
+ // this allows us to not worry about reading beyond the end
// in the inner loop (reducing branches)
- // We only need to read one because the min record size is 6 bytes
- // so 2nd last record must be > 8 from end
+ // We need at least the vector lane size bytes back
if (endByte == buffer.byteSize()) {
- endByte -= 2; // skip final new line
- while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
+ endByte -= 1; // skip final new line
+ // reverse at least vector lane width
+ while (endByte > 0 && buffer.byteSize() - endByte < BYTE_SPECIES.vectorByteSize()) {
endByte--;
+ while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
+ endByte--;
+ }
if (endByte > 0)
endByte++;
- // copy into a 8n sized buffer to avoid reading off end
- MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4);
+ // copy into a larger buffer to avoid reading off end
+ MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize());
for (long i = endByte; i < buffer.byteSize(); i++)
end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
- Stat station = parseStation(0, end, stations);
- processTemperature(station.name.length + 1, end, station);
+ int index = 0;
+ while (endByte + index < buffer.byteSize()) {
+ Stat station = parseStation(index, end, stations);
+ index = (int) processTemperature(index + station.name.length + 1, end, station);
+ }
}
while (startByte < endByte) {