From 22b960ada294a87d5cf60c2dc1b13585287cab73 Mon Sep 17 00:00:00 2001 From: Samuel Yvon Date: Fri, 12 Jan 2024 11:16:13 -0800 Subject: Graal Native for SamuelYvon (#332) * Graal Native * I need a GC :( * Fix slash lolz * Fix god damn output lol I forgot java :D * Custom hash, custom key * More optimisations * I don't need "optimize-build" I don't care about image size! :D --- .../onebrc/CalculateAverage_SamuelYvon.java | 342 +++++++++++++++++++++ .../onebrc/CalculateAverage_samuelyvon.java | 249 --------------- 2 files changed, 342 insertions(+), 249 deletions(-) create mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_SamuelYvon.java delete mode 100644 src/main/java/dev/morling/onebrc/CalculateAverage_samuelyvon.java (limited to 'src/main/java/dev/morling') diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_SamuelYvon.java b/src/main/java/dev/morling/onebrc/CalculateAverage_SamuelYvon.java new file mode 100644 index 0000000..45ffd0d --- /dev/null +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_SamuelYvon.java @@ -0,0 +1,342 @@ +/* + * Copyright 2023 The original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package dev.morling.onebrc; + +//import jdk.incubator.vector.ByteVector; + +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.stream.Collectors; + +/** + * Samuel Yvon's entry. + *

+ * Explanation behind my reasoning: + * - I want to make it as fast as possible without it being an unreadable mess; I want to avoid bit fiddling UTF-8 + * and use the provided facilities + * - I use the fact that we know the number of stations to optimize HashMap creation (75% rule) + * - I stole branch-less compare from royvanrijn + * - I assume valid ASCII encoding for the number part, which allows me to parse it manually + * (should hold for valid UTF-8) + * - I have not done Java in forever. Especially what the heck it's become. I've looked at the other submissions and + * the given sample to get inspiration. I did not even know about this Stream API thing. + *

+ * + *

+ * Future ideas: + * - Probably can Vector-Apirize the number parsing (but it's three to four numbers, is it worth?) + *

+ * + *

+ * Observations: + * - [2024-01-09] The branch-less code from royvarijn does not have a huge impact + *

+ * + *

+ * Changelogs: + * 2024-01-09: Naive multi-threaded, no floats, manual line parsing + *

+ */ +public class CalculateAverage_SamuelYvon { + + private static final String FILE = "./measurements.txt"; + + private static final int MAX_STATIONS = 10000; + + private static final byte SEMICOL = 0x3B; + + private static final byte DOT = '.'; + + private static final byte MINUS = '-'; + + private static final byte ZERO = '0'; + + private static final String SLASH_S = "/"; + + private static final byte NEWLINE = '\n'; + + // The minimum line length in bytes (over-egg.) + private static final int MIN_LINE_LENGTH_BYTES = 200; + + private static final int DJB2_INIT = 5381; + + /** + * Branchless min (unprecise for large numbers, but good enough) + * + * @author royvanrijn + */ + private static int branchlessMax(final int a, final int b) { + final int diff = a - b; + final int dsgn = diff >> 31; + return a - (diff & dsgn); + } + + /** + * Branchless min (unprecise for large numbers, but good enough) + * + * @author royvanrijn + */ + private static int branchlessMin(final int a, final int b) { + final int diff = a - b; + final int dsgn = diff >> 31; + return b + (diff & dsgn); + } + + /** + * A magic key that contains references to the String and where it's located in memory + */ + private static final class StationName { + private final int hash; + + private final byte[] value; + + private StationName(int hash, MappedByteBuffer backing, int pos, int len) { + this.hash = hash; + this.value = new byte[len]; + backing.get(pos, this.value); + } + + @Override + public int hashCode() { + return hash; + } + + @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") + @Override + public boolean equals(Object obj) { + // Should NEVER be true + // if (!(obj instanceof StationName)) { + // return false; + // } + StationName other = (StationName) obj; + + if (this.value.length != other.value.length) { + return false; + } + + // Byte for byte compare. This actually is a bug! I'm assuming the input + // is UTF-8 normalized, which in real life would probably not be the case. + // TODO: SIMD? + return Arrays.equals(this.value, other.value); + } + + @Override + public String toString() { + return new String(this.value, StandardCharsets.UTF_8); + } + } + + private static class StationMeasureAgg { + private int min; + private int max; + private long sum; + private long count; + + private final StationName station; + + private String memoizedName; + + public StationMeasureAgg(StationName name) { + // Actual numbers are between -99.9 and 99.9, but we *10 to avoid float + this.station = name; + min = 1000; + max = -1000; + sum = 0; + count = 0; + } + + @Override + public int hashCode() { + return this.station.hash; + } + + /** + * Get the city name, but also memoized it to avoid building it multiple times + * + * @return the city name + */ + public String city() { + if (null == this.memoizedName) { + this.memoizedName = station.toString(); + } + return this.memoizedName; + } + + public StationMeasureAgg mergeWith(StationMeasureAgg other) { + min = branchlessMin(min, other.min); + max = branchlessMax(max, other.max); + + sum += other.sum; + count += other.count; + + return this; + } + + public void accumulate(int number) { + min = branchlessMin(min, number); + max = branchlessMax(max, number); + + sum += number; + count++; + } + + @Override + public String toString() { + double min = Math.round((double) this.min) / 10.0; + double max = Math.round((double) this.max) / 10.0; + double mean = Math.round((((double) this.sum / this.count))) / 10.0; + return min + SLASH_S + mean + SLASH_S + max; + } + + public StationName station() { + return this.station; + } + } + + private static HashMap parseChunk(MappedByteBuffer chunk) { + HashMap m = HashMap.newHashMap(MAX_STATIONS); + + int i = 0; + while (i < chunk.limit()) { + int j = i; + + int hash = DJB2_INIT; + + // Implement a version of djb2 while we read until the semi, we read anyways + for (; j < chunk.limit(); ++j) { + byte b = chunk.get(j); + + if (b == SEMICOL) { + break; + } + + // How will this behave in java? Why do I get no control OF SIGNEDNESS FFS + // Can I assume int is 32 bits? What is this language! In Java 1.7 it could be 16 bits? + // Apparently not anymore?? + hash = (((hash << 5) + hash) + b); + } + + StationName name = new StationName(hash, chunk, i, j - i); + + // Skip the `;` + j++; + + // Parse the int ourselves, avoids a 'String::replace' to remove the + // digit. + int temp = 0; + boolean neg = chunk.get(j) == MINUS; + for (j = j + (neg ? 1 : 0); j < chunk.limit(); ++j) { + temp *= 10; + byte c = chunk.get(j); + if (c != DOT) { + temp += (char) (c - ZERO); + } + else { + j++; + break; + } + } + + // The decimal point + temp += (char) (chunk.get(j) - ZERO); + + i = j + 1; + + while (chunk.get(i++) != NEWLINE) + ; + + if (neg) + temp = -temp; + + m.computeIfAbsent(name, StationMeasureAgg::new).accumulate(temp); + } + + return m; + } + + private static int approximateChunks() { + // https://stackoverflow.com/a/4759606 + // I don't remember Java :D + return Runtime.getRuntime().availableProcessors(); + } + + private static List getFileChunks() throws IOException { + int approxChunkCount = approximateChunks(); + + List fileChunks = new ArrayList<>(approxChunkCount * 2); + + // Compute chunks offsets and lengths + try (RandomAccessFile file = new RandomAccessFile(FILE, "r")) { + long totalOffset = 0; + long fLength = file.length(); + long approximateLength = Long.max(fLength / approxChunkCount, MIN_LINE_LENGTH_BYTES); + + while (totalOffset < fLength) { + long offset = totalOffset; + int length = (int) approximateLength; + + boolean eof = offset + length >= fLength; + if (eof) { + length = (int) (fLength - offset); + } + + MappedByteBuffer out = file.getChannel().map(FileChannel.MapMode.READ_ONLY, totalOffset, length); + + while (out.get(length - 1) != NEWLINE) { + length--; + } + + out.position(0); + out.limit(length); + + fileChunks.add(out); + + totalOffset += length; + } + } + + return fileChunks; + } + + public static void main(String[] args) throws IOException { + var fileChunks = getFileChunks(); + + // Map per core, giving the non-overlapping memory slices + final Map allMeasures = fileChunks.parallelStream().map(CalculateAverage_SamuelYvon::parseChunk).flatMap(x -> x.values().stream()) + .collect(Collectors.toMap(StationMeasureAgg::station, x -> x, StationMeasureAgg::mergeWith, HashMap::new)); + + // Give a capacity that should never be gone past + StringBuilder sb = new StringBuilder(allMeasures.size() * (100 + 4 + 20)); + sb.append('{'); + + allMeasures.values().stream().sorted(Comparator.comparing(StationMeasureAgg::city)).forEach(x -> { + sb.append(x.city()); + sb.append('='); + sb.append(x); + sb.append(", "); + }); + + sb.delete(sb.length() - 2, sb.length()); // delete trailing comma and space + + sb.append('}'); + + System.out.println(sb); + } +} diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_samuelyvon.java b/src/main/java/dev/morling/onebrc/CalculateAverage_samuelyvon.java deleted file mode 100644 index 10a2179..0000000 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_samuelyvon.java +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Copyright 2023 The original authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package dev.morling.onebrc; - -import java.io.IOException; -import java.io.RandomAccessFile; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.charset.StandardCharsets; -import java.util.*; -import java.util.stream.Collectors; - -/** - * Samuel Yvon's entry. - *

- * Explanation behind my reasoning: - * - I want to make it as fast as possible without it being an unreadable mess; I want to avoid bit fiddling UTF-8 - * and use the provided facilities - * - I use the fact that we know the number of stations to optimize HashMap creation (75% rule) - * - I stole branch-less compare from royvanrijn - * - I assume valid ASCII encoding for the number part, which allows me to parse it manually - * (should hold for valid UTF-8) - * - I have not done Java in forever. Especially what the heck it's become. I've looked at the other submissions and - * the given sample to get inspiration. I did not even know about this Stream API thing. - *

- * - *

- * Future ideas: - * - Probably can Vector-Apirize the number parsing (but it's three to four numbers, is it worth?) - *

- * - *

- * Observations: - * - [2024-01-09] The branch-less code from royvarijn does not have a huge impact - *

- * - *

- * Changelogs: - * 2024-01-09: Naive multi-threaded, no floats, manual line parsing - *

- */ -public class CalculateAverage_samuelyvon { - - private static final String FILE = "./measurements.txt"; - - private static final int MAX_STATIONS = 10000; - - private static final byte SEMICOL = 0x3B; - - private static final byte MINUS = '-'; - - private static final byte ZERO = '0'; - - private static final byte NEWLINE = '\n'; - - // The minimum line length in bytes (over-egg.) - private static final int MIN_LINE_LENGTH_BYTES = 200; - - /** - * Branchless min (unprecise for large numbers, but good enough) - * - * @author royvanrijn - */ - private static int branchlessMax(final int a, final int b) { - final int diff = a - b; - final int dsgn = diff >> 31; - return a - (diff & dsgn); - } - - /** - * Branchless min (unprecise for large numbers, but good enough) - * - * @author royvanrijn - */ - private static int branchlessMin(final int a, final int b) { - final int diff = a - b; - final int dsgn = diff >> 31; - return b + (diff & dsgn); - } - - private static class StationMeasureAgg { - private int min; - private int max; - private long sum; - private long count; - - private final String city; - - public StationMeasureAgg(String city) { - // Actual numbers are between -99.9 and 99.9, but we *10 to avoid float - this.city = city; - min = 1000; - max = -1000; - sum = 0; - count = 0; - } - - public String city() { - return this.city; - } - - public StationMeasureAgg mergeWith(StationMeasureAgg other) { - min = branchlessMin(min, other.min); - max = branchlessMax(max, other.max); - - sum += other.sum; - count += other.count; - - return this; - } - - public void accumulate(int number) { - min = branchlessMin(min, number); - max = branchlessMax(max, number); - - sum += number; - count++; - } - - @Override - public String toString() { - double min = Math.round((double) this.min) / 10.0; - double max = Math.round((double) this.max) / 10.0; - double mean = Math.round((((double) this.sum / this.count))) / 10.0; - return min + "/" + mean + "/" + max; - } - } - - private static HashMap parseChunk(MappedByteBuffer chunk) { - HashMap m = HashMap.newHashMap(MAX_STATIONS); - - int i = 0; - while (i < chunk.limit()) { - int j = i; - for (; j < chunk.limit(); ++j) { - // TODO: Could compute a hash here, store a byte array, and decode UTF-8 at the - // latest moment. - if (chunk.get(j) == SEMICOL) { - break; - } - } - - byte[] backingNameArray = new byte[j - i]; - chunk.get(i, backingNameArray); - String name = new String(backingNameArray, StandardCharsets.UTF_8); - - // Skip the `;` - j++; - - // Parse the int ourselves, avoids a 'String::replace' to remove the - // digit. - int temp = 0; - boolean neg = chunk.get(j) == MINUS; - for (j = j + (neg ? 1 : 0); j < chunk.limit(); ++j) { - temp *= 10; - byte c = chunk.get(j); - if (c != '.') { - temp += (char) (c - ZERO); - } - else { - j++; - break; - } - } - - // The decimal point - temp += (char) (chunk.get(j) - ZERO); - - i = j + 1; - - while (chunk.get(i++) != '\n') - ; - - if (neg) - temp = -temp; - - m.computeIfAbsent(name, StationMeasureAgg::new).accumulate(temp); - } - - return m; - } - - private static int approximateChunks() { - // https://stackoverflow.com/a/4759606 - // I don't remember Java :D - return Runtime.getRuntime().availableProcessors(); - } - - private static List getFileChunks() throws IOException { - int approxChunkCount = approximateChunks(); - - List fileChunks = new ArrayList<>(approxChunkCount * 2); - - // Compute chunks offsets and lengths - try (RandomAccessFile file = new RandomAccessFile(FILE, "r")) { - long totalOffset = 0; - long fLength = file.length(); - long approximateLength = Long.max(fLength / approxChunkCount, MIN_LINE_LENGTH_BYTES); - - while (totalOffset < fLength) { - long offset = totalOffset; - int length = (int) approximateLength; - - boolean eof = offset + length >= fLength; - if (eof) { - length = (int) (fLength - offset); - } - - MappedByteBuffer out = file.getChannel().map(FileChannel.MapMode.READ_ONLY, totalOffset, length); - - while (out.get(length - 1) != NEWLINE) { - length--; - } - - out.position(0); - out.limit(length); - - fileChunks.add(out); - - totalOffset += length; - } - } - - return fileChunks; - } - - public static void main(String[] args) throws IOException { - var fileChunks = getFileChunks(); - - // Map per core, giving the non-overlapping memory slices - final Map sortedMeasures = fileChunks.parallelStream().map(CalculateAverage_samuelyvon::parseChunk) - .flatMap(x -> x.values().stream()).collect(Collectors.toMap(StationMeasureAgg::city, x -> x, StationMeasureAgg::mergeWith, TreeMap::new)); - - System.out.println(sortedMeasures); - } -} -- cgit v1.2.3