aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Marko Topolnik <marko.topolnik@gmail.com> 2024-01-29 20:51:52 +0100
committer GitHub <noreply@github.com> 2024-01-29 20:51:52 +0100
commit 886f0cdb4df4fdbee7c99fb1fdd5c72d9608d635 (patch)
tree 431af15eea528903d64cd7c0d6b73c9312958bb1
parent 5ba094c8fded54677e787220e352a1baf74cacec (diff)
mtopolnik submission 3 (#637)
* calculate_average_mtopolnik * short hash (just first 8 bytes of name) * Remove unneeded checks * Remove archiving classes * 2x larger hashtable * Add "set" to setters * Simplify parsing temperature, remove newline search * Reduce the size of the name slot * Store name length and use to detect collision * Reduce memory loads in parseTemperature * Use short for min/max * Extract constant for semicolon * Fix script header * Explicit bash shell in shebang * Inline usage of broadcast semicolon * Try vectorization * Remove vectorization * Go Unsafe * Use SWAR temperature parsing by merykitty * Inline some things * Remove commented-out MemorySegment usage * Inline namesMem.asSlice() invocation * Try out JVM JIT flags * Implement strcmp * Remove unused instance variables * Optimize hashing * Put station name into hashtable * Reorder method * Remove usage of MemorySegment.getUtf8String Replace with UNSAFE.copyMemory() and new String() * Fix hashing bug * Remove outdated comments * Fix informative constants * Use broadcastByte() more * Improve method naming * More hashing * Revert more hashing * Add commented-out code to hash 16 bytes * Slight cleanup * Align hashtable at cacheline boundary * Add Graal Native image * Revert Graal Native image This reverts commit d916a42326d89bd1a841bbbecfae185adb8679d7. 
* Simplify shell script (no SDK selection) * Move a constant, zero out hashtable on start * Better name comparison * Add prepare_mtopolnik.sh * Cleaner idiom in name comparison * AND instead of MOD for hashtable indexing * Improve word masking code * Fix formatting * Reduce memory loads * Remove endianness checks * Avoid hash == 0 problem * Fix subtle bug * MergeSort of parallel results * Touch up perf * Touch up perf * Remove -Xmx256m * Extract result printing method * Print allocation details on OOME * Single mmap * Use global allocation arena * Add commented-out Xmx64m XXMaxDirectMemorySize=1g * withinSafeZone * Update cursor earlier * Better assert * Fix bug in addrOfSemicolonSafe * Move declaration lower * Simplify code * Add rounding error test case * Fix DANGER_ZONE_LEN * Deoptimize parseTemperatureSimple() * Inline parseTemperatureAndAdvanceCursor() * Skip masking until the last load * Conditionally fetch name words * Cleanup * Use native image * Use subprocess * Simpler code * Cleanup * Avoid extra condition on hot path
-rwxr-xr-x calculate_average_mtopolnik.sh 10
-rwxr-xr-x prepare_mtopolnik.sh 7
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java 466
3 files changed, 222 insertions, 261 deletions
diff --git a/calculate_average_mtopolnik.sh b/calculate_average_mtopolnik.sh
index 24b5a1c..acd1024 100755
--- a/calculate_average_mtopolnik.sh
+++ b/calculate_average_mtopolnik.sh
@@ -15,5 +15,11 @@
# limitations under the License.
#
-java --enable-preview \
- --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik
+if [ -f target/CalculateAverage_mtopolnik_image ]; then
+ echo "Using native image 'target/CalculateAverage_mtopolnik_image'" 1>&2
+ target/CalculateAverage_mtopolnik_image
+else
+ JAVA_OPTS="--enable-preview"
+ echo "Native image not found, using JVM mode." 1>&2
+ java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_mtopolnik
+fi
diff --git a/prepare_mtopolnik.sh b/prepare_mtopolnik.sh
index f83a3ff..d84f20d 100755
--- a/prepare_mtopolnik.sh
+++ b/prepare_mtopolnik.sh
@@ -16,4 +16,9 @@
#
source "$HOME/.sdkman/bin/sdkman-init.sh"
-sdk use java 21.0.1-graal 1>&2
+sdk use java 21.0.2-graal 1>&2
+
+if [ ! -f target/CalculateAverage_mtopolnik_image ]; then
+ NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -H:+UnlockExperimentalVMOptions -H:-GenLoopSafepoints -march=native --enable-preview -H:InlineAllBonus=10 -H:-ParseRuntimeOptions --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_mtopolnik"
+ native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_mtopolnik_image dev.morling.onebrc.CalculateAverage_mtopolnik
+fi
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
index 51ea415..61294a4 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_mtopolnik.java
@@ -29,18 +29,15 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
+import static java.lang.ProcessBuilder.Redirect.PIPE;
+import static java.util.Arrays.asList;
+
public class CalculateAverage_mtopolnik {
private static final Unsafe UNSAFE = unsafe();
private static final int MAX_NAME_LEN = 100;
private static final int STATS_TABLE_SIZE = 1 << 16;
private static final int TABLE_INDEX_MASK = STATS_TABLE_SIZE - 1;
private static final String MEASUREMENTS_TXT = "measurements.txt";
- private static final byte SEMICOLON = ';';
- private static final long BROADCAST_SEMICOLON = broadcastByte(SEMICOLON);
-
- // These two are just informative, I let the IDE calculate them for me
- private static final long NATIVE_MEM_PER_THREAD = StatsAccessor.SIZEOF * STATS_TABLE_SIZE;
- private static final long NATIVE_MEM_ON_8_THREADS = 8 * NATIVE_MEM_PER_THREAD;
private static Unsafe unsafe() {
try {
@@ -53,31 +50,23 @@ public class CalculateAverage_mtopolnik {
}
}
- static class StationStats implements Comparable<StationStats> {
- String name;
- long sum;
- int count;
- int min;
- int max;
-
- @Override
- public String toString() {
- return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
- }
-
- @Override
- public boolean equals(Object that) {
- return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name);
- }
-
- @Override
- public int compareTo(StationStats that) {
- return name.compareTo(that.name);
- }
- }
-
public static void main(String[] args) throws Exception {
- calculate();
+ if (args.length >= 1 && args[0].equals("--worker")) {
+ calculate();
+ System.out.close();
+ return;
+ }
+ var curProcInfo = ProcessHandle.current().info();
+ var cmdLine = new ArrayList<String>();
+ cmdLine.add(curProcInfo.command().get());
+ cmdLine.addAll(asList(curProcInfo.arguments().get()));
+ cmdLine.add("--worker");
+ var process = new ProcessBuilder()
+ .command(cmdLine)
+ .inheritIO().redirectOutput(PIPE)
+ .start()
+ .getInputStream().transferTo(System.out);
+
}
static void calculate() throws Exception {
@@ -113,7 +102,6 @@ public class CalculateAverage_mtopolnik {
}
private static class ChunkProcessor implements Runnable {
- private static final long NAMEBUF_SIZE = 2 * Long.BYTES;
private static final int CACHELINE_SIZE = 64;
private final long inputBase;
@@ -122,8 +110,6 @@ public class CalculateAverage_mtopolnik {
private final int myIndex;
private StatsAccessor stats;
- private long nameBufBase;
- private long cursor;
ChunkProcessor(long chunkStart, long chunkLimit, StationStats[][] results, int myIndex) {
this.inputBase = chunkStart;
@@ -138,16 +124,12 @@ public class CalculateAverage_mtopolnik {
long totalAllocated = 0;
String threadName = Thread.currentThread().getName();
long statsByteSize = STATS_TABLE_SIZE * StatsAccessor.SIZEOF;
- var diagnosticString = String.format("Thread %s needs %,d bytes, managed to allocate before OOM: ",
- threadName, statsByteSize + NAMEBUF_SIZE);
+ var diagnosticString = String.format("Thread %s needs %,d bytes", threadName, statsByteSize);
try {
stats = new StatsAccessor(confinedArena.allocate(statsByteSize, CACHELINE_SIZE));
- totalAllocated = statsByteSize;
- nameBufBase = confinedArena.allocate(NAMEBUF_SIZE).address();
}
catch (OutOfMemoryError e) {
System.err.print(diagnosticString);
- System.err.println(totalAllocated);
throw e;
}
processChunk();
@@ -155,227 +137,110 @@ public class CalculateAverage_mtopolnik {
}
}
- private static final int MAX_TEMPERATURE_LEN = 5;
- private static final int MAX_ROW_LEN = MAX_NAME_LEN + 1 + MAX_TEMPERATURE_LEN + 1;
- private static final long DANGER_ZONE_LENGTH = ((MAX_ROW_LEN - 1) / 8 * 8 + 8);
-
private void processChunk() {
+ final long inputSize = this.inputSize;
+ final long inputBase = this.inputBase;
+ long cursor = 0;
+ long lastNameWord;
while (cursor < inputSize) {
- boolean withinSafeZone;
- long word1;
- long word2;
- long nameLen;
long nameStartAddress = inputBase + cursor;
- if (cursor + DANGER_ZONE_LENGTH <= inputSize) {
- withinSafeZone = true;
- word1 = UNSAFE.getLong(nameStartAddress);
- word2 = UNSAFE.getLong(nameStartAddress + Long.BYTES);
- nameLen = nameLen(word1, word2, withinSafeZone);
- word1 = maskWord(word1, nameLen);
- word2 = maskWord(word2, nameLen - Long.BYTES);
+ long nameWord0 = UNSAFE.getLong(nameStartAddress);
+ long nameWord1 = 0;
+ long matchBits = semicolonMatchBits(nameWord0);
+ long hash;
+ int nameLen;
+ int temperature;
+ if (matchBits != 0) {
+ nameLen = nameLen(matchBits);
+ nameWord0 = maskWord(nameWord0, matchBits);
+ cursor += nameLen;
+ long tempWord = UNSAFE.getLong(inputBase + cursor);
+ int dotPos = dotPos(tempWord);
+ temperature = parseTemperature(tempWord, dotPos);
+ cursor += (dotPos >> 3) + 3;
+ hash = hash(nameWord0);
+ if (stats.gotoName0(hash, nameWord0)) {
+ stats.observe(temperature);
+ continue;
+ }
+ lastNameWord = nameWord0;
}
- else {
- withinSafeZone = false;
- UNSAFE.putLong(nameBufBase, 0);
- UNSAFE.putLong(nameBufBase + Long.BYTES, 0);
- UNSAFE.copyMemory(nameStartAddress, nameBufBase, Long.min(NAMEBUF_SIZE, inputSize - cursor));
- word1 = UNSAFE.getLong(nameBufBase);
- word2 = UNSAFE.getLong(nameBufBase + Long.BYTES);
- nameLen = nameLen(word1, word2, withinSafeZone);
+ else { // nameLen > 8
+ hash = hash(nameWord0);
+ nameWord1 = UNSAFE.getLong(nameStartAddress + Long.BYTES);
+ matchBits = semicolonMatchBits(nameWord1);
+ if (matchBits != 0) {
+ nameLen = Long.BYTES + nameLen(matchBits);
+ nameWord1 = maskWord(nameWord1, matchBits);
+ cursor += nameLen;
+ long tempWord = UNSAFE.getLong(inputBase + cursor);
+ int dotPos = dotPos(tempWord);
+ temperature = parseTemperature(tempWord, dotPos);
+ cursor += (dotPos >> 3) + 3;
+ if (stats.gotoName1(hash, nameWord0, nameWord1)) {
+ stats.observe(temperature);
+ continue;
+ }
+ lastNameWord = nameWord1;
+ }
+ else { // nameLen > 16
+ nameLen = 2 * Long.BYTES;
+ while (true) {
+ lastNameWord = UNSAFE.getLong(nameStartAddress + nameLen);
+ matchBits = semicolonMatchBits(lastNameWord);
+ if (matchBits != 0) {
+ nameLen += nameLen(matchBits);
+ lastNameWord = maskWord(lastNameWord, matchBits);
+ cursor += nameLen;
+ long tempWord = UNSAFE.getLong(inputBase + cursor);
+ int dotPos = dotPos(tempWord);
+ temperature = parseTemperature(tempWord, dotPos);
+ cursor += (dotPos >> 3) + 3;
+ break;
+ }
+ nameLen += Long.BYTES;
+ }
+ }
}
- long hash = hash(word1);
- assert nameLen > 0 && nameLen <= 100 : nameLen;
- long tempStartAddress = nameStartAddress + nameLen + 1;
- int temperature = withinSafeZone
- ? parseTemperatureSwarAndAdvanceCursor(tempStartAddress)
- : parseTemperatureSimpleAndAdvanceCursor(tempStartAddress);
- updateStats(hash, nameStartAddress, nameLen, word1, word2, temperature, withinSafeZone);
+ stats.gotoAndObserve(hash, nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord, temperature);
}
}
- private void updateStats(
- long hash, long nameStartAddress, long nameLen, long nameWord1, long nameWord2,
- int temperature, boolean withinSafeZone) {
- int tableIndex = (int) (hash & TABLE_INDEX_MASK);
- while (true) {
- stats.gotoIndex(tableIndex);
- if (stats.hash() == hash && stats.nameLen() == nameLen && nameEquals(
- stats.nameAddress(), nameStartAddress, nameLen, nameWord1, nameWord2, withinSafeZone)) {
- stats.setSum(stats.sum() + temperature);
- stats.setCount(stats.count() + 1);
- stats.setMin((short) Integer.min(stats.min(), temperature));
- stats.setMax((short) Integer.max(stats.max(), temperature));
- return;
- }
- if (stats.nameLen() != 0) {
- tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
- continue;
- }
- stats.setHash(hash);
- stats.setNameLen((int) nameLen);
- stats.setSum(temperature);
- stats.setCount(1);
- stats.setMin((short) temperature);
- stats.setMax((short) temperature);
- UNSAFE.copyMemory(nameStartAddress, stats.nameAddress(), nameLen);
- return;
- }
- }
+ private static final long BROADCAST_SEMICOLON = 0x3B3B3B3B3B3B3B3BL;
+ private static final long BROADCAST_0x01 = 0x0101010101010101L;
+ private static final long BROADCAST_0x80 = 0x8080808080808080L;
- // Credit: merykitty
- private int parseTemperatureSwarAndAdvanceCursor(long tempStartAddress) {
- long word = UNSAFE.getLong(tempStartAddress);
- final long negated = ~word;
- final int dotPos = Long.numberOfTrailingZeros(negated & 0x10101000);
- cursor = (tempStartAddress + (dotPos / 8) + 3) - inputBase;
- final long signed = (negated << 59) >> 63;
- final long removeSignMask = ~(signed & 0xFF);
- final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
- final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
- return (int) ((absValue ^ signed) - signed);
+ private static long semicolonMatchBits(long word) {
+ long diff = word ^ BROADCAST_SEMICOLON;
+ return (diff - BROADCAST_0x01) & (~diff & BROADCAST_0x80);
}
- private int parseTemperatureSimpleAndAdvanceCursor(long tempStartAddress) {
- final byte minus = (byte) '-';
- final byte zero = (byte) '0';
- final byte dot = (byte) '.';
-
- byte ch = UNSAFE.getByte(tempStartAddress);
- long address = tempStartAddress;
- int temperature;
- int sign;
- if (ch == minus) {
- sign = -1;
- address++;
- ch = UNSAFE.getByte(address);
- }
- else {
- sign = 1;
- }
- temperature = ch - zero;
- address++;
- ch = UNSAFE.getByte(address);
- if (ch == dot) {
- address++;
- ch = UNSAFE.getByte(address);
- }
- else {
- temperature = 10 * temperature + (ch - zero);
- address += 2;
- ch = UNSAFE.getByte(address);
- }
- temperature = 10 * temperature + (ch - zero);
- // address - inputBase is the length of the temperature field.
- // A newline character follows the temperature, and so we advance
- // the cursor past the newline to the start of the next line.
- cursor = (address + 2) - inputBase;
- return sign * temperature;
- }
-
- private static long hash(long word1) {
- long seed = 0x51_7c_c1_b7_27_22_0a_95L;
- int rotDist = 17;
-
- long hash = word1;
- hash *= seed;
- hash = Long.rotateLeft(hash, rotDist);
- // hash ^= word2;
- // hash *= seed;
- // hash = Long.rotateLeft(hash, rotDist);
- return hash;
- }
-
- private static boolean nameEquals(long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2,
- boolean withinSafeZone) {
- boolean mismatch1 = maskWord(inputWord1, len) != UNSAFE.getLong(statsAddr);
- boolean mismatch2 = maskWord(inputWord2, len - Long.BYTES) != UNSAFE.getLong(statsAddr + Long.BYTES);
- if (len <= 2 * Long.BYTES) {
- return !(mismatch1 | mismatch2);
- }
- if (withinSafeZone) {
- int i = 2 * Long.BYTES;
- for (; i <= len - Long.BYTES; i += Long.BYTES) {
- if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) {
- return false;
- }
- }
- return maskWord(UNSAFE.getLong(inputAddr + i), len - i) == UNSAFE.getLong(statsAddr + i);
- }
- else {
- for (int i = 2 * Long.BYTES; i < len; i++) {
- if (UNSAFE.getByte(inputAddr + i) != UNSAFE.getByte(statsAddr + i)) {
- return false;
- }
- }
- }
- return true;
+ // credit: artsiomkorzun
+ private static long maskWord(long word, long matchBits) {
+ long mask = matchBits ^ (matchBits - 1);
+ return word & mask;
}
- private static long maskWord(long word, long len) {
- long halfShiftDistance = Long.max(0, Long.BYTES - len) << 2;
- long mask = (~0L >>> halfShiftDistance) >>> halfShiftDistance; // avoid Java trap of shiftDist % 64
- return word & mask;
+ // credit: merykitty
+ private static int dotPos(long word) {
+ return Long.numberOfTrailingZeros(~word & 0x10101000);
}
- private static final long BROADCAST_0x01 = broadcastByte(0x01);
- private static final long BROADCAST_0x80 = broadcastByte(0x80);
-
- // Adapted from https://jameshfisher.com/2017/01/24/bitwise-check-for-zero-byte/
- // and https://github.com/ashvardanian/StringZilla/blob/14e7a78edcc16b031c06b375aac1f66d8f19d45a/stringzilla/stringzilla.h#L139-L169
- long nameLen(long word1, long word2, boolean withinSafeZone) {
- {
- long matchBits1 = matchBits(word1);
- long matchBits2 = matchBits(word2);
- if ((matchBits1 | matchBits2) != 0) {
- int trailing1 = Long.numberOfTrailingZeros(matchBits1);
- int match1IsNonZero = trailing1 & 63;
- match1IsNonZero |= match1IsNonZero >>> 3;
- match1IsNonZero |= match1IsNonZero >>> 1;
- match1IsNonZero |= match1IsNonZero >>> 1;
- // Now match1IsNonZero is 1 if it's non-zero, else 0. Use it to
- // raise the lowest bit in trailing2 if trailing1 is nonzero. This forces
- // trailing2 to be zero if trailing1 is non-zero.
- int trailing2 = Long.numberOfTrailingZeros(matchBits2 | match1IsNonZero) & 63;
- // trailing1 | trailing2 works like trailing1 + trailing2 because if trailing2 is non-zero,
- // then trailing1 is 64, and since trailing2 is < 64, there's no bit overlap.
- return (trailing1 | trailing2) >> 3;
- }
- }
- long nameStartAddress = inputBase + cursor;
- long address = nameStartAddress + 2 * Long.BYTES;
- long limit = inputBase + inputSize;
- if (withinSafeZone) {
- for (; address < limit; address += Long.BYTES) {
- var block = maskWord(UNSAFE.getLong(address), limit - address);
- long matchBits = matchBits(block);
- if (matchBits != 0) {
- return address + (Long.numberOfTrailingZeros(matchBits) >> 3) - nameStartAddress;
- }
- }
- throw new RuntimeException("Semicolon not found");
- }
- return addrOfSemicolonSafe(address, limit) - nameStartAddress;
+ // credit: merykitty
+ private static int parseTemperature(long word, int dotPos) {
+ final long signed = (~word << 59) >> 63;
+ final long removeSignMask = ~(signed & 0xFF);
+ final long digits = ((word & removeSignMask) << (28 - dotPos)) & 0x0F000F0F00L;
+ final long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF;
+ return (int) ((absValue ^ signed) - signed);
}
- private static long addrOfSemicolonSafe(long address, long limit) {
- for (; address < limit - Long.BYTES + 1; address += Long.BYTES) {
- var block = UNSAFE.getLong(address);
- long matchBits = matchBits(block);
- if (matchBits != 0) {
- return address + (Long.numberOfTrailingZeros(matchBits) >> 3);
- }
- }
- for (; address < limit; address++) {
- if (UNSAFE.getByte(address) == SEMICOLON) {
- return address;
- }
- }
- throw new RuntimeException("Semicolon not found");
+ private static int nameLen(long separator) {
+ return (Long.numberOfTrailingZeros(separator) >>> 3) + 1;
}
- private static long matchBits(long word) {
- long diff = word ^ BROADCAST_SEMICOLON;
- return (diff - BROADCAST_0x01) & ~diff & BROADCAST_0x80;
+ private static long hash(long word) {
+ return Long.rotateLeft(word * 0x51_7c_c1_b7_27_22_0a_95L, 17);
}
// Copies the results from native memory to Java heap and puts them into the results array.
@@ -403,22 +268,6 @@ public class CalculateAverage_mtopolnik {
Arrays.sort(exported);
results[myIndex] = exported;
}
-
- private final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
-
- private String longToString(long word) {
- buf.clear();
- buf.putLong(word);
- return new String(buf.array(), StandardCharsets.UTF_8); // + "|" + Arrays.toString(buf.array());
- }
- }
-
- private static long broadcastByte(int b) {
- long nnnnnnnn = b;
- nnnnnnnn |= nnnnnnnn << 8;
- nnnnnnnn |= nnnnnnnn << 16;
- nnnnnnnn |= nnnnnnnn << 32;
- return nnnnnnnn;
}
static class StatsAccessor {
@@ -446,6 +295,16 @@ public class CalculateAverage_mtopolnik {
slotBase = address + index * SIZEOF;
}
+ private boolean gotoName0(long hash, long nameWord0) {
+ gotoIndex((int) (hash & TABLE_INDEX_MASK));
+ return hash() == hash && nameWord0() == nameWord0;
+ }
+
+ private boolean gotoName1(long hash, long nameWord0, long nameWord1) {
+ gotoIndex((int) (hash & TABLE_INDEX_MASK));
+ return hash() == hash && nameWord0() == nameWord0 && nameWord1() == nameWord1;
+ }
+
long hash() {
return UNSAFE.getLong(slotBase + HASH_OFFSET);
}
@@ -474,9 +333,17 @@ public class CalculateAverage_mtopolnik {
return slotBase + NAME_OFFSET;
}
+ long nameWord0() {
+ return UNSAFE.getLong(nameAddress());
+ }
+
+ long nameWord1() {
+ return UNSAFE.getLong(nameAddress() + Long.BYTES);
+ }
+
String exportNameString() {
- final var bytes = new byte[nameLen()];
- UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, nameLen());
+ final var bytes = new byte[nameLen() - 1];
+ UNSAFE.copyMemory(null, nameAddress(), bytes, ARRAY_BASE_OFFSET, bytes.length);
return new String(bytes, StandardCharsets.UTF_8);
}
@@ -503,6 +370,59 @@ public class CalculateAverage_mtopolnik {
void setMax(short max) {
UNSAFE.putShort(slotBase + MAX_OFFSET, max);
}
+
+ void gotoAndObserve(
+ long hash, long nameStartAddress, int nameLen, long nameWord0, long nameWord1, long lastNameWord,
+ int temperature) {
+ int tableIndex = (int) (hash & TABLE_INDEX_MASK);
+ while (true) {
+ gotoIndex(tableIndex);
+ if (hash() == hash && nameLen() == nameLen && nameEquals(
+ nameAddress(), nameStartAddress, nameLen, nameWord0, nameWord1, lastNameWord)) {
+ observe(temperature);
+ break;
+ }
+ if (nameLen() != 0) {
+ tableIndex = (tableIndex + 1) & TABLE_INDEX_MASK;
+ continue;
+ }
+ initialize(hash, nameLen, nameStartAddress, temperature);
+ break;
+ }
+ }
+
+ void initialize(long hash, long nameLen, long nameStartAddress, int temperature) {
+ setHash(hash);
+ setNameLen((int) nameLen);
+ setSum(temperature);
+ setCount(1);
+ setMin((short) temperature);
+ setMax((short) temperature);
+ UNSAFE.copyMemory(nameStartAddress, nameAddress(), nameLen);
+ }
+
+ void observe(int temperature) {
+ setSum(sum() + temperature);
+ setCount(count() + 1);
+ setMin((short) Integer.min(min(), temperature));
+ setMax((short) Integer.max(max(), temperature));
+ }
+
+ private static boolean nameEquals(
+ long statsAddr, long inputAddr, long len, long inputWord1, long inputWord2, long lastInputWord) {
+ boolean mismatch1 = inputWord1 != UNSAFE.getLong(statsAddr);
+ boolean mismatch2 = inputWord2 != UNSAFE.getLong(statsAddr + Long.BYTES);
+ if (len <= 2 * Long.BYTES) {
+ return !(mismatch1 | mismatch2);
+ }
+ int i = 2 * Long.BYTES;
+ for (; i <= len - Long.BYTES; i += Long.BYTES) {
+ if (UNSAFE.getLong(inputAddr + i) != UNSAFE.getLong(statsAddr + i)) {
+ return false;
+ }
+ }
+ return i == len || lastInputWord == UNSAFE.getLong(statsAddr + i);
+ }
}
private static void mergeSortAndPrint(StationStats[][] results) {
@@ -556,4 +476,34 @@ public class CalculateAverage_mtopolnik {
}
System.out.println('}');
}
+
+ static class StationStats implements Comparable<StationStats> {
+ String name;
+ long sum;
+ int count;
+ int min;
+ int max;
+
+ @Override
+ public String toString() {
+ return String.format("%s=%.1f/%.1f/%.1f", name, min / 10.0, Math.round((double) sum / count) / 10.0, max / 10.0);
+ }
+
+ @Override
+ public boolean equals(Object that) {
+ return that.getClass() == StationStats.class && ((StationStats) that).name.equals(this.name);
+ }
+
+ @Override
+ public int compareTo(StationStats that) {
+ return name.compareTo(that.name);
+ }
+ }
+
+ private static String longToString(long word) {
+ final ByteBuffer buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
+ buf.clear();
+ buf.putLong(word);
+ return new String(buf.array(), StandardCharsets.UTF_8);
+ }
}