about | summary | refs | log | tree | commit | diff
path: root/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java 275
1 files changed, 137 insertions, 138 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
index 5efd4a7..516a6ab 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java
@@ -24,68 +24,49 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
-import java.util.stream.IntStream;
public class CalculateAverage_artsiomkorzun {
private static final Path FILE = Path.of("./measurements.txt");
private static final long FILE_SIZE = size(FILE);
+ private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
private static final int SEGMENT_SIZE = 16 * 1024 * 1024;
private static final int SEGMENT_COUNT = (int) ((FILE_SIZE + SEGMENT_SIZE - 1) / SEGMENT_SIZE);
private static final int SEGMENT_OVERLAP = 1024;
public static void main(String[] args) throws Exception {
- /*
- * for (int i = 0; i < 10; i++) {
- * long start = System.currentTimeMillis();
- * execute();
- * long end = System.currentTimeMillis();
- * System.err.println("Time: " + (end - start));
- * }
- */
+ // for (int i = 0; i < 10; i++) {
+ // long start = System.currentTimeMillis();
+ // execute();
+ // long end = System.currentTimeMillis();
+ // System.err.println("Time: " + (end - start));
+ // }
execute();
}
- private static void execute() {
- Aggregates aggregates = IntStream.range(0, SEGMENT_COUNT)
- .parallel()
- .mapToObj(CalculateAverage_artsiomkorzun::aggregate)
- .reduce(new Aggregates(), CalculateAverage_artsiomkorzun::merge)
- .sort();
+ private static void execute() throws Exception {
+ AtomicInteger counter = new AtomicInteger();
+ AtomicReference<Aggregates> result = new AtomicReference<>();
+ Aggregator[] aggregators = new Aggregator[PARALLELISM];
- print(aggregates);
- }
-
- private static Aggregates aggregate(int segment) {
- long position = (long) SEGMENT_SIZE * segment;
- int size = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, FILE_SIZE - position);
- int limit = Math.min(SEGMENT_SIZE, size - 1);
-
- MappedByteBuffer buffer = map(position, size); // leaking until gc
-
- if (position > 0) {
- next(buffer);
+ for (int i = 0; i < aggregators.length; i++) {
+ aggregators[i] = new Aggregator(counter, result);
+ aggregators[i].start();
}
- Aggregates aggregates = new Aggregates();
- Row row = new Row();
-
- while (buffer.position() <= limit) {
- parse(buffer, row);
- aggregates.add(row);
+ for (int i = 0; i < aggregators.length; i++) {
+ aggregators[i].join();
}
- return aggregates;
- }
+ Aggregates aggregates = result.get();
+ aggregates.sort();
- private static Aggregates merge(Aggregates lefts, Aggregates rights) {
- Aggregates to = (lefts.size() < rights.size()) ? rights : lefts;
- Aggregates from = (lefts.size() < rights.size()) ? lefts : rights;
- from.visit(to::merge);
- return to;
+ print(aggregates);
}
private static void print(Aggregates aggregates) {
@@ -111,62 +92,11 @@ public class CalculateAverage_artsiomkorzun {
}
}
- private static MappedByteBuffer map(long position, int size) {
- try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) {
- return channel.map(FileChannel.MapMode.READ_ONLY, position, size); // leaking until gc
- }
- catch (Throwable e) {
- throw new RuntimeException(e);
- }
- }
-
- private static void next(ByteBuffer buffer) {
- while (buffer.get() != '\n') {
- // continue
- }
- }
-
- private static void parse(ByteBuffer buffer, Row row) {
- int index = 0;
- byte b;
-
- while ((b = buffer.get()) != ';') {
- row.station[index++] = b;
- }
-
- row.length = index;
-
- double value = 0;
- double multiplier = 1;
-
- b = buffer.get();
- if (b == '-') {
- multiplier = -1;
- }
- else {
- assert b >= '0' && b <= '9';
- value = b - '0';
- }
-
- while ((b = buffer.get()) != '.') {
- assert b >= '0' && b <= '9';
- value = 10 * value + (b - '0');
- }
-
- b = buffer.get();
- assert b >= '0' && b <= '9';
- value = 10 * value + (b - '0');
-
- b = buffer.get();
- assert b == '\n';
-
- row.temperature = value * multiplier;
- }
-
private static class Row {
final byte[] station = new byte[256];
int length;
- double temperature;
+ int hash;
+ int temperature;
@Override
public String toString() {
@@ -176,23 +106,25 @@ public class CalculateAverage_artsiomkorzun {
private static class Aggregate implements Comparable<Aggregate> {
final byte[] station;
- double min;
- double max;
- double sum;
- double count;
-
- public Aggregate(byte[] station, int length, double temperature) {
- this.station = Arrays.copyOf(station, length);
- this.min = temperature;
- this.max = temperature;
- this.sum = temperature;
+ final int hash;
+ int min;
+ int max;
+ long sum;
+ int count;
+
+ public Aggregate(Row row) {
+ this.station = Arrays.copyOf(row.station, row.length);
+ this.hash = row.hash;
+ this.min = row.temperature;
+ this.max = row.temperature;
+ this.sum = row.temperature;
this.count = 1;
}
- public void add(double temperature) {
- min = Math.min(min, temperature);
- max = Math.max(max, temperature);
- sum += temperature;
+ public void add(Row row) {
+ min = Math.min(min, row.temperature);
+ max = Math.max(max, row.temperature);
+ sum += row.temperature;
count++;
}
@@ -223,7 +155,7 @@ public class CalculateAverage_artsiomkorzun {
@Override
public String toString() {
- return new String(station) + "=" + round(min) + "/" + round(sum / count) + "/" + round(max);
+ return new String(station) + "=" + round(min) + "/" + round(1.0 * sum / count) + "/" + round(max);
}
private static double round(double v) {
@@ -255,26 +187,21 @@ public class CalculateAverage_artsiomkorzun {
}
public void add(Row row) {
- byte[] station = row.station;
- int length = row.length;
- double temperature = row.temperature;
-
- int hash = hash(station, length);
- int index = hash & (aggregates.length - 1);
+ int index = row.hash & (aggregates.length - 1);
while (true) {
Aggregate aggregate = aggregates[index];
if (aggregate == null) {
- aggregates[index] = new Aggregate(station, length, temperature);
+ aggregates[index] = new Aggregate(row);
if (++size >= limit) {
grow();
}
break;
}
- if (equal(station, length, aggregate.station, aggregate.station.length)) {
- aggregate.add(temperature);
+ if (row.hash == aggregate.hash && Arrays.equals(row.station, 0, row.length, aggregate.station, 0, aggregate.station.length)) {
+ aggregate.add(row);
break;
}
@@ -283,10 +210,7 @@ public class CalculateAverage_artsiomkorzun {
}
public void merge(Aggregate right) {
- byte[] station = right.station;
-
- int hash = hash(station, station.length);
- int index = hash & (aggregates.length - 1);
+ int index = right.hash & (aggregates.length - 1);
while (true) {
Aggregate aggregate = aggregates[index];
@@ -299,7 +223,7 @@ public class CalculateAverage_artsiomkorzun {
break;
}
- if (equal(station, station.length, aggregate.station, aggregate.station.length)) {
+ if (right.hash == aggregate.hash && Arrays.equals(right.station, aggregate.station)) {
aggregate.merge(right);
break;
}
@@ -309,7 +233,7 @@ public class CalculateAverage_artsiomkorzun {
}
public Aggregates sort() {
- Arrays.parallelSort(aggregates, Comparator.nullsLast(Aggregate::compareTo));
+ Arrays.sort(aggregates, Comparator.nullsLast(Aggregate::compareTo));
return this;
}
@@ -320,8 +244,7 @@ public class CalculateAverage_artsiomkorzun {
for (Aggregate aggregate : oldAggregates) {
if (aggregate != null) {
- int hash = hash(aggregate.station, aggregate.station.length);
- int index = hash & (aggregates.length - 1);
+ int index = aggregate.hash & (aggregates.length - 1);
while (aggregates[index] != null) {
index = (index + 1) & (aggregates.length - 1);
@@ -331,29 +254,105 @@ public class CalculateAverage_artsiomkorzun {
}
}
}
+ }
- private static int hash(byte[] array, int length) {
- int hash = 0;
+ private static class Aggregator extends Thread {
- for (int i = 0; i < length; i++) {
- hash = 71 * hash + array[i];
- }
+ private final AtomicInteger counter;
+ private final AtomicReference<Aggregates> result;
- return hash;
+ public Aggregator(AtomicInteger counter, AtomicReference<Aggregates> result) {
+ super("aggregator");
+ this.counter = counter;
+ this.result = result;
}
- private static boolean equal(byte[] left, int leftLength, byte[] right, int rightLength) {
- if (leftLength != rightLength) {
- return false;
+ @Override
+ public void run() {
+ Aggregates aggregates = new Aggregates();
+ Row row = new Row();
+
+ try (FileChannel channel = FileChannel.open(FILE, StandardOpenOption.READ)) {
+ for (int segment; (segment = counter.getAndIncrement()) < SEGMENT_COUNT;) {
+ aggregate(channel, segment, aggregates, row);
+ }
}
+ catch (Throwable e) {
+ throw new RuntimeException(e);
+ }
+
+ while (!result.compareAndSet(null, aggregates)) {
+ Aggregates rights = result.getAndSet(null);
- for (int i = 0; i < leftLength; i++) {
- if (left[i] != right[i]) {
- return false;
+ if (rights != null) {
+ aggregates = merge(aggregates, rights);
}
}
+ }
+
+        /**
+         * Parses the rows of one file segment and folds them into {@code aggregates}.
+         *
+         * The mapped window covers SEGMENT_SIZE bytes plus SEGMENT_OVERLAP extra bytes
+         * (fewer at the end of the file), so a row that begins inside the segment can be
+         * read in full even when it runs past the segment boundary. Every segment except
+         * the first starts by skipping past the first newline it contains. The mapping is
+         * not explicitly released here.
+         *
+         * @param channel    open read channel on the measurements file
+         * @param segment    zero-based segment index
+         * @param aggregates table that receives every parsed row
+         * @param row        scratch row reused for each parsed line
+         * @throws Exception if mapping the file region fails
+         */
+        private static void aggregate(FileChannel channel, int segment, Aggregates aggregates, Row row) throws Exception {
+            long start = (long) SEGMENT_SIZE * segment;
+            long remaining = FILE_SIZE - start;
+            int mapped = (int) Math.min(SEGMENT_SIZE + SEGMENT_OVERLAP, remaining);
+            // Highest buffer offset at which a row may begin and still be handled here.
+            int lastStart = Math.min(SEGMENT_SIZE, mapped - 1);
+
+            MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, start, mapped);
+
+            if (start > 0) {
+                next(buffer);
+            }
+
+            int offset = buffer.position();
+            while (offset <= lastStart) {
+                offset = parse(buffer, row, offset);
+                aggregates.add(row);
+            }
+        }
+
+        /**
+         * Combines two partial results into a single table and returns it.
+         *
+         * The table with fewer entries is visited and each of its entries is merged into
+         * the table with more entries, so the number of lookups and insertions is bounded
+         * by the smaller side. The returned object is one of the two arguments; the other
+         * must not be used afterwards.
+         *
+         * @param lefts  first partial result
+         * @param rights second partial result
+         * @return the larger of the two tables, now holding the merged data
+         */
+        private static Aggregates merge(Aggregates lefts, Aggregates rights) {
+            Aggregates target = lefts;
+            Aggregates source = rights;
+
+            // Always fold the smaller table into the larger one.
+            if (target.size() < source.size()) {
+                target = rights;
+                source = lefts;
+            }
+
+            source.visit(target::merge);
+            return target;
+        }
+
+        /**
+         * Advances the buffer's position to just past the next newline byte.
+         *
+         * If no newline remains, the position is left at the buffer's limit instead of
+         * running off the end with a BufferUnderflowException; the caller then finds no
+         * row start inside the segment and parses nothing.
+         *
+         * @param buffer buffer whose position is moved forward
+         */
+        private static void next(ByteBuffer buffer) {
+            while (buffer.hasRemaining()) {
+                if (buffer.get() == '\n') {
+                    return;
+                }
+            }
+        }
+
+ private static int parse(ByteBuffer buffer, Row row, int offset) {
+ byte[] station = row.station;
+ int length = 0;
+ int hash = 0;
+
+ for (byte b; (b = buffer.get(offset++)) != ';';) {
+ station[length++] = b;
+ hash = 71 * hash + b;
+ }
+
+ row.length = length;
+ row.hash = hash;
+
+ int sign = 1;
+
+ if (buffer.get(offset) == '-') {
+ sign = -1;
+ offset++;
+ }
+
+ int value = buffer.get(offset++) - '0';
+
+ if (buffer.get(offset) != '.') {
+ value = 10 * value + buffer.get(offset++) - '0';
+ }
- return true;
+ value = 10 * value + buffer.get(offset + 1) - '0';
+ row.temperature = value * sign;
+ return offset + 3;
}
}
}