aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc
diff options
context:
space:
mode:
authorThomas Wuerthinger <thomas.wuerthinger@oracle.com>2024-02-01 10:57:05 +0100
committerGitHub <noreply@github.com>2024-02-01 10:57:05 +0100
commit241d42ca6609b6bc32b403b1f4ee4d1fe6e325f8 (patch)
tree7f0cf053499e0cd880913b3d8457acfb4117e124 /src/main/java/dev/morling/onebrc
parent4debc7c5dd1b00f0dbc1822425cda727b250cad8 (diff)
One last improvement for thomaswue (#702)
* Combine <8 and 8-16 cases into one case. * Adopt mask-based approach for the <16 length city fast path (idea of Van Phu Do). * Slightly improved code layout. * Update perf number.
Diffstat (limited to 'src/main/java/dev/morling/onebrc')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java127
1 files changed, 66 insertions, 61 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
index dc4df0c..8e311fa 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
@@ -27,11 +27,14 @@ import java.util.concurrent.atomic.AtomicLong;
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread.
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in
* the end.
- * Runs in 0.39s on an Intel i9-13900K.
+ * Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s.
* Credit:
* Quan Anh Mai for branchless number parsing code
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea
* Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers
+ * Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if
+ * more work is performed
+ * Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting
*/
public class CalculateAverage_thomaswue {
private static final String FILE = "./measurements.txt";
@@ -141,9 +144,15 @@ public class CalculateAverage_thomaswue {
long delimiterMask1 = findDelimiter(word1);
long delimiterMask2 = findDelimiter(word2);
long delimiterMask3 = findDelimiter(word3);
- Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults);
- Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults);
- Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults);
+ long word1b = scanner1.getLongAt(scanner1.pos() + 8);
+ long word2b = scanner2.getLongAt(scanner2.pos() + 8);
+ long word3b = scanner3.getLongAt(scanner3.pos() + 8);
+ long delimiterMask1b = findDelimiter(word1b);
+ long delimiterMask2b = findDelimiter(word2b);
+ long delimiterMask3b = findDelimiter(word3b);
+ Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults);
+ Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults);
+ Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
@@ -155,76 +164,70 @@ public class CalculateAverage_thomaswue {
while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
- record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
+ long wordB = scanner1.getLongAt(scanner1.pos() + 8);
+ long posB = findDelimiter(wordB);
+ record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1));
}
while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
- record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
+ long wordB = scanner2.getLongAt(scanner2.pos() + 8);
+ long posB = findDelimiter(wordB);
+ record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2));
}
while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
- record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
+ long wordB = scanner3.getLongAt(scanner3.pos() + 8);
+ long posB = findDelimiter(wordB);
+ record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}
- private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) {
+ private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL,
+ 0xFFFFFFFFFFFFFFFFL };
+ private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL };
+
+ private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results,
+ List<Result> collectedResults) {
Result existingResult;
long word = initialWord;
long delimiterMask = initialDelimiterMask;
long hash;
long nameAddress = scanner.pos();
-
- // Search for ';', one long at a time. There are two common cases that a specially treated:
- // (b) the ';' is found in the first 16 bytes
- if (delimiterMask != 0) {
- // Special case for when the ';' is found in the first 8 bytes.
- int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
- word = (word << (63 - trailingZeros));
- scanner.add(trailingZeros >>> 3);
- hash = word;
+ long word2 = wordB;
+ long delimiterMask2 = delimiterMaskB;
+ if ((delimiterMask | delimiterMask2) != 0) {
+ int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8
+ int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8
+ long mask = MASK2[letterCount1];
+ word = word & MASK1[letterCount1];
+ word2 = mask & word2 & MASK1[letterCount2];
+ hash = word ^ word2;
existingResult = results[hashToIndex(hash, results)];
- if (existingResult != null && existingResult.lastNameLong == word) {
+ scanner.add(letterCount1 + (letterCount2 & mask));
+ if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) {
return existingResult;
}
}
else {
- // Special case for when the ';' is found in bytes 9-16.
- hash = word;
- long prevWord = word;
- scanner.add(8);
- word = scanner.getLong();
- delimiterMask = findDelimiter(word);
- if (delimiterMask != 0) {
- int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
- word = (word << (63 - trailingZeros));
- scanner.add(trailingZeros >>> 3);
- hash ^= word;
- existingResult = results[hashToIndex(hash, results)];
- if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) {
- return existingResult;
+ // Slow-path for when the ';' could not be found in the first 16 bytes.
+ hash = word ^ word2;
+ scanner.add(16);
+ while (true) {
+ word = scanner.getLong();
+ delimiterMask = findDelimiter(word);
+ if (delimiterMask != 0) {
+ int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
+ word = (word << (63 - trailingZeros));
+ scanner.add(trailingZeros >>> 3);
+ hash ^= word;
+ break;
}
- }
- else {
- // Slow-path for when the ';' could not be found in the first 16 bytes.
- scanner.add(8);
- hash ^= word;
- while (true) {
- word = scanner.getLong();
- delimiterMask = findDelimiter(word);
- if (delimiterMask != 0) {
- int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
- word = (word << (63 - trailingZeros));
- scanner.add(trailingZeros >>> 3);
- hash ^= word;
- break;
- }
- else {
- scanner.add(8);
- hash ^= word;
- }
+ else {
+ scanner.add(8);
+ hash ^= word;
}
}
}
@@ -249,8 +252,8 @@ public class CalculateAverage_thomaswue {
}
}
- int remainingShift = (64 - (nameLength + 1 - i) << 3);
- if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) {
+ int remainingShift = (64 - ((nameLength + 1 - i) << 3));
+ if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) {
break;
}
else {
@@ -297,7 +300,7 @@ public class CalculateAverage_thomaswue {
}
private static int hashToIndex(long hash, Result[] results) {
- long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17);
+ long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15);
return (int) (hashAsInt & (results.length - 1));
}
@@ -324,21 +327,23 @@ public class CalculateAverage_thomaswue {
private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) {
Result r = new Result();
results[hash] = r;
- int i = 0;
- for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) {
+ int totalLength = nameLength + 1;
+ r.firstNameWord = scanner.getLongAt(nameAddress);
+ r.secondNameWord = scanner.getLongAt(nameAddress + 8);
+ if (totalLength <= 8) {
+ r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1];
+ r.secondNameWord = 0;
}
- if (nameLength + 1 > 8) {
- r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8);
+ else if (totalLength < 16) {
+ r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9];
}
- int remainingShift = (64 - (nameLength + 1 - i) << 3);
- r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.nameAddress = nameAddress;
collectedResults.add(r);
return r;
}
private static final class Result {
- long lastNameLong, secondLastNameLong;
+ long firstNameWord, secondNameWord;
short min, max;
int count;
long sum;