aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAntonio Muñoz <antoniogmc@gmail.com>2024-01-25 23:07:20 +0100
committerGitHub <noreply@github.com>2024-01-25 23:07:20 +0100
commit65d2c1b0c911579c0f04b9f996ed357f57ac10e3 (patch)
tree0379af6e1a775ff0edb8a5c8b04b5113b771e0b1 /src
parent0bd167557183922825a0a9ec3a1d347f4ba24b2f (diff)
tonivade improved solution (#582)
* tonivade improved not using HashMap * use java 21.0.2 * same hash same station * remove unused parameter in sameSation * use length too * refactor parallelization * use parallel GC * refactor * refactor
Diffstat (limited to 'src')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java274
1 files changed, 146 insertions, 128 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java b/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java
index bd28488..9deb3f2 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_tonivade.java
@@ -15,9 +15,6 @@
*/
package dev.morling.onebrc;
-import static java.util.Comparator.comparing;
-import static java.util.stream.Collectors.joining;
-
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
@@ -26,9 +23,8 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
import java.util.Map;
+import java.util.TreeMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.StructuredTaskScope;
import java.util.concurrent.StructuredTaskScope.Subtask;
@@ -37,32 +33,16 @@ public class CalculateAverage_tonivade {
private static final String FILE = "./measurements.txt";
- private static final int EOL = 10;
- private static final int MINUS = 45;
- private static final int SEMICOLON = 59;
+ private static final int MIN_CHUNK_SIZE = 1024;
+ private static final int MAX_NAME_LENGTH = 128;
+ private static final int MAX_TEMP_LENGTH = 8;
public static void main(String[] args) throws IOException, InterruptedException, ExecutionException {
- var result = readFile();
-
- var measurements = getMeasurements(result);
-
- System.out.println(measurements);
- }
-
- static record PartialResult(int end, Map<Name, Station> map) {
-
- void merge(Map<Name, Station> result) {
- map.forEach((name, station) -> result.merge(name, station, Station::merge));
- }
+ System.out.println(readFile());
}
- private static String getMeasurements(Map<Name, Station> result) {
- return result.values().stream().sorted(comparing(Station::getName))
- .map(Station::asString).collect(joining(", ", "{", "}"));
- }
-
- private static Map<Name, Station> readFile() throws IOException, InterruptedException, ExecutionException {
- Map<Name, Station> result = HashMap.newHashMap(10_000);
+ private static Map<String, Station> readFile() throws IOException, InterruptedException, ExecutionException {
+ Map<String, Station> result = new TreeMap<>();
try (var channel = FileChannel.open(Paths.get(FILE), StandardOpenOption.READ)) {
long consumed = 0;
long remaining = channel.size();
@@ -70,8 +50,11 @@ public class CalculateAverage_tonivade {
var buffer = channel.map(
MapMode.READ_ONLY, consumed, Math.min(remaining, Integer.MAX_VALUE));
- if (buffer.remaining() <= 1024) {
- var partialResult = readChunk(buffer, 0, buffer.remaining());
+ int chunks = Runtime.getRuntime().availableProcessors();
+ int chunkSize = buffer.remaining() / chunks;
+ int leftover = buffer.remaining() % chunks;
+ if (chunkSize < MIN_CHUNK_SIZE) {
+ var partialResult = new Chunk(buffer, 0, buffer.remaining()).read();
consumed += partialResult.end();
remaining -= partialResult.end();
@@ -79,17 +62,12 @@ public class CalculateAverage_tonivade {
partialResult.merge(result);
}
else {
- var chunks = Runtime.getRuntime().availableProcessors();
- var chunksSize = buffer.remaining() / chunks;
- var leftover = buffer.remaining() % chunks;
-
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
var tasks = new ArrayList<Subtask<PartialResult>>(chunks);
for (int i = 0; i < chunks; i++) {
- int start = i * chunksSize;
- int length = chunksSize + (i < chunks ? leftover : 0);
- tasks.add(scope.fork(() -> readChunk(
- buffer, findStart(buffer, start), start + length)));
+ int start = i * chunkSize;
+ int length = chunkSize + (i < chunks ? leftover : 0);
+ tasks.add(scope.fork(new Chunk(buffer, start, length)::read));
}
scope.join();
scope.throwIfFailed();
@@ -106,132 +84,154 @@ public class CalculateAverage_tonivade {
return result;
}
- private static PartialResult readChunk(ByteBuffer buffer, int start, int end) {
- final byte[] name = new byte[128];
- final byte[] temp = new byte[8];
- final Map<Name, Station> map = HashMap.newHashMap(1000);
- int position = start;
- while (position < end) {
- int semicolon = readName(buffer, position, end - position, name);
- if (semicolon < 0) {
- break;
- }
+ static final class Chunk {
- int endOfLine = readTemp(buffer, semicolon + 1, end - semicolon - 1, temp);
- if (endOfLine < 0) {
- break;
- }
+ private static final int EOL = 10;
+ private static final int MINUS = 45;
+ private static final int SEMICOLON = 59;
- map.computeIfAbsent(new Name(name, semicolon - position), Station::new)
- .add(parseTemp(temp, endOfLine - semicolon - 1));
+ final ByteBuffer buffer;
+ final int start;
+ final int end;
- // skip end of line
- position = endOfLine + 1;
+ final byte[] name = new byte[MAX_NAME_LENGTH];
+ final byte[] temp = new byte[MAX_TEMP_LENGTH];
+ final Stations stations = new Stations();
+
+ int hash;
+
+ Chunk(ByteBuffer buffer, int start, int length) {
+ this.buffer = buffer;
+ this.start = findStart(buffer, start);
+ this.end = start + length;
}
- return new PartialResult(position, map);
- }
- private static int findStart(ByteBuffer buffer, int start) {
- if (start > 0 && buffer.get(start - 1) != EOL) {
- for (int i = start - 2; i > 0; i--) {
- byte b = buffer.get(i);
- if (b == EOL) {
- return i + 1;
+ private static int findStart(ByteBuffer buffer, int start) {
+ if (start > 0 && buffer.get(start - 1) != EOL) {
+ for (int i = start - 2; i > 0; i--) {
+ byte b = buffer.get(i);
+ if (b == EOL) {
+ return i + 1;
+ }
}
}
+ return start;
}
- return start;
- }
- private static int readName(ByteBuffer buffer, int offset, int length, byte[] name) {
- return readUntil(buffer, offset, length, name, SEMICOLON);
- }
+ PartialResult read() {
+ int position = start;
+ while (position < end) {
+ int semicolon = readName(position, end - position);
+ if (semicolon < 0) {
+ break;
+ }
- private static int readTemp(ByteBuffer buffer, int offset, int length, byte[] percentage) {
- return readUntil(buffer, offset, length, percentage, EOL);
- }
+ int endOfLine = readTemp(semicolon + 1, end - semicolon - 1);
+ if (endOfLine < 0) {
+ break;
+ }
+
+ stations.find(name, semicolon - position, hash)
+ .add(parseTemp(temp, endOfLine - semicolon - 1));
- private static int readUntil(ByteBuffer buffer, int offset, int length, byte[] array, int target) {
- for (int i = 0; i < length; i++) {
- byte b = buffer.get(i + offset);
- if (b == target) {
- return i + offset;
+ // skip end of line
+ position = endOfLine + 1;
}
- array[i] = b;
+ return new PartialResult(position, stations.buckets);
+ }
+
+ private int readName(int offset, int length) {
+ hash = 1;
+ for (int i = 0; i < length; i++) {
+ byte b = buffer.get(i + offset);
+ if (b == SEMICOLON) {
+ return i + offset;
+ }
+ name[i] = b;
+ hash = 31 * hash + b;
+ }
+ return -1;
+ }
+
+ private int readTemp(int offset, int length) {
+ for (int i = 0; i < length; i++) {
+ byte b = buffer.get(i + offset);
+ if (b == EOL) {
+ return i + offset;
+ }
+ temp[i] = b;
+ }
+ return -1;
}
- return -1;
- }
- // non null double between -99.9 (inclusive) and 99.9 (inclusive), always with one fractional digit
- private static int parseTemp(byte[] value, int length) {
- int period = length - 2;
- if (value[0] == MINUS) {
- int left = parseLeft(value, 1, period - 1);
+ // non null double between -99.9 (inclusive) and 99.9 (inclusive), always with one fractional digit
+ private static int parseTemp(byte[] value, int length) {
+ int period = length - 2;
+ if (value[0] == MINUS) {
+ int left = parseLeft(value, 1, period - 1);
+ int right = toInt(value[period + 1]);
+ return -(left + right);
+ }
+ int left = parseLeft(value, 0, period);
int right = toInt(value[period + 1]);
- return -(left + right);
+ return left + right;
}
- int left = parseLeft(value, 0, period);
- int right = toInt(value[period + 1]);
- return left + right;
- }
- private static int parseLeft(byte[] value, int start, int length) {
- if (length == 1) {
- return toInt(value[start]) * 10;
+ private static int parseLeft(byte[] value, int start, int length) {
+ if (length == 1) {
+ return toInt(value[start]) * 10;
+ }
+ // two chars
+ int a = toInt(value[start]) * 100;
+ int b = toInt(value[start + 1]) * 10;
+ return a + b;
}
- // two chars
- int a = toInt(value[start]) * 100;
- int b = toInt(value[start + 1]) * 10;
- return a + b;
- }
- private static int toInt(byte c) {
- return c - 48;
+ private static int toInt(byte c) {
+ return c - 48;
+ }
}
- static final class Name {
+ static final class Stations {
- private final byte[] value;
+ private static final int NUMBER_OF_BUCKETS = 1000;
+ private static final int BUCKET_SIZE = 50;
- Name(byte[] source, int length) {
- value = new byte[length];
- System.arraycopy(source, 0, value, 0, length);
- }
-
- @Override
- public int hashCode() {
- return Arrays.hashCode(value);
- }
+ final Station[][] buckets = new Station[NUMBER_OF_BUCKETS][BUCKET_SIZE];
- @Override
- public boolean equals(Object obj) {
- if (obj instanceof Name other) {
- return Arrays.equals(value, other.value);
+ Station find(byte[] name, int length, int hash) {
+ var bucket = buckets[Math.abs(hash % NUMBER_OF_BUCKETS)];
+ for (int i = 0; i < BUCKET_SIZE; i++) {
+ if (bucket[i] == null) {
+ bucket[i] = new Station(name, length, hash);
+ return bucket[i];
+ }
+ else if (bucket[i].sameName(length, hash)) {
+ return bucket[i];
+ }
}
- return false;
- }
-
- @Override
- public String toString() {
- return new String(value, StandardCharsets.UTF_8);
+ throw new IllegalStateException("no more space left");
}
}
static final class Station {
- private final Name name;
+ private final byte[] name;
+ private final int hash;
- private int min = Integer.MAX_VALUE;
- private int max = Integer.MIN_VALUE;
+ private int min = 1000;
+ private int max = -1000;
private int sum;
private long count;
- Station(Name name) {
- this.name = name;
+ Station(byte[] source, int length, int hash) {
+ name = new byte[length];
+ System.arraycopy(source, 0, name, 0, length);
+ this.hash = hash;
}
String getName() {
- return name.toString();
+ return new String(name, StandardCharsets.UTF_8);
}
void add(int value) {
@@ -249,8 +249,13 @@ public class CalculateAverage_tonivade {
return this;
}
- String asString() {
- return name + "=" + toDouble(min) + "/" + round(mean()) + "/" + toDouble(max);
+ @Override
+ public String toString() {
+ return toDouble(min) + "/" + round(mean()) + "/" + toDouble(max);
+ }
+
+ boolean sameName(int length, int hash) {
+ return name.length == length && this.hash == hash;
}
private double mean() {
@@ -265,4 +270,17 @@ public class CalculateAverage_tonivade {
return Math.round(value * 10.) / 10.;
}
}
+
+ static record PartialResult(int end, Station[][] stations) {
+
+ void merge(Map<String, Station> result) {
+ for (Station[] bucket : stations) {
+ for (Station station : bucket) {
+ if (station != null) {
+ result.merge(station.getName(), station, Station::merge);
+ }
+ }
+ }
+ }
+ }
}