about | summary | refs | log | tree | commit | diff
path: root/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java 324
1 files changed, 199 insertions, 125 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
index a8c4e4c..4bffe78 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java
@@ -15,45 +15,53 @@
*/
package dev.morling.onebrc;
-import java.io.*;
-import java.nio.*;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
import java.nio.channels.*;
-import java.util.concurrent.*;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
import java.util.stream.*;
import java.util.*;
+import static java.lang.foreign.ValueLayout.*;
+
/* A fast implementation with no unsafe.
* Features:
- * * memory mapped file
+ * * memory mapped file via the preview Foreign Function & Memory API (Arena / MemorySegment)
* * read chunks in parallel
* * minimise allocation
* * no unsafe
*
* Timings on 4 core i7-7500U CPU @ 2.70GHz:
* average_baseline: 4m48s
- * ianopolous: 19s
+ * ianopolous: 16s
*/
public class CalculateAverage_ianopolousfast {
public static final int MAX_LINE_LENGTH = 107;
- public static final int MAX_STATIONS = 10_000;
+ public static final int MAX_STATIONS = 1 << 14;
+ private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
public static void main(String[] args) throws Exception {
- File input = new File("./measurements.txt");
- long filesize = input.length();
- // keep chunk size between 256 MB and 1G (1 chunk for files < 256MB)
- long chunkSize = Math.min(Math.max((filesize + 31) / 32, 256 * 1024 * 1024), 1024 * 1024 * 1024L);
- int nChunks = (int) ((filesize + chunkSize - 1) / chunkSize);
- ExecutorService pool = Executors.newVirtualThreadPerTaskExecutor();
- List<Future<List<List<Stat>>>> allResults = IntStream.range(0, nChunks)
- .mapToObj(i -> pool.submit(() -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize))))
+ Arena arena = Arena.global();
+ Path input = Path.of("measurements.txt");
+ FileChannel channel = (FileChannel) Files.newByteChannel(input, StandardOpenOption.READ);
+ long filesize = Files.size(input);
+ MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena);
+ int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
+ long chunkSize = (filesize + nChunks - 1) / nChunks;
+ List<List<List<Stat>>> allResults = IntStream.range(0, nChunks)
+ .parallel()
+ .mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap))
.toList();
TreeMap<String, Stat> merged = allResults.stream()
.parallel()
.flatMap(f -> {
try {
- return f.get().stream().filter(Objects::nonNull).flatMap(Collection::stream);
+ return f.stream().filter(Objects::nonNull).flatMap(Collection::stream);
}
catch (Exception e) {
e.printStackTrace();
@@ -64,25 +72,39 @@ public class CalculateAverage_ianopolousfast {
System.out.println(merged);
}
- public static boolean matchingStationBytes(int start, int end, ByteBuffer buffer, Stat existing) {
- if (end - start != existing.name.length)
+ public static boolean matchingStationBytes(long start, long end, int offset, MemorySegment buffer, Stat existing) {
+ int len = (int) (end - start);
+ if (len != existing.name.length)
return false;
- for (int i = start; i < end; i++) {
- if (existing.name[i - start] != buffer.get(i))
+ for (int i = offset; i < len; i++) {
+ if (existing.name[i] != buffer.get(JAVA_BYTE, offset + start++))
return false;
}
return true;
}
- public static Stat dedupeStation(int start, int end, long hash, ByteBuffer buffer, List<List<Stat>> stations) {
- int index = Math.floorMod(hash ^ (hash >> 32), MAX_STATIONS);
+ private static int hashToIndex(long hash, int len) {
+ // From Thomas Wuerthinger's entry
+ int hashAsInt = (int) (hash ^ (hash >>> 28));
+ int finalHash = (hashAsInt ^ (hashAsInt >>> 15));
+ return (finalHash & (len - 1));
+ }
+
+ public static Stat parseStation(long start, long end, long first8, long second8,
+ MemorySegment buffer) {
+ byte[] stationBuffer = new byte[(int) (end - start)];
+ for (long off = start; off < end; off++)
+ stationBuffer[(int) (off - start)] = buffer.get(JAVA_BYTE, off);
+ return new Stat(stationBuffer, first8, second8);
+ }
+
+ public static Stat dedupeStation(long start, long end, long hash, long first8, long second8,
+ MemorySegment buffer, List<List<Stat>> stations) {
+ int index = hashToIndex(hash, MAX_STATIONS);
List<Stat> matches = stations.get(index);
if (matches == null) {
List<Stat> value = new ArrayList<>();
- byte[] stationBuffer = new byte[end - start];
- buffer.position(start);
- buffer.get(stationBuffer);
- Stat res = new Stat(stationBuffer);
+ Stat res = parseStation(start, end, first8, second8, buffer);
value.add(res);
stations.set(index, value);
return res;
@@ -90,136 +112,185 @@ public class CalculateAverage_ianopolousfast {
else {
for (int i = 0; i < matches.size(); i++) {
Stat s = matches.get(i);
- if (matchingStationBytes(start, end, buffer, s))
+ if (first8 == s.first8 && second8 == s.second8 && matchingStationBytes(start, end, 16, buffer, s))
return s;
}
- byte[] stationBuffer = new byte[end - start];
- buffer.position(start);
- buffer.get(stationBuffer);
- Stat res = new Stat(stationBuffer);
+ Stat res = parseStation(start, end, first8, second8, buffer);
matches.add(res);
return res;
}
}
- public static int getSemicolon(long d) {
+ public static Stat dedupeStation8(long start, long end, long hash, long first8, MemorySegment buffer, List<List<Stat>> stations) {
+ int index = hashToIndex(hash, MAX_STATIONS);
+ List<Stat> matches = stations.get(index);
+ if (matches == null) {
+ List<Stat> value = new ArrayList<>();
+ Stat station = parseStation(start, end, first8, 0, buffer);
+ value.add(station);
+ stations.set(index, value);
+ return station;
+ }
+ else {
+ for (int i = 0; i < matches.size(); i++) {
+ Stat s = matches.get(i);
+ if (first8 == s.first8 && s.name.length <= 8)
+ return s;
+ }
+ Stat station = parseStation(start, end, first8, 0, buffer);
+ matches.add(station);
+ return station;
+ }
+ }
+
+ public static Stat dedupeStation16(long start, long end, long hash, long first8, long second8, MemorySegment buffer, List<List<Stat>> stations) {
+ int index = hashToIndex(hash, MAX_STATIONS);
+ List<Stat> matches = stations.get(index);
+ if (matches == null) {
+ List<Stat> value = new ArrayList<>();
+ Stat res = parseStation(start, end, first8, second8, buffer);
+ value.add(res);
+ stations.set(index, value);
+ return res;
+ }
+ else {
+ for (int i = 0; i < matches.size(); i++) {
+ Stat s = matches.get(i);
+ if (first8 == s.first8 && second8 == s.second8 && s.name.length <= 16)
+ return s;
+ }
+ Stat res = parseStation(start, end, first8, second8, buffer);
+ matches.add(res);
+ return res;
+ }
+ }
+
+ public static long hasSemicolon(long d) {
// from Hacker's Delight page 92
d = d ^ 0x3b3b3b3b3b3b3b3bL;
long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
- y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
+ return ~(y | d | 0x7f7f7f7f7f7f7f7fL);
+ }
+
+ public static int getSemicolonIndex(long y) {
+ // from Hacker's Delight page 92
return Long.numberOfLeadingZeros(y) >> 3;
}
- public static long updateHash(long hash, long x) {
- return ((hash << 5) ^ x) * 0x517cc1b727220a95L; // fxHash
+ static long maskHighBytes(long d, int nbytes) {
+ return d & (-1L << ((8 - nbytes) * 8));
}
- public static Stat parseStation(int lineStart, ByteBuffer buffer, List<List<Stat>> stations) {
+ public static Stat parseStation(long lineStart, MemorySegment buffer, List<List<Stat>> stations) {
// find semicolon and update hash as we go, reading a long at a time
- long d = buffer.getLong(lineStart);
+ long d = buffer.get(LONG_LAYOUT, lineStart);
+ long hasSemi = hasSemicolon(d);
+ if (hasSemi != 0) {
+ int semiIndex = getSemicolonIndex(hasSemi);
+ d = maskHighBytes(d, semiIndex);
+ return dedupeStation8(lineStart, lineStart + semiIndex, d, d, buffer, stations);
+ }
+ long first8 = d;
+ long hash = d;
+
+ d = buffer.get(LONG_LAYOUT, lineStart + 8);
+ hasSemi = hasSemicolon(d);
+ if (hasSemi != 0) {
+ int semiIndex = getSemicolonIndex(hasSemi);
+ if (semiIndex == 0)
+ return dedupeStation8(lineStart, lineStart + 8, first8, first8, buffer, stations);
+ d = maskHighBytes(d, semiIndex);
+ return dedupeStation16(lineStart, lineStart + 8 + semiIndex, first8 ^ d, first8, d, buffer, stations);
+ }
- int semiIndex = getSemicolon(d);
- int index = 0;
- long hash = 0;
- while (semiIndex == 8) {
- hash = updateHash(hash, d);
+ int index = 8;
+ long second8 = d;
+ while (hasSemi == 0) {
+ hash = hash ^ d;
index += 8;
- d = buffer.getLong(lineStart + index);
- semiIndex = getSemicolon(d);
+ d = buffer.get(LONG_LAYOUT, lineStart + index);
+ hasSemi = hasSemicolon(d);
}
- // mask extra bytes off last long
- d = d & (-1L << ((8 - semiIndex) * 8));
+ int semiIndex = getSemicolonIndex(hasSemi);
+ d = maskHighBytes(d, semiIndex);
if (semiIndex > 0) {
- hash = updateHash(hash, d);
+ hash = hash ^ d;
}
- return dedupeStation(lineStart, lineStart + index + semiIndex, hash, buffer, stations);
+ return dedupeStation(lineStart, lineStart + index + semiIndex, hash, first8, second8, buffer, stations);
}
- public static int processTemperature(int lineSplit, MappedByteBuffer buffer, Stat station) {
- short temperature;
- boolean negative = false;
- byte b = buffer.get(lineSplit++);
- if (b == '-') {
- negative = true;
- b = buffer.get(lineSplit++);
- }
- temperature = (short) (b - 0x30);
- b = buffer.get(lineSplit++);
- if (b == '.') {
- b = buffer.get(lineSplit++);
- temperature = (short) (temperature * 10 + (b - 0x30));
- }
- else {
- temperature = (short) (temperature * 10 + (b - 0x30));
- lineSplit++;
- b = buffer.get(lineSplit++);
- temperature = (short) (temperature * 10 + (b - 0x30));
- }
- temperature = negative ? (short) -temperature : temperature;
+ public static int getDot(long d) {
+ // from Hacker's Delight page 92
+ d = d ^ 0x2e2e2e2e2e2e2e2eL;
+ long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
+ y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
+ return Long.numberOfLeadingZeros(y) >> 3;
+ }
+
+ public static short getMinus(long d) {
+ d = d & 0xff00000000000000L;
+ d = d ^ 0x2d2d2d2d2d2d2d2dL;
+ long y = (d & 0x7f7f7f7f7f7f7f7fL) + 0x7f7f7f7f7f7f7f7fL;
+ y = ~(y | d | 0x7f7f7f7f7f7f7f7fL);
+ return (short) ((Long.numberOfLeadingZeros(y) >> 6) - 1);
+ }
+
+ public static long processTemperature(long lineSplit, MemorySegment buffer, Stat station) {
+ long d = buffer.get(LONG_LAYOUT, lineSplit);
+ // negative is either 0 or -1
+ short negative = getMinus(d);
+ d = d << (negative * -8);
+ int dotIndex = getDot(d);
+ d = (d >> 8) | 0x30000000_00000000L; // add a leading 0 digit
+ d = d >> 8 * (5 - dotIndex);
+ short temperature = (short) ((byte) d - '0' +
+ 10 * (((byte) (d >> 16)) - '0') +
+ 100 * (((byte) (d >> 24)) - '0'));
+ temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merykitty
station.add(temperature);
- return lineSplit + 1;
+ return lineSplit - negative + dotIndex + 3;
}
- public static List<List<Stat>> parseStats(long startByte, long endByte) {
- try {
- RandomAccessFile file = new RandomAccessFile("./measurements.txt", "r");
- long maxEnd = Math.min(file.length(), endByte + MAX_LINE_LENGTH);
- long len = maxEnd - startByte;
- if (len > Integer.MAX_VALUE)
- throw new RuntimeException("Segment size must fit into an int");
- int maxDone = (int) (endByte - startByte);
- MappedByteBuffer buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, startByte, len);
- int done = 0;
- // read first partial line
- if (startByte > 0) {
- for (int i = 0; i < MAX_LINE_LENGTH; i++) {
- byte b = buffer.get(i);
- if (b == '\n') {
- done = i + 1;
- break;
- }
+ public static List<List<Stat>> parseStats(long startByte, long endByte, MemorySegment buffer) {
+ // read first partial line
+ if (startByte > 0) {
+ for (int i = 0; i < MAX_LINE_LENGTH; i++) {
+ byte b = buffer.get(JAVA_BYTE, startByte++);
+ if (b == '\n') {
+ break;
}
}
+ }
- List<List<Stat>> stations = new ArrayList<>(MAX_STATIONS);
- for (int i = 0; i < MAX_STATIONS; i++)
- stations.add(null);
-
- // Handle reading the very last line in the file
- // this allows us to not worry about reading a long beyond the end
- // in the inner loop (reducing branches)
- // We only need to read one because the min record size is 6 bytes
- // so 2nd last record must be > 8 from end
- if (endByte == file.length()) {
- int offset = (int) (file.length() - startByte - 1);
- while (buffer.get(offset) != '\n') // final new line
- offset--;
- offset--;
- while (offset > 0 && buffer.get(offset) != '\n') // end of second last line
- offset--;
- maxDone = offset;
- if (offset > 0)
- offset++;
- // copy into a 8n sized buffer to avoid reading off end
- int roundedSize = (int) (file.length() - startByte) - offset;
- roundedSize = (roundedSize + 7) / 8 * 8;
- byte[] end = new byte[roundedSize];
- for (int i = offset; i < (int) (file.length() - startByte); i++)
- end[i - offset] = buffer.get(i);
- Stat station = parseStation(0, ByteBuffer.wrap(end), stations);
- processTemperature(offset + station.name.length + 1, buffer, station);
- }
+ List<List<Stat>> stations = new ArrayList<>(MAX_STATIONS);
+ for (int i = 0; i < MAX_STATIONS; i++)
+ stations.add(null);
- int lineStart = done;
- while (lineStart < maxDone) {
- Stat station = parseStation(lineStart, buffer, stations);
- lineStart = processTemperature(lineStart + station.name.length + 1, buffer, station);
- }
- return stations;
+ // Handle the very last line of the file separately, up front.
+ // This lets the inner loop read a whole long at a time without
+ // checking for the end of the mapping (fewer branches).
+ // Only the last record needs this: the minimum record size is 6 bytes,
+ // so 8-byte reads starting in the second-to-last record end before the end of the file.
+ if (endByte == buffer.byteSize()) {
+ endByte -= 2; // skip final new line
+ while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n')
+ endByte--;
+
+ if (endByte > 0)
+ endByte++;
+ // copy the last line into a separate (MAX_LINE_LENGTH + 4)-byte segment so 8-byte reads cannot run off the end of the file
+ MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 4);
+ for (long i = endByte; i < buffer.byteSize(); i++)
+ end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i));
+ Stat station = parseStation(0, end, stations);
+ processTemperature(station.name.length + 1, end, station);
}
- catch (IOException e) {
- throw new RuntimeException(e);
+
+ while (startByte < endByte) {
+ Stat station = parseStation(startByte, buffer, stations);
+ startByte = processTemperature(startByte + station.name.length + 1, buffer, station);
}
+ return stations;
}
public static class Stat {
@@ -227,9 +298,12 @@ public class CalculateAverage_ianopolousfast {
int count = 0;
short min = Short.MAX_VALUE, max = Short.MIN_VALUE;
long total = 0;
+ final long first8, second8;
- public Stat(byte[] name) {
+ public Stat(byte[] name, long first8, long second8) {
this.name = name;
+ this.first8 = first8;
+ this.second8 = second8;
}
public void add(short value) {
@@ -263,4 +337,4 @@ public class CalculateAverage_ianopolousfast {
return round((double) min) + "/" + round(((double) total) / count) + "/" + round((double) max);
}
}
-}
+} \ No newline at end of file