aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
authorAnita SV <anitasvasu@gmail.com>2024-01-31 00:41:33 -0800
committerGitHub <noreply@github.com>2024-01-31 09:41:33 +0100
commitaf2b5517c894347d42e8382b4b7559bdd9a7d337 (patch)
treee8270199360038b7a64f917c1333d16d5774f083 /src/main/java/dev/morling/onebrc
parent974ddbae606c412b4d35ada360879d96e93d00b1 (diff)
anitasv 3.8s vs 3m 19s : Improved using custom hashmap. (#672)
* Some optimizations while staying safe * bug fix not caught on tests
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java163
1 files changed, 118 insertions, 45 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java b/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java
index c15250d..7d3d6af 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_anitasv.java
@@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.*;
-import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class CalculateAverage_anitasv {
@@ -44,14 +43,14 @@ public class CalculateAverage_anitasv {
.asByteBuffer();
while (buf.hasRemaining()) {
if (buf.get() == ch) {
- return position + buf.position() - 1;
+ return position + (buf.position() - 1);
}
}
return -1;
}
- byte[] getRange(long start, long end) {
- return mmapMemory.asSlice(start, end - start).toArray(ValueLayout.JAVA_BYTE);
+ MemorySegment getRange(long start, long end) {
+ return mmapMemory.asSlice(start, end - start);
}
int parseDouble(long start, long end) {
@@ -86,22 +85,122 @@ public class CalculateAverage_anitasv {
return buf2.hashCode();
}
- public boolean matches(byte[] existingStation, long start, long end) {
- ByteBuffer buf1 = ByteBuffer.wrap(existingStation);
- ByteBuffer buf2 = mmapMemory.asSlice(start, end - start).asByteBuffer();
- return buf1.equals(buf2);
+ public long truncate(long index) {
+ return Math.min(index, mmapMemory.byteSize());
+ }
+
+ public long getLong(long position) {
+ return mmapMemory.get(ValueLayout.JAVA_LONG_UNALIGNED, position);
}
}
- private record ResultRow(byte[] station, IntSummaryStatistics statistics) {
+ private record ResultRow(IntSummaryStatistics statistics, int keyLength, int next) {
+ }
+
+ private static class FastHashMap {
+ private final byte[] keys;
+ private final ResultRow[] values;
+
+ private final int capacityMinusOne;
- public String toString() {
- return STR."\{new String(station, StandardCharsets.UTF_8)} : \{statToString(statistics)}";
+ private final MemorySegment keySegment;
+
+ private int next = -1;
+
+ private FastHashMap(int capacity) {
+ this.capacityMinusOne = capacity - 1;
+ this.keys = new byte[capacity << 7];
+ this.keySegment = MemorySegment.ofArray(keys);
+ this.values = new ResultRow[capacity];
}
+
+ IntSummaryStatistics find(int hash, Shard shard, long stationStart, long stationEnd) {
+ int initialIndex = hash & capacityMinusOne;
+ int lookupLength = (int) (stationEnd - stationStart);
+ int lookupAligned = ((lookupLength + 7) & (-8));
+ int i = initialIndex;
+
+ lookupAligned = (int) (shard.truncate(stationStart + lookupAligned) - stationStart) - 7;
+
+ do {
+ int keyIndex = i << 7;
+
+ if (keys[keyIndex] != 0 && keys[keyIndex + lookupLength] == 0) {
+
+ int mismatch = -1, j;
+ for (j = 0; j < lookupAligned; j += 8) {
+ long entryLong = keySegment.get(ValueLayout.JAVA_LONG_UNALIGNED, keyIndex + j);
+ long lookupLong = shard.getLong(stationStart + j);
+ if (entryLong != lookupLong) {
+ int diff = Long.numberOfTrailingZeros(entryLong ^ lookupLong);
+ mismatch = j + (diff >> 3);
+ break;
+ }
+ }
+ if (mismatch == -1) {
+ for (; j < lookupLength; j++) {
+ byte entryByte = keys[keyIndex + j];
+ byte lookupByte = shard.getByte(stationStart + j);
+ if (entryByte != lookupByte) {
+ mismatch = j;
+ break;
+ }
+ }
+ }
+ if (mismatch == -1 || mismatch >= lookupLength) {
+ return this.values[i].statistics;
+ }
+ }
+ if (keys[keyIndex] == 0) {
+ MemorySegment fullLookup = shard.getRange(stationStart, stationEnd);
+
+ keySegment.asSlice(keyIndex, lookupLength)
+ .copyFrom(fullLookup);
+
+ keys[keyIndex + lookupLength] = 0;
+ IntSummaryStatistics stats = new IntSummaryStatistics();
+ ResultRow resultRow = new ResultRow(stats, lookupLength, this.next);
+ this.next = i;
+ this.values[i] = resultRow;
+ return stats;
+ }
+
+ if (i == capacityMinusOne) {
+ i = 0;
+ }
+ else {
+ i++;
+ }
+ } while (i != initialIndex);
+ throw new IllegalStateException("Hash size too small");
+ }
+
+ Iterable<Map.Entry<String, IntSummaryStatistics>> values() {
+ return () -> new Iterator<>() {
+
+ int scan = FastHashMap.this.next;
+
+ @Override
+ public boolean hasNext() {
+ return scan != -1;
+ }
+
+ @Override
+ public Map.Entry<String, IntSummaryStatistics> next() {
+ ResultRow resultRow = values[scan];
+ IntSummaryStatistics stats = resultRow.statistics;
+ String key = new String(keys, scan << 7, resultRow.keyLength,
+ StandardCharsets.UTF_8);
+ scan = resultRow.next;
+ return new AbstractMap.SimpleEntry<>(key, stats);
+ }
+ };
+ }
+
}
- private static Map<String, IntSummaryStatistics> process(Shard shard) {
- HashMap<Integer, List<ResultRow>> result = new HashMap<>(1 << 14);
+ private static Iterable<Map.Entry<String, IntSummaryStatistics>> process(Shard shard) {
+ FastHashMap result = new FastHashMap(1 << 14);
boolean skip = shard.chunkStart != 0;
for (long position = shard.chunkStart; position < shard.chunkEnd; position++) {
@@ -116,45 +215,19 @@ public class CalculateAverage_anitasv {
long temperatureEnd = shard.indexOf(stationEnd + 1, (byte) '\n');
int temperature = shard.parseDouble(stationEnd + 1, temperatureEnd);
- List<ResultRow> collisions = result.get(hash);
- if (collisions == null) {
- collisions = new ArrayList<>();
- result.put(hash, collisions);
- }
-
- boolean found = false;
- for (ResultRow existing : collisions) {
- byte[] existingStation = existing.station();
- if (shard.matches(existingStation, position, stationEnd)) {
- existing.statistics.accept(temperature);
- found = true;
- break;
- }
- }
- if (!found) {
- IntSummaryStatistics stats = new IntSummaryStatistics();
- stats.accept(temperature);
- ResultRow rr = new ResultRow(shard.getRange(position, stationEnd), stats);
- collisions.add(rr);
- }
+ IntSummaryStatistics stats = result.find(hash, shard, position, stationEnd);
+ stats.accept(temperature);
position = temperatureEnd;
}
}
- return result.values()
- .stream()
- .flatMap(Collection::stream)
- .map(rr -> new AbstractMap.SimpleImmutableEntry<>(
- new String(rr.station, StandardCharsets.UTF_8),
- rr.statistics))
- .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ return result.values();
}
- private static Map<String, IntSummaryStatistics> combineResults(List<Map<String, IntSummaryStatistics>> list) {
-
+ private static Map<String, IntSummaryStatistics> combineResults(List<Iterable<Map.Entry<String, IntSummaryStatistics>>> list) {
Map<String, IntSummaryStatistics> output = HashMap.newHashMap(1024);
- for (Map<String, IntSummaryStatistics> map : list) {
- for (Map.Entry<String, IntSummaryStatistics> entry : map.entrySet()) {
+ for (Iterable<Map.Entry<String, IntSummaryStatistics>> map : list) {
+ for (Map.Entry<String, IntSummaryStatistics> entry : map) {
output.compute(entry.getKey(), (ignore, val) -> {
if (val == null) {
return entry.getValue();