about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author zerninv <zerninvasilii@yandex.ru> 2024-01-15 19:25:52 +0000
committer GitHub <noreply@github.com> 2024-01-15 20:25:52 +0100
commit d18b10708b632e42822af41342271f16eff7073a (patch)
tree eecf9b56d4a7bae92fffe7e0e8e2b1e4a975114b /src
parent dd9a3dde7e692198cdd58570fc2bd1822d2ca237 (diff)
Sixth attempt CalculateAverage_zerninv.java (#407)
* rethink chunking
* fix typo
Diffstat (limited to 'src')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java 206
1 files changed, 118 insertions, 88 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java
index 789db73..2e7ea4c 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_zerninv.java
@@ -25,14 +25,15 @@ import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
-import java.util.*;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.TreeMap;
public class CalculateAverage_zerninv {
private static final String FILE = "./measurements.txt";
- private static final int MIN_FILE_SIZE = 1024 * 1024 * 16;
+ private static final int L3_CACHE_SIZE = 128 * 1024 * 1024;
+ private static final int CORES = Runtime.getRuntime().availableProcessors();
+ private static final int CHUNK_SIZE = (L3_CACHE_SIZE - MeasurementContainer.SIZE * MeasurementContainer.ENTRY_SIZE * CORES) / CORES - 1024 * CORES;
// #.##
private static final int THREE_DIGITS_MASK = 0x2e0000;
@@ -48,47 +49,48 @@ public class CalculateAverage_zerninv {
private static final Unsafe UNSAFE = initUnsafe();
- public static void main(String[] args) throws IOException {
- var results = new HashMap<String, MeasurementAggregation>();
+ public static void main(String[] args) throws IOException, InterruptedException {
try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
var fileSize = channel.size();
+ var minChunkSize = Math.min(fileSize, CHUNK_SIZE);
+
+ var tasks = new TaskThread[CORES];
+ for (int i = 0; i < tasks.length; i++) {
+ tasks[i] = new TaskThread(new MeasurementContainer(), (int) (fileSize / minChunkSize / CORES + 1));
+ }
+
var memorySegment = channel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global());
- long address = memorySegment.address();
- var cores = Runtime.getRuntime().availableProcessors();
- var minChunkSize = fileSize < MIN_FILE_SIZE ? fileSize : fileSize / cores;
+ var address = memorySegment.address();
var chunks = splitByChunks(address, address + fileSize, minChunkSize);
+ for (int i = 0; i < chunks.size() - 1; i++) {
+ var task = tasks[i % CORES];
+ task.addChunk(chunks.get(i), chunks.get(i + 1));
+ }
- var executor = Executors.newFixedThreadPool(cores);
- List<Future<Map<String, MeasurementAggregation>>> fResults = new ArrayList<>();
- for (int i = 1; i < chunks.size(); i++) {
- final long prev = chunks.get(i - 1);
- final long curr = chunks.get(i);
- fResults.add(executor.submit(() -> calcForChunk(prev, curr)));
+ for (var task : tasks) {
+ task.start();
}
- fResults.forEach(f -> {
- try {
- f.get().forEach((key, value) -> {
- var result = results.get(key);
- if (result != null) {
- result.merge(value);
- }
- else {
- results.put(key, value);
- }
- });
- }
- catch (InterruptedException | ExecutionException e) {
- e.printStackTrace();
- }
- });
- executor.shutdown();
- }
+ var results = new TreeMap<String, TemperatureAggregation>();
+ for (var task : tasks) {
+ task.join();
+ task.measurements()
+ .forEach(measurement -> {
+ var aggr = results.get(measurement.station());
+ if (aggr == null) {
+ results.put(measurement.station(), measurement.aggregation());
+ }
+ else {
+ aggr.merge(measurement.aggregation());
+ }
+ });
+ }
- var bos = new BufferedOutputStream(System.out);
- bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8));
- bos.write('\n');
- bos.flush();
+ var bos = new BufferedOutputStream(System.out);
+ bos.write(new TreeMap<>(results).toString().getBytes(StandardCharsets.UTF_8));
+ bos.write('\n');
+ bos.flush();
+ }
}
private static Unsafe initUnsafe() {
@@ -103,7 +105,7 @@ public class CalculateAverage_zerninv {
}
private static List<Long> splitByChunks(long address, long end, long minChunkSize) {
- List<Long> result = new ArrayList<>();
+ List<Long> result = new ArrayList<>((int) ((end - address) / minChunkSize + 1));
result.add(address);
while (address < end) {
address += Math.min(end - address, minChunkSize);
@@ -114,60 +116,20 @@ public class CalculateAverage_zerninv {
return result;
}
- private static Map<String, MeasurementAggregation> calcForChunk(long offset, long end) {
- var results = new MeasurementContainer();
-
- long cityOffset;
- int hashCode, temperature, word;
- byte cityNameSize, b;
-
- while (offset < end) {
- cityOffset = offset;
- hashCode = 0;
- while ((b = UNSAFE.getByte(offset++)) != DELIMITER) {
- hashCode = hashCode * 31 + b;
- }
- cityNameSize = (byte) (offset - cityOffset - 1);
-
- word = UNSAFE.getInt(offset);
- offset += 4;
-
- if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) {
- word >>>= 8;
- temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK));
- }
- else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) {
- temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111;
- }
- else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) {
- temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11;
- offset--;
- }
- else {
- // #.##-
- word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24);
- temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK));
- }
- offset++;
- results.put(cityOffset, cityNameSize, hashCode, (short) temperature);
- }
- return results.toStringMap();
- }
-
- private static final class MeasurementAggregation {
+ private static final class TemperatureAggregation {
private long sum;
private int count;
private short min;
private short max;
- public MeasurementAggregation(long sum, int count, short min, short max) {
+ public TemperatureAggregation(long sum, int count, short min, short max) {
this.sum = sum;
this.count = count;
this.min = min;
this.max = max;
}
- public void merge(MeasurementAggregation o) {
+ public void merge(TemperatureAggregation o) {
if (o == null) {
return;
}
@@ -183,6 +145,9 @@ public class CalculateAverage_zerninv {
}
}
+ private record Measurement(String station, TemperatureAggregation aggregation) {
+ }
+
private static final class MeasurementContainer {
private static final int SIZE = 1024 * 16;
@@ -235,26 +200,26 @@ public class CalculateAverage_zerninv {
}
}
- public Map<String, MeasurementAggregation> toStringMap() {
- var result = new HashMap<String, MeasurementAggregation>();
+ public List<Measurement> measurements() {
+ var result = new ArrayList<Measurement>(1000);
int count;
for (int i = 0; i < SIZE; i++) {
long ptr = this.address + i * ENTRY_SIZE;
count = UNSAFE.getInt(ptr + COUNT_OFFSET);
if (count != 0) {
- var measurements = new MeasurementAggregation(
+ var measurements = new TemperatureAggregation(
UNSAFE.getLong(ptr + SUM_OFFSET),
count,
UNSAFE.getShort(ptr + MIN_OFFSET),
UNSAFE.getShort(ptr + MAX_OFFSET));
var key = createString(UNSAFE.getLong(ptr + ADDRESS_OFFSET), UNSAFE.getByte(ptr + SIZE_OFFSET));
- result.put(key, measurements);
+ result.add(new Measurement(key, measurements));
}
}
return result;
}
- private boolean isEqual(long address, long address2, byte size) {
+ private static boolean isEqual(long address, long address2, byte size) {
for (int i = 0; i < size; i++) {
if (UNSAFE.getByte(address + i) != UNSAFE.getByte(address2 + i)) {
return false;
@@ -271,4 +236,69 @@ public class CalculateAverage_zerninv {
return new String(arr);
}
}
-}
\ No newline at end of file
+
+ private static class TaskThread extends Thread {
+ private final MeasurementContainer container;
+ private final List<Long> begins;
+ private final List<Long> ends;
+
+ private TaskThread(MeasurementContainer container, int chunks) {
+ this.container = container;
+ this.begins = new ArrayList<>(chunks);
+ this.ends = new ArrayList<>(chunks);
+ }
+
+ public void addChunk(long begin, long end) {
+ begins.add(begin);
+ ends.add(end);
+ }
+
+ @Override
+ public void run() {
+ for (int i = 0; i < begins.size(); i++) {
+ calcForChunk(begins.get(i), ends.get(i));
+ }
+ }
+
+ public List<Measurement> measurements() {
+ return container.measurements();
+ }
+
+ private void calcForChunk(long offset, long end) {
+ long cityOffset;
+ int hashCode, temperature, word;
+ byte cityNameSize, b;
+
+ while (offset < end) {
+ cityOffset = offset;
+ hashCode = 0;
+ while ((b = UNSAFE.getByte(offset++)) != DELIMITER) {
+ hashCode = hashCode * 31 + b;
+ }
+ cityNameSize = (byte) (offset - cityOffset - 1);
+
+ word = UNSAFE.getInt(offset);
+ offset += 4;
+
+ if ((word & TWO_NEGATIVE_DIGITS_MASK) == TWO_NEGATIVE_DIGITS_MASK) {
+ word >>>= 8;
+ temperature = ZERO * 11 - ((word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK));
+ }
+ else if ((word & THREE_DIGITS_MASK) == THREE_DIGITS_MASK) {
+ temperature = (word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK) - ZERO * 111;
+ }
+ else if ((word & TWO_DIGITS_MASK) == TWO_DIGITS_MASK) {
+ temperature = (word & BYTE_MASK) * 10 + ((word >>> 16) & BYTE_MASK) - ZERO * 11;
+ offset--;
+ }
+ else {
+ // #.##-
+ word = (word >>> 8) | (UNSAFE.getByte(offset++) << 24);
+ temperature = ZERO * 111 - ((word & BYTE_MASK) * 100 + ((word >>> 8) & BYTE_MASK) * 10 + ((word >>> 24) & BYTE_MASK));
+ }
+ offset++;
+ container.put(cityOffset, cityNameSize, hashCode, (short) temperature);
+ }
+ }
+ }
+}