aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorkumarsaurav123 <kumar.saurav@eko.co.in>2024-01-20 02:05:25 +0530
committerGitHub <noreply@github.com>2024-01-19 21:35:25 +0100
commitf6bcaae4b99bca976e5facefb20649ea085a458d (patch)
tree974644c9bcad9c6751e095c37bf3ac44986f5cc1 /src/main
parent836f0805ad40956f704d80caa068c460fc30da5b (diff)
kumarsaurav123 # Attempt 3 (#470)
* Use Memory Segment * Reduce Number of threads
Diffstat (limited to 'src/main')
-rw-r--r--src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java277
1 files changed, 137 insertions, 140 deletions
diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java b/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java
index 5b59d05..f991f9f 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_kumarsaurav123.java
@@ -15,18 +15,20 @@
*/
package dev.morling.onebrc;
+import java.io.IOException;
import java.io.RandomAccessFile;
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
-import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collector;
-import java.util.stream.IntStream;
import static java.util.stream.Collectors.groupingBy;
@@ -40,7 +42,10 @@ public class CalculateAverage_kumarsaurav123 {
}
}
- private static record ResultRow(String station,double min, double mean, double max,double sum,double count) {
+ private static record Pair(long start, int size) {
+ }
+
+ private static record ResultRow(String station, double min, double mean, double max, double sum, double count) {
public String toString() {
return round(min) + "/" + round(mean) + "/" + round(max);
}
@@ -61,18 +66,13 @@ public class CalculateAverage_kumarsaurav123 {
private String station;
}
- public static void main(String[] args) {
- HashMap<Byte, Integer> map = new HashMap<>();
- map.put((byte) 48, 0);
- map.put((byte) 49, 1);
- map.put((byte) 50, 2);
- map.put((byte) 51, 3);
- map.put((byte) 52, 4);
- map.put((byte) 53, 5);
- map.put((byte) 54, 6);
- map.put((byte) 55, 7);
- map.put((byte) 56, 8);
- map.put((byte) 57, 9);
+ public static void main(String[] args) throws IOException {
+ long start = System.currentTimeMillis();
+ System.out.println(run(FILE));
+ // System.out.println(System.currentTimeMillis() - start);
+ }
+
+ public static String run(String filePath) throws IOException {
Collector<ResultRow, MeasurementAggregator, ResultRow> collector2 = Collector.of(
MeasurementAggregator::new,
(a, m) -> {
@@ -91,7 +91,7 @@ public class CalculateAverage_kumarsaurav123 {
return res;
},
agg -> {
- return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
+ return new ResultRow(agg.station, agg.min, (Math.round(agg.sum * 10.0) / 10.0) / agg.count, agg.max, agg.sum, agg.count);
});
Collector<Measurement, MeasurementAggregator, ResultRow> collector = Collector.of(
MeasurementAggregator::new,
@@ -114,143 +114,140 @@ public class CalculateAverage_kumarsaurav123 {
agg -> {
return new ResultRow(agg.station, agg.min, agg.sum / agg.count, agg.max, agg.sum, agg.count);
});
-
- long start = System.currentTimeMillis();
- long len = Paths.get(FILE).toFile().length();
- Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
- int chunkSize = 1_0000_00;
- long proc = Math.max(1, (len / chunkSize));
- ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2 * 2 * 2);
+ ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
List<ResultRow> measurements = Collections.synchronizedList(new ArrayList<ResultRow>());
- IntStream.range(0, (int) proc)
- .mapToObj(i -> {
- return new Runnable() {
- @Override
- public void run() {
- try {
- RandomAccessFile file = new RandomAccessFile(FILE, "r");
- byte[] allBytes2 = new byte[chunkSize];
- file.seek((long) i * (long) chunkSize);
- int l = file.read(allBytes2);
- byte[] eol = "\n".getBytes(StandardCharsets.UTF_8);
- byte[] sep = ";".getBytes(StandardCharsets.UTF_8);
-
- List<Measurement> mst = new ArrayList<>();
- int st = 0;
- int cnt = 0;
- ArrayList<byte[]> local = new ArrayList<>();
-
- for (int i = 0; i < l; i++) {
- if (allBytes2[i] == eol[0]) {
- if (i != 0) {
- byte[] s2 = new byte[i - st];
- System.arraycopy(allBytes2, st, s2, 0, s2.length);
- if (cnt != 0) {
- for (int j = 0; j < s2.length; j++) {
- if (s2[j] == sep[0]) {
- byte[] city = new byte[j];
- byte[] value = new byte[s2.length - j - 1];
- System.arraycopy(s2, 0, city, 0, city.length);
- System.arraycopy(s2, city.length + 1, value, 0, value.length);
- double d = 0.0;
- int s = -1;
- for (int k = value.length - 1; k >= 0; k--) {
- if (value[k] == 45) {
- d = d * -1;
- }
- else if (value[k] == 46) {
- }
- else {
- d = d + map.get(value[k]).intValue() * Math.pow(10, s);
- s++;
- }
- }
- mst.add(new Measurement(new String(city), d));
-
- }
- }
-
- }
- else {
- local.add(s2);
- }
+ int chunkSize = 1_0000_00;
+ Map<Integer, List<byte[]>> leftOutsMap = new ConcurrentSkipListMap<>();
+ RandomAccessFile file = new RandomAccessFile(filePath, "r");
+ long filelength = file.length();
+ AtomicInteger kk = new AtomicInteger();
+ MemorySegment memorySegment = file.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, filelength, Arena.global());
+ int nChunks = 1000;
+
+ int pChunkSize = Math.min(Integer.MAX_VALUE, (int) (memorySegment.byteSize() / (1000 * 20)));
+ if (pChunkSize < 100) {
+ pChunkSize = (int) memorySegment.byteSize();
+ nChunks = 1;
+ }
+ ArrayList<Pair> chunks = createStartAndEnd(pChunkSize, nChunks, memorySegment);
+ chunks.stream()
+ .map(p -> {
- }
- cnt++;
- st = i + 1;
- }
- }
- if (st < l) {
- byte[] s2 = new byte[allBytes2.length - st];
- System.arraycopy(allBytes2, st, s2, 0, s2.length);
- local.add(s2);
- }
- leftOutsMap.put(i, local);
- allBytes2 = null;
- measurements.addAll(mst.stream()
- .collect(groupingBy(Measurement::station, collector))
- .values());
- // System.out.println(measurements.size());
- }
- catch (Exception e) {
- // throw new RuntimeException(e);
- System.out.println("");
- }
- }
- };
+ return createRunnable(memorySegment, p, collector, measurements, kk.getAndIncrement());
})
- .forEach(executor::submit);
- executor.shutdown();
-
+ .forEach(executorService::submit);
+ executorService.shutdown();
try {
- executor.awaitTermination(10, TimeUnit.MINUTES);
+ executorService.awaitTermination(10, TimeUnit.MINUTES);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
- Collection<Measurement> lMeasure = new ArrayList<>();
- List<byte[]> leftOuts = leftOutsMap.values()
- .stream()
- .flatMap(List::stream)
- .toList();
- int size = 0;
- for (int i = 0; i < leftOuts.size(); i++) {
- size = size + leftOuts.get(i).length;
- }
- byte[] allBytes = new byte[size];
- int pos = 0;
- for (int i = 0; i < leftOuts.size(); i++) {
- System.arraycopy(leftOuts.get(i), 0, allBytes, pos, leftOuts.get(i).length);
- pos = pos + leftOuts.get(i).length;
- }
- List<String> l = Arrays.asList(new String(allBytes).split(";"));
- List<Measurement> measurements1 = new ArrayList<>();
- String city = l.get(0);
- for (int i = 0; i < l.size() - 1; i++) {
- int sIndex = l.get(i + 1).indexOf('.') + 2;
-
- String tempp = l.get(i + 1).substring(0, sIndex);
- measurements1.add(new Measurement(city, Double.parseDouble(tempp)));
- city = l.get(i + 1).substring(sIndex);
- }
- measurements.addAll(measurements1.stream()
- .collect(groupingBy(Measurement::station, collector))
- .values());
Map<String, ResultRow> measurements2 = new TreeMap<>(measurements
.stream()
.parallel()
.collect(groupingBy(ResultRow::station, collector2)));
+ return measurements2.toString();
+ }
- // Read from bytes 1000 to 2000
- // Something like this
+ private static ArrayList<Pair> createStartAndEnd(int chunksize, int nChunks, MemorySegment memorySegment) {
+ ArrayList<Pair> startSizePairs = new ArrayList<>();
+ byte eol = "\n".getBytes(StandardCharsets.UTF_8)[0];
+ long start = 0;
+ long end = -1;
+ if (nChunks == 1) {
+ startSizePairs.add(new Pair(0, chunksize));
+ return startSizePairs;
+ }
+ else {
+ while (start < memorySegment.byteSize()) {
+ start = end + 1;
+ end = Math.min(memorySegment.byteSize() - 1, start + chunksize - 1);
+ while (memorySegment.get(ValueLayout.JAVA_BYTE, end) != eol) {
+ end--;
+
+ }
+ startSizePairs.add(new Pair(start, (int) (end - start + 1)));
+ }
+ }
+ return startSizePairs;
+ }
- //
- // Map<String, ResultRow> measurements = new TreeMap<>(Files.lines(Paths.get(FILE))
- // .map(l -> new Measurement(l.split(";")))
- // .collect(groupingBy(m -> m.station(), collector)));
+ public static Runnable createRunnable(MemorySegment memorySegment, Pair p, Collector<Measurement, MeasurementAggregator, ResultRow> collector,
+ List<ResultRow> measurements, int kk) {
+ return new Runnable() {
+ @Override
+ public void run() {
+ try {
+ long start = System.currentTimeMillis();
+
+ byte[] allBytes2 = new byte[p.size];
+ MemorySegment lMemory = memorySegment.asSlice(p.start, p.size);
+ lMemory.asByteBuffer().get(allBytes2);
+ HashMap<Byte, Integer> map = new HashMap<>();
+ // Runtime runtime = Runtime.getRuntime();
+ // long memoryMax = runtime.maxMemory();
+ // long memoryUsed = runtime.totalMemory() - runtime.freeMemory();
+ // double memoryUsedPercent = (memoryUsed * 100.0) / memoryMax;
+ // System.out.println("memoryUsedPercent: " + memoryUsedPercent);
+ map.put((byte) 48, 0);
+ map.put((byte) 49, 1);
+ map.put((byte) 50, 2);
+ map.put((byte) 51, 3);
+ map.put((byte) 52, 4);
+ map.put((byte) 53, 5);
+ map.put((byte) 54, 6);
+ map.put((byte) 55, 7);
+ map.put((byte) 56, 8);
+ map.put((byte) 57, 9);
+ byte[] eol = "\n".getBytes(StandardCharsets.UTF_8);
+ byte[] sep = ";".getBytes(StandardCharsets.UTF_8);
+
+ List<Measurement> mst = new ArrayList<>();
+ int st = 0;
+
+ for (int i = 0; i < allBytes2.length; i++) {
+ if (allBytes2[i] == eol[0]) {
+ byte[] s2 = new byte[i - st];
+ System.arraycopy(allBytes2, st, s2, 0, s2.length);
+ for (int j = 0; j < s2.length; j++) {
+ if (s2[j] == sep[0]) {
+ byte[] city = new byte[j];
+ byte[] value = new byte[s2.length - j - 1];
+ System.arraycopy(s2, 0, city, 0, city.length);
+ System.arraycopy(s2, city.length + 1, value, 0, value.length);
+ double d = 0.0;
+ int s = -1;
+ for (int k = value.length - 1; k >= 0; k--) {
+ if (value[k] == 45) {
+ d = d * -1;
+ }
+ else if (value[k] == 46) {
+ }
+ else {
+ d = d + map.get(value[k]).intValue() * Math.pow(10, s);
+ s++;
+ }
+ }
+ mst.add(new Measurement(new String(city), d));
- System.out.println(measurements2);
- // System.out.println(System.currentTimeMillis() - start);
+ }
+ }
+ st = i + 1;
+ }
+ }
+ // System.out.println("Task " + kk + "Completed in " + (System.currentTimeMillis() - start));
+ measurements.addAll(mst.stream()
+ .collect(groupingBy(Measurement::station, collector))
+ .values());
+
+ }
+ catch (Exception e) {
+ // throw new RuntimeException(e);
+ System.out.println("");
+ }
+ }
+ };
}
}