aboutsummaryrefslogtreecommitdiff
path: root/src/main/java
diff options
context:
space:
mode:
authorVemana <vemana.github@gmail.com>2024-01-20 02:17:55 +0530
committerGitHub <noreply@github.com>2024-01-19 21:47:55 +0100
commit6e3893c6a60ba8c514601e41f61b1d8240e2b8b5 (patch)
tree569f8775c38133d6be0e69d060e5f5100f987f83 /src/main/java
parent144a6af1645d8ae9b302463f3ad472a5b8a50d62 (diff)
Reduce variance by (1) Using common chunks at the end (2) Busy looping (#486)
on automatic closing of ByteBuffers. Previously, a straggler could hold up closing the ByteBuffers. Also: improve tracing code and parametrize additional options to aid in tuning. Our previous PR was surprising; parallelizing the munmap() call did not yield anywhere near the performance gain I expected. The local machine had a 10% gain while the testing machine only showed a 2% gain. I am still not clear why that happened, and the two best theories I have are: (1) variance due to stragglers (which this change addresses); (2) munmap() is either too fast or too slow relative to the other instructions compared to our local machine — I don't know which. We'll have to use adaptive tuning, but that's in a different change.
Diffstat (limited to 'src/main/java')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java257
1 file changed, 169 insertions, 88 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java
index 8f690e3..3e64ac9 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vemana.java
@@ -171,7 +171,7 @@ public class CalculateAverage_vemana {
int chunkSizeBits = 20;
// For the last commonChunkFraction fraction of total work, use smaller chunk sizes
- double commonChunkFraction = 0;
+ double commonChunkFraction = 0.03;
// Use commonChunkSizeBits for the small-chunk size
int commonChunkSizeBits = 18;
@@ -181,11 +181,17 @@ public class CalculateAverage_vemana {
int minReservedBytesAtFileTail = 9;
+ int nThreads = -1;
+
String inputFile = "measurements.txt";
+ double munmapFraction = 0.03;
+
+ boolean fakeAdvance = false;
+
for (String arg : args) {
- String key = arg.substring(0, arg.indexOf('='));
- String value = arg.substring(key.length() + 1);
+ String key = arg.substring(0, arg.indexOf('=')).trim();
+ String value = arg.substring(key.length() + 1).trim();
switch (key) {
case "chunkSizeBits":
chunkSizeBits = Integer.parseInt(value);
@@ -202,6 +208,15 @@ public class CalculateAverage_vemana {
case "inputfile":
inputFile = value;
break;
+ case "munmapFraction":
+ munmapFraction = Double.parseDouble(value);
+ break;
+ case "fakeAdvance":
+ fakeAdvance = Boolean.parseBoolean(value);
+ break;
+ case "nThreads":
+ nThreads = Integer.parseInt(value);
+ break;
default:
throw new IllegalArgumentException("Unknown argument: " + arg);
}
@@ -218,14 +233,17 @@ public class CalculateAverage_vemana {
System.out.println(
new Runner(
Path.of(inputFile),
+ nThreads,
chunkSizeBits,
commonChunkFraction,
commonChunkSizeBits,
hashtableSizeBits,
- minReservedBytesAtFileTail)
+ minReservedBytesAtFileTail,
+ munmapFraction,
+ fakeAdvance)
.getSummaryStatistics());
- Tracing.recordEvent("After printing result");
+ Tracing.recordEvent("Final result printed");
}
public record AggregateResult(Map<String, Stat> tempStats) {
@@ -286,8 +304,8 @@ public class CalculateAverage_vemana {
bufferEnd = bufferStart = -1;
}
- public void close(int shardIdx) {
- Tracing.recordWorkStart("cleaner", shardIdx);
+ public void close(String closerId, int shardIdx) {
+ Tracing.recordWorkStart(closerId, shardIdx);
if (byteBuffer != null) {
unclosedBuffers.add(byteBuffer);
}
@@ -297,7 +315,7 @@ public class CalculateAverage_vemana {
unclosedBuffers.clear();
bufferEnd = bufferStart = -1;
byteBuffer = null;
- Tracing.recordWorkEnd("cleaner", shardIdx);
+ Tracing.recordWorkEnd(closerId, shardIdx);
}
public void setRange(long rangeStart, long rangeEnd) {
@@ -383,7 +401,7 @@ public class CalculateAverage_vemana {
public interface LazyShardQueue {
- void close(int shardIdx);
+ void close(String closerId, int shardIdx);
Optional<ByteRange> fileTailEndWork(int idx);
@@ -415,37 +433,48 @@ public class CalculateAverage_vemana {
private final double commonChunkFraction;
private final int commonChunkSizeBits;
+ private final boolean fakeAdvance;
private final int hashtableSizeBits;
private final Path inputFile;
private final int minReservedBytesAtFileTail;
+ private final double munmapFraction;
+ private final int nThreads;
private final int shardSizeBits;
public Runner(
Path inputFile,
+ int nThreads,
int chunkSizeBits,
double commonChunkFraction,
int commonChunkSizeBits,
int hashtableSizeBits,
- int minReservedBytesAtFileTail) {
+ int minReservedBytesAtFileTail,
+ double munmapFraction,
+ boolean fakeAdvance) {
this.inputFile = inputFile;
+ this.nThreads = nThreads;
this.shardSizeBits = chunkSizeBits;
this.commonChunkFraction = commonChunkFraction;
this.commonChunkSizeBits = commonChunkSizeBits;
this.hashtableSizeBits = hashtableSizeBits;
this.minReservedBytesAtFileTail = minReservedBytesAtFileTail;
+ this.munmapFraction = munmapFraction;
+ this.fakeAdvance = fakeAdvance;
}
AggregateResult getSummaryStatistics() throws Exception {
- int nThreads = Runtime.getRuntime().availableProcessors();
+ int nThreads = this.nThreads < 0 ? Runtime.getRuntime().availableProcessors() : this.nThreads;
+
LazyShardQueue shardQueue = new SerialLazyShardQueue(
1L << shardSizeBits,
inputFile,
nThreads,
commonChunkFraction,
commonChunkSizeBits,
- minReservedBytesAtFileTail);
+ minReservedBytesAtFileTail,
+ munmapFraction,
+ fakeAdvance);
- List<Future<AggregateResult>> results = new ArrayList<>();
ExecutorService executorService = Executors.newFixedThreadPool(
nThreads,
runnable -> {
@@ -454,42 +483,56 @@ public class CalculateAverage_vemana {
return thread;
});
+ List<Future<AggregateResult>> results = new ArrayList<>();
for (int i = 0; i < nThreads; i++) {
final int shardIdx = i;
final Callable<AggregateResult> callable = () -> {
- Tracing.recordWorkStart("shard", shardIdx);
+ Tracing.recordWorkStart("Shard", shardIdx);
AggregateResult result = new ShardProcessor(shardQueue, hashtableSizeBits, shardIdx).processShard();
- Tracing.recordWorkEnd("shard", shardIdx);
+ Tracing.recordWorkEnd("Shard", shardIdx);
return result;
};
results.add(executorService.submit(callable));
}
Tracing.recordEvent("Basic push time");
- AggregateResult result = executorService.submit(() -> merge(results)).get();
+ // This particular sequence of Futures is so that both merge and munmap() can work as shards
+ // finish their computation without blocking on the entire set of shards to complete. In
+ // particular, munmap() doesn't need to wait on merge.
+ // First, submit a task to merge the results and then submit a task to cleanup bytebuffers
+ // from completed shards.
+ Future<AggregateResult> resultFutures = executorService.submit(() -> merge(results));
+ // Note that munmap() is serial and not parallel and hence we use just one thread.
+ executorService.submit(() -> closeByteBuffers(results, shardQueue));
+ AggregateResult result = resultFutures.get();
Tracing.recordEvent("Merge results received");
- // Note that munmap() is serial and not parallel
- executorService.submit(
- () -> {
- for (int i = 0; i < nThreads; i++) {
- shardQueue.close(i);
- }
- });
-
- Tracing.recordEvent("Waiting for executor shutdown");
-
+ Tracing.recordEvent("About to shutdown executor and wait");
executorService.shutdown();
executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
-
Tracing.recordEvent("Executor terminated");
- Tracing.analyzeWorkThreads("cleaner", nThreads);
- Tracing.recordEvent("After cleaner finish printed");
+ Tracing.analyzeWorkThreads(nThreads);
return result;
}
+ private void closeByteBuffers(
+ List<Future<AggregateResult>> results, LazyShardQueue shardQueue) {
+ int n = results.size();
+ boolean[] isDone = new boolean[n];
+ int remaining = results.size();
+ while (remaining > 0) {
+ for (int i = 0; i < n; i++) {
+ if (!isDone[i] && results.get(i).isDone()) {
+ remaining--;
+ isDone[i] = true;
+ shardQueue.close("Ending Cleaner", i);
+ }
+ }
+ }
+ }
+
private AggregateResult merge(List<Future<AggregateResult>> results)
throws ExecutionException, InterruptedException {
Tracing.recordEvent("Merge start time");
@@ -516,7 +559,6 @@ public class CalculateAverage_vemana {
}
}
Tracing.recordEvent("Merge end time");
- Tracing.analyzeWorkThreads("shard", results.size());
return new AggregateResult(output);
}
}
@@ -532,6 +574,7 @@ public class CalculateAverage_vemana {
private final long commonChunkSize;
private final AtomicLong commonPool;
private final long effectiveFileSize;
+ private final boolean fakeAdvance;
private final long fileSize;
private final long[] perThreadData;
private final RandomAccessFile raf;
@@ -543,8 +586,11 @@ public class CalculateAverage_vemana {
int shards,
double commonChunkFraction,
int commonChunkSizeBits,
- int fileTailReservedBytes)
+ int fileTailReservedBytes,
+ double munmapFraction,
+ boolean fakeAdvance)
throws IOException {
+ this.fakeAdvance = fakeAdvance;
Checks.checkArg(commonChunkFraction < 0.9 && commonChunkFraction >= 0);
Checks.checkArg(fileTailReservedBytes >= 0);
this.raf = new RandomAccessFile(filePath.toFile(), "r");
@@ -580,8 +626,8 @@ public class CalculateAverage_vemana {
// its work, where R = relative speed of unmap compared to the computation.
// For our problem, R ~ 75 because unmap unmaps 30GB/sec (but, it is serial) while
// cores go through data at the rate of 400MB/sec.
- perThreadData[pos + 3] = (long) (currentChunks * (0.03 * (shards - i)));
- perThreadData[pos + 4] = 1;
+ perThreadData[pos + 3] = (long) (currentChunks * (munmapFraction * (shards - i)));
+ perThreadData[pos + 4] = 1; // true iff munmap() hasn't been triggered yet
currentStart += currentChunks * chunkSize;
remainingChunks -= currentChunks;
}
@@ -596,8 +642,8 @@ public class CalculateAverage_vemana {
}
@Override
- public void close(int shardIdx) {
- byteRanges[shardIdx << 4].close(shardIdx);
+ public void close(String closerId, int shardIdx) {
+ byteRanges[shardIdx << 4].close(closerId, shardIdx);
}
@Override
@@ -616,14 +662,18 @@ public class CalculateAverage_vemana {
public ByteRange take(int shardIdx) {
// Try for thread local range
final int pos = shardIdx << 4;
- long rangeStart = perThreadData[pos];
- final long chunkEnd = perThreadData[pos + 1];
+ final long rangeStart;
final long rangeEnd;
- if (rangeStart < chunkEnd) {
+ if (perThreadData[pos + 2] >= 1) {
+ rangeStart = perThreadData[pos];
rangeEnd = rangeStart + chunkSize;
- perThreadData[pos] = rangeEnd;
+ // Don't do this in the if-check; it causes negative values that trigger intermediate
+ // cleanup
perThreadData[pos + 2]--;
+ if (!fakeAdvance) {
+ perThreadData[pos] = rangeEnd;
+ }
}
else {
rangeStart = commonPool.getAndAdd(commonChunkSize);
@@ -634,8 +684,8 @@ public class CalculateAverage_vemana {
rangeEnd = rangeStart + commonChunkSize;
}
- if (perThreadData[pos + 2] <= perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
- if (attemptClose(shardIdx)) {
+ if (perThreadData[pos + 2] < perThreadData[pos + 3] && perThreadData[pos + 4] > 0) {
+ if (attemptIntermediateClose(shardIdx)) {
perThreadData[pos + 4]--;
}
}
@@ -645,9 +695,9 @@ public class CalculateAverage_vemana {
return chunk;
}
- private boolean attemptClose(int shardIdx) {
+ private boolean attemptIntermediateClose(int shardIdx) {
if (seqLock.acquire()) {
- byteRanges[shardIdx << 4].close(shardIdx);
+ close("Intermediate Cleaner", shardIdx);
seqLock.release();
return true;
}
@@ -964,12 +1014,22 @@ public class CalculateAverage_vemana {
static class Tracing {
- private static final long[] cleanerTimes = new long[1 << 6 << 1];
- private static final long[] threadTimes = new long[1 << 6 << 1];
+ private static final Map<String, ThreadTimingsArray> knownWorkThreadEvents;
private static long startTime;
- static void analyzeWorkThreads(String id, int nThreads) {
- printTimingsAnalysis(id + " Stats", nThreads, timingsArray(id));
+ static {
+ // Maintain the ordering to be chronological in execution
+ // Map.of(..) screws up ordering
+ knownWorkThreadEvents = new LinkedHashMap<>();
+ for (String id : List.of("Shard", "Intermediate Cleaner", "Ending Cleaner")) {
+ knownWorkThreadEvents.put(id, new ThreadTimingsArray(id, 1 << 6 << 1));
+ }
+ }
+
+ static void analyzeWorkThreads(int nThreads) {
+ for (ThreadTimingsArray array : knownWorkThreadEvents.values()) {
+ errPrint(array.analyze(nThreads));
+ }
}
static void recordAppStart() {
@@ -981,11 +1041,11 @@ public class CalculateAverage_vemana {
}
static void recordWorkEnd(String id, int threadId) {
- timingsArray(id)[2 * threadId + 1] = System.nanoTime();
+ knownWorkThreadEvents.get(id).recordEnd(threadId);
}
static void recordWorkStart(String id, int threadId) {
- timingsArray(id)[2 * threadId] = System.nanoTime();
+ knownWorkThreadEvents.get(id).recordStart(threadId);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -998,57 +1058,78 @@ public class CalculateAverage_vemana {
errPrint(STR."\{message} = \{(nanoTime - startTime) / 1_000_000}ms");
}
- private static void printTimingsAnalysis(String header, int nThreads, long[] timestamps) {
- long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
- long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
- long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+ public static class ThreadTimingsArray {
- long[] durationsMs = new long[nThreads];
- long[] completionsMs = new long[nThreads];
- long[] beginMs = new long[nThreads];
- for (int i = 0; i < nThreads; i++) {
- long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
- durationsMs[i] = durationNs / 1_000_000;
- completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
- beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+ private static String toString(long[] array) {
+ return Arrays.stream(array)
+ .map(x -> x < 0 ? -1 : x)
+ .mapToObj(x -> String.format("%6d", x))
+ .collect(Collectors.joining(", ", "[ ", " ]"));
+ }
- minDuration = Math.min(minDuration, durationNs);
- maxDuration = Math.max(maxDuration, durationNs);
+ private final String id;
+ private final long[] timestamps;
+ private boolean hasData = false;
- minBegin = Math.min(minBegin, timestamps[2 * i]);
- maxBegin = Math.max(maxBegin, timestamps[2 * i]);
+ public ThreadTimingsArray(String id, int maxSize) {
+ this.timestamps = new long[maxSize];
+ this.id = id;
+ }
- maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1]);
- minCompletion = Math.min(minCompletion, timestamps[2 * i + 1]);
- }
- errPrint(
- STR."""
+ public String analyze(int nThreads) {
+ if (!hasData) {
+ return "%s has no thread timings data".formatted(id);
+ }
+ Checks.checkArg(nThreads <= timestamps.length);
+ long minDuration = Long.MAX_VALUE, maxDuration = Long.MIN_VALUE;
+ long minBegin = Long.MAX_VALUE, maxCompletion = Long.MIN_VALUE;
+ long maxBegin = Long.MIN_VALUE, minCompletion = Long.MAX_VALUE;
+
+ long[] durationsMs = new long[nThreads];
+ long[] completionsMs = new long[nThreads];
+ long[] beginMs = new long[nThreads];
+ for (int i = 0; i < nThreads; i++) {
+ long durationNs = timestamps[2 * i + 1] - timestamps[2 * i];
+ durationsMs[i] = durationNs / 1_000_000;
+ completionsMs[i] = (timestamps[2 * i + 1] - startTime) / 1_000_000;
+ beginMs[i] = (timestamps[2 * i] - startTime) / 1_000_000;
+
+ minDuration = Math.min(minDuration, durationNs);
+ maxDuration = Math.max(maxDuration, durationNs);
+
+ minBegin = Math.min(minBegin, timestamps[2 * i] - startTime);
+ maxBegin = Math.max(maxBegin, timestamps[2 * i] - startTime);
+
+ maxCompletion = Math.max(maxCompletion, timestamps[2 * i + 1] - startTime);
+ minCompletion = Math.min(minCompletion, timestamps[2 * i + 1] - startTime);
+ }
+ return STR."""
-------------------------------------------------------------------------------------------
- \{header}
+ \{id} Stats
-------------------------------------------------------------------------------------------
Max duration = \{maxDuration / 1_000_000} ms
Min duration = \{minDuration / 1_000_000} ms
- Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms
+ Timespan[max(end)-min(start)] = \{(maxCompletion - minBegin) / 1_000_000} ms [\{maxCompletion / 1_000_000} - \{minBegin / 1_000_000} ]
Completion Timespan[max(end)-min(end)] = \{(maxCompletion - minCompletion) / 1_000_000} ms
Begin Timespan[max(begin)-min(begin)] = \{(maxBegin - minBegin) / 1_000_000} ms
- Durations = \{toString(durationsMs)} in ms
- Begin Timestamps = \{toString(beginMs)} in ms
- Completion Timestamps = \{toString(completionsMs)} in ms
- """);
- }
+ Average Duration = \{Arrays.stream(durationsMs)
+ .average()
+ .getAsDouble()} ms
+ Durations = \{toString(durationsMs)} ms
+ Begin Timestamps = \{toString(beginMs)} ms
+ Completion Timestamps = \{toString(completionsMs)} ms
+ """;
+ }
- private static long[] timingsArray(String id) {
- return switch (id) {
- case "cleaner" -> cleanerTimes;
- case "shard" -> threadTimes;
- default -> throw new RuntimeException("");
- };
- }
+ public void recordEnd(int idx) {
+ timestamps[2 * idx + 1] = System.nanoTime();
+ hasData = true;
+ }
- private static String toString(long[] array) {
- return Arrays.stream(array)
- .mapToObj(x -> String.format("%6d", x))
- .collect(Collectors.joining(", ", "[ ", " ]"));
+ public void recordStart(int idx) {
+ timestamps[2 * idx] = System.nanoTime();
+ hasData = true;
+ }
}
}
}