aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java397
-rw-r--r--src/main/java/dev/morling/onebrc/CreateMeasurements3.java2
2 files changed, 222 insertions, 177 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java
index 7b549dc..7123c2c 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_giovannicuccu.java
@@ -15,10 +15,19 @@
*/
package dev.morling.onebrc;
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
import static java.util.stream.Collectors.*;
import java.io.IOException;
import java.io.RandomAccessFile;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
@@ -31,34 +40,42 @@ import java.util.*;
import java.util.concurrent.*;
/*
- Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn
+ Solution without unsafe that borrows the ideas of splullara, thomasvue, royvanrijn and merykitty
*/
public class CalculateAverage_giovannicuccu {
private static final String FILE = "./measurements.txt";
- public static record PartitionBoundary(long start, long end) {
+ private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_256;
+ private static final int BYTE_SPECIES_LANES = BYTE_SPECIES.length();
+ private static final ByteOrder NATIVE_ORDER = ByteOrder.nativeOrder();
+ public static final VectorSpecies<Integer> INT_SPECIES = IntVector.SPECIES_256;
+ public static final int INT_SPECIES_LANES = INT_SPECIES.length();
+
+ public static final int KEY_SIZE = 128;
+
+ public static record PartitionBoundary(Path path, long start, long end) {
}
public static interface PartitionCalculator {
- PartitionBoundary[] computePartitionsBoundaries(Path path);
+ List<PartitionBoundary> computePartitionsBoundaries(Path path);
}
public static class ProcessorPartitionCalculator implements PartitionCalculator {
- public PartitionBoundary[] computePartitionsBoundaries(Path path) {
+ public List<PartitionBoundary> computePartitionsBoundaries(Path path) {
try {
int numberOfSegments = Runtime.getRuntime().availableProcessors();
long fileSize = path.toFile().length();
long segmentSize = fileSize / numberOfSegments;
- PartitionBoundary[] segmentBoundaries = new PartitionBoundary[numberOfSegments];
+ List<PartitionBoundary> segmentBoundaries = new ArrayList<>(numberOfSegments);
try (RandomAccessFile randomAccessFile = new RandomAccessFile(path.toFile(), "r")) {
long segStart = 0;
long segEnd = segmentSize;
for (int i = 0; i < numberOfSegments; i++) {
segEnd = findEndSegment(randomAccessFile, segEnd, fileSize);
- segmentBoundaries[i] = new PartitionBoundary(segStart, segEnd);
+ segmentBoundaries.add(new PartitionBoundary(path, segStart, segEnd));
segStart = segEnd;
segEnd = Math.min(segEnd + segmentSize, fileSize);
}
@@ -81,51 +98,27 @@ public class CalculateAverage_giovannicuccu {
}
}
- public static class MeasurementAggregator {
- private final int hash;
+ private static class MeasurementAggregatorVectorized {
+
private int min;
private int max;
private double sum;
private long count;
- private final byte[] station;
- private final int offset;
- private final String name;
+ private final int len;
+ private final int hash;
- private final long[] data;
- private final int dataOffset;
+ private final int offset;
+ private byte[] data;
- public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue, long[] data, int dataOffset) {
+ public MeasurementAggregatorVectorized(byte[] data, int offset, int len, int hash, int initialValue) {
min = initialValue;
max = initialValue;
sum = initialValue;
count = 1;
- this.station = station;
- this.offset = offset;
+ this.len = len;
this.hash = hash;
- this.data = data;
- this.dataOffset = dataOffset;
- this.name = new String(station, 0, offset, StandardCharsets.UTF_8);
- }
-
- public MeasurementAggregator(byte[] station, int offset, int hash, int initialValue) {
- min = initialValue;
- max = initialValue;
- sum = initialValue;
- count = 1;
- this.station = station;
this.offset = offset;
- this.hash = hash;
- this.data = new long[0];
- this.dataOffset = 0;
- this.name = new String(station, 0, offset, StandardCharsets.UTF_8);
- }
-
- public boolean hasSameStation(byte[] stationIn, int offsetIn) {
- return Arrays.equals(stationIn, 0, offsetIn, station, 0, offset);
- }
-
- public boolean hasSameStation(long[] dataIn, int offsetIn) {
- return Arrays.equals(dataIn, 0, offsetIn, data, 0, dataOffset);
+ this.data = data;
}
public void add(int value) {
@@ -139,8 +132,7 @@ public class CalculateAverage_giovannicuccu {
count++;
}
- public void merge(MeasurementAggregator other) {
- // System.out.println("min=" +min + " other min=" +other.min);
+ public void merge(MeasurementAggregatorVectorized other) {
min = Math.min(min, other.min);
max = Math.max(max, other.max);
sum += other.sum;
@@ -149,7 +141,7 @@ public class CalculateAverage_giovannicuccu {
@Override
public String toString() {
- return round((double) min / 10) + "/" + round((sum / (double) count) / 10) + "/" + round((double) max / 10);
+ return round(min / 10.) + "/" + round(sum / (double) (10 * count)) + "/" + round(max / 10.);
}
private double round(double value) {
@@ -164,116 +156,141 @@ public class CalculateAverage_giovannicuccu {
return hash;
}
- public String getName() {
- return name;
+ public int getLen() {
+ return len;
+ }
+
+ public boolean dataEquals(byte[] data, int offset) {
+ return Arrays.equals(this.data, this.offset, this.offset + len, data, offset, offset + len);
+
}
- public byte[] getStation() {
- return station;
+ public String getName() {
+ return new String(data, offset, len, StandardCharsets.UTF_8);
}
public int getOffset() {
return offset;
}
- public long[] getData() {
+ public byte[] getData() {
return data;
}
-
}
- public static class MeasurementList {
-
+ private static class MeasurementListVectorized {
private static final int SIZE = 1024 * 64;
- private final MeasurementAggregator[] measurements = new MeasurementAggregator[SIZE];
+ private final MeasurementAggregatorVectorized[] measurements = new MeasurementAggregatorVectorized[SIZE];
+ private final byte[] keyData = new byte[SIZE * KEY_SIZE];
- public void add(byte[] station, int offset, int hash, int value) {
+ private final MemorySegment dataSegment = MemorySegment.ofArray(keyData);
+
+ public void addWithByteVector(ByteVector chunk1, int len, int hash, int value, MemorySegment memorySegment, long offset) {
int index = hash & (SIZE - 1);
- if (measurements[index] == null) {
- measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value);
- }
- else {
- if (measurements[index].hasSameStation(station, offset)) {
- measurements[index].add(value);
- }
- else {
- while (measurements[index] != null && !measurements[index].hasSameStation(station, offset)) {
- index = (index + 1) & (SIZE - 1);
+ int i = 0;
+ while (measurements[index] != null) {
+ if (measurements[index].getLen() == len && measurements[index].getHash() == hash) {
+ var nodeKey = ByteVector.fromArray(BYTE_SPECIES, keyData, index * KEY_SIZE);
+ long eqMask = chunk1.compare(VectorOperators.EQ, nodeKey).toLong();
+ long validMask = -1L >>> (64 - len);
+ if ((eqMask & validMask) == validMask) {
+ measurements[index].add(value);
+ return;
}
- if (measurements[index] == null) {
- measurements[index] = new MeasurementAggregator(station.clone(), offset, hash, value);
+ }
+ index = (index + 1) & (SIZE - 1);
+ }
+ MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len);
+ measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value);
+ }
+
+ public void add(int len, int hash, int value, MemorySegment memorySegment, long offset) {
+ int index = hash & (SIZE - 1);
+ while (measurements[index] != null) {
+ if (measurements[index].getLen() == len && measurements[index].getHash() == hash) {
+ int i = 0;
+ while (i < len && keyData[index * KEY_SIZE + i] == memorySegment.get(ValueLayout.JAVA_BYTE, offset + i)) {
+ i++;
}
- else {
+ if (i == len) {
measurements[index].add(value);
+ return;
}
}
+ index = (index + 1) & (SIZE - 1);
}
+ MemorySegment.copy(memorySegment, offset, dataSegment, (long) index * KEY_SIZE, len);
+ measurements[index] = new MeasurementAggregatorVectorized(keyData, index * KEY_SIZE, len, hash, value);
}
- public void merge(MeasurementAggregator measurementAggregator) {
- int index = (measurementAggregator.getHash() & (SIZE - 1));
- if (measurements[index] == null) {
- measurements[index] = measurementAggregator;
- }
- else {
- while (measurements[index] != null && !measurements[index].hasSameStation(measurementAggregator.getStation(), measurementAggregator.getOffset())) {
- index = (index + 1) & (SIZE - 1);
- }
- if (measurements[index] == null) {
- measurements[index] = measurementAggregator;
- }
- else {
- measurements[index].merge(measurementAggregator);
+ public void merge(MeasurementAggregatorVectorized measurementAggregator) {
+ int index = measurementAggregator.getHash() & (SIZE - 1);
+ while (measurements[index] != null) {
+ if (measurements[index].getLen() == measurementAggregator.getLen() && measurements[index].getHash() == measurementAggregator.getHash()) {
+ if (measurementAggregator.dataEquals(measurements[index].getData(), measurements[index].getOffset())) {
+ measurements[index].merge(measurementAggregator);
+ return;
+ }
}
+ index = (index + 1) & (SIZE - 1);
}
+ measurements[index] = measurementAggregator;
}
- public MeasurementAggregator[] getMeasurements() {
+ public MeasurementAggregatorVectorized[] getMeasurements() {
return measurements;
}
+
}
- public static class MMapReader {
- private final Path path;
- private final PartitionBoundary[] boundaries;
+ private static class MMapReaderMemorySegment {
+ private final Path path;
+ private final List<PartitionBoundary> boundaries;
private final boolean serial;
+ private static final byte SEPARATOR = ';';
+ ByteVector separators = ByteVector.broadcast(BYTE_SPECIES, SEPARATOR);
+ private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
- public MMapReader(Path path, PartitionCalculator partitionCalculator, boolean serial) {
+ public MMapReaderMemorySegment(Path path, PartitionCalculator partitionCalculator, boolean serial) {
this.path = path;
this.serial = serial;
boundaries = partitionCalculator.computePartitionsBoundaries(path);
}
- public TreeMap<String, MeasurementAggregator> elaborate() {
- try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.length)) {
- List<Future<MeasurementList>> futures = new ArrayList<>();
+ public TreeMap<String, MeasurementAggregatorVectorized> elaborate() throws IOException {
+ try (ExecutorService executor = Executors.newFixedThreadPool(boundaries.size());
+ FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ);
+ var arena = Arena.ofShared()) {
+
+ List<Future<MeasurementListVectorized>> futures = new ArrayList<>();
for (PartitionBoundary boundary : boundaries) {
if (serial) {
- FutureTask<MeasurementList> future = new FutureTask<>(() -> computeListForPartition(boundary.start(), boundary.end()));
+ FutureTask<MeasurementListVectorized> future = new FutureTask<>(() -> computeListForPartition(
+ fileChannel, boundary));
future.run();
- // System.out.println("done with partition " + boundary);
futures.add(future);
}
else {
- Future<MeasurementList> future = executor.submit(() -> computeListForPartition(boundary.start(), boundary.end()));
+ Future<MeasurementListVectorized> future = executor.submit(() -> computeListForPartition(
+ fileChannel, boundary));
futures.add(future);
}
}
- TreeMap<String, MeasurementAggregator> ris = reduce(futures);
+ TreeMap<String, MeasurementAggregatorVectorized> ris = reduce(futures);
return ris;
}
}
- private TreeMap<String, MeasurementAggregator> reduce(List<Future<MeasurementList>> futures) {
+ private TreeMap<String, MeasurementAggregatorVectorized> reduce(List<Future<MeasurementListVectorized>> futures) {
try {
- TreeMap<String, MeasurementAggregator> risMap = new TreeMap<>();
- MeasurementList ris = new MeasurementList();
- for (Future<MeasurementList> future : futures) {
- MeasurementList results = future.get();
+ TreeMap<String, MeasurementAggregatorVectorized> risMap = new TreeMap<>();
+ MeasurementListVectorized ris = new MeasurementListVectorized();
+ for (Future<MeasurementListVectorized> future : futures) {
+ MeasurementListVectorized results = future.get();
merge(ris, results);
}
- for (MeasurementAggregator m : ris.getMeasurements()) {
+ for (MeasurementAggregatorVectorized m : ris.getMeasurements()) {
if (m != null) {
risMap.put(m.getName(), m);
}
@@ -286,101 +303,134 @@ public class CalculateAverage_giovannicuccu {
}
}
- private void merge(MeasurementList result, MeasurementList partial) {
- for (MeasurementAggregator m : partial.getMeasurements()) {
+ private void merge(MeasurementListVectorized result, MeasurementListVectorized partial) {
+ for (MeasurementAggregatorVectorized m : partial.getMeasurements()) {
if (m != null) {
result.merge(m);
}
}
}
- private MeasurementList computeListForPartition(long start, long end) {
- MeasurementList list = new MeasurementList();
- try {
- try (FileChannel fileChannel = (FileChannel) Files.newByteChannel((path), StandardOpenOption.READ)) {
- MappedByteBuffer mappedByteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, start, end - start);
- mappedByteBuffer.order(BYTE_ORDER.LITTLE_ENDIAN);
- int limit = mappedByteBuffer.limit();
- int startLine;
- byte[] stationb = new byte[100];
- while ((startLine = mappedByteBuffer.position()) < limit - 110) {
- int currentPosition = startLine;
- byte b = 0;
- int i = 0;
- int hash = 0;
-
- while ((b = mappedByteBuffer.get(currentPosition++)) != ';') {
- stationb[i++] = b;
- hash = 31 * hash + b;
+ private MeasurementListVectorized computeListForPartition(FileChannel fileChannel, PartitionBoundary boundary) {
+ try (var arena = Arena.ofConfined()) {
+ var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, boundary.start(), boundary.end() - boundary.start(), arena);
+ MeasurementListVectorized list = new MeasurementListVectorized();
+ long size = memorySegment.byteSize();
+ long offset = 0;
+ long safe = size - KEY_SIZE;
+ // ByteBuffer byteBuffer = memorySegment.asByteBuffer();
+ // byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
+ ByteVector chunk1 = ByteVector.zero(BYTE_SPECIES);
+ ByteVector chunk2 = ByteVector.zero(BYTE_SPECIES);
+ while (offset < safe) {
+ int len = 0;
+ chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER);
+ int equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue();
+ len += equals;
+ if (equals == BYTE_SPECIES_LANES) {
+ while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') {
+ len++;
}
- if (hash < 0) {
- hash = -hash;
- }
-
- long numberWord = mappedByteBuffer.getLong(currentPosition);
- int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
- int value = convertIntoNumber(decimalSepPos, numberWord);
- mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3);
-
- list.add(stationb, i, hash, value);
+ }
+ int hash = hash(memorySegment, offset, len);
+ long prevOffset = offset;
+ offset += len + 1;
+
+ long numberWord = memorySegment.get(JAVA_LONG_LT, offset);
+ int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
+ int value = convertIntoNumber(decimalSepPos, numberWord);
+ offset += (decimalSepPos >>> 3) + 3;
+ // System.out.println("Value=" + value);
+ if (len < BYTE_SPECIES_LANES) {
+ list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset);
}
- while ((startLine = mappedByteBuffer.position()) < limit) {
- int currentPosition = startLine;
- byte b = 0;
- int i = 0;
- int hash = 0;
- while ((b = mappedByteBuffer.get(currentPosition++)) != ';') {
- stationb[i++] = b;
- hash = 31 * hash + b;
- }
- if (hash < 0) {
- hash = -hash;
+ else {
+ list.add(len, hash, value, memorySegment, prevOffset);
+ }
+ }
+
+ while (offset < size) {
+ int len = 0;
+ int equals = BYTE_SPECIES_LANES;
+ if (offset + BYTE_SPECIES_LANES < size) {
+ chunk1 = ByteVector.fromMemorySegment(BYTE_SPECIES, memorySegment, offset, NATIVE_ORDER);
+ equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue();
+ len += equals;
+ if (equals == BYTE_SPECIES_LANES) {
+ while (memorySegment.get(ValueLayout.JAVA_BYTE, offset + len) != ';') {
+ len++;
+ }
}
+ }
+ else {
+ byte[] bytes = new byte[BYTE_SPECIES_LANES];
+ MemorySegment.copy(memorySegment, offset + len, MemorySegment.ofArray(bytes), 0, (size - offset - len));
+ // byteBuffer.get(offset + len, bytes, 0, (int) (size - offset - len));
+ chunk1 = ByteVector.fromArray(BYTE_SPECIES, bytes, 0);
+ equals = chunk1.compare(VectorOperators.EQ, separators).firstTrue();
+ len += equals;
+ }
+ int hash = hash(memorySegment, offset, len);
+ long prevOffset = offset;
+ offset += len + 1;
- int value = 0;
- if (currentPosition <= limit - 8) {
- long numberWord = mappedByteBuffer.getLong(currentPosition);
- int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
- value = convertIntoNumber(decimalSepPos, numberWord);
- mappedByteBuffer.position(currentPosition + (decimalSepPos >>> 3) + 3);
+ int value = 0;
+ if (offset < size - 8) {
+ long numberWord = memorySegment.get(JAVA_LONG_LT, offset);
+ int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000);
+ value = convertIntoNumber(decimalSepPos, numberWord);
+ offset += (decimalSepPos >>> 3) + 3;
+ }
+ else {
+ long currentPosition = offset;
+ int sign = 1;
+ byte b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++);
+ if (b == '-') {
+ sign = -1;
}
else {
- int sign = 1;
- b = mappedByteBuffer.get(currentPosition++);
- if (b == '-') {
- sign = -1;
- }
- else {
- value = b - '0';
- }
- while ((b = mappedByteBuffer.get(currentPosition++)) != '.') {
- value = value * 10 + (b - '0');
- }
- b = mappedByteBuffer.get(currentPosition);
+ value = b - '0';
+ }
+ while ((b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition++)) != '.') {
value = value * 10 + (b - '0');
- if (sign == -1) {
- value = -value;
- }
- mappedByteBuffer.position(currentPosition + 2);
}
-
- list.add(stationb, i, hash, value);
+ b = memorySegment.get(ValueLayout.JAVA_BYTE, currentPosition);
+ value = value * 10 + (b - '0');
+ if (sign == -1) {
+ value = -value;
+ }
+ offset = currentPosition + 2;
+ }
+ if (len < BYTE_SPECIES_LANES) {
+ list.addWithByteVector(chunk1, len, hash, value, memorySegment, prevOffset);
+ }
+ else {
+ list.add(len, hash, value, memorySegment, prevOffset);
}
}
+ return list;
}
catch (IOException e) {
- System.out.println("Error");
- System.err.println(e);
+ throw new RuntimeException(e);
}
- return list;
}
- private static final ByteOrder BYTE_ORDER = ByteOrder.nativeOrder();
+ private static final int GOLDEN_RATIO = 0x9E3779B9;
+ private static final int HASH_LROTATE = 5;
- private static long getLongLittleEndian(long value) {
- value = Long.reverseBytes(value);
- return value;
+ private static int hash(MemorySegment memorySegment, long start, int len) {
+ int x;
+ int y;
+ if (len >= Integer.BYTES) {
+ x = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start);
+ y = memorySegment.get(ValueLayout.JAVA_INT_UNALIGNED, start + len - Integer.BYTES);
+ }
+ else {
+ x = memorySegment.get(ValueLayout.JAVA_BYTE, start);
+ y = memorySegment.get(ValueLayout.JAVA_BYTE, start + len - Byte.BYTES);
+ }
+ return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO;
}
private static int convertIntoNumber(int decimalSepPos, long numberWord) {
@@ -405,16 +455,11 @@ public class CalculateAverage_giovannicuccu {
return (int) value;
}
- private static long[] masks = new long[]{ 0x0000000000000000, 0xFF00000000000000L, 0xFFFF000000000000L,
- 0xFFFFFF0000000000L, 0xFFFFFFFF00000000L, 0xFFFFFFFFFF000000L, 0xFFFFFFFFFF0000L, 0xFFFFFFFFFFFF00L };
-
}
public static void main(String[] args) throws IOException {
- long start = System.currentTimeMillis();
- MMapReader reader = new MMapReader(Paths.get(FILE), new ProcessorPartitionCalculator(), false);
- Map<String, MeasurementAggregator> measurements = reader.elaborate();
- // System.out.println("ela=" + (System.currentTimeMillis() - start));
+ MMapReaderMemorySegment reader = new MMapReaderMemorySegment(Paths.get(FILE), new ProcessorPartitionCalculator(), false);
+ Map<String, MeasurementAggregatorVectorized> measurements = reader.elaborate();
System.out.println(measurements);
}
diff --git a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java
index 804b83c..9bcc16d 100644
--- a/src/main/java/dev/morling/onebrc/CreateMeasurements3.java
+++ b/src/main/java/dev/morling/onebrc/CreateMeasurements3.java
@@ -55,7 +55,7 @@ public class CreateMeasurements3 {
out.write(station.name);
out.write(';');
out.write(Double.toString(Math.round(temp * 10.0) / 10.0));
- out.newLine();
+ out.write('\n');
if (i % 50_000_000 == 0) {
System.out.printf("Wrote %,d measurements in %,d ms%n", i, System.currentTimeMillis() - start);
}