From bd4cff945daf1526c31605623dc56d1d31885ed6 Mon Sep 17 00:00:00 2001 From: Thomas Wuerthinger Date: Fri, 12 Jan 2024 20:51:22 +0100 Subject: Adding Scanner object and also tuning for better branch prediction for about +6%. (#341) --- .../morling/onebrc/CalculateAverage_thomaswue.java | 283 +++++++++++++-------- 1 file changed, 182 insertions(+), 101 deletions(-) (limited to 'src/main/java/dev') diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index 7fd880a..10e92fc 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -32,30 +32,25 @@ import java.util.stream.IntStream; * Simple solution that memory maps the input file, then splits it into one segment per available core and uses * sun.misc.Unsafe to directly access the mapped memory. Uses a long at a time when checking for collision. *

- * Runs in 0.70s on my Intel i9-13900K + * Runs in 0.66s on my Intel i9-13900K * Perf stats: - * 40,622,862,783 cpu_core/cycles/ - * 48,241,929,925 cpu_atom/cycles/ + * 35,935,262,091 cpu_core/cycles/ + * 47,305,591,173 cpu_atom/cycles/ */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; // Holding the current result for a single city. private static class Result { - final long nameAddress; - long lastNameLong; - int remainingShift; - int min; - int max; + long lastNameLong, secondLastNameLong, nameAddress; + int nameLength, remainingShift; + int min, max, count; long sum; - int count; - private Result(long nameAddress, int value) { + private Result(long nameAddress) { this.nameAddress = nameAddress; - this.min = value; - this.max = value; - this.sum = value; - this.count = 1; + this.min = Integer.MAX_VALUE; + this.max = Integer.MIN_VALUE; } public String toString() { @@ -73,6 +68,10 @@ public class CalculateAverage_thomaswue { sum += other.sum; count += other.count; } + + public String calcName() { + return new Scanner(nameAddress, nameAddress + nameLength).getString(nameLength); + } } public static void main(String[] args) throws IOException { @@ -81,122 +80,155 @@ public class CalculateAverage_thomaswue { long[] chunks = getSegments(numberOfChunks); // Parallel processing of segments. - List> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> { - HashMap cities = HashMap.newHashMap(1 << 10); - parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], cities); - return cities; - }).parallel().toList(); - - // Accumulate results sequentially. - HashMap result = allResults.getFirst(); - for (int i = 1; i < allResults.size(); ++i) { - for (Map.Entry entry : allResults.get(i).entrySet()) { - Result current = result.putIfAbsent(entry.getKey(), entry.getValue()); - if (current != null) { - current.add(entry.getValue()); - } - } - } + List> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1])) + .map(resultArray -> { + List results = new ArrayList<>(); + for (Result r : resultArray) { + if (r != null) { + results.add(r); + } + } + return results; + }).parallel().toList(); // Final output. - System.out.println(new TreeMap<>(result)); + System.out.println(accumulateResults(allResults)); } - private static final Unsafe UNSAFE = initUnsafe(); - - private static Unsafe initUnsafe() { - try { - Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); - theUnsafe.setAccessible(true); - return (Unsafe) theUnsafe.get(Unsafe.class); - } - catch (NoSuchFieldException | IllegalAccessException e) { - throw new RuntimeException(e); + // Accumulate results sequentially for simplicity. + private static TreeMap accumulateResults(List> allResults) { + TreeMap result = new TreeMap<>(); + for (List resultArr : allResults) { + for (Result r : resultArr) { + String name = r.calcName(); + Result current = result.putIfAbsent(name, r); + if (current != null) { + current.add(r); + } + } } + return result; } - private static void parseLoop(long chunkStart, long chunkEnd, HashMap cities) { + // Main parse loop. + private static Result[] parseLoop(long chunkStart, long chunkEnd) { Result[] results = new Result[1 << 18]; - long scanPtr = chunkStart; - while (scanPtr < chunkEnd) { - long nameAddress = scanPtr; + Scanner scanner = new Scanner(chunkStart, chunkEnd); + while (scanner.hasNext()) { + long nameAddress = scanner.pos(); long hash = 0; // Search for ';', one long at a time. - long word = UNSAFE.getLong(scanPtr); + long word = scanner.getLong(); int pos = findDelimiter(word); if (pos != 8) { - scanPtr += pos; - word = word & (-1L >>> ((8 - pos - 1) << 3)); + scanner.add(pos); + word = mask(word, pos); hash ^= word; + + Result existingResult = results[hashToIndex(hash, results)]; + if (existingResult != null && existingResult.lastNameLong == word) { + scanAndRecord(scanner, existingResult); + continue; + } } else { - scanPtr += 8; + scanner.add(8); hash ^= word; - while (true) { - word = UNSAFE.getLong(scanPtr); - pos = findDelimiter(word); - if (pos != 8) { - scanPtr += pos; - word = word & (-1L >>> ((8 - pos - 1) << 3)); - hash ^= word; - break; + long prevWord = word; + word = scanner.getLong(); + pos = findDelimiter(word); + if (pos != 8) { + scanner.add(pos); + word = mask(word, pos); + hash ^= word; + Result existingResult = results[hashToIndex(hash, results)]; + if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { + scanAndRecord(scanner, existingResult); + continue; } - else { - scanPtr += 8; - hash ^= word; + } + else { + scanner.add(8); + hash ^= word; + while (true) { + word = scanner.getLong(); + pos = findDelimiter(word); + if (pos != 8) { + scanner.add(pos); + word = mask(word, pos); + hash ^= word; + break; + } + else { + scanner.add(8); + hash ^= word; + } } } } // Save length of name for later. - int nameLength = (int) (scanPtr - nameAddress); - scanPtr++; + int nameLength = (int) (scanner.pos() - nameAddress); + scanner.add(1); - long numberWord = UNSAFE.getLong(scanPtr); - // The 4th binary digit of the ascii of a digit is 1 while - // that of the '.' is 0. This finds the decimal separator - // The value can be 12, 20, 28 + long numberWord = scanner.getLong(); int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); int number = convertIntoNumber(decimalSepPos, numberWord); - - // Skip past new line. - // scanPtr++; - scanPtr += (decimalSepPos >>> 3) + 3; + scanner.add((decimalSepPos >>> 3) + 3); // Final calculation for index into hash table. - int hashAsInt = (int) (hash ^ (hash >>> 32)); - int finalHash = (hashAsInt ^ (hashAsInt >>> 18)); - int tableIndex = (finalHash & (results.length - 1)); + int tableIndex = hashToIndex(hash, results); outer: while (true) { Result existingResult = results[tableIndex]; if (existingResult == null) { - newEntry(results, cities, nameAddress, number, tableIndex, nameLength); - break; + existingResult = newEntry(results, nameAddress, tableIndex, nameLength, scanner); } - else { - // Check for collision. - int i = 0; - for (; i < nameLength + 1 - 8; i += 8) { - if (UNSAFE.getLong(existingResult.nameAddress + i) != UNSAFE.getLong(nameAddress + i)) { - tableIndex = (tableIndex + 1) & (results.length - 1); - continue outer; - } - } - if (((existingResult.lastNameLong ^ UNSAFE.getLong(nameAddress + i)) << existingResult.remainingShift) == 0) { - existingResult.min = Math.min(existingResult.min, number); - existingResult.max = Math.max(existingResult.max, number); - existingResult.sum += number; - existingResult.count++; - break; - } - else { - // Collision error, try next. + // Check for collision. + int i = 0; + for (; i < nameLength + 1 - 8; i += 8) { + if (scanner.getLongAt(existingResult.nameAddress + i) != scanner.getLongAt(nameAddress + i)) { tableIndex = (tableIndex + 1) & (results.length - 1); + continue outer; } } + if (((existingResult.lastNameLong ^ scanner.getLongAt(nameAddress + i)) << existingResult.remainingShift) == 0) { + record(existingResult, number); + break; + } + else { + // Collision error, try next. + tableIndex = (tableIndex + 1) & (results.length - 1); + } } } + return results; + } + + private static void scanAndRecord(Scanner scanPtr, Result existingResult) { + scanPtr.add(1); + long numberWord = scanPtr.getLong(); + int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + int number = convertIntoNumber(decimalSepPos, numberWord); + scanPtr.add((decimalSepPos >>> 3) + 3); + record(existingResult, number); + } + + private static void record(Result existingResult, int number) { + existingResult.min = Math.min(existingResult.min, number); + existingResult.max = Math.max(existingResult.max, number); + existingResult.sum += number; + existingResult.count++; + } + + private static int hashToIndex(long hash, Result[] results) { + int hashAsInt = (int) (hash ^ (hash >>> 32)); + int finalHash = (hashAsInt ^ (hashAsInt >>> 18)); + return (finalHash & (results.length - 1)); + } + + private static long mask(long word, int pos) { + return word & (-1L >>> ((8 - pos - 1) << 3)); } // Special method to convert a number in the specific format into an int value without branches created by @@ -229,19 +261,18 @@ public class CalculateAverage_thomaswue { return Long.numberOfTrailingZeros(tmp) >>> 3; } - private static void newEntry(Result[] results, HashMap cities, long nameAddress, int number, int hash, int nameLength) { - Result r = new Result(nameAddress, number); + private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner) { + Result r = new Result(nameAddress); results[hash] = r; - byte[] bytes = new byte[nameLength]; int i = 0; for (; i < nameLength + 1 - 8; i += 8) { + r.secondLastNameLong = (scanner.getLongAt(nameAddress + i)); } - r.lastNameLong = UNSAFE.getLong(nameAddress + i); r.remainingShift = (64 - (nameLength + 1 - i) << 3); - UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); - String nameAsString = new String(bytes, StandardCharsets.UTF_8); - cities.put(nameAsString, r); + r.lastNameLong = (scanner.getLongAt(nameAddress + i) & (-1L >>> r.remainingShift)); + r.nameLength = nameLength; + return r; } private static long[] getSegments(int numberOfChunks) throws IOException { @@ -252,10 +283,11 @@ public class CalculateAverage_thomaswue { long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); chunks[0] = mappedAddress; long endAddress = mappedAddress + fileSize; + Scanner s = new Scanner(mappedAddress, mappedAddress + fileSize); for (int i = 1; i < numberOfChunks; ++i) { long chunkAddress = mappedAddress + i * segmentSize; // Align to first row start. - while (chunkAddress < endAddress && UNSAFE.getByte(chunkAddress++) != '\n') { + while (chunkAddress < endAddress && (s.getLongAt(chunkAddress++) & 0xFF) != '\n') { // nop } chunks[i] = Math.min(chunkAddress, endAddress); @@ -264,4 +296,53 @@ public class CalculateAverage_thomaswue { return chunks; } } + + private static class Scanner { + + private static final Unsafe UNSAFE = initUnsafe(); + + private static Unsafe initUnsafe() { + try { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + return (Unsafe) theUnsafe.get(Unsafe.class); + } + catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + long pos, end; + + public Scanner(long start, long end) { + this.pos = start; + this.end = end; + } + + boolean hasNext() { + return pos < end; + } + + long pos() { + return pos; + } + + void add(int delta) { + pos += delta; + } + + long getLong() { + return UNSAFE.getLong(pos); + } + + long getLongAt(long pos) { + return UNSAFE.getLong(pos); + } + + public String getString(int nameLength) { + byte[] bytes = new byte[nameLength]; + UNSAFE.copyMemory(null, pos, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength); + return new String(bytes, StandardCharsets.UTF_8); + } + } } -- cgit v1.2.3