diff options
| author | Dimitar Dimitrov <dimitar.dimitrov@gmail.com> | 2024-01-07 03:24:48 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-06 19:24:48 +0100 |
| commit | 14918bb306784daf78988cd252c4bf451130b695 (patch) | |
| tree | b6124a1afe13d711b20ae0f6d4f847595d868df0 /src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java | |
| parent | e8b2d2d7b4caf114b12fa0386e35f741f31df905 (diff) | |
ddimtirov - supporting hash collisions, should have fixed #101
* ddimtirov - supporting hash collisions, should have fixed #101
* Make life easier for Windows users who need to use WSL to run the tests
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java')
| -rw-r--r-- | src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java | 278 |
1 files changed, 177 insertions, 101 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java index 7f5ac50..7602d43 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ddimtirov.java @@ -26,165 +26,245 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.time.Duration; import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.TreeMap; - -// gunnar morling - 2:10 -// roy van rijn - 1:01 -// 0:37 +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.LongAdder; public class CalculateAverage_ddimtirov { private static final String FILE = "./measurements.txt"; + private static final int MAX_STATIONS = 100_000; + private static final int MAX_STATION_NAME_LENGTH = 100; + private static final int MAX_UTF8_CODEPOINT_SIZE = 3; - private static final int HASH_NO_CLASH_MODULUS = 49999; private static final int OFFSET_MIN = 0; private static final int OFFSET_MAX = 1; private static final int OFFSET_COUNT = 2; + private static final boolean assertions = CalculateAverage_ddimtirov.class.desiredAssertionStatus(); + private static final Map<String, LongAdder> hashCollisionOccurrences = new ConcurrentHashMap<>(); + @SuppressWarnings("RedundantSuppression") - public static void main(String[] args) throws IOException { - var path = Path.of(FILE); - var start = Instant.now(); - var desiredSegmentsCount = Runtime.getRuntime().availableProcessors(); + public static void main(String[] args) throws IOException, InterruptedException { + var path = Path.of(args.length>0 ? 
args[0] : FILE); + Instant start = null;// Instant.now(); + var desiredSegmentsCount = Runtime.getRuntime().availableProcessors(); var fileSegments = FileSegment.forFile(path, desiredSegmentsCount); - var trackers = fileSegments.stream().parallel().map(fileSegment -> { - try (var fileChannel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { - var tracker = new Tracker(); - var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, fileSegment.start(), fileSegment.size(), Arena.ofConfined()); - tracker.processSegment(memorySegment); - return tracker; - } - catch (IOException e) { - throw new RuntimeException(e); - } - }).toList(); + var loaders = new ThreadGroup("Loaders"); + var trackers = Collections.synchronizedList(new ArrayList<Tracker>()); + var threads = fileSegments.stream().map(fileSegment -> Thread // manually start thread per segment + .ofPlatform() + .group(loaders) + .name(STR."Segment \{fileSegment}") + .start(() -> { + try (var fileChannel = (FileChannel) Files.newByteChannel(path, StandardOpenOption.READ)) { + var tracker = new Tracker(); + var memorySegment = fileChannel.map(FileChannel.MapMode.READ_ONLY, fileSegment.start(), fileSegment.size(), Arena.ofConfined()); + tracker.processSegment(memorySegment); + trackers.add(tracker); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }) + ).toList(); + + for (Thread thread : threads) thread.join(); + assert trackers.size() == threads.size(); + assert trackers.size() <= desiredSegmentsCount; - var result = summarizeTrackers(trackers); + var result = summarizeTrackers(trackers.toArray(Tracker[]::new)); System.out.println(result); // noinspection ConstantValue - if (start != null) + if (start != null) { System.err.println(Duration.between(start, Instant.now())); - assert Files.readAllLines(Path.of("measurements_result.txt")).getFirst().equals(result); + if (assertions) System.err.printf("hash clashes: %s%n", hashCollisionOccurrences); + } + assert 
Files.readAllLines(Path.of("measurements.out")).getFirst().equals(result); } - record FileSegment(long start, long size) { + record FileSegment(int index, long start, long size) { + @Override + public String toString() { + return STR."#\{index} [\{start}..\{start + size}] \{size} bytes"; + } + public static List<FileSegment> forFile(Path file, int desiredSegmentsCount) throws IOException { try (var raf = new RandomAccessFile(file.toFile(), "r")) { var segments = new ArrayList<FileSegment>(); var fileSize = raf.length(); - var segmentSize = fileSize / desiredSegmentsCount; - for (int segmentIdx = 0; segmentIdx < desiredSegmentsCount; segmentIdx++) { - var segStart = segmentIdx * segmentSize; - var segEnd = (segmentIdx == desiredSegmentsCount - 1) ? fileSize : segStart + segmentSize; - segStart = findSegmentBoundary(raf, segmentIdx, 0, segStart, segEnd); - segEnd = findSegmentBoundary(raf, segmentIdx, desiredSegmentsCount - 1, segEnd, fileSize); + var segmentSize = Math.max(1024 * 1024, fileSize / desiredSegmentsCount); - var segSize = segEnd - segStart; - - segments.add(new FileSegment(segStart, segSize)); + var i = 1; + var prevEnd = 0L; + while (prevEnd < fileSize-1) { + var start = prevEnd; + var end = findNewLineAfter(raf, prevEnd + segmentSize, fileSize); + segments.add(new FileSegment(i, start, end - start)); + prevEnd = end; } return segments; } } - private static long findSegmentBoundary(RandomAccessFile raf, int i, int skipForSegment, long location, long fileSize) throws IOException { - if (i == skipForSegment) return location; - + private static long findNewLineAfter(RandomAccessFile raf, long location, long fileSize) throws IOException { raf.seek(location); while (location < fileSize) { location++; - if (raf.read() == '\n') break; + int c = raf.read(); + if (c == '\r' || c == '\n') break; } - return location; + return Math.min(location, fileSize - 1); + } + } + + static class Accumulator { + public final String name; + public int min = Integer.MAX_VALUE, 
max = Integer.MIN_VALUE, count; + public long sum; + + public Accumulator(String name) { + this.name = name; + } + + public void accumulate(int min, int max, int count, long sum) { + if (this.min > min) + this.min = min; + if (this.max < max) + this.max = max; + this.count += count; + this.sum += sum; + } + + @Override + public String toString() { + var mean = Math.round((double) sum / count) / 10.0; + return (min / 10.0) + "/" + mean + "/" + (max / 10.0); } } - private static String summarizeTrackers(List<Tracker> trackers) { - var result = new TreeMap<String, String>(); - for (var i = 0; i < HASH_NO_CLASH_MODULUS; i++) { - String name = null; + private static String summarizeTrackers(Tracker[] trackers) { + var result = new TreeMap<String, Accumulator>(); + + for (var i = 0; i < Tracker.SIZE; i++) { + Accumulator acc = null; - var min = Integer.MAX_VALUE; - var max = Integer.MIN_VALUE; - var sum = 0L; - var count = 0L; for (Tracker tracker : trackers) { - if (tracker.names[i] == null) + var name = tracker.names[i]; + if (name == null) { continue; - if (name == null) - name = tracker.names[i]; - - var minn = tracker.minMaxCount[i * 3]; - var maxx = tracker.minMaxCount[i * 3 + 1]; - if (minn < min) - min = minn; - if (maxx > max) - max = maxx; - count += tracker.minMaxCount[i * 3 + 2]; - sum += tracker.sums[i]; + } + else if (acc == null || !name.equals(acc.name)) { + acc = result.computeIfAbsent(name, Accumulator::new); + } + acc.accumulate( + tracker.minMaxCount[i * 3 + OFFSET_MIN], + tracker.minMaxCount[i * 3 + OFFSET_MAX], + tracker.minMaxCount[i * 3 + OFFSET_COUNT], + tracker.sums[i]); } - if (name == null) - continue; - - var mean = Math.round((double) sum / count) / 10.0; - result.put(name, (min / 10.0) + "/" + mean + "/" + (max / 10.0)); } return result.toString(); } static class Tracker { - private final int[] minMaxCount = new int[HASH_NO_CLASH_MODULUS * 3]; - private final long[] sums = new long[HASH_NO_CLASH_MODULUS]; - private final String[] names = 
new String[HASH_NO_CLASH_MODULUS]; + public static final int SIZE = MAX_STATIONS * 10; + + private static final int CORRECTION_0_TO_9 = '0' * 10 + '0'; + private static final int CORRECTION_10_TO_99 = '0' * 100 + '0' * 10 + '0'; + + private final byte[] tempNameBytes = new byte[MAX_STATION_NAME_LENGTH * MAX_UTF8_CODEPOINT_SIZE]; + private final int[] minMaxCount = new int[SIZE * 3]; + private final long[] sums = new long[SIZE]; + private final String[] names = new String[SIZE]; + private final byte[][] nameBytes = new byte[SIZE][]; private void processSegment(MemorySegment memory) { - int position = 0; long limit = memory.byteSize(); - while (position < limit) { - int pos = position; + + var pos = 0; + + // skip newlines so the chunk limit check can work correctly + while (pos < limit) { + byte c = memory.get(ValueLayout.JAVA_BYTE, pos); + if (c != '\r' && c != '\n') + break; + pos++; + } + + while (pos < limit) { byte b; int nameLength = 0, nameHash = 0; while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != ';') { + tempNameBytes[nameLength++] = b; nameHash = nameHash * 31 + b; - nameLength++; } - int temperature = 0, sign = 1; - outer: while ((b = memory.get(ValueLayout.JAVA_BYTE, pos++)) != '\n') { - switch (b) { - case '\r': - pos++; - break outer; - case '.': - break; - case '-': - sign = -1; - break; - default: - var digit = b - '0'; - assert digit >= 0 && digit <= 9; - temperature = 10 * temperature + digit; - } + int sign; + if (memory.get(ValueLayout.JAVA_BYTE, pos) == '-') { + sign = -1; + pos++; } + else { + sign = 1; + } + + int temperature; // between [-99.9 and 99.9], mapped to fixed point int (scaled by 10) + if (memory.get(ValueLayout.JAVA_BYTE, pos + 1) == '.') { // between -9.99 and 9.99 + assert memory.get(ValueLayout.JAVA_BYTE, pos + 1) == '.'; + temperature = memory.get(ValueLayout.JAVA_BYTE, pos) * 10 + + memory.get(ValueLayout.JAVA_BYTE, pos + 2) - CORRECTION_0_TO_9; + pos += 3; // #.# - 3 chars + } + else { // between [-99.9 and -9.99] OR 
[9.99 and 99.9] + assert memory.get(ValueLayout.JAVA_BYTE, pos + 2) == '.'; + temperature = memory.get(ValueLayout.JAVA_BYTE, pos) * 100 + + memory.get(ValueLayout.JAVA_BYTE, pos + 1) * 10 + + memory.get(ValueLayout.JAVA_BYTE, pos + 3) - CORRECTION_10_TO_99; + pos += 4; // ##.# - 4 chars + } + + processLine(nameHash, tempNameBytes, nameLength, temperature * sign); - processLine(nameHash, memory, position, nameLength, temperature * sign); - position = pos; + // skip newlines so the chunk limit check can work correctly + while (pos < limit) { + byte c = memory.get(ValueLayout.JAVA_BYTE, pos); + if (c != '\r' && c != '\n') + break; + pos++; + } } } - public void processLine(int nameHash, MemorySegment buffer, int nameOffset, int nameLength, int temperature) { - var i = Math.abs(nameHash) % HASH_NO_CLASH_MODULUS; + public void processLine(int nameHash, byte[] nameBytesBuffer, int nameLength, int temperature) { + var i = Math.abs(nameHash) % SIZE; - if (names[i] == null) { - names[i] = parseName(buffer, nameOffset, nameLength); + while (true) { + if (names[i] == null) { + byte[] trimmedBytes = Arrays.copyOf(nameBytesBuffer, nameLength); + names[i] = new String(trimmedBytes, StandardCharsets.UTF_8); + nameBytes[i] = trimmedBytes; + minMaxCount[i*3 + OFFSET_MIN] = Integer.MAX_VALUE; + minMaxCount[i*3 + OFFSET_MAX] = Integer.MIN_VALUE; + break; + } + else if (nameBytes[i].length==nameLength && Arrays.equals(nameBytes[i], 0, nameLength, nameBytesBuffer, 0, nameLength)) { + break; + } + if (assertions) { + var key = new String(nameBytesBuffer, 0, nameLength, StandardCharsets.UTF_8); + hashCollisionOccurrences.computeIfAbsent(key, _ -> new LongAdder()).increment(); + } + i = (i + 1) % SIZE; } - else { - assert parseName(buffer, nameOffset, nameLength).equals(names[i]) : parseName(buffer, nameOffset, nameLength) + "!=" + names[i]; + if (assertions) { + var key = new String(nameBytesBuffer, 0, nameLength, StandardCharsets.UTF_8); + if (hashCollisionOccurrences.containsKey(key)) 
{ + hashCollisionOccurrences.computeIfAbsent(STR."\{key}[\{i}]", _ -> new LongAdder()).increment(); + } } sums[i] += temperature; @@ -200,9 +280,5 @@ public class CalculateAverage_ddimtirov { minMaxCount[mmcIndex + OFFSET_COUNT]++; } - private String parseName(MemorySegment memory, int nameOffset, int nameLength) { - byte[] array = memory.asSlice(nameOffset, nameLength).toArray(ValueLayout.JAVA_BYTE); - return new String(array, StandardCharsets.UTF_8); - } } } |
