aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java266
1 files changed, 93 insertions, 173 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java
index 0764b65..a7baf9b 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_C5H12O5.java
@@ -16,19 +16,14 @@
package dev.morling.onebrc;
import java.io.IOException;
-import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.CompletionHandler;
import java.nio.charset.StandardCharsets;
-import java.nio.file.Files;
-import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
@@ -45,91 +40,101 @@ import java.util.concurrent.LinkedBlockingQueue;
* @author Xylitol
*/
public class CalculateAverage_C5H12O5 {
- private static final int BUFFER_CAPACITY = 1024 * 1024;
+ private static final int BUFFER_CAPACITY = 1024 * 1024 * 10;
private static final int MAP_CAPACITY = 10000;
- private static final int QUEUE_CAPACITY = 2;
+ private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();
+ private static final BlockingQueue<byte[]> BYTES_QUEUE = new LinkedBlockingQueue<>(PROCESSORS);
+ private static long readPosition;
public static void main(String[] args) throws Exception {
- // Files.list(Paths.get("./src/test/resources/samples"))
- // .filter(file -> file.toString().endsWith(".txt"))
- // .forEach(file -> {
- // try {
- // String actual = calc(file);
- // String expected = Files.readAllLines(Paths.get(file.toString().replace(".txt", ".out"))).get(0);
- // System.out.println(file.getFileName() + ": " + expected.equals(actual));
- // } catch (Exception e) {
- // System.out.println(file.getFileName() + ": " + false);
- // e.printStackTrace();
- // }
- // });
- // long start = System.currentTimeMillis();
- System.out.println(calc(Paths.get("./measurements.txt")));
- // System.out.println("Time: " + (System.currentTimeMillis() - start) + "ms");
+ System.out.println(calc("./measurements.txt"));
}
/**
* Calculate the average.
*/
- public static String calc(Path file) throws IOException, ExecutionException, InterruptedException {
- long[] positions = fragment(file, Runtime.getRuntime().availableProcessors());
- FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[positions.length];
- for (int i = 0; i < positions.length; i++) {
- tasks[i] = new FutureTask<>(new Task(file, (i == 0 ? 0 : positions[i - 1] + 1), positions[i]));
- new Thread(tasks[i]).start();
- }
+ public static String calc(String path) throws IOException, ExecutionException, InterruptedException {
+ readPosition = 0;
Map<String, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
- for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
- task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
- }
- return new TreeMap<>(result).toString();
- }
-
- /**
- * Fragment the file into chunks.
- */
- private static long[] fragment(Path filePath, int chunkNum) throws IOException {
- long fileSize = Files.size(filePath);
- long chunkSize = fileSize / chunkNum;
- long[] positions = new long[chunkNum];
- try (RandomAccessFile file = new RandomAccessFile(filePath.toFile(), "r")) {
- long position = chunkSize;
- for (int i = 0; i < chunkNum - 1; i++) {
- if (position >= fileSize) {
- break;
+ // read and offer to queue
+ try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(
+ Paths.get(path), Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor())) {
+ ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
+ channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
+ @Override
+ public void completed(Integer bytesRead, ByteBuffer buffer) {
+ try {
+ if (bytesRead > 0) {
+ for (int i = buffer.position() - 1; i >= 0; i--) {
+ if (buffer.get(i) == '\n') {
+ buffer.limit(i + 1);
+ break;
+ }
+ }
+ buffer.flip();
+ byte[] bytes = new byte[buffer.remaining()];
+ buffer.get(bytes);
+ readPosition += buffer.limit();
+ BYTES_QUEUE.put(bytes);
+ buffer.clear();
+ channel.read(buffer, readPosition, buffer, this);
+ }
+ else {
+ for (int i = 0; i < PROCESSORS; i++) {
+ BYTES_QUEUE.put(new byte[0]);
+ }
+ }
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
}
- file.seek(position);
- while (file.read() != '\n') {
- position++;
+
+ @Override
+ public void failed(Throwable exc, ByteBuffer buffer) {
+ // ignore
}
- positions[i] = position;
- position += chunkSize;
+ });
+
+ @SuppressWarnings("unchecked")
+ FutureTask<Map<MeasurementName, MeasurementData>>[] tasks = new FutureTask[PROCESSORS];
+ for (int i = 0; i < PROCESSORS; i++) {
+ tasks[i] = new FutureTask<>(new Task());
+ new Thread(tasks[i]).start();
+ }
+ for (FutureTask<Map<MeasurementName, MeasurementData>> task : tasks) {
+ task.get().forEach((k, v) -> result.merge(k.toString(), v, MeasurementData::merge));
}
}
- positions[chunkNum - 1] = fileSize;
- return Arrays.stream(positions).filter(value -> value != 0).toArray();
+ return new TreeMap<>(result).toString();
}
/**
* The measurement name.
*/
- private record MeasurementName(byte[] bytes) {
+ private record MeasurementName(byte[] bytes, int length) {
@Override
- public boolean equals(Object other) {
- if (!(other instanceof MeasurementName)) {
+ public boolean equals(Object name) {
+ MeasurementName other = (MeasurementName) name;
+ if (other.length != length) {
return false;
}
- return Arrays.equals(bytes, ((MeasurementName) other).bytes);
+ return Arrays.compare(bytes, 0, length, other.bytes, 0, length) == 0;
}
@Override
public int hashCode() {
- return Arrays.hashCode(bytes);
+ int result = 1;
+ for (int i = 0; i < length; i++) {
+ result = 31 * result + bytes[i];
+ }
+ return result;
}
@Override
public String toString() {
- return new String(bytes, StandardCharsets.UTF_8);
+ return new String(bytes, 0, length, StandardCharsets.UTF_8);
}
}
@@ -168,116 +173,53 @@ public class CalculateAverage_C5H12O5 {
}
/**
- * The task to read and calculate.
+ * The task to calculate.
*/
private static class Task implements Callable<Map<MeasurementName, MeasurementData>> {
- private final Path file;
- private long readPosition;
- private long calcPosition;
- private final long limitSize;
- private final BlockingQueue<byte[]> bytesQueue = new LinkedBlockingQueue<>(QUEUE_CAPACITY);
-
- public Task(Path file, long position, long limitSize) {
- this.file = file;
- this.readPosition = position;
- this.calcPosition = position;
- this.limitSize = limitSize;
- }
@Override
- public Map<MeasurementName, MeasurementData> call() throws IOException {
- // read and offer to queue
- AsynchronousFileChannel channel = AsynchronousFileChannel.open(
- file, Set.of(StandardOpenOption.READ), Executors.newVirtualThreadPerTaskExecutor());
- ByteBuffer buffer = ByteBuffer.allocateDirect(BUFFER_CAPACITY);
- channel.read(buffer, readPosition, buffer, new CompletionHandler<>() {
- @Override
- public void completed(Integer bytesRead, ByteBuffer buffer) {
- if (bytesRead > 0 && readPosition < limitSize) {
- try {
- buffer.flip();
- byte[] bytes = new byte[buffer.remaining()];
- buffer.get(bytes);
- readPosition += bytesRead;
- if (readPosition > limitSize) {
- int diff = (int) (readPosition - limitSize);
- byte[] newBytes = new byte[bytes.length - diff];
- System.arraycopy(bytes, 0, newBytes, 0, newBytes.length);
- bytesQueue.put(newBytes);
- }
- else {
- bytesQueue.put(bytes);
- buffer.clear();
- channel.read(buffer, readPosition, buffer, this);
- }
- }
- catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- }
- }
- }
-
- @Override
- public void failed(Throwable exc, ByteBuffer buffer) {
- // ignore
- }
- });
-
+ public Map<MeasurementName, MeasurementData> call() throws InterruptedException {
// poll from queue and calculate
Map<MeasurementName, MeasurementData> result = HashMap.newHashMap(MAP_CAPACITY);
- byte[] readBytes = null;
- byte[] remaining = null;
- while (calcPosition < limitSize) {
- readBytes = bytesQueue.poll();
- if (readBytes != null) {
- List<byte[]> lines = split(readBytes, (byte) '\n');
- for (int i = 0; i < lines.size(); i++) {
- byte[] lineBytes = lines.get(i);
- if (i == 0 && remaining != null) {
- byte[] newBytes = new byte[remaining.length + lineBytes.length];
- System.arraycopy(remaining, 0, newBytes, 0, remaining.length);
- System.arraycopy(lineBytes, 0, newBytes, remaining.length, lineBytes.length);
- lineBytes = newBytes;
+ for (byte[] bytes = BYTES_QUEUE.take(); true; bytes = BYTES_QUEUE.take()) {
+ if (bytes.length == 0) {
+ break;
+ }
+ int start = 0;
+ for (int end = 0; end < bytes.length; end++) {
+ if (bytes[end] == '\n') {
+ byte[] newBytes = new byte[end - start];
+ System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
+ int semicolon = newBytes.length - 4;
+ for (; semicolon >= 0; semicolon--) {
+ if (newBytes[semicolon] == ';') {
+ break;
+ }
+ }
+ MeasurementName station = new MeasurementName(newBytes, semicolon);
+ int value = toInt(newBytes, semicolon + 1);
+ MeasurementData data = result.get(station);
+ if (data != null) {
+ data.merge(value, value, value, 1);
}
- if (i == lines.size() - 1) {
- remaining = lineBytes;
- break;
+ else {
+ result.put(station, new MeasurementData(value));
}
- agg(result, lineBytes);
+ start = end + 1;
}
- calcPosition += readBytes.length;
}
}
- if (remaining != null && remaining.length > 0) {
- agg(result, remaining);
- }
- channel.close();
return result;
}
/**
- * Aggregate the measurement data.
- */
- private static void agg(Map<MeasurementName, MeasurementData> result, byte[] bytes) {
- List<byte[]> parts = split(bytes, (byte) ';');
- MeasurementName station = new MeasurementName(parts.getFirst());
- int value = toInt(parts.getLast());
- MeasurementData data = result.get(station);
- if (data != null) {
- data.merge(value, value, value, 1);
- }
- else {
- result.put(station, new MeasurementData(value));
- }
- }
-
- /**
* Convert the byte array to int.
*/
- private static int toInt(byte[] bytes) {
+ private static int toInt(byte[] bytes, int start) {
boolean negative = false;
int result = 0;
- for (byte b : bytes) {
+ for (int i = start; i < bytes.length; i++) {
+ byte b = bytes[i];
if (b == '-') {
negative = true;
continue;
@@ -288,27 +230,5 @@ public class CalculateAverage_C5H12O5 {
}
return negative ? -result : result;
}
-
- /**
- * Split the byte array by given byte.
- */
- private static List<byte[]> split(byte[] bytes, byte separator) {
- List<byte[]> result = new ArrayList<>();
- int start = 0;
- for (int end = 0; end < bytes.length; end++) {
- if (bytes[end] == separator) {
- byte[] newBytes = new byte[end - start];
- System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
- result.add(newBytes);
- start = end + 1;
- }
- }
- if (start <= bytes.length) {
- byte[] newBytes = new byte[bytes.length - start];
- System.arraycopy(bytes, start, newBytes, 0, newBytes.length);
- result.add(newBytes);
- }
- return result;
- }
}
}