about | summary | refs | log | tree | commit | diff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java
diff options
context:
space:
mode:
author: Peter Levart <peter.levart@gmail.com> 2024-01-31 18:07:56 +0100
committer: GitHub <noreply@github.com> 2024-01-31 18:07:56 +0100
commit: 3cc4fc85d83122eba8944036691d00e195b6aa57 (patch)
tree: c35426fbc3c08c6aa65c7e615b4053b668b8a1b7 /src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java
parent: c5b7b19e57624d2c510acc9efe7d5b2884dec9e8 (diff)
update1: restructuring for better compilation (#661)
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java 165
1 files changed, 81 insertions, 84 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java b/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java
index fd42d45..80c9e89 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_plevart.java
@@ -29,6 +29,7 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Comparator;
+import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -43,9 +44,10 @@ public class CalculateAverage_plevart {
private static final int INITIAL_TABLE_CAPACITY = 8192;
public static void main(String[] args) throws IOException {
- var arena = Arena.global();
+ System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
try (
- var channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ)) {
+ var channel = (FileChannel) Files.newByteChannel(FILE, StandardOpenOption.READ);
+ var arena = Arena.ofShared()) {
var segment = channel.map(FileChannel.MapMode.READ_ONLY, 0, Files.size(FILE), arena);
int regions = Runtime.getRuntime().availableProcessors();
IntStream
@@ -54,7 +56,6 @@ public class CalculateAverage_plevart {
.mapToObj(r -> calculateRegion(segment, regions, r))
.reduce(StatsTable::reduce)
.ifPresent(System.out::println);
- segment.unload();
}
}
@@ -68,14 +69,12 @@ public class CalculateAverage_plevart {
end = skipPastNl(segment, end);
}
- var stats = new StatsTable(segment, INITIAL_TABLE_CAPACITY);
- calculateAdjustedRegion(segment, start, end, stats);
- return stats;
+ return calculateAdjustedRegion(segment, start, end);
}
private static long skipPastNl(MemorySegment segment, long i) {
int skipped = 0;
- while (skipped++ < MAX_LINE_LEN && getByte(segment, i++) != '\n') {
+ while (skipped++ < MAX_LINE_LEN && segment.get(ValueLayout.JAVA_BYTE, i++) != '\n') {
}
if (skipped > MAX_LINE_LEN) {
throw new IllegalArgumentException(
@@ -84,27 +83,28 @@ public class CalculateAverage_plevart {
return i;
}
- private static void calculateAdjustedRegion(MemorySegment segment, long start, long end, StatsTable stats) {
+ private static StatsTable calculateAdjustedRegion(MemorySegment segment, long start, long end) {
+ var stats = new StatsTable(segment, INITIAL_TABLE_CAPACITY);
+
var species = ByteVector.SPECIES_PREFERRED;
- long speciesByteSize = species.vectorByteSize();
long cityStart = start, numberStart = 0;
int cityLen = 0;
for (long i = start, j = i; i < end; j = i) {
long semiNlSet;
- if (end - i >= speciesByteSize) {
+ if (end - i >= species.vectorByteSize()) {
var vec = ByteVector.fromMemorySegment(species, segment, i, ByteOrder.nativeOrder());
semiNlSet = vec.compare(VectorOperators.EQ, (byte) ';')
.or(vec.compare(VectorOperators.EQ, (byte) '\n'))
.toLong();
- i += speciesByteSize;
+ i += species.vectorByteSize();
}
else { // tail, smaller than speciesByteSize
semiNlSet = 0;
long mask = 1;
while (i < end && mask != 0) {
- int c = getByte(segment, i++);
+ int c = segment.get(ValueLayout.JAVA_BYTE, i++);
if (c == '\n' || c == ';') {
semiNlSet |= mask;
}
@@ -120,63 +120,17 @@ public class CalculateAverage_plevart {
}
else { // nl
int numberLen = (int) (j - numberStart);
- calculateEntry(segment, cityStart, cityLen, numberStart, numberLen, stats);
+ stats.calculateEntry(cityStart, cityLen, numberStart, numberLen);
cityStart = ++j;
numberStart = 0;
}
}
}
- }
- private static void calculateEntry(MemorySegment segment, long cityStart, int cityLen, long numberStart, int numberLen, StatsTable stats) {
- int hash = StatsTable.hash(segment, cityStart, cityLen);
- int number = parseNumber(segment, numberStart, numberLen);
- stats.aggregate(cityStart, cityLen, hash, 1, number, number, number);
- }
-
- private static int parseNumber(MemorySegment segment, long off, int len) {
- int c0 = getByte(segment, off);
- int d0;
- int sign;
- if (c0 == '-') {
- off++;
- len--;
- d0 = getByte(segment, off) - '0';
- sign = -1;
- } else {
- d0 = c0 - '0';
- sign = 1;
- }
- return sign * switch (len) {
- case 1 -> d0 * 10; // 9
- case 2 -> {
- int d1 = getByte(segment, off + 1) - '0';
- yield d0 * 100 + d1 * 10; // 99
- }
- case 3 -> {
- int d2 = getByte(segment, off + 2) - '0';
- yield d0 * 10 + d2; // 9.9
- }
- case 4 -> {
- int d1 = getByte(segment, off + 1) - '0';
- int d3 = getByte(segment, off + 3) - '0';
- yield d0 * 100 + d1 * 10 + d3; // 99.9
- }
- default -> {
- throw new IllegalArgumentException("Invalid number: " + getString(segment, off, len));
- }
- };
- }
-
- private static int getByte(MemorySegment segment, long off) {
- return segment.get(ValueLayout.JAVA_BYTE, off);
- }
-
- private static String getString(MemorySegment segment, long off, int len) {
- return new String(segment.asSlice(off, len).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8);
+ return stats;
}
- final static class StatsTable implements Cloneable {
+ final static class StatsTable {
private static final int LOAD_FACTOR = 16;
// offsets of fields
private static final int _lenHash = 0,
@@ -190,7 +144,7 @@ public class CalculateAverage_plevart {
private long[] table;
StatsTable(MemorySegment segment, int capacity) {
- this.segment = segment;
+ this.segment = Objects.requireNonNull(segment);
int pow2cap = Integer.highestOneBit(capacity);
if (pow2cap < capacity) {
pow2cap <<= 1;
@@ -199,6 +153,13 @@ public class CalculateAverage_plevart {
this.table = new long[idx(pow2cap)];
}
+ private StatsTable(StatsTable st) {
+ this.segment = st.segment;
+ this.pow2cap = st.pow2cap;
+ this.loadedSize = st.loadedSize;
+ this.table = st.table;
+ }
+
private static int idx(int i) {
return i << 3;
}
@@ -237,7 +198,49 @@ public class CalculateAverage_plevart {
}
}
- static int hash(MemorySegment segment, long off, int len) {
+ void calculateEntry(long cityStart, int cityLen, long numberStart, int numberLen) {
+ int hash = hash(cityStart, cityLen);
+ int number = parseNumber(numberStart, numberLen);
+ aggregate(cityStart, cityLen, hash, 1, number, number, number);
+ }
+
+ int parseNumber(long off, int len) {
+ int c0 = segment.get(ValueLayout.JAVA_BYTE, off);
+ int d0;
+ int sign;
+ if (c0 == '-') {
+ off++;
+ len--;
+ d0 = segment.get(ValueLayout.JAVA_BYTE, off) - '0';
+ sign = -1;
+ } else {
+ d0 = c0 - '0';
+ sign = 1;
+ }
+ return sign * switch (len) {
+ case 1 -> d0 * 10; // 9
+ case 2 -> {
+ int d1 = segment.get(ValueLayout.JAVA_BYTE, off + 1) - '0';
+ yield d0 * 100 + d1 * 10; // 99
+ }
+ case 3 -> {
+ int d2 = segment.get(ValueLayout.JAVA_BYTE, off + 2) - '0';
+ yield d0 * 10 + d2; // 9.9
+ }
+ case 4 -> {
+ int d1 = segment.get(ValueLayout.JAVA_BYTE, off + 1) - '0';
+ int d3 = segment.get(ValueLayout.JAVA_BYTE, off + 3) - '0';
+ yield d0 * 100 + d1 * 10 + d3; // 99.9
+ }
+ default ->
+ throw new IllegalArgumentException(
+ "Invalid number: " +
+ new String(segment.asSlice(off, len).toArray(ValueLayout.JAVA_BYTE), StandardCharsets.UTF_8)
+ );
+ };
+ }
+
+ int hash(long off, int len) {
if (len > Integer.BYTES) {
int head = segment.get(ValueLayout.JAVA_INT_UNALIGNED, off);
int tail = segment.get(ValueLayout.JAVA_INT_UNALIGNED, off + len - Integer.BYTES);
@@ -251,7 +254,11 @@ public class CalculateAverage_plevart {
}
}
- static boolean equals(MemorySegment segment, long off1, long off2, int len) {
+ private static boolean bothLessThan(long a, long b, long threshold) {
+ return (a < threshold) && (b < threshold);
+ }
+
+ boolean equals(long off1, long off2, int len) {
while (len >= Long.BYTES) {
if (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off1) != segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off2)) {
return false;
@@ -261,16 +268,16 @@ public class CalculateAverage_plevart {
len -= Long.BYTES;
}
// still enough memory to compare two longs, but masked?
- if (Math.max(off1, off2) + Long.BYTES <= segment.byteSize()) {
+ if (bothLessThan(off1, off2, segment.byteSize() - Long.BYTES + 1)) {
long mask = LEN_LONG_MASK[len];
return (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off1) & mask) == (segment.get(ValueLayout.JAVA_LONG_UNALIGNED, off2) & mask);
}
else {
- return equalsAtBorder(segment, off1, off2, len);
+ return equalsAtBorder(off1, off2, len);
}
}
- private static boolean equalsAtBorder(MemorySegment segment, long off1, long off2, int len) {
+ private boolean equalsAtBorder(long off1, long off2, int len) {
if (len > Integer.BYTES) {
if (segment.get(ValueLayout.JAVA_INT_UNALIGNED, off1) != segment.get(ValueLayout.JAVA_INT_UNALIGNED, off2)) {
return false;
@@ -290,7 +297,7 @@ public class CalculateAverage_plevart {
// key
long off, int len, int hash,
// value
- long count, long sum, long min, long max) {
+ long count, long sum, int min, int max) {
long lenHash = lenHash(len, hash);
int mask = pow2cap - 1;
for (int i = hash & mask, probe = 0; probe < pow2cap; i = (i + 1) & mask, probe++) {
@@ -309,11 +316,11 @@ public class CalculateAverage_plevart {
}
return;
}
- if (lenHash_i == lenHash && equals(segment, table[idx + _off], off, len)) {
+ if (lenHash_i == lenHash && equals(off, table[idx + _off], len)) {
table[idx + _count] += count;
table[idx + _sum] += sum;
- table[idx + _min] = Math.min(min, table[idx + _min]);
- table[idx + _max] = Math.max(max, table[idx + _max]);
+ table[idx + _min] = Math.min(min, (int) table[idx + _min]);
+ table[idx + _max] = Math.max(max, (int) table[idx + _max]);
return;
}
}
@@ -325,7 +332,7 @@ public class CalculateAverage_plevart {
throw new OutOfMemoryError("StatsTable capacity exceeded");
}
else {
- var oldStats = clone();
+ var oldStats = new StatsTable(this);
pow2cap <<= 1;
table = new long[idx(pow2cap)];
loadedSize = 0;
@@ -333,16 +340,6 @@ public class CalculateAverage_plevart {
}
}
- @Override
- protected StatsTable clone() {
- try {
- return (StatsTable) super.clone();
- }
- catch (CloneNotSupportedException e) {
- throw new InternalError(e);
- }
- }
-
StatsTable reduce(StatsTable other) {
other
.idxStream()
@@ -353,8 +350,8 @@ public class CalculateAverage_plevart {
hash(other.table[idx + _lenHash]),
other.table[idx + _count],
other.table[idx + _sum],
- other.table[idx + _min],
- other.table[idx + _max]));
+ (int) other.table[idx + _min],
+ (int) other.table[idx + _max]));
return this;
}