1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
|
/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import sun.misc.Unsafe;
/**
* Changelog:
*
* Initial submission: 62000 ms
* Chunked reader: 16000 ms
* Optimized parser: 13000 ms
* Branchless methods: 11000 ms
* Adding memory mapped files: 6500 ms (based on bjhara's submission)
* Skipping string creation: 4700 ms
* Custom hashmap... 4200 ms
* Added SWAR token checks: 3900 ms
* Skipped String creation: 3500 ms (idea from kgonia)
* Improved String skip: 3250 ms
* Segmenting files: 3150 ms (based on spullara's code)
* Not using SWAR for EOL: 2850 ms
* Inlining hash calculation: 2450 ms
* Replacing branchless code: 2200 ms (sometimes we need to kill the things we love)
* Added unsafe memory access: 1900 ms (keeping the long[] small and local)
* Fixed bug, UNSAFE bytes String: 1850 ms
* Separate hash from entries: 1550 ms
* Various tweaks for Linux/cache 1550 ms (should/could make a difference on target machine)
* Improved layout/predictability: 1400 ms
* Delayed String creation again: 1350 ms
* Remove writing to buffer: 1335 ms
* Optimized collecting at the end: 1310 ms
* Adding a lot of comments: priceless
* Changed to flyweight byte[]: 1290 ms (adds even more Unsafe, was initially slower, now faster)
* More LOC now parallel: 1260 ms (moved more to processMemoryArea, recombining in ConcurrentHashMap)
* Storing only the address: 1240 ms (this is now faster, tried before, was slower)
* Unrolling scan-loop: 1200 ms (seems to help, perhaps even more on target machine)
* Adding more readable reader: 1300 ms (scores got worse on target machine anyway)
*
* Using old x86 MacBook and perf: 3500 ms (different machine for testing)
* Decided to rewrite loop for 16 b: 3050 ms
* Small changes, limited heap: 2950 ms
*
* I have some instructions that could be removed, but faster with...
*
* Big thanks to Francesco Nigro, Thomas Wuerthinger, Quan Anh Mai and many others for ideas.
*
* Follow me at: @royvanrijn
*/
public class CalculateAverage_royvanrijn {

    private static final String FILE = "./measurements.txt";
    // private static final String FILE = "src/test/resources/samples/measurements-1.txt";

    private static final Unsafe UNSAFE = initUnsafe();

    // The input is split into PROCESSORS segments: PROCESSORS - 1 worker threads plus the main thread.
    private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();

    /**
     * Flyweight entry stored in a byte[]. The ENTRY_* constants below are Unsafe
     * offsets, i.e. they already include {@code Unsafe.ARRAY_BYTE_BASE_OFFSET}.
     * <p>
     * Layout, in order:
     * byte: length (name length in bytes, padded up to a multiple of 16)
     * long: sum
     * int: min
     * int: max
     * int: count
     * byte[]: cityname (zero padded up to "length" bytes)
     */
    // ------------------------------------------------------------------------
    private static final int ENTRY_LENGTH = (Unsafe.ARRAY_BYTE_BASE_OFFSET);
    private static final int ENTRY_SUM = (ENTRY_LENGTH + Byte.BYTES);
    private static final int ENTRY_MIN = (ENTRY_SUM + Long.BYTES);
    private static final int ENTRY_MAX = (ENTRY_MIN + Integer.BYTES);
    private static final int ENTRY_COUNT = (ENTRY_MAX + Integer.BYTES);
    private static final int ENTRY_NAME = (ENTRY_COUNT + Integer.BYTES);
    private static final int ENTRY_NAME_8 = ENTRY_NAME + 8;
    private static final int ENTRY_NAME_16 = ENTRY_NAME + 16;
    private static final int ENTRY_BASESIZE_WHITESPACE = ENTRY_NAME + 7; // with enough empty bytes to fill a long
    // ------------------------------------------------------------------------
    private static final int PREMADE_MAX_SIZE = 1 << 5; // pre-initialize some entries in memory, keep them close
    private static final int PREMADE_ENTRIES = 512; // amount of pre-created entries we should use
    private static final int TABLE_SIZE = 1 << 19; // large enough for the contest.
    private static final int TABLE_MASK = (TABLE_SIZE - 1);

    /**
     * Re-launches the current JVM command line with an extra {@code --worker}
     * argument and copies the child's standard output to ours.
     * <p>
     * Idea of thomaswue, don't wait for slow unmap: the worker closes its stdout
     * as soon as the result is printed, which ends {@code transferTo} here.
     *
     * @throws IOException if the child process cannot be started or its output cannot be copied
     */
    private static void spawnWorker() throws IOException {
        ProcessHandle.Info info = ProcessHandle.current().info();
        ArrayList<String> workerCommand = new ArrayList<>();
        info.command().ifPresent(workerCommand::add);
        info.arguments().ifPresent(args -> workerCommand.addAll(Arrays.asList(args)));
        workerCommand.add("--worker");
        new ProcessBuilder()
                .command(workerCommand)
                .inheritIO()
                .redirectOutput(ProcessBuilder.Redirect.PIPE)
                .start()
                .getInputStream()
                .transferTo(System.out);
    }

    /**
     * Entry point. Without {@code --worker} as first argument this only spawns a
     * worker process; with it, the file is memory mapped, split into
     * {@code PROCESSORS} segments, aggregated in parallel and printed as
     * <code>{name=min/mean/max, ...}</code> sorted by name.
     *
     * @param args command line arguments; {@code args[0]} equal to {@code "--worker"} selects worker mode
     * @throws Exception on I/O failure or if the thread is interrupted while joining workers
     */
    public static void main(String[] args) throws Exception {
        if (args.length == 0 || !("--worker".equals(args[0]))) {
            spawnWorker();
            return;
        }

        // Map the whole file. The mapping belongs to the global arena, so the
        // channel itself can be closed as soon as the mapping exists.
        final long fileSize;
        final long mapAddress;
        try (FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
            fileSize = fileChannel.size();
            mapAddress = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
        }

        // Calculate input segments.
        final long segmentSize = (fileSize + PROCESSORS - 1) / PROCESSORS;
        final long endOfFile = mapAddress + fileSize;
        final Thread[] parallelThreads = new Thread[PROCESSORS - 1];

        // This is where the entries will land:
        final ConcurrentHashMap<String, byte[]> measurements = new ConcurrentHashMap<>(1 << 10);

        // One thread per segment, except for the last segment which is handled below.
        long lastAddress = mapAddress;
        for (int i = 0; i < PROCESSORS - 1; ++i) {
            final long fromAddress = lastAddress;
            final long toAddress = Math.min(endOfFile, fromAddress + segmentSize);
            final Thread thread = new Thread(() ->
            // The actual work is done here:
            mergeTable(processMemoryArea(fromAddress, toAddress, fromAddress == mapAddress), measurements));
            thread.start(); // start a.s.a.p.
            parallelThreads[i] = thread;
            lastAddress = toAddress;
        }

        // Use the current thread for the remaining part of memory. This segment
        // starts at the beginning of the file when there are no worker threads
        // (single processor) or the file is empty; in that case the reader must
        // not step back before the mapping to look for a line break.
        mergeTable(processMemoryArea(lastAddress, endOfFile, lastAddress == mapAddress), measurements);

        // Wait for all threads to finish:
        for (Thread thread : parallelThreads) {
            // Can we implement work-stealing? Not sure how...
            thread.join();
        }

        final String joined = measurements.entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .map(entry -> entry.getKey() + '=' + entryValuesToString(entry.getValue()))
                .collect(Collectors.joining(", "));
        System.out.print("{" + joined);
        System.out.println("}");
        System.out.close(); // close the stream so the parent process stops waiting
    }

    /**
     * Merges every non-null entry of a per-thread table into the shared result map,
     * keyed by station name and combined with {@link #mergeEntry(byte[], byte[])}.
     *
     * @param table        open-addressing table produced by {@code processMemoryArea}
     * @param measurements shared map receiving the merged entries
     */
    private static void mergeTable(final byte[][] table, final ConcurrentHashMap<String, byte[]> measurements) {
        for (byte[] entry : table) {
            if (entry != null) {
                measurements.merge(entryToName(entry), entry, CalculateAverage_royvanrijn::mergeEntry);
            }
        }
    }

    /**
     * Initializes a new entry for a name longer than one 16-byte chunk.
     * The first {@code entryLength - 16} name bytes are copied from memory; the
     * final 16 bytes come from the two delimiter-masked read buffers.
     *
     * @param entry       destination array (must have room for the header plus {@code entryLength} name bytes)
     * @param fromAddress native address of the first name byte
     * @param entryLength padded name length in bytes (multiple of 16)
     * @param temp        first temperature, in tenths of a degree
     * @param readBuffer1 second-to-last 8 name bytes, zeroed from the delimiter on
     * @param readBuffer2 last 8 name bytes, zeroed from the delimiter on
     * @return {@code entry}
     */
    private static byte[] fillEntry(final byte[] entry, final long fromAddress, final int entryLength, final int temp, final long readBuffer1, final long readBuffer2) {
        UNSAFE.putLong(entry, ENTRY_SUM, temp);
        UNSAFE.putInt(entry, ENTRY_MIN, temp);
        UNSAFE.putInt(entry, ENTRY_MAX, temp);
        UNSAFE.putInt(entry, ENTRY_COUNT, 1);
        UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) entryLength);
        UNSAFE.copyMemory(null, fromAddress, entry, ENTRY_NAME, entryLength - 16);
        UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 16, readBuffer1);
        UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 8, readBuffer2);
        return entry;
    }

    /**
     * Initializes a new entry whose name fits in a single 16-byte chunk; the
     * name is written purely from the two delimiter-masked read buffers.
     *
     * @param entry       destination array
     * @param entryLength padded name length in bytes (the caller passes 16)
     * @param temp        first temperature, in tenths of a degree
     * @param readBuffer1 first 8 name bytes, zeroed from the delimiter on
     * @param readBuffer2 next 8 name bytes, zeroed from the delimiter on
     * @return {@code entry}
     */
    private static byte[] fillEntry16(final byte[] entry, final int entryLength, final int temp, final long readBuffer1, final long readBuffer2) {
        UNSAFE.putLong(entry, ENTRY_SUM, temp);
        UNSAFE.putInt(entry, ENTRY_MIN, temp);
        UNSAFE.putInt(entry, ENTRY_MAX, temp);
        UNSAFE.putInt(entry, ENTRY_COUNT, 1);
        UNSAFE.putByte(entry, ENTRY_LENGTH, (byte) entryLength);
        UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 16, readBuffer1);
        UNSAFE.putLong(entry, ENTRY_NAME + entryLength - 8, readBuffer2);
        return entry;
    }

    /**
     * Adds one measurement to an existing entry: bumps count and sum and
     * lowers/raises min or max when {@code temp} is outside the current range.
     *
     * @param entry entry to update in place
     * @param temp  temperature, in tenths of a degree
     */
    public static void updateEntry(final byte[] entry, final int temp) {
        int entryMin = UNSAFE.getInt(entry, ENTRY_MIN);
        int entryMax = UNSAFE.getInt(entry, ENTRY_MAX);
        long entrySum = UNSAFE.getLong(entry, ENTRY_SUM) + temp;
        int entryCount = UNSAFE.getInt(entry, ENTRY_COUNT) + 1;
        // min <= max always holds (both start at the first value), so at most one branch applies.
        if (temp < entryMin) {
            UNSAFE.putInt(entry, ENTRY_MIN, temp);
        }
        else if (temp > entryMax) {
            UNSAFE.putInt(entry, ENTRY_MAX, temp);
        }
        UNSAFE.putInt(entry, ENTRY_COUNT, entryCount);
        UNSAFE.putLong(entry, ENTRY_SUM, entrySum);
    }

    /**
     * Folds the statistics of {@code merge} into {@code entry} (sum and count are
     * added, min/max are combined). {@code merge} is not modified.
     *
     * @param entry entry that receives the combined statistics
     * @param merge entry whose statistics are added
     * @return {@code entry}
     */
    public static byte[] mergeEntry(final byte[] entry, final byte[] merge) {
        long sum = UNSAFE.getLong(merge, ENTRY_SUM);
        final int mergeMin = UNSAFE.getInt(merge, ENTRY_MIN);
        final int mergeMax = UNSAFE.getInt(merge, ENTRY_MAX);
        int count = UNSAFE.getInt(merge, ENTRY_COUNT);
        sum += UNSAFE.getLong(entry, ENTRY_SUM);
        count += UNSAFE.getInt(entry, ENTRY_COUNT);
        int entryMin = UNSAFE.getInt(entry, ENTRY_MIN);
        int entryMax = UNSAFE.getInt(entry, ENTRY_MAX);
        entryMin = Math.min(entryMin, mergeMin);
        entryMax = Math.max(entryMax, mergeMax);
        UNSAFE.putInt(entry, ENTRY_MIN, entryMin);
        UNSAFE.putInt(entry, ENTRY_MAX, entryMax);
        UNSAFE.putLong(entry, ENTRY_SUM, sum);
        UNSAFE.putInt(entry, ENTRY_COUNT, count);
        return entry;
    }

    /**
     * Decodes the station name stored in an entry.
     * <p>
     * The stored name is zero padded up to the recorded length, so only trailing
     * NUL bytes are removed. Leading or trailing spaces that are part of the name
     * are preserved (String.trim() would strip them and could make two distinct
     * stations share one key).
     *
     * @param entry entry to read the name from
     * @return the name decoded as UTF-8, without the zero padding
     */
    private static String entryToName(final byte[] entry) {
        // The length is stored in a single byte; read it unsigned.
        final int paddedLength = UNSAFE.getByte(entry, ENTRY_LENGTH) & 0xFF;
        // ENTRY_NAME is an Unsafe offset; convert it to a plain array index.
        final int nameStart = ENTRY_NAME - Unsafe.ARRAY_BYTE_BASE_OFFSET;
        int nameLength = paddedLength;
        while (nameLength > 0 && entry[nameStart + nameLength - 1] == 0) {
            nameLength--;
        }
        return new String(entry, nameStart, nameLength, StandardCharsets.UTF_8);
    }

    /**
     * Formats the statistics of an entry as {@code min/mean/max}, each converted
     * from tenths to a value with one decimal via {@link #round(double)}.
     *
     * @param entry entry to format
     * @return the formatted statistics
     */
    private static String entryValuesToString(final byte[] entry) {
        return (round(UNSAFE.getInt(entry, ENTRY_MIN))
                + "/" +
                round((1.0 * UNSAFE.getLong(entry, ENTRY_SUM)) /
                        UNSAFE.getInt(entry, ENTRY_COUNT))
                + "/" +
                round(UNSAFE.getInt(entry, ENTRY_MAX)));
    }

    /**
     * Debug helper: renders {@code length} bytes starting at {@code address}
     * (relative to {@code target}, or absolute when {@code target} is null) as
     * one char per byte. Bytes are treated as unsigned, matching the long overload.
     *
     * @param target  object the address is relative to, or null for native memory
     * @param address Unsafe offset or native address of the first byte
     * @param length  number of bytes to render
     * @return the bytes as a string, one char per byte
     */
    private static String printMemory(final Object target, final long address, int length) {
        final StringBuilder result = new StringBuilder(Math.max(length, 0));
        for (int i = 0; i < length; i++) {
            result.append((char) (UNSAFE.getByte(target, address + i) & 0xFF));
        }
        return result.toString();
    }

    /**
     * Debug helper: renders the lowest {@code length} bytes of {@code value},
     * least significant byte first, as one char per byte.
     *
     * @param value  the packed bytes
     * @param length number of bytes to render
     * @return the bytes as a string, one char per byte
     */
    private static String printMemory(final long value, int length) {
        final StringBuilder result = new StringBuilder(Math.max(length, 0));
        for (int i = 0; i < length; i++) {
            result.append((char) ((value >> (i << 3)) & 0xFF));
        }
        return result.toString();
    }

    /**
     * Rounds a value expressed in tenths to the nearest whole tenth (using
     * {@link Math#round(double)}) and converts it to units.
     *
     * @param value value in tenths
     * @return the rounded value divided by 10
     */
    private static double round(final double value) {
        return Math.round(value) / 10.0;
    }

    /**
     * Cursor over one segment of the memory mapped file. It reads a station name
     * in 16-byte chunks (two longs), maintaining a running hash and the last two
     * delimiter-masked longs, and then parses the temperature that follows.
     */
    private static final class Reader {

        private long ptr; // native address of the next byte to read
        private long readBuffer1; // first long of the most recent chunk, zeroed from the ';' on
        private long readBuffer2; // second long of the most recent chunk, zeroed from the ';' on (0 if ';' was in the first long)
        private long hash;
        private long entryStart; // address of the first byte of the current name
        private int entryLength; // bytes consumed for the name, rounded up to a multiple of 16
        private final long endAddress;

        /**
         * @param startAddress first address of the segment
         * @param endAddress   address just past the segment; lines starting before it are processed
         * @param isFileStart  true when the segment starts at the first byte of the file
         */
        Reader(final long startAddress, final long endAddress, final boolean isFileStart) {
            this.ptr = startAddress;
            this.endAddress = endAddress;
            // Adjust start to next delimiter: step one byte back so that a segment
            // starting exactly on a line start keeps that line, then skip past the
            // first '\n' found.
            if (!isFileStart) {
                ptr--;
                while (ptr < endAddress) {
                    if (UNSAFE.getByte(ptr++) == '\n') {
                        break;
                    }
                }
            }
        }

        /** Resets the per-line state; must be called at the start of every line. */
        private void processStart() {
            hash = 0;
            entryStart = ptr;
            entryLength = 0;
        }

        /** @return true while the cursor is still inside the segment */
        private boolean hasNext() {
            return (ptr < endAddress);
        }

        private static final long DELIMITER_MASK = 0x3B3B3B3B3B3B3B3BL; // ';' in every byte

        /**
         * Reads the next 16 name bytes and folds them into the hash.
         *
         * @return true when no ';' was found in these 16 bytes (the name continues
         *         and ptr advanced by 16); false when the ';' was found, in which
         *         case ptr points just past it
         */
        private boolean readNext() {
            long lastRead = UNSAFE.getLong(ptr);
            entryLength += 16;
            // Find delimiter and create mask for long1:
            // XOR turns every ';' byte into 0, the SWAR expression then sets the
            // high bit of (at least) the lowest zero byte.
            long comparisonResult1 = (lastRead ^ DELIMITER_MASK);
            long highBitMask1 = (comparisonResult1 - 0x0101010101010101L) & (~comparisonResult1 & 0x8080808080808080L);
            boolean noContent1 = highBitMask1 == 0;
            // mask covers the delimiter byte and everything above it
            long mask1 = noContent1 ? 0 : ~((highBitMask1 >>> 7) - 1);
            // 1-based byte position of the delimiter, 0 when absent
            int position1 = noContent1 ? 0 : 1 + (Long.numberOfTrailingZeros(highBitMask1) >> 3);
            readBuffer1 = lastRead & ~mask1;
            hash ^= readBuffer1;
            int delimiter1 = position1 == 0 ? 0 : position1; // not necessary, but faster?
            if (delimiter1 != 0) {
                hash ^= hash >> 32;
                readBuffer2 = 0;
                ptr += delimiter1;
                return false;
            }
            lastRead = UNSAFE.getLong(ptr + 8);
            // Repeat for long2
            long comparisonResult2 = (lastRead ^ DELIMITER_MASK);
            long highBitMask2 = (comparisonResult2 - 0x0101010101010101L) & (~comparisonResult2 & 0x8080808080808080L);
            boolean noContent2 = highBitMask2 == 0;
            long mask2 = noContent2 ? 0 : ~((highBitMask2 >>> 7) - 1);
            int position2 = noContent2 ? 0 : 1 + (Long.numberOfTrailingZeros(highBitMask2) >> 3);
            // Apply masks
            readBuffer2 = lastRead & ~mask2;
            hash ^= readBuffer2;
            int delimiter2 = position2 == 0 ? 0 : position2 + 8; // not necessary, but faster?
            hash ^= hash >> 32;
            if (delimiter2 != 0) {
                ptr += delimiter2;
                return false;
            }
            ptr += 16;
            return true;
        }

        /**
         * Finishes the hash of the current name and parses the temperature.
         *
         * @return the temperature in tenths of a degree
         */
        private int processEndAndGetTemperature() {
            finalizeHash();
            return readTemperature();
        }

        private void finalizeHash() {
            hash ^= hash >> 17; // extra entropy
        }

        private static final long DOT_BITS = 0x10101000;
        private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);

        /**
         * Parses the temperature at ptr and moves ptr past the end of the line.
         * Awesome idea of merykitty.
         *
         * @return the temperature in tenths of a degree
         */
        private int readTemperature() {
            // This is the number part: X.X, -X.X, XX.X or -XX.X
            final long numberBytes = UNSAFE.getLong(ptr);
            final long invNumberBytes = ~numberBytes;
            final int dotPosition = Long.numberOfTrailingZeros(invNumberBytes & DOT_BITS);
            // Calculates the sign
            final long signed = (invNumberBytes << 59) >> 63;
            final int min28 = (dotPosition ^ 0b11100);
            final long minusFilter = ~(signed & 0xFF);
            // Use the pre-calculated decimal position to adjust the values
            final long digits = ((numberBytes & minusFilter) << min28) & 0x0F000F0F00L;
            // Update the pointer here, bit awkward, but we have all the data
            ptr += (dotPosition >> 3) + 3;
            // Multiply by a magic (100 * 0x1000000 + 10 * 0x10000 + 1), to get the result
            final long absValue = ((digits * MAGIC_MULTIPLIER) >>> 32) & 0x3FF;
            // And apply the sign
            return (int) ((absValue + signed) ^ signed); // non-patented method of doing the same trick
        }

        /**
         * Checks whether the name just read equals the name stored in {@code entry},
         * for names longer than 16 bytes. Note that the compare helpers return true
         * on a MISMATCH.
         *
         * @param entry candidate entry from the table
         * @return true when all compared name bytes are equal
         */
        private boolean matches(final byte[] entry) {
            int step = 0;
            // All chunks before the last one are compared straight from memory...
            for (; step < entryLength - 16;) {
                if (compare(null, entryStart + step, entry, ENTRY_NAME + step)) {
                    return false;
                }
                step += 8;
            }
            // ...the last 16 bytes via the delimiter-masked buffers.
            if (compare(readBuffer1, entry, ENTRY_NAME + step)) {
                return false;
            }
            step += 8;
            if (compare(readBuffer2, entry, ENTRY_NAME + step)) {
                return false;
            }
            return true;
        }

        /**
         * Same as {@link #matches(byte[])} for names that fit in a single 16-byte
         * chunk: only the two masked buffers are compared.
         *
         * @param entry candidate entry from the table
         * @return true when both longs are equal
         */
        private boolean matches16(final byte[] entry) {
            if (compare(readBuffer1, entry, ENTRY_NAME)) {
                return false;
            }
            if (compare(readBuffer2, entry, ENTRY_NAME + 8)) {
                return false;
            }
            return true;
        }
    }

    /**
     * Aggregates all lines that start inside [startAddress, endAddress) into a
     * fresh open-addressing (linear probing) table.
     *
     * @param startAddress first native address of the segment
     * @param endAddress   native address just past the segment
     * @param isFileStart  true when startAddress is the first byte of the file
     * @return table of TABLE_SIZE slots; unused slots are null
     */
    private static byte[][] processMemoryArea(final long startAddress, final long endAddress, boolean isFileStart) {
        final byte[][] table = new byte[TABLE_SIZE][];
        final byte[][] preConstructedEntries = new byte[PREMADE_ENTRIES][ENTRY_BASESIZE_WHITESPACE + PREMADE_MAX_SIZE];
        // The reader aligns itself to the first line start of the segment.
        final Reader reader = new Reader(startAddress, endAddress, isFileStart);
        byte[] entry;
        int entryCount = 0;
        while (reader.hasNext()) {
            reader.processStart();
            if (!reader.readNext()) {
                // Delimiter found within the first 16 bytes:
                int temperature = reader.processEndAndGetTemperature();
                // Find or insert the entry:
                int index = (int) (reader.hash & TABLE_MASK);
                while (true) {
                    entry = table[index];
                    if (entry == null) {
                        byte[] entryBytes = (entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                                : new byte[ENTRY_BASESIZE_WHITESPACE + 16]; // with enough room
                        table[index] = fillEntry16(entryBytes, 16, temperature, reader.readBuffer1, reader.readBuffer2);
                        break;
                    }
                    else if (reader.matches16(entry)) {
                        updateEntry(entry, temperature);
                        break;
                    }
                    else {
                        // Move to the next index
                        index = (index + 1) & TABLE_MASK;
                    }
                }
                continue;
            }
            // Longer name: keep reading 16-byte chunks until the delimiter shows up.
            while (reader.readNext())
                ;
            int temperature = reader.processEndAndGetTemperature();
            // Find or insert the entry:
            int index = (int) (reader.hash & TABLE_MASK);
            while (true) {
                entry = table[index];
                if (entry == null) {
                    int length = reader.entryLength;
                    // NOTE(review): length is at least 32 on this path (two or more chunks were read),
                    // so "length < PREMADE_MAX_SIZE" never holds and a new array is always allocated here.
                    byte[] entryBytes = (length < PREMADE_MAX_SIZE && entryCount < PREMADE_ENTRIES) ? preConstructedEntries[entryCount++]
                            : new byte[ENTRY_BASESIZE_WHITESPACE + length]; // with enough room
                    table[index] = fillEntry(entryBytes, reader.entryStart, length, temperature, reader.readBuffer1, reader.readBuffer2);
                    break;
                }
                else if (reader.matches(entry)) {
                    updateEntry(entry, temperature);
                    break;
                }
                else {
                    // Move to the next index
                    index = (index + 1) & TABLE_MASK;
                }
            }
        }
        return table;
    }

    /**
     * Compares two longs read through Unsafe.
     *
     * @return true when the two longs DIFFER
     */
    private static boolean compare(final Object object1, final long address1, final Object object2, final long address2) {
        return UNSAFE.getLong(object1, address1) != UNSAFE.getLong(object2, address2);
    }

    /**
     * Compares a long value against a long read through Unsafe.
     *
     * @return true when the two longs DIFFER
     */
    private static boolean compare(final long value1, final Object object2, final long address2) {
        return value1 != UNSAFE.getLong(object2, address2);
    }

    /*
     * `___` ___ ___ _ ___` ` ___ ` _ ` _ ` _` ___
     * / ` \| _ \ __| \| \ \ / /_\ | | | | | | __|
     * | () | _ / __|| . |\ V / _ \| |_| |_| | ._|
     * \___/|_| |___|_|\_| \_/_/ \_\___|\___/|___|
     * ---------------- BETTER SOFTWARE, FASTER --
     *
     * https://www.openvalue.eu/
     */

    /**
     * Obtains the {@code sun.misc.Unsafe} singleton through reflection on its
     * private {@code theUnsafe} field.
     *
     * @return the Unsafe instance
     * @throws RuntimeException wrapping the reflection failure when the field cannot be read
     */
    private static Unsafe initUnsafe() {
        try {
            final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            return (Unsafe) theUnsafe.get(Unsafe.class);
        }
        catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }
}
|