aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java470
1 files changed, 470 insertions, 0 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java b/src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java
new file mode 100644
index 0000000..54761a3
--- /dev/null
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_cliffclick.java
@@ -0,0 +1,470 @@
+/*
+ * Copyright 2024 Cliff Click
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package dev.morling.onebrc;
+
+import java.io.*;
+import java.lang.reflect.Field;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.*;
+import java.util.Arrays;
+import sun.misc.Unsafe;
+
+abstract class CalculateAverage_cliffclick {
+ // abstract class CNC {
+ public static final int NCPUS = Runtime.getRuntime().availableProcessors();
+ public static final long HASSEMI = 0x3B3B3B3B3B3B3B3BL;
+
+ private static final Unsafe UNSAFE;
+ private static long MMAP_ADDRESS;
+ static {
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ UNSAFE = (Unsafe) theUnsafe.get(Unsafe.class);
+
+ Field f;
+ try {
+ f = java.nio.Buffer.class.getDeclaredField("address");
+ }
+ catch (java.lang.NoSuchFieldException e) {
+ throw new RuntimeException(e);
+ }
+ MMAP_ADDRESS = UNSAFE.objectFieldOffset(f);
+
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ if (args.length < 1)
+ args = new String[]{ "measurements.txt" };
+
+ Work w = work(args);
+ String foo = w.toString();
+ byte[] bar = new byte[foo.length()];
+ foo.getBytes(0, foo.length(), bar, 0);
+ System.out.write(bar);
+ System.out.write('\n');
+ }
+
+ // General work flow:
+
+ // Spawn threads. Make empty hash for sums, counts.
+ // Seek to offset. Skip till newline. Parse to end of chunk, plus rest of partial line.
+
+ // SWAR out city name into n8; both hash and uhash name.
+ // Lookup in sums; if miss: insert map uhash to 0 cnt; insert map uhash to String name also;
+ // if hit: bump count; same index in sums array, bump sums with scaled decimal
+ // At end, grab more work until done.
+
+ // At end, across all threads total sums & counts.
+ // Compute averages, lookup names and print.
+ static Work work(String[] args) throws Exception {
+ File f = new File(args[0]);
+
+ // How many threads?
+ int ncpus = (int) Math.min((f.length() >> 14) + 1, NCPUS); // Keep 1<<14 min work
+ long len = (f.length() / ncpus) + 1;
+
+ Work[] WS = new Work[ncpus];
+ Thread[] TS = new Thread[ncpus];
+
+ // Spawn work on threads
+ for (int i = 0; i < ncpus; i++) {
+ long s = i * len;
+ Work w = WS[i] = new Work();
+ Thread T = TS[i] = new Thread() {
+ public void run() {
+ tstart(w, f, s, Math.min(len, f.length() - s));
+ }
+ };
+ T.start();
+ }
+
+ TS[0].join();
+ Work W = WS[0];
+ for (int i = 1; i < ncpus; i++) {
+ TS[i].join();
+ W.reduce(WS[i]);
+ }
+ return W;
+ }
+
+ static void tstart(Work w, File f, long start, long len) {
+ try {
+ // Thread gets a chunk of work
+ FileChannel fc = FileChannel.open(f.toPath(), StandardOpenOption.READ);
+ final int MAX_MAP = 1 << 30;
+
+ for (long s = start; s < start + len; s += MAX_MAP) {
+ int maxlen = (int) Math.min(len + 1, MAX_MAP); // Length capped at MAX_MAP
+ long rem = f.length() - s;
+ int mlen = (int) Math.min(rem, maxlen + 100); // Add a little extra so can finish out a line
+ int clen = (int) Math.min(rem, maxlen);
+ // mmap is capped at MAX_MAP (plus change), or
+ // to the end of the chosen parse length (plus change)
+ // or the end of the file in any case
+ MappedByteBuffer mmap = fc.map(FileChannel.MapMode.READ_ONLY, s, mlen);
+ // Chunk runs to min(MAX_MAP, parse length, eof), plus it runs to the end
+ // of any partial line.
+ do_chunk(w, s > 0, clen, mmap);
+ }
+ }
+ catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ }
+
+ // Has a zero byte in a long, copied straight from Hackers Delight
+ static long has0(long x) {
+ return (x - 0x0101010101010101L) & (~x) & 0x8080808080808080L;
+ }
+
+ // Parse a chunk, from 0 to limit in mmap. Runs past limit to finish any
+ // partial line. If skip1, then skip any leading partial line.
+ static void do_chunk(Work w, boolean skip1, int limit, MappedByteBuffer mmap) {
+ assert mmap.isDirect();
+ int idx = 0;
+ int max = mmap.limit();
+ long base = UNSAFE.getLong(mmap, MMAP_ADDRESS);
+
+ // If start>0, skip until first newline
+ if (skip1)
+ idx = skipFirst(idx, base);
+
+ // The very last entry will want to fetch 8 bytes, some of which may go
+ // past the mmap max - do this entry now, before looping.
+ if (limit == max)
+ limit = skipLast(limit, w, base);
+
+ // Edges of the ~2G region taken care of. Now do the giant middle part.
+
+ // For this chunk of file do...
+ while (idx < limit) {
+ int cityx = idx; // Used if we find a new city name
+
+ // SWAR read and build n8; the long-as-a-string value. Also track start
+ // and end of the string, in case it is new and needs to be inserted into
+ // the n8->city_name map.
+ long n8 = 0;
+ // Read a misaligned long
+ long x = UNSAFE.getLong(base + idx);
+ // Found semi ?
+ long hasM = has0(x ^ HASSEMI);
+ while (hasM == 0) {
+ // Read 2nd word of city
+ n8 ^= x;
+ idx += 8;
+ // Read a misaligned long
+ x = UNSAFE.getLong(base + idx);
+ // Found semi ?
+ hasM = has0(x ^ HASSEMI);
+ }
+ // Found a semicolon this word.
+ // The high bit of the byte in question is set.
+ int shr = Long.numberOfTrailingZeros(hasM) + 1;
+ if (shr > 8) {
+ int shr2 = 72 - shr;
+ n8 ^= (x << shr2) >> shr2;
+ idx += (shr >> 3) - 1;
+ }
+
+ // Skip semicolon
+ idx++;
+
+ // Reading tempature, and add
+ idx = parseData(idx, w, cityx, n8, base);
+ }
+ }
+
+ // The very last entry will want to fetch 8 bytes, some of which may go
+ // past the mmap max - do this entry now, before looping.
+ private static int skipLast(int limit, Work w, long base) {
+ limit--;
+ while (limit > 0 && UNSAFE.getByte(base + limit - 1) != '\n')
+ limit--;
+ long n8 = 0, mask = 0, c;
+ int i = limit;
+ while ((c = UNSAFE.getByte(base + i)) != ';') {
+ mask = (mask >> 8) | (c << 56);
+ i++;
+ if (((limit - i) & 7) == 0) {
+ n8 ^= mask;
+ mask = 0;
+ }
+ }
+ int shr = (limit - i) & 7;
+ n8 ^= (mask >> (shr << 3));
+ parseData(i + 1, w, limit, n8, base);
+ return limit;
+ }
+
+ // Parse temp data, and insert entry into hash table
+ private static int parseData(int idx, Work w, int cityx, long n8, long base) {
+ // Reading tempature:
+ int temp = 0;
+ boolean neg = false;
+ byte b = UNSAFE.getByte(base + idx++);
+ if (b == '-') {
+ neg = true;
+ b = UNSAFE.getByte(base + idx++);
+ }
+ temp = b - '0';
+ b = UNSAFE.getByte(base + idx++);
+ if (b != '.') {
+ temp = temp * 10 + b - '0';
+ idx++;
+ }
+ // Read fraction digit; scaled decimal temp
+ b = UNSAFE.getByte(base + idx++);
+ temp = temp * 10 + b - '0';
+ if (neg)
+ temp = -temp;
+ // Skip newline
+ idx++;
+ // F*KING WINDOWS.
+ // Skip CR
+ // idx++;
+ w.insert(n8, temp, base, cityx);
+ return idx;
+ }
+
+ private static int skipFirst(int idx, long base) {
+ while (UNSAFE.getByte(base + idx++) != '\n')
+ ;
+ // WINDOWS
+ // idx++;
+ return idx;
+ }
+
+ private static class Work {
+ private static final int TAB_SIZE = 0x4000; // 512 for 413 cities
+ // Fixed size hashtable. Longs are packed to hold the data.
+ // cnt uhash
+ // 8 7 6 5 4 3 2 1
+ // min max temp sum
+ // 8 7 6 5 4 3 2 1
+ long[] table = new long[TAB_SIZE * 2]; //
+ String[] cities = new String[TAB_SIZE]; // Same index holds city names
+
+ // Gather for city bits
+ final byte[] city = new byte[256];
+ int reprobes;
+
+ void insert(long n8, int temp, long base, int cityx) {
+ // 3 bytes uniquely id city, left at 4
+ int uhash = (int) uhash_final(n8);
+ // Index in small table
+ int ihash = hash_hash(uhash);
+ long cnt_key = table[(ihash << 1)];
+ long min_max = table[(ihash << 1) + 1];
+ int key = key(cnt_key);
+ while (key != uhash) {
+ if (key == 0) {
+ // Miss in hash table
+ cnt_key = uhash & 0xFFFFFFFFL;
+ min_max = min_max(0x7FFF, 0xF000, 0);
+ // Put city name in cities
+ new_city(ihash, base + cityx);
+ break;
+ }
+ // Reprobe. Seeiong 53M reprobes out of 1000M rows, so a 5.3% reprobe rate
+ ihash = reprobe(ihash, uhash);
+ cnt_key = table[(ihash << 1)];
+ min_max = table[(ihash << 1) + 1];
+ key = key(cnt_key);
+ }
+ // assert cities[ihash].equals(toChar(n8)) : String.format("uhash=0x%08x %s %s, FAILS FOR %d",uhash,cities[ihash],toChar(n8),X0);
+
+ // Break down parts
+ int min = min(min_max);
+ min = Math.min(min, temp);
+ int max = max(min_max);
+ max = Math.max(max, temp);
+ int sum = temp(min_max);
+ sum += temp;
+ min_max = min_max(min, max, sum);
+ // Back into table
+ table[(ihash << 1)] = cnt_key + (1L << 32);
+ table[(ihash << 1) + 1] = min_max;
+ }
+
+ // Hash the n8 value; the 3 bytes uniquely identify the city.
+ static long uhash_final(long n8) {
+ return n8 ^ (n8 >> 29);
+ }
+
+ // New city
+ void new_city(int ihash, long base_cityx) {
+ // Put city name in cities
+ int i = 0;
+ byte c;
+ while ((c = UNSAFE.getByte(base_cityx++)) != ';')
+ city[i++] = c;
+ cities[ihash] = new String(city, 0, 0, i);
+ }
+
+ private static int hash_hash(int uhash) {
+ // Index in small table
+ int ihash = uhash;
+ ihash = ihash ^ (ihash >> 17);
+ ihash = ihash + 29 * uhash;
+ ihash &= (TAB_SIZE - 1);
+ return ihash;
+ }
+
+ private static int reprobe(int ihash, int uhash) {
+ return (ihash + (uhash | 1)) & (TAB_SIZE - 1);
+ }
+
+ // Convert the large unique hash into a smaller table hash
+ int ihash(int uhash) {
+ // Index in small table
+ int ihash = hash_hash(uhash);
+ long cnt_key = table[ihash << 1];
+ int key = key(cnt_key);
+ while (key != uhash) {
+ if (key == 0)
+ return ihash;
+ // Reprobe
+ ihash = reprobe(ihash, uhash);
+ cnt_key = table[(ihash << 1)];
+ key = key(cnt_key);
+ }
+ return ihash;
+ }
+
+ void reduce(Work w) {
+ for (int i = 0; i < w.cities.length; i++) {
+ if (w.cities[i] == null)
+ continue;
+
+ // Break down parts
+ long cnt_key = w.table[(i << 1)];
+ long min_max = w.table[(i << 1) + 1];
+ int cnt = cnt(cnt_key);
+ int key = key(cnt_key);
+ int min = min(min_max);
+ int max = max(min_max);
+ int sum = temp(min_max);
+
+ // Find key in local table
+ int ihash = ihash(key);
+ long cnt_key0 = table[(ihash << 1)];
+ long min_max0 = table[(ihash << 1) + 1];
+ int cnt0 = cnt(cnt_key0);
+ int key0 = key(cnt_key0);
+ int min0 = min(min_max0);
+ int max0 = max(min_max0);
+ int sum0 = temp(min_max0);
+
+ cnt0 += cnt;
+ sum0 += sum;
+ min0 = Math.min(min0, min);
+ max0 = Math.max(max0, max);
+ if (key0 == 0) {
+ key0 = key;
+ min0 = min;
+ cities[ihash] = w.cities[i];
+ }
+ table[(ihash << 1)] = cnt_key(cnt0, key0);
+ table[(ihash << 1) + 1] = min_max(min0, max0, sum0);
+ }
+ }
+
+ static int key(long cnt_key) {
+ return (int) cnt_key;
+ }
+
+ static int cnt(long cnt_key) {
+ return (int) (cnt_key >> 32);
+ }
+
+ static int min(long min_max) {
+ return (int) (min_max >> 48);
+ } // Signed right shift; min often negative
+
+ static int max(long min_max) {
+ return (short) ((min_max >>> 32) & 0xFFFF);
+ }// Unsigned right shift;
+
+ static int temp(long min_max) {
+ return (int) min_max;
+ }; // Low int
+
+ static long cnt_key(int cnt, int key) {
+ return ((long) cnt << 32) | (key & 0xFFFFFFFFL);
+ }
+
+ static long min_max(int min, int max, int sum) {
+ return ((long) min << 48) | ((long) (max & 0xFFFF) << 32) | (((long) sum) & 0xFFFFFFFFL);
+ }
+
+ @Override
+ public String toString() {
+ int ncitys = 0; // totals 413
+ for (int i = 0; i < TAB_SIZE; i++)
+ if (cities[i] != null)
+ ncitys++;
+ // Index of city entries
+ Integer[] is = new Integer[ncitys];
+ for (int i = 0, j = 0; i < TAB_SIZE; i++)
+ if (cities[i] != null)
+ is[j++] = i;
+
+ // Sort indices
+ Arrays.sort(is, (x, y) -> cities[x].compareTo(cities[y]));
+
+ StringBuilder sb = new StringBuilder().append("{");
+ for (int i : is) {
+ String city = cities[i];
+ int cnt = cnt(table[(i << 1)]);
+ long min_max = table[(i << 1) + 1];
+ double min = min(min_max) / 10.0;
+ double max = max(min_max) / 10.0;
+ double temp = temp(min_max) / 10.0;
+ double mean = temp / cnt;
+ sb.append(String.format("%s=%.1f/%.1f/%.1f, ", city, min, mean, max));
+ }
+ if (sb.length() > 2)
+ sb.setLength(sb.length() - 2);
+ return sb.append("}").toString();
+ }
+ }
+
+ // Debugging utilities
+ static String toHex(byte[] bs) {
+ StringBuilder sb = new StringBuilder().append("[");
+ for (byte b : bs)
+ sb.append(String.format("%02X,", b));
+ sb.setLength(sb.length() - 1);
+ return sb.append("]").toString();
+ }
+
+ static String toChar(long x) {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 8; i++) {
+ char c = (char) (x & 0xFF);
+ if (c != 0)
+ sb.append(c);
+ x >>= 8;
+ }
+ return sb.toString();
+ }
+}