aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
diff options
context:
space:
mode:
authorAndrew Sun <as-com@users.noreply.github.com>2024-01-12 14:42:22 -0500
committerGitHub <noreply@github.com>2024-01-12 20:42:22 +0100
commitdac38bc97fb1411d1ae7a1a354fe9c7ae0c659d2 (patch)
tree3b39c568f2e7bcb67aa7f421d2106cc968212f26 /src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
parent90cd353fbee10c58f11f177cfe9e7f6083f1c846 (diff)
Optimizations to Andrew Sun's entry (#310)
Squashed commit of the following: commit 44d3736de87834b41118d45831e59fc2b052117c Merge: fcf795f 3127962 Author: Andrew Sun <as-com@users.noreply.github.com> Date: Thu Jan 11 20:01:13 2024 -0500 Merge branch 'gunnarmorling:main' into as-com commit fcf795fbabacbd91891d11d21450ee4b1c479dc5 Author: Andrew Sun <me@andrewsun.com> Date: Wed Jan 10 21:14:01 2024 -0500 Optimizations to Andrew Sun's entry commit 4203924711bab5252ff3cbb50a90f4ce4e8e67c2 Merge: 9aed05a 085168a Author: Andrew Sun <me@andrewsun.com> Date: Wed Jan 10 19:40:19 2024 -0500 Merge remote-tracking branch 'upstream/main' into as-com commit 9aed05a04bd27fe7323e66c347b1011c77da322c Merge: 3f8df58 c2d120f Author: Andrew Sun <me@andrewsun.com> Date: Sun Jan 7 16:45:27 2024 -0500 Merge remote-tracking branch 'origin/as-com' into as-com # Conflicts: # calculate_average_asun.sh # src/main/java/dev/morling/onebrc/CalculateAverage_asun.java commit c2d120f0cb7f18c720a81a7f898102b310f9ecb9 Author: Andrew Sun <me@andrewsun.com> Date: Sat Jan 6 00:45:47 2024 -0500 Add entry by Andrew Sun commit 3f8df5803bcc8f3e29ed8bfff3077eb0e8cdab15 Author: Andrew Sun <me@andrewsun.com> Date: Sat Jan 6 00:45:47 2024 -0500 Add entry by Andrew Sun
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_asun.java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_asun.java229
1 files changed, 164 insertions, 65 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
index 88a90ea..0f5b0da 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_asun.java
@@ -16,12 +16,16 @@
package dev.morling.onebrc;
import jdk.incubator.vector.*;
+import sun.misc.Unsafe;
import java.io.File;
import java.io.IOException;
-import java.io.RandomAccessFile;
+import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.VarHandle;
+import java.lang.reflect.Field;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
@@ -31,8 +35,9 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
+import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutionException;
-import java.util.stream.Collectors;
+import java.util.concurrent.atomic.AtomicLongArray;
// based on spullara's submission
@@ -53,26 +58,72 @@ public class CalculateAverage_asun {
ASC = ByteVector.fromArray(BYTE_SPECIES, bytes, 0);
}
- public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
- long start = System.currentTimeMillis();
- var filename = args.length == 0 ? FILE : args[0];
- var file = new File(filename);
+ private static final Unsafe UNSAFE;
- List<FileSegment> fileSegments = getFileSegments(file);
- // System.out.println(System.currentTimeMillis() - start);
- var resultsMap = fileSegments.stream().map(segment -> {
+ static {
+ try {
+ Field f = Unsafe.class.getDeclaredField("theUnsafe");
+ f.setAccessible(true);
+ UNSAFE = (Unsafe) f.get(null);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static AtomicLongArray segmentQueue;
+ @SuppressWarnings("FieldMayBeFinal")
+ // @jdk.internal.vm.annotation.Contended
+ private static volatile int head = 0;
+ @SuppressWarnings("FieldMayBeFinal")
+ // @jdk.internal.vm.annotation.Contended
+ private static volatile int tail = 0;
+ @SuppressWarnings("FieldMayBeFinal")
+ // @jdk.internal.vm.annotation.Contended
+ private static volatile boolean doneQueueing = false;
+
+ private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
+ private static final VarHandle headHandle;
+ private static final VarHandle tailHandle;
+ private static final VarHandle doneHandle;
+
+ static {
+ try {
+ headHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "head", int.class);
+ tailHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "tail", int.class);
+ doneHandle = LOOKUP.findStaticVarHandle(CalculateAverage_asun.class, "doneQueueing", boolean.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static final ArrayBlockingQueue<ByteArrayToResultMap> workerOutput = new ArrayBlockingQueue<>(Runtime.getRuntime().availableProcessors());
+
+ private static class Worker implements Runnable {
+ private long segmentStart;
+ private long segmentEnd;
+
+ private final MemorySegment ms;
+
+ private Worker(MemorySegment ms) {
+ this.ms = ms;
+ }
+
+ @Override
+ public void run() {
var resultMap = new ByteArrayToResultMap();
- long segmentEnd = segment.end();
- try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(filename), StandardOpenOption.READ)) {
- var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start());
- var ms = MemorySegment.ofBuffer(bb);
+ var ms = this.ms.asSlice(0);
+ var msAddr = ms.address();
+ var actualLimit = ms.byteSize();
+ var buffer = new byte[100 + VECTOR_SIZE];
- // Up to 100 characters for a city name
- var buffer = new byte[100 + VECTOR_SIZE];
+ while (pollSegment()) {
long startLine;
- long pos = 0;
- long limit = ms.byteSize();
- long vectorLimit = limit - VECTOR_SIZE;
+ long pos = segmentStart;
+ long limit = segmentEnd;
+ long vectorLimit = Math.min(limit, actualLimit - VECTOR_SIZE);
+ long longLimit = Math.min(limit, actualLimit - 8);
// int[] lastHashMult = new int[]{ 7, 31, 63, 15, 255, 127, 3, 511 };
// IntVector lastMul = IntVector.fromArray(INT_SPECIES, lastHashMult, 0);
@@ -117,12 +168,15 @@ public class CalculateAverage_asun {
int nameLen = (int) (currentPosition - startLine);
currentPosition++;
- if (currentPosition >= limit - 8) {
+ if (currentPosition >= longLimit) {
break;
}
- long g = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, currentPosition);
- int negative = (g & 0xff) == '-' ? -1 : 1;
+ long g = UNSAFE.getLong(msAddr + currentPosition);
+ // long g = ms.get(ValueLayout.JAVA_LONG_UNALIGNED, currentPosition);
+ boolean minus = (g & 0xff) == '-';
+ long minusL = (minus ? 1L : 0L) - 1;
+ int negative = minus ? -1 : 1;
// 00101101 MINUS
// 00101110 PERIOD
@@ -136,17 +190,13 @@ public class CalculateAverage_asun {
int tzc = Long.numberOfTrailingZeros(lf);
long bytesToLF = tzc / 8;
- int shift = 64 - tzc & 0b111000;
-
- long reversedDigits = Long.reverseBytes(g) >> shift;
- long digitBits = reversedDigits & (0x1010101010101010L >> shift);
- long digitsExt = (digitBits >> 1 | digitBits >> 2 | digitBits >> 3 | digitBits >> 4);
+ int shift = 72 - tzc & 0b111000;
- long digitsOnly = Long.compress(reversedDigits, digitsExt);
+ long reversedDigits = Long.reverseBytes(g & (0xFFFFFFFFFFFFFF00L | minusL)) >> shift;
- long temp = (digitsOnly & 0xf)
- + 10 * ((digitsOnly >> 4) & 0xf)
- + 100 * ((digitsOnly >> 8) & 0xf);
+ long temp = (reversedDigits & 0xf)
+ + 10 * ((reversedDigits >> 16) & 0xf)
+ + 100 * ((reversedDigits >> 24) & 0xf);
temp *= negative;
@@ -194,42 +244,94 @@ public class CalculateAverage_asun {
resultMap.putOrMerge(buffer, 0, offset, temp, hash);
pos = currentPosition;
}
- return resultMap;
}
- catch (IOException e) {
- throw new RuntimeException(e);
- }
- }).parallel().flatMap(partition -> partition.getAll().stream())
- .collect(Collectors.toMap(e -> new String(e.key()), Entry::value, CalculateAverage_asun::merge, TreeMap::new));
- System.out.println(resultsMap);
+ workerOutput.add(resultMap);
+ }
- // System.out.println(System.currentTimeMillis() - start);
+ private boolean pollSegment() {
+ int head;
+ int tail;
+
+ do {
+ head = (int) headHandle.getAcquire();
+ tail = (int) tailHandle.getAcquire();
+
+ while (head >= tail) {
+ if ((boolean) doneHandle.getAcquire()) {
+ return false;
+ }
+
+ head = (int) headHandle.getAcquire();
+ tail = (int) tailHandle.getAcquire();
+ }
+ } while (!headHandle.compareAndSet(head, head + 1));
+
+ segmentStart = segmentQueue.getPlain(head * 2);
+ segmentEnd = segmentQueue.getPlain(head * 2 + 1);
+
+ return true;
+ }
- Runtime.getRuntime().halt(0);
}
- private static List<FileSegment> getFileSegments(File file) throws IOException {
- int numberOfSegments = Runtime.getRuntime().availableProcessors() * 8;
+ public static void main(String[] args) throws IOException, ExecutionException, InterruptedException {
+ // long start = System.currentTimeMillis();
+ var filename = args.length == 0 ? FILE : args[0];
+ var file = new File(filename);
+
+ @SuppressWarnings("resource")
+ var fileChannel = (FileChannel) Files.newByteChannel(Path.of(filename), StandardOpenOption.READ);
+ var ms = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size(), Arena.global());
+
long fileSize = file.length();
- long segmentSize = fileSize / numberOfSegments;
- List<FileSegment> segments = new ArrayList<>(numberOfSegments);
- // Pointless to split small files
- if (segmentSize < 1_000_000) {
- segments.add(new FileSegment(0, fileSize));
- return segments;
+ long segmentSize = 10_000_000;
+ int numberOfSegments = (int) (file.length() / segmentSize + 1) * 2;
+ segmentQueue = new AtomicLongArray(numberOfSegments);
+ int tail = 0;
+
+ int processors = Runtime.getRuntime().availableProcessors();
+
+ Thread.ofPlatform().daemon().start(() -> {
+ for (int i = 0; i < processors - 1; i++) {
+ Thread.ofPlatform().daemon().start(new Worker(ms));
+ }
+
+ new Worker(ms).run();
+ });
+
+ long segStart = 0;
+ while (segStart < fileSize) {
+ long segEnd = findSegment(ms, Math.min(segStart + segmentSize, fileSize), fileSize);
+ segmentQueue.setRelease(tail * 2, segStart);
+ segmentQueue.setRelease(tail * 2 + 1, segEnd);
+ tailHandle.setRelease(++tail);
+
+ segStart = segEnd;
}
- try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
- for (int i = 0; i < numberOfSegments; i++) {
- long segStart = i * segmentSize;
- long segEnd = (i == numberOfSegments - 1) ? fileSize : segStart + segmentSize;
- segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd);
- segEnd = findSegment(i, numberOfSegments - 1, randomAccessFile, segEnd, fileSize);
-
- segments.add(new FileSegment(segStart, segEnd));
+
+ doneHandle.setRelease(true);
+
+ // System.out.println(System.currentTimeMillis() - start);
+
+ var resultsMap = new TreeMap<String, Result>();
+ for (int i = 0; i < processors; i++) {
+ var result = workerOutput.take();
+
+ // System.out.println(i + " " + (System.currentTimeMillis() - start));
+
+ for (Entry e : result.getAll()) {
+ resultsMap.merge(new String(e.key()), e.value(), CalculateAverage_asun::merge);
}
+
+ // System.out.println(i + " " + (System.currentTimeMillis() - start));
}
- return segments;
+
+ System.out.println(resultsMap);
+
+ // System.out.println(System.currentTimeMillis() - start);
+
+ Runtime.getRuntime().halt(0);
}
private static Result merge(Result v, Result value) {
@@ -244,14 +346,14 @@ public class CalculateAverage_asun {
return v;
}
- private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
- if (i != skipSegment) {
- raf.seek(location);
- while (location < fileSize) {
+ private static long findSegment(MemorySegment ms, long location, long fileSize) {
+ while (location < fileSize) {
+ if (ms.get(ValueLayout.JAVA_BYTE, location) == '\n') {
location++;
- if (raf.read() == '\n')
- break;
+ break;
}
+
+ location++;
}
return location;
}
@@ -283,9 +385,6 @@ public class CalculateAverage_asun {
record Entry(byte[] key, Result value) {
}
- record FileSegment(long start, long end) {
- }
-
static class ByteArrayToResultMap {
public static final int MAPSIZE = 1024 * 128;
Result[] slots = new Result[MAPSIZE];