aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java382
1 files changed, 182 insertions, 200 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
index db60403..178a6e1 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_bufistov.java
@@ -15,11 +15,17 @@
*/
package dev.morling.onebrc;
+import sun.misc.Unsafe;
+
import static java.lang.Math.toIntExact;
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
+import java.nio.file.Paths;
+import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -32,66 +38,6 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.util.concurrent.Future;
-class ResultRow {
- byte[] station;
-
- String stationString;
- long min, max, count, suma;
-
- ResultRow() {
- }
-
- ResultRow(byte[] station, long value) {
- this.station = new byte[station.length];
- System.arraycopy(station, 0, this.station, 0, station.length);
- this.min = value;
- this.max = value;
- this.count = 1;
- this.suma = value;
- }
-
- ResultRow(long value) {
- this.min = value;
- this.max = value;
- this.count = 1;
- this.suma = value;
- }
-
- void setStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition) {
- this.station = new byte[endPosition - startPosition];
- byteBuffer.slice(startPosition, station.length).get(this.station, 0, station.length);
- }
-
- public String toString() {
- stationString = new String(station, StandardCharsets.UTF_8);
- return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0);
- }
-
- private double round(double value) {
- return Math.round(value * 10.0) / 10.0;
- }
-
- ResultRow update(long newValue) {
- this.count += 1;
- this.suma += newValue;
- if (newValue < this.min) {
- this.min = newValue;
- }
- else if (newValue > this.max) {
- this.max = newValue;
- }
- return this;
- }
-
- ResultRow merge(ResultRow another) {
- this.count += another.count;
- this.suma += another.suma;
- this.min = Math.min(this.min, another.min);
- this.max = Math.max(this.max, another.max);
- return this;
- }
-}
-
class ByteArrayWrapper {
private final byte[] data;
@@ -110,100 +56,176 @@ class ByteArrayWrapper {
}
}
-class OpenHash {
- ResultRow[] data;
- int dataSizeMask;
+public class CalculateAverage_bufistov {
- // ResultRow metrics = new ResultRow();
+ static class ResultRow {
+ byte[] station;
- public OpenHash(int capacityPow2) {
- assert capacityPow2 <= 20;
- int dataSize = 1 << capacityPow2;
- dataSizeMask = dataSize - 1;
- data = new ResultRow[dataSize];
- }
+ String stationString;
+ long min, max, count, suma;
- int hashByteArray(byte[] array) {
- int result = 0;
- long mask = 0;
- for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) {
- result += array[i] << mask;
+ ResultRow() {
}
- return result & dataSizeMask;
- }
- void merge(byte[] station, long value, int hashValue) {
- while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) {
- hashValue += 1;
- hashValue &= dataSizeMask;
+ ResultRow(byte[] station, long value) {
+ this.station = new byte[station.length];
+ System.arraycopy(station, 0, this.station, 0, station.length);
+ this.min = value;
+ this.max = value;
+ this.count = 1;
+ this.suma = value;
}
- if (data[hashValue] == null) {
- data[hashValue] = new ResultRow(station, value);
+
+ ResultRow(long value) {
+ this.min = value;
+ this.max = value;
+ this.count = 1;
+ this.suma = value;
}
- else {
- data[hashValue].update(value);
+
+ void setStation(long startPosition, long endPosition) {
+ this.station = new byte[(int) (endPosition - startPosition)];
+ for (int i = 0; i < this.station.length; ++i) {
+ this.station[i] = UNSAFE.getByte(startPosition + i);
+ }
}
- // metrics.update(delta);
- }
- void merge(byte[] station, long value) {
- merge(station, value, hashByteArray(station));
- }
+ public String toString() {
+ stationString = new String(station, StandardCharsets.UTF_8);
+ return stationString + "=" + round(min / 10.0) + "/" + round(suma / 10.0 / count) + "/" + round(max / 10.0);
+ }
- void merge(MappedByteBuffer byteBuffer, final int startPosition, final int endPosition, int hashValue, final long value) {
- while (data[hashValue] != null && !equalsToStation(byteBuffer, startPosition, endPosition, data[hashValue].station)) {
- hashValue += 1;
- hashValue &= dataSizeMask;
+ private double round(double value) {
+ return Math.round(value * 10.0) / 10.0;
}
- if (data[hashValue] == null) {
- data[hashValue] = new ResultRow(value);
- data[hashValue].setStation(byteBuffer, startPosition, endPosition);
+
+ void update(long newValue) {
+ this.count += 1;
+ this.suma += newValue;
+ if (newValue < this.min) {
+ this.min = newValue;
+ }
+ else if (newValue > this.max) {
+ this.max = newValue;
+ }
}
- else {
- data[hashValue].update(value);
+
+ ResultRow merge(ResultRow another) {
+ this.count += another.count;
+ this.suma += another.suma;
+ this.min = Math.min(this.min, another.min);
+ this.max = Math.max(this.max, another.max);
+ return this;
}
}
- boolean equalsToStation(MappedByteBuffer byteBuffer, int startPosition, int endPosition, byte[] station) {
- if (endPosition - startPosition != station.length) {
- return false;
+ static class OpenHash {
+ ResultRow[] data;
+ int dataSizeMask;
+
+ // ResultRow metrics = new ResultRow();
+
+ public OpenHash(int capacityPow2) {
+ assert capacityPow2 <= 20;
+ int dataSize = 1 << capacityPow2;
+ dataSizeMask = dataSize - 1;
+ data = new ResultRow[dataSize];
}
- for (int i = 0; i < station.length; ++i, ++startPosition) {
- if (byteBuffer.get(startPosition) != station[i])
+
+ int hashByteArray(byte[] array) {
+ int result = 0;
+ long mask = 0;
+ for (int i = 0; i < array.length; ++i, mask = ((mask + 1) & 3)) {
+ result += array[i] << mask;
+ }
+ return result & dataSizeMask;
+ }
+
+ void merge(byte[] station, long value, int hashValue) {
+ while (data[hashValue] != null && !Arrays.equals(station, data[hashValue].station)) {
+ hashValue += 1;
+ hashValue &= dataSizeMask;
+ }
+ if (data[hashValue] == null) {
+ data[hashValue] = new ResultRow(station, value);
+ }
+ else {
+ data[hashValue].update(value);
+ }
+ // metrics.update(delta);
+ }
+
+ void merge(byte[] station, long value) {
+ merge(station, value, hashByteArray(station));
+ }
+
+ void merge(final long startPosition, long endPosition, int hashValue, long value) {
+ while (data[hashValue] != null && !equalsToStation(startPosition, endPosition, data[hashValue].station)) {
+ hashValue += 1;
+ hashValue &= dataSizeMask;
+ }
+ if (data[hashValue] == null) {
+ data[hashValue] = new ResultRow(value);
+ data[hashValue].setStation(startPosition, endPosition);
+ }
+ else {
+ data[hashValue].update(value);
+ }
+ }
+
+ boolean equalsToStation(long startPosition, long endPosition, byte[] station) {
+ if (endPosition - startPosition != station.length) {
return false;
+ }
+ for (int i = 0; i < station.length; ++i, ++startPosition) {
+ if (UNSAFE.getByte(startPosition) != station[i])
+ return false;
+ }
+ return true;
}
- return true;
- }
- HashMap<ByteArrayWrapper, ResultRow> toJavaHashMap() {
- HashMap<ByteArrayWrapper, ResultRow> result = new HashMap<>(20000);
- for (int i = 0; i < data.length; ++i) {
- if (data[i] != null) {
- var key = new ByteArrayWrapper(data[i].station);
- result.put(key, data[i]);
+ HashMap<ByteArrayWrapper, ResultRow> toJavaHashMap() {
+ HashMap<ByteArrayWrapper, ResultRow> result = new HashMap<>(20000);
+ for (int i = 0; i < data.length; ++i) {
+ if (data[i] != null) {
+ var key = new ByteArrayWrapper(data[i].station);
+ result.put(key, data[i]);
+ }
}
+ return result;
}
- return result;
}
-}
-public class CalculateAverage_bufistov {
+ static final Unsafe UNSAFE;
+
+ static {
+ try {
+ Field unsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafe.setAccessible(true);
+ UNSAFE = (Unsafe) unsafe.get(Unsafe.class);
+ }
+ catch (Throwable e) {
+ throw new RuntimeException(e);
+ }
+ }
static final long LINE_SEPARATOR = '\n';
public static class FileRead implements Callable<HashMap<ByteArrayWrapper, ResultRow>> {
private final FileChannel fileChannel;
+
private long currentLocation;
- private int bytesToRead;
+ private long bytesToRead;
+
+ private static final int hashCapacityPow2 = 18;
- private final int hashCapacityPow2 = 18;
- private final int hashCapacityMask = (1 << hashCapacityPow2) - 1;
+ static final int hashCapacityMask = (1 << hashCapacityPow2) - 1;
- public FileRead(long startLocation, int bytesToRead, FileChannel fileChannel) {
+ public FileRead(FileChannel fileChannel, long startLocation, long bytesToRead, boolean firstSegment) {
+ this.fileChannel = fileChannel;
this.currentLocation = startLocation;
this.bytesToRead = bytesToRead;
- this.fileChannel = fileChannel;
}
@Override
@@ -211,21 +233,13 @@ public class CalculateAverage_bufistov {
try {
OpenHash openHash = new OpenHash(hashCapacityPow2);
log("Reading the channel: " + currentLocation + ":" + bytesToRead);
- byte[] suffix = new byte[128];
if (currentLocation > 0) {
- toLineBegin(suffix);
- }
- while (bytesToRead > 0) {
- int bufferSize = Math.min(1 << 24, bytesToRead);
- MappedByteBuffer byteBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bufferSize);
- bytesToRead -= bufferSize;
- currentLocation += bufferSize;
- int suffixBytes = 0;
- if (currentLocation < fileChannel.size()) {
- suffixBytes = toLineBegin(suffix);
- }
- processChunk(byteBuffer, bufferSize, suffix, suffixBytes, openHash);
+ toLineBeginPrefix();
}
+ toLineBeginSuffix();
+ var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, currentLocation, bytesToRead, Arena.global());
+ currentLocation = memorySegment.address();
+ processChunk(openHash);
log("Done Reading the channel: " + currentLocation + ":" + bytesToRead);
return openHash.toJavaHashMap();
}
@@ -240,39 +254,40 @@ public class CalculateAverage_bufistov {
return byteBuffer.get();
}
- int toLineBegin(byte[] suffix) throws IOException {
- int bytesConsumed = 0;
- if (getByte(currentLocation - 1) != LINE_SEPARATOR) {
- while (getByte(currentLocation) != LINE_SEPARATOR) { // Small bug here if last chunk is less than a line and has no '\n' at the end. Valid input should have '\n' for all rows.
- suffix[bytesConsumed++] = getByte(currentLocation);
- ++currentLocation;
- --bytesToRead;
- }
+ void toLineBeginPrefix() throws IOException {
+ while (getByte(currentLocation - 1) != LINE_SEPARATOR) {
++currentLocation;
--bytesToRead;
}
- return bytesConsumed;
}
- void processChunk(MappedByteBuffer byteBuffer, int bufferSize, byte[] suffix, int suffixBytes, OpenHash result) {
- int nameBegin = 0;
- int nameEnd = -1;
- int numberBegin = -1;
+ void toLineBeginSuffix() throws IOException {
+ while (getByte(currentLocation + bytesToRead - 1) != LINE_SEPARATOR) {
+ ++bytesToRead;
+ }
+ }
+
+ void processChunk(OpenHash result) {
+ long nameBegin = currentLocation;
+ long nameEnd = -1;
+ long numberBegin = -1;
int currentHash = 0;
int currentMask = 0;
int nameHash = 0;
- for (int currentPosition = 0; currentPosition < bufferSize; ++currentPosition) {
- byte nextByte = byteBuffer.get(currentPosition);
+ long end = currentLocation + bytesToRead;
+ byte nextByte;
+ for (; currentLocation < end; ++currentLocation) {
+ nextByte = UNSAFE.getByte(currentLocation);
if (nextByte == ';') {
- nameEnd = currentPosition;
- numberBegin = currentPosition + 1;
+ nameEnd = currentLocation;
+ numberBegin = currentLocation + 1;
nameHash = currentHash & hashCapacityMask;
}
else if (nextByte == LINE_SEPARATOR) {
- long value = getValue(byteBuffer, numberBegin, currentPosition);
- // log("Station name: '" + getStationName(byteBuffer, nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash);
- result.merge(byteBuffer, nameBegin, nameEnd, nameHash, value);
- nameBegin = currentPosition + 1;
+ long value = getValue(numberBegin, currentLocation);
+ // log("Station name: '" + getStationName(nameBegin, nameEnd) + "' value: " + value + " hash: " + nameHash);
+ result.merge(nameBegin, nameEnd, nameHash, value);
+ nameBegin = currentLocation + 1;
currentHash = 0;
currentMask = 0;
}
@@ -281,38 +296,14 @@ public class CalculateAverage_bufistov {
currentMask = (currentMask + 1) & 3;
}
}
- if (nameBegin < bufferSize) {
- byte[] lastLine = new byte[bufferSize - nameBegin + suffixBytes];
- byte[] prefix = new byte[bufferSize - nameBegin];
- byteBuffer.slice(nameBegin, prefix.length).get(prefix, 0, prefix.length);
- System.arraycopy(prefix, 0, lastLine, 0, prefix.length);
- System.arraycopy(suffix, 0, lastLine, prefix.length, suffixBytes);
- processLastLine(lastLine, result);
- }
}
- void processLastLine(byte[] lastLine, OpenHash result) {
- int numberBegin = -1;
- byte[] stationName = null;
- for (int i = 0; i < lastLine.length; ++i) {
- if (lastLine[i] == ';') {
- stationName = new byte[i];
- System.arraycopy(lastLine, 0, stationName, 0, stationName.length);
- numberBegin = i + 1;
- break;
- }
- }
- long value = getValue(lastLine, numberBegin);
- // log("Station name: '" + new String(stationName, StandardCharsets.UTF_8) + "' value: " + value);
- result.merge(stationName, value);
- }
-
- long getValue(MappedByteBuffer byteBuffer, int startLocation, int endLocation) {
- byte nextByte = byteBuffer.get(startLocation);
+ long getValue(long startLocation, long endLocation) {
+ byte nextByte = UNSAFE.getByte(startLocation);
boolean negate = nextByte == '-';
long result = negate ? 0 : nextByte - '0';
- for (int i = startLocation + 1; i < endLocation; ++i) {
- nextByte = byteBuffer.get(i);
+ for (long i = startLocation + 1; i < endLocation; ++i) {
+ nextByte = UNSAFE.getByte(i);
if (nextByte != '.') {
result *= 10;
result += nextByte - '0';
@@ -321,23 +312,11 @@ public class CalculateAverage_bufistov {
return negate ? -result : result;
}
- long getValue(byte[] lastLine, int startLocation) {
- byte nextByte = lastLine[startLocation];
- boolean negate = nextByte == '-';
- long result = negate ? 0 : nextByte - '0';
- for (int i = startLocation + 1; i < lastLine.length; ++i) {
- nextByte = lastLine[i];
- if (nextByte != '.') {
- result *= 10;
- result += nextByte - '0';
- }
+ String getStationName(long from, long to) {
+ byte[] bytes = new byte[(int) (to - from)];
+ for (int i = 0; i < bytes.length; ++i) {
+ bytes[i] = UNSAFE.getByte(from + i);
}
- return negate ? -result : result;
- }
-
- String getStationName(MappedByteBuffer byteBuffer, int from, int to) {
- byte[] bytes = new byte[to - from];
- byteBuffer.slice(from, to - from).get(0, bytes);
return new String(bytes, StandardCharsets.UTF_8);
}
}
@@ -349,7 +328,7 @@ public class CalculateAverage_bufistov {
}
log("InputFile: " + fileName);
FileInputStream fileInputStream = new FileInputStream(fileName);
- int numThreads = 32;
+ int numThreads = 2 * Runtime.getRuntime().availableProcessors();
if (args.length > 1) {
numThreads = Integer.parseInt(args[1]);
}
@@ -363,9 +342,12 @@ public class CalculateAverage_bufistov {
long startLocation = 0;
ArrayList<Future<HashMap<ByteArrayWrapper, ResultRow>>> results = new ArrayList<>(numThreads);
+ var fileChannel = FileChannel.open(Paths.get(fileName));
+ boolean firstSegment = true;
while (remaining_size > 0) {
long actualSize = Math.min(chunk_size, remaining_size);
- results.add(executor.submit(new FileRead(startLocation, toIntExact(actualSize), channel)));
+ results.add(executor.submit(new FileRead(fileChannel, startLocation, toIntExact(actualSize), firstSegment)));
+ firstSegment = false;
remaining_size -= actualSize;
startLocation += actualSize;
}