/* * Copyright 2023 The original authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package dev.morling.onebrc; import java.io.IOException; import java.lang.foreign.Arena; import java.lang.reflect.Field; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.Arrays; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import sun.misc.Unsafe; /** * Changelog: * * Initial submission: 62000 ms * Chunked reader: 16000 ms * Optimized parser: 13000 ms * Branchless methods: 11000 ms * Adding memory mapped files: 6500 ms (based on bjhara's submission) * Skipping string creation: 4700 ms * Custom hashmap... 4200 ms * Added SWAR token checks: 3900 ms * Skipped String creation: 3500 ms (idea from kgonia) * Improved String skip: 3250 ms * Segmenting files: 3150 ms (based on spullara's code) * Not using SWAR for EOL: 2850 ms * Inlining hash calculation: 2450 ms * Replacing branchless code: 2200 ms (sometimes we need to kill the things we love) * Added unsafe memory access: 1900 ms (keeping the long[] small and local) * Fixed bug, UNSAFE bytes String: 1850 ms * Separate hash from entries: 1550 ms * Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine) * Improved layout/predictability: 1400 ms * Delayed String creation again: 1350 ms * Remove writing to buffer: 1335 ms * Optimized collecting at the end: 1310 ms * Adding a lot of comments: priceless * Changed to flyweight byte[]: 1290 ms (adds even more Unsafe, was initially slower, now faster) * More LOC now parallel: 1260 ms (moved more to processMemoryArea, recombining in ConcurrentHashMap) * Storing only the address: 1240 ms (this is now faster, tried before, was slower) * Unrolling scan-loop: 1200 ms (seems to help, perhaps even more on target machine) * Adding more readable reader: 1300 ms (scores got worse on target machine anyway) * * Using old x86 MacBook and perf: 3500 ms (different machine for testing) * Decided to rewrite loop for 16 b: 3050 ms * Small changes, limited heap: 2950 ms * * I have some instructions that could be removed, but faster with... * * Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai and many others for ideas. * * Follow me at: @royvanrijn */ public class CalculateAverage_royvanrijn { private static final String FILE = "./measurements.txt"; // private static final String FILE = "src/test/resources/samples/measurements-1.txt"; private static final Unsafe UNSAFE = initUnsafe(); // Twice the processors, smoothens things out. private static final int PROCESSORS = Runtime.getRuntime().availableProcessors(); /** * Flyweight entry in a byte[], max 128 bytes. *

* long: sum * int: min * int: max * int: count * byte: length * byte[]: cityname */ // ------------------------------------------------------------------------ private static final int ENTRY_LENGTH = (Unsafe.ARRAY_BYTE_BASE_OFFSET); private static final int ENTRY_SUM = (ENTRY_LENGTH + Byte.BYTES); private static final int ENTRY_MIN = (ENTRY_SUM + Long.BYTES); private static final int ENTRY_MAX = (ENTRY_MIN + Integer.BYTES); private static final int ENTRY_COUNT = (ENTRY_MAX + Integer.BYTES); private static final int ENTRY_NAME = (ENTRY_COUNT + Integer.BYTES); private static final int ENTRY_NAME_8 = ENTRY_NAME + 8; private static final int ENTRY_NAME_16 = ENTRY_NAME + 16; private static final int ENTRY_BASESIZE_WHITESPACE = ENTRY_NAME + 7; // with enough empty bytes to fill a long // ------------------------------------------------------------------------ private static final int PREMADE_MAX_SIZE = 1 << 5; // pre-initialize some entries in memory, keep them close private static final int PREMADE_ENTRIES = 512; // amount of pre-created entries we should use private static final int TABLE_SIZE = 1 << 19; // large enough for the contest. private static final int TABLE_MASK = (TABLE_SIZE - 1); // Idea of thomaswue, don't wait for slow unmap: private static void spawnWorker() throws IOException { ProcessHandle.Info info = ProcessHandle.current().info(); ArrayList workerCommand = new ArrayList<>(); info.command().ifPresent(workerCommand::add); info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args))); workerCommand.add("--worker"); new ProcessBuilder() .command(workerCommand) .inheritIO() .redirectOutput(ProcessBuilder.Redirect.PIPE) .start() .getInputStream() .transferTo(System.out); } public static void main(String[] args) throws Exception { if (args.length == 0 || !("--worker".equals(args[0]))) { spawnWorker(); return; } // Calculate input segments. final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); final long fileSize = fileChannel.size(); final long segmentSize = (fileSize + PROCESSORS - 1) / PROCESSORS; final long mapAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); final Thread[] parallelThreads = new Thread[PROCESSORS - 1]; // This is where the entries will land: final ConcurrentHashMap measurements = new ConcurrentHashMap(1 << 10); // We create separate threads for twice the amount of processors. long lastAddress = mapAddress; final long endOfFile = mapAddress + fileSize; for (int i = 0; i < PROCESSORS - 1; ++i) { final long fromAddress = lastAddress; final long toAddress = Math.min(endOfFile, fromAddress + segmentSize); final Thread thread = new Thread(() -> { // The actual work is done here: final byte[][] table = processMemoryArea(fromAddress, toAddress, fromAddress == mapAddress); for (byte[] entry : table) { if (entry != null) { measurements.merge(entryToName(entry), entry, CalculateAverage_royvanrijn::mergeEntry); } } }); thread.start(); // start a.s.a.p. parallelThreads[i] = thread; lastAddress = toAddress; } // Use the current thread for the part of memory: final byte[][] table = processMemoryArea(lastAddress, mapAddress + fileSize, false); for (byte[] entry : table) { if (entry != null) { measurements.merge(entryToName(entry), entry, CalculateAverage_royvanrijn::mergeEntry); } } // Wait for all threads to finish: for (Thread thread : parallelThreads) { // Can we implement work-stealing? Not sure how... thread.join(); } // If we don't reach start of file, System.out.print("{" + measurements.entrySet().stream().sorted(Map.Entry.comparingByKey()) .map(entry -> entry.getKey() + '=' + entryValuesToString(entry.getValue())) .collect(Collectors.joining(", "))); System.out.println("}"); System.out.close(); // close the stream to stop } private static byte[] fillEntry(final byte[] entry, final long fromAddress, final int entryLength, final int temp, final long readBuffer1, final long readBuffer2) { UNSAFE.putLong(entry, ENTRY_SUM, temp); UNSAFE.putInt(entry, ENTRY_MIN, temp); UNSAFE.putInt(entry, ENTRY_MAX, temp); UNSAFE.putInt(entry, ENTRY_COUNT, 1); UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) entryLength); UNSAFE.copyMemory(null, fromAddress, entry, ENTRY_NAME, entryLength - 16); UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 16, readBuffer1); UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 8, readBuffer2); return entry; } private static byte[] fillEntry16(final byte[] entry, final int entryLength, final int temp, final long readBuffer1, final long readBuffer2) { UNSAFE.putLong(entry, ENTRY_SUM, temp); UNSAFE.putInt(entry, ENTRY_MIN, temp); UNSAFE.putInt(entry, ENTRY_MAX, temp); UNSAFE.putInt(entry, ENTRY_COUNT, 1); UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) entryLength); UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 16, readBuffer1); UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 8, readBuffer2); return entry; } public static void updateEntry(final byte[] entry, final int temp) { int entryMin = UNSAFE.getInt(entry, ENTRY_MIN); int entryMax = UNSAFE.getInt(entry, ENTRY_MAX); long entrySum = UNSAFE.getLong(entry, ENTRY_SUM) + temp; int entryCount = UNSAFE.getInt(entry, ENTRY_COUNT) + 1; if (temp < entryMin) { UNSAFE.putInt(entry, ENTRY_MIN, temp); } else if (temp > entryMax) { UNSAFE.putInt(entry, ENTRY_MAX, temp); } UNSAFE.putInt(entry, ENTRY_COUNT, entryCount); UNSAFE.putLong(entry, ENTRY_SUM, entrySum); } public static byte[] mergeEntry(final byte[] entry, final byte[] merge) { long sum = UNSAFE.getLong(merge, ENTRY_SUM); final int mergeMin = UNSAFE.getInt(merge, ENTRY_MIN); final int mergeMax = UNSAFE.getInt(merge, ENTRY_MAX); int count = UNSAFE.getInt(merge, ENTRY_COUNT); sum += UNSAFE.getLong(entry, ENTRY_SUM); count += UNSAFE.getInt(entry, ENTRY_COUNT); int entryMin = UNSAFE.getInt(entry, ENTRY_MIN); int entryMax = UNSAFE.getInt(entry, ENTRY_MAX); entryMin = Math.min(entryMin, mergeMin); entryMax = Math.max(entryMax, mergeMax); UNSAFE.putInt(entry, ENTRY_MIN, entryMin); UNSAFE.putInt(entry, ENTRY_MAX, entryMax); UNSAFE.putLong(entry, ENTRY_SUM, sum); UNSAFE.putInt(entry, ENTRY_COUNT, count); return entry; } private static String entryToName(final byte[] entry) { // Get the length from memory: int length = UNSAFE.getByte(entry, ENTRY_LENGTH); byte[] name = new byte[length]; UNSAFE.copyMemory(entry, ENTRY_NAME, name, Unsafe.ARRAY_BYTE_BASE_OFFSET, length); // Create a new String with the existing byte[]: return new String(name, StandardCharsets.UTF_8).trim(); } private static String entryValuesToString(final byte[] entry) { return (round(UNSAFE.getInt(entry, ENTRY_MIN)) + "/" + round((1.0 * UNSAFE.getLong(entry, ENTRY_SUM)) / UNSAFE.getInt(entry, ENTRY_COUNT)) + "/" + round(UNSAFE.getInt(entry, ENTRY_MAX))); } // Print a piece of memory: // For debug. private static String printMemory(final Object target, final long address, int length) { String result = ""; for (int i = 0; i < length; i++) { result += (char) UNSAFE.getByte(target, address + i); } return result; } // Print a piece of memory: // For debug. private static String printMemory(final long value, int length) { String result = ""; for (int i = 0; i < length; i++) { result += (char) ((value >> (i << 3)) & 0xFF); } return result; } private static double round(final double value) { return Math.round(value) / 10.0; } private static final class Reader { private long ptr; private long readBuffer1; private long readBuffer2; private long hash; private long entryStart; private int entryLength; // in bytes rounded to nearest 16 private final long endAddress; Reader(final long startAddress, final long endAddress, final boolean isFileStart) { this.ptr = startAddress; this.endAddress = endAddress; // Adjust start to next delimiter: if (!isFileStart) { ptr--; while (ptr < endAddress) { if (UNSAFE.getByte(ptr++) == '\n') { break; } } } } private void processStart() { hash = 0; entryStart = ptr; entryLength = 0; } private boolean hasNext() { return (ptr < endAddress); } private static final long DELIMITER_MASK = 0x3B3B3B3B3B3B3B3BL; private boolean readNext() { long lastRead = UNSAFE.getLong(ptr); entryLength += 16; // Find delimiter and create mask for long1 long comparisonResult1 = (lastRead ^ DELIMITER_MASK); long highBitMask1 = (comparisonResult1 - 0x0101010101010101L) & (~comparisonResult1 & 0x8080808080808080L); boolean noContent1 = highBitMask1 == 0; long mask1 = noContent1 ? 0 : ~((highBitMask1 >>> 7) - 1); int position1 = noContent1 ? 0 : 1 + (Long.numberOfTrailingZeros(highBitMask1) >> 3); readBuffer1 = lastRead & ~mask1; hash ^= readBuffer1; int delimiter1 = position1 == 0 ? 0 : position1; // not nnecessary, but faster? if (delimiter1 != 0) { hash ^= hash >> 32; readBuffer2 = 0; ptr += delimiter1; return false; } lastRead = UNSAFE.getLong(ptr + 8); // Repeat for long2 long comparisonResult2 = (lastRead ^ DELIMITER_MASK); long highBitMask2 = (comparisonResult2 - 0x0101010101010101L) & (~comparisonResult2 & 0x8080808080808080L); boolean noContent2 = highBitMask2 == 0; long mask2 = noContent2 ? 0 : ~((highBitMask2 >>> 7) - 1); int position2 = noContent2 ? 0 : 1 + (Long.numberOfTrailingZeros(highBitMask2) >> 3); // Apply masks readBuffer2 = lastRead & ~mask2; hash ^= readBuffer2; int delimiter2 = position2 == 0 ? 0 : position2 + 8; // not necessary, but faster? hash ^= hash >> 32; if (delimiter2 != 0) { ptr += delimiter2; return false; } ptr += 16; return true; } private int processEndAndGetTemperature() { finalizeHash(); return readTemperature(); } private void finalizeHash() { hash ^= hash >> 17; // extra entropy } private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); // Awesome idea of merykitty: private int readTemperature() { // This is the number part: X.X, -X.X, XX.x or -XX.X final long numberBytes = UNSAFE.getLong(ptr); final long invNumberBytes = ~numberBytes; final int dotPosition = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS); // Calculates the sign final long signed = (invNumberBytes << 59) >> 63; final int min28 = (dotPosition ^ 0b11100); final long minusFilter = ~(signed & 0xFF); // Use the pre-calculated decimal position to adjust the values final long digits = ((numberBytes & minusFilter) << min28) & 0x0F000F0F00L; // Update the pointer here, bit awkward, but we have all the data ptr += (dotPosition >> 3) + 3; // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF; // And perform abs() return (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick } private boolean matches(final byte[] entry) { int step = 0; for (; step < entryLength - 16;) { if (compare(null, entryStart + step, entry, ENTRY_NAME + step)) { return false; } step += 8; } if (compare(readBuffer1, entry, ENTRY_NAME + step)) { return false; } step += 8; if (compare(readBuffer2, entry, ENTRY_NAME + step)) { return false; } return true; } private boolean matches16(final byte[] entry) { if (compare(readBuffer1, entry, ENTRY_NAME)) { return false; } if (compare(readBuffer2, entry, ENTRY_NAME + 8)) { return false; } return true; } } private static byte[][] processMemoryArea(final long startAddress, final long endAddress, boolean isFileStart) { final byte[][] table = new byte[TABLE_SIZE][]; final byte[][] preConstructedEntries = new byte[PREMADE_ENTRIES][ENTRY_BASESIZE_WHITESPACE + PREMADE_MAX_SIZE]; final Reader reader = new Reader(startAddress, endAddress, isFileStart); byte[] entry; int entryCount = 0; // Find the correct starting position while (reader.hasNext()) { reader.processStart(); if (!reader.readNext()) { // First 16 bytes: int temperature = reader.processEndAndGetTemperature(); // Find or insert the entry: int index = (int) (reader.hash & TABLE_MASK); while (true) { entry = table[index]; if (entry == null) { byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] : new byte[ENTRY_BASESIZE_WHITESPACE + 16]; // with enough room table[index] = fillEntry16(entryBytes, 16, temperature, reader.readBuffer1, reader.readBuffer2); break; } else if (reader.matches16(entry)) { updateEntry(entry, temperature); break; } else { // Move to the next index index = (index + 1) & TABLE_MASK; } } continue; } while (reader.readNext()) ; int temperature = reader.processEndAndGetTemperature(); // Find or insert the entry: int index = (int) (reader.hash & TABLE_MASK); while (true) { entry = table[index]; if (entry == null) { int length = reader.entryLength; byte[] entryBytes = (length < PREMADE_MAX_SIZE && entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++] : new byte[ENTRY_BASESIZE_WHITESPACE + length]; // with enough room table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature, reader.readBuffer1, reader.readBuffer2); break; } else if (reader.matches(entry)) { updateEntry(entry, temperature); break; } else { // Move to the next index index = (index + 1) & TABLE_MASK; } } } return table; } private static boolean compare(final Object object1, final long address1, final Object object2, final long address2) { return UNSAFE.getLong(object1, address1) != UNSAFE.getLong(object2, address2); } private static boolean compare(final long value1, final Object object2, final long address2) { return value1 != UNSAFE.getLong(object2, address2); } /* * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___ * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __| * | () | _ / __|| . |\ V / _ \| |_| |_| | ._| * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___| * ---------------- BETTER SOFTWARE, FASTER -- * * https://www.openvalue.eu/ */ private static Unsafe initUnsafe() { try { final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); return (Unsafe) theUnsafe.get(Unsafe.class); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } }