about | summary | refs | log | tree | commit | diff
path: root/src/main
diff options
context:
space:
mode:
author: Charlie Evans <charlibot@protonmail.com> 2024-01-14 13:34:08 +0000
committer: GitHub <noreply@github.com> 2024-01-14 14:34:08 +0100
commit: 695760b31b7b5f2a8542cd641eb1c9a389980a42 (patch)
tree: 038f677d4b833bf4cad063cda0ad1fca5e457d91 /src/main
parent: e4f0891d2dddff9461945cc83fe36b36c26dba4a (diff)
Charlibot - use memory mapping (#372)
* add memory map approach
* cleanup
Diffstat (limited to 'src/main')
-rw-r--r-- src/main/java/dev/morling/onebrc/CalculateAverage_charlibot.java 294
1 files changed, 108 insertions, 186 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_charlibot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_charlibot.java
index f5de00e..f71535e 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_charlibot.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_charlibot.java
@@ -15,9 +15,14 @@
*/
package dev.morling.onebrc;
-import java.io.*;
+import sun.misc.Unsafe;
+
+import java.lang.foreign.Arena;
+import java.lang.reflect.Field;
+import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.Collectors;
@@ -26,12 +31,23 @@ public class CalculateAverage_charlibot {
private static final String FILE = "./measurements.txt";
- private static final int BUFFER_SIZE = 1024 * 1024 * 10;
+ private static final Unsafe UNSAFE = initUnsafe();
+
+ private static Unsafe initUnsafe() {
+ try {
+ final Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ return (Unsafe) theUnsafe.get(Unsafe.class);
+ }
+ catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
private static final int MAP_CAPACITY = 16384; // Need at least 10,000 so 2^14 = 16384. Might need 2^15 = 32768.
public static void main(String[] args) throws Exception {
- multiThreadedReadingDoItAll();
+ memoryMap();
}
// Copied from Roy van Rijn's code
@@ -75,111 +91,74 @@ public class CalculateAverage_charlibot {
}
}
- static int hashArraySlice(byte[] array, int offset, int length) {
- int hashcode = 0;
- for (int i = offset; i < offset + length; i++) {
- hashcode = 31 * hashcode + array[i];
- }
- // Not sure the below actually helps much?
- // hashcode = hashcode >>> 16; // Do the same trick as-in hashmap since we're using power of 2
- return hashcode;
- }
+ static class MeasurementMap3 {
- static class MeasurementMap {
+ final Measurement[] measurements;
+ final byte[][] cities;
- final int[][] map;
final int capacity = MAP_CAPACITY;
- final int numIntsToStoreCity = 25; // stores up to 100 characters.
- int minPos = numIntsToStoreCity;
- int maxPos = numIntsToStoreCity + 1;
- int sumPos = numIntsToStoreCity + 2;
- int countPos = numIntsToStoreCity + 3;
-
- MeasurementMap() {
- map = new int[capacity][numIntsToStoreCity + 4]; // length of string and then the city encoded cast bytes to int. then min, max, sum, count,
+ MeasurementMap3() {
+ measurements = new Measurement[capacity];
+ cities = new byte[capacity][128]; // 100 bytes for the city. Round up to nearest power of 2.
}
- public void insert(byte[] array, int offset, int length, int value) {
- int hashcode = hashArraySlice(array, offset, length);
+ public void insert(long fromAddress, long toAddress, int hashcode, int value) {
int index = hashcode & (capacity - 1); // same trick as in hashmap. This is the same as (% capacity).
- tryInsert(index, array, offset, length, value);
+ tryInsert(index, fromAddress, toAddress, value);
}
- private void tryInsert(int mapIndex, byte[] array, int offset, int length, int value) {
+ private void tryInsert(int mapIndex, long fromAddress, long toAddress, int value) {
+ byte length = (byte) (toAddress - fromAddress);
outer: while (true) {
- int[] jas = map[mapIndex];
- if (jas[0] == 0) {
- // just insert
- int i = 0;
- int jasIndex = -1;
- while (i < length) {
- byte b = array[i + offset];
- // i & 3 is the same as i % 4
- if ((i & 3) == 0) { // when at i=0,4,8,12 then
- jasIndex++;
+ byte[] cityArray = cities[mapIndex];
+ Measurement jas = measurements[mapIndex];
+ if (jas != null) {
+ if (cityArray[0] == length) {
+ int i = 0;
+ while (i < length) {
+ byte b = UNSAFE.getByte(fromAddress + i);
+ if (b != cityArray[i + 1]) {
+ mapIndex = (mapIndex + 1) & (capacity - 1);
+ continue outer;
+ }
+ i++;
}
- jas[jasIndex] = jas[jasIndex] | ((b & 0xFF) << (8 * (i & 3)));
- i++;
+ jas.min = min(value, jas.min);
+ jas.max = max(value, jas.max);
+ jas.sum += value;
+ jas.count += 1;
+ break;
+ }
+ else {
+ mapIndex = (mapIndex + 1) & (capacity - 1);
}
- jas[minPos] = value;
- jas[maxPos] = value;
- jas[sumPos] = value;
- jas[countPos] = 1;
- break;
}
else {
+ // just insert
int i = 0;
- int jasIndex = -1;
+ cityArray[0] = length;
while (i < length) {
- byte b = array[i + offset];
- if ((i & 3) == 0) { // when at i=0,4,8,12,... then
- jasIndex++;
- }
- byte inJas = (byte) (jas[jasIndex] >>> (8 * (i & 3)));
- if (b != inJas) {
- mapIndex = (mapIndex + 1) & (capacity - 1);
- continue outer;
- }
+ byte b = UNSAFE.getByte(fromAddress + i);
+ cityArray[i + 1] = b;
i++;
}
- jas[minPos] = min(value, jas[minPos]);
- jas[maxPos] = max(value, jas[maxPos]);
- jas[sumPos] += value;
- jas[countPos] += 1;
+ measurements[mapIndex] = new Measurement(value);
break;
+
}
}
}
public HashMap<String, Measurement> toMap() {
HashMap<String, Measurement> hashMap = new HashMap<>();
- for (int[] jas : map) {
- if (jas[0] != 0) {
- int jasIndex = 0;
- byte[] array = new byte[numIntsToStoreCity * 4];
- while (jasIndex < numIntsToStoreCity) {
- int tmp = jas[jasIndex];
- array[jasIndex * 4] = (byte) tmp;
- array[jasIndex * 4 + 1] = (byte) (tmp >>> 8);
- array[jasIndex * 4 + 2] = (byte) (tmp >>> 16);
- array[jasIndex * 4 + 3] = (byte) (tmp >>> 24);
- jasIndex++;
- }
- int length = array.length;
- for (int i = 0; i < array.length; i++) {
- if (array[i] == 0) {
- length = i;
- break;
- }
- }
- String city = new String(array, 0, length, StandardCharsets.UTF_8);
- Measurement m = new Measurement(0);
- m.min = jas[minPos];
- m.max = jas[maxPos];
- m.sum = jas[sumPos];
- m.count = jas[countPos];
- hashMap.put(city, m);
+ for (int mapIndex = 0; mapIndex < cities.length; mapIndex++) {
+ byte[] cityArray = cities[mapIndex];
+ Measurement measurement = measurements[mapIndex];
+ if (measurement != null) {
+ int length = cityArray[0];
+ String city = new String(cityArray, 1, length, StandardCharsets.UTF_8);
+ hashMap.put(city, measurement);
}
}
return hashMap;
@@ -190,124 +169,68 @@ public class CalculateAverage_charlibot {
}
}
- public static void multiThreadedReadingDoItAll() throws Exception {
- File file = Path.of(FILE).toFile();
- long length = file.length();
- int numProcessors = Runtime.getRuntime().availableProcessors();
- long chunkToRead = length / numProcessors;
-
- // make life easier by spending a bit of time up front to find line breaks around the chunks
- final long[] startPositions = new long[numProcessors + 1];
- try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
- byte[] buffer = new byte[256];
- for (int processIdx = 1; processIdx < numProcessors; processIdx++) {
- long initialSeekPoint = processIdx * chunkToRead;
- raf.seek(initialSeekPoint);
- int bytesRead = raf.read(buffer);
- // if (bytesRead != buffer.length) {
- // throw new Exception("Actual read is not same as requested. " + bytesRead);
- // }
- int i = 0;
- while (buffer[i] != '\n') {
- i++;
+ public static long[] getChunks(int numChunks) throws Exception {
+ long[] chunks = new long[numChunks + 1];
+ try (FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
+ long fileSize = fileChannel.size();
+ long sizeOfChunk = fileSize / numChunks;
+ var address = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
+ chunks[0] = address;
+ for (int processIdx = 1; processIdx < numChunks; processIdx++) {
+ long chunkAddress = processIdx * sizeOfChunk + address;
+ while (UNSAFE.getByte(chunkAddress) != '\n') {
+ chunkAddress++;
}
- initialSeekPoint += (i + 1);
- startPositions[processIdx] = initialSeekPoint;
+ chunkAddress++;
+ chunks[processIdx] = chunkAddress;
}
- startPositions[numProcessors] = length;
+ chunks[numChunks] = address + fileSize;
}
+ return chunks;
+ }
+ public static void memoryMap() throws Exception {
+ int numProcessors = Runtime.getRuntime().availableProcessors();
+ long[] chunks = getChunks(numProcessors);
try (ExecutorService executorService = Executors.newWorkStealingPool(numProcessors)) {
Future[] results = new Future[numProcessors];
for (int processIdx = 0; processIdx < numProcessors; processIdx++) {
- long seekPoint = startPositions[processIdx];
- long bytesToRead = startPositions[processIdx + 1] - startPositions[processIdx];
+ int finalProcessIdx = processIdx;
Future<HashMap<String, Measurement>> future = executorService.submit(() -> {
- MeasurementMap measurements = new MeasurementMap();
- try (FileInputStream fis = new FileInputStream(file)) {
- long actualSkipped = fis.skip(seekPoint);
- if (actualSkipped != seekPoint) {
- throw new Exception("Uho oh");
+ long chunkIdx = chunks[finalProcessIdx];
+ long chunkEnd = chunks[finalProcessIdx + 1];
+ MeasurementMap3 measurements = new MeasurementMap3();
+ while (chunkIdx < chunkEnd) {
+ long cityStart = chunkIdx;
+ byte b;
+ int hashcode = 0;
+ while ((b = UNSAFE.getByte(chunkIdx)) != ';') {
+ hashcode = 31 * hashcode + b;
+ chunkIdx++;
}
- byte[] buffer = new byte[BUFFER_SIZE];
- long totalBytesRead = 0;
- int bytesRead;
- int currentCityLength = 0;
- while ((bytesRead = fis.read(buffer, currentCityLength, buffer.length - currentCityLength)) != -1) {
- totalBytesRead -= currentCityLength; // avoid double counting. There must be a better way !
- if (totalBytesRead >= bytesToRead && currentCityLength == 0) {
- // we have read everything we intend to and there is no city in the buffer to finish processing
- return measurements.toMap();
- }
- int i = 0;
- int cityIndexStart = 0;
- int cityLength;
- int multiplier = 1;
- int value = 0;
- while (i < bytesRead + currentCityLength) {
- if (totalBytesRead >= bytesToRead) {
- // we have read everything we intend to for this chunk
- return measurements.toMap();
- }
- if (buffer[i] == ';') {
- cityLength = i - cityIndexStart;
- i++;
- totalBytesRead++;
- if (i == bytesRead + currentCityLength) {
- System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength);
- bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength);
- currentCityLength = cityLength;
- cityIndexStart = 0;
- i = cityLength;
- }
- if (buffer[i] == '-') {
- multiplier = -1;
- i++;
- totalBytesRead++;
- if (i == bytesRead + currentCityLength) {
- System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength);
- bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength);
- currentCityLength = cityLength;
- cityIndexStart = 0;
- i = cityLength;
- }
- }
- while (buffer[i] != '\n') {
- if (buffer[i] != '.') {
- value = (value * 10) + (buffer[i] - '0');
- }
- i++;
- totalBytesRead++;
- if (i == bytesRead + currentCityLength) {
- System.arraycopy(buffer, cityIndexStart, buffer, 0, cityLength);
- bytesRead = fis.read(buffer, cityLength, buffer.length - cityLength);
- currentCityLength = cityLength;
- cityIndexStart = 0;
- i = cityLength;
- }
- }
- value = value * multiplier; // is boolean check faster?
- measurements.insert(buffer, cityIndexStart, cityLength, value);
- if (totalBytesRead >= bytesToRead) {
- return measurements.toMap();
- }
- // buffer[i] == \n so go one more
- cityIndexStart = i + 1;
- value = 0;
- multiplier = 1;
- }
- i++;
- totalBytesRead++;
+ long cityEnd = chunkIdx;
+ chunkIdx++;
+ int multiplier = 1;
+ b = UNSAFE.getByte(chunkIdx);
+ if (b == '-') {
+ multiplier = -1;
+ chunkIdx++;
+ }
+ int value = 0;
+ while ((b = UNSAFE.getByte(chunkIdx)) != '\n') {
+ if (b != '.') {
+ value = (value * 10) + (b - '0');
}
- currentCityLength = buffer.length - cityIndexStart;
- System.arraycopy(buffer, cityIndexStart, buffer, 0, currentCityLength);
+ chunkIdx++;
}
+ value = value * multiplier;
+ measurements.insert(cityStart, cityEnd, hashcode, value);
+ chunkIdx++;
}
return measurements.toMap();
});
results[processIdx] = future;
}
-
final HashMap<String, Measurement> measurements = new HashMap<>();
for (Future f : results) {
HashMap<String, Measurement> m = (HashMap<String, Measurement>) f.get();
@@ -328,5 +251,4 @@ public class CalculateAverage_charlibot {
System.out.println("}");
}
}
-
}