aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/dev/morling/onebrc/CalculateAverage_palmr.java
blob: 5dcfcde25e0a789e4ff8a543715245e2df7e1319 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
/*
 *  Copyright 2023 The original authors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package dev.morling.onebrc;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

public class CalculateAverage_palmr {
    private static final String FILE = "./measurements.txt";
    private static final int CHUNK_SIZE = 1024 * 1024 * 10; // Trial and error showed ~10MB to be a good size on our machine
    private static final int STATION_NAME_BUFFER_SIZE = 128;
    private static final int THREAD_COUNT = Math.min(8, Runtime.getRuntime().availableProcessors());
    private static final char SEPARATOR_CHAR = ';';
    private static final char END_OF_RECORD = '\n';
    private static final char MINUS_CHAR = '-';
    private static final char DECIMAL_POINT_CHAR = '.';

    /**
     * Aggregates min/mean/max per station from a file of {@code <station>;<value>} lines, processing chunks of
     * the file in a parallel stream, and prints the results sorted by station name (as a {@link TreeMap} string).
     *
     * @param args optional; when present, {@code args[0]} is used as the input path instead of {@value #FILE}
     * @throws IOException if the input file cannot be opened, or its length / chunk-boundary bytes cannot be read
     */
    public static void main(String[] args) throws IOException {
        final var path = args.length > 0 ? args[0] : FILE;

        // The file (and its channel) is closed only after collect() has returned, i.e. once every chunk is parsed.
        try (var file = new RandomAccessFile(path, "r")) {
            final var channel = file.getChannel();

            final TreeMap<String, MeasurementAggregator> results = StreamSupport.stream(ThreadChunk.chunk(file, THREAD_COUNT), true)
                    .map(chunk -> parseChunk(chunk, channel))
                    .flatMap(perChunk -> perChunk.getAsUnorderedList().stream())
                    // Aggregators for the same station coming from different chunks are merged; TreeMap sorts by name.
                    .collect(Collectors.toMap(m -> new String(m.stationNameBytes, StandardCharsets.UTF_8), m -> m, MeasurementAggregator::merge, TreeMap::new));
            System.out.println(results);
        }
    }

    /**
     * A contiguous byte range {@code [startPoint, endPoint)} of the input file handled by one parallel task.
     * Every range except the final one ends immediately after a record terminator.
     */
    private record ThreadChunk(long startPoint, long endPoint, long size) {
        /**
         * Splits the file into at most {@code chunkCount} contiguous, non-overlapping chunks that together cover
         * the whole file. Every chunk but the last starts from {@code max(CHUNK_SIZE, fileSize / chunkCount)} bytes
         * and is extended to just past the next newline, so no record straddles two chunks; the last chunk always
         * runs to end of file.
         *
         * @param file       the open input file; its file pointer is moved by this method
         * @param chunkCount maximum number of chunks to produce; must be positive
         * @return an ordered, sized spliterator over the chunks (empty for an empty file)
         * @throws IOException if the file length or boundary bytes cannot be read
         */
        public static Spliterator<ThreadChunk> chunk(final RandomAccessFile file, final int chunkCount) throws IOException {
            if (chunkCount < 1) {
                throw new IllegalArgumentException("chunkCount must be positive: " + chunkCount);
            }
            final var fileSize = file.length();
            final var idealChunkSize = Math.max(CHUNK_SIZE, fileSize / chunkCount);
            final var chunks = new ThreadChunk[chunkCount];

            var validChunks = 0;
            var startPoint = 0L;
            while (startPoint < fileSize && validChunks < chunkCount) {
                final var isLastChunk = validChunks == chunkCount - 1;
                // The last permitted chunk takes everything left, so the whole file is always covered.
                var endPoint = isLastChunk ? fileSize : Math.min(startPoint + idealChunkSize, fileSize);
                if (endPoint < fileSize) {
                    endPoint = nextRecordStart(file, endPoint, fileSize);
                }
                chunks[validChunks++] = new ThreadChunk(startPoint, endPoint, endPoint - startPoint);
                startPoint = endPoint;
            }

            // CONCURRENT and IMMUTABLE are mutually exclusive per the Spliterator contract, and SORTED does not
            // apply (ThreadChunk is not Comparable), so only the characteristics that actually hold are reported.
            return Spliterators.spliterator(chunks, 0, validChunks,
                    Spliterator.ORDERED | Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.IMMUTABLE);
        }

        /**
         * Returns the offset just past the first newline at or after {@code position}, or {@code fileSize} when
         * no newline occurs before end of file (never beyond it, so the chunk can always be mapped).
         */
        private static long nextRecordStart(final RandomAccessFile file, final long position, final long fileSize) throws IOException {
            file.seek(position);
            var cursor = position;
            while (cursor < fileSize) {
                cursor++;
                if (file.readByte() == END_OF_RECORD) {
                    break;
                }
            }
            return cursor;
        }
    }

    /**
     * Parses one chunk by memory-mapping it in windows of at most {@code CHUNK_SIZE} bytes. Parsing state carries
     * over between windows, so a record split across two windows is still parsed correctly.
     *
     * @return this chunk's per-station aggregators
     */
    private static ByteArrayKeyedMap parseChunk(ThreadChunk chunk, FileChannel channel) {
        final var state = new State();

        var offset = chunk.startPoint;
        while (offset < chunk.endPoint) {
            parseData(channel, state, offset, Math.min(CHUNK_SIZE, chunk.endPoint - offset));
            offset += CHUNK_SIZE;
        }

        // A final record with no trailing newline (only possible at end of file) is still pending in the state.
        completeRecord(state);

        return state.aggregators;
    }

    /**
     * Maps {@code bufferSize} bytes of the file starting at {@code offset} and feeds them through the record
     * parser, updating {@code state}.
     *
     * @throws UncheckedIOException if the region cannot be mapped
     */
    private static void parseData(final FileChannel channel,
                                  final State state,
                                  final long offset,
                                  final long bufferSize) {
        final ByteBuffer byteBuffer;
        try {
            byteBuffer = channel.map(FileChannel.MapMode.READ_ONLY, offset, bufferSize);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to map " + bufferSize + " bytes at offset " + offset, e);
        }

        while (byteBuffer.hasRemaining()) {
            final byte currentChar = byteBuffer.get();

            if (currentChar == SEPARATOR_CHAR) {
                state.parsingValue = true;
            }
            else if (currentChar == END_OF_RECORD) {
                completeRecord(state);
            }
            else if (!state.parsingValue) {
                state.stationBuffer[state.stationPointerEnd++] = currentChar;
                state.signedHashCode = 31 * state.signedHashCode + (currentChar & 0xff);
            }
            else if (currentChar == MINUS_CHAR) {
                state.exponent = -0.1;
            }
            else if (currentChar != DECIMAL_POINT_CHAR) {
                // Digits accumulate as an integer; the decimal point is skipped and applied via `exponent`.
                state.measurementValue = state.measurementValue * 10 + (currentChar - '0');
            }
        }
    }

    /**
     * Folds the record currently held in {@code state} into its station's aggregator (skipped when no station
     * name bytes were read), then resets the state for the next record.
     */
    private static void completeRecord(final State state) {
        if (state.stationPointerEnd != 0) {
            final var value = state.measurementValue * state.exponent;

            final MeasurementAggregator aggregator = state.aggregators.computeIfAbsent(state.stationBuffer, state.stationPointerEnd, state.signedHashCode);
            aggregator.count++;
            aggregator.min = Math.min(aggregator.min, value);
            aggregator.max = Math.max(aggregator.max, value);
            aggregator.sum += value;
        }
        state.reset();
    }

    /** Mutable parsing state owned by a single chunk-parsing task. */
    private static final class State {
        final ByteArrayKeyedMap aggregators = new ByteArrayKeyedMap();
        boolean parsingValue = false; // true once the ';' separator of the current record has been seen
        // Name bytes of the current record; a name longer than STATION_NAME_BUFFER_SIZE bytes overflows this array.
        final byte[] stationBuffer = new byte[STATION_NAME_BUFFER_SIZE];
        int signedHashCode = 0;       // running 31-based hash of the name bytes
        int stationPointerEnd = 0;    // number of name bytes written to stationBuffer
        double measurementValue = 0;  // value digits read so far, as an integer (decimal point ignored)
        double exponent = 0.1;        // multiplier for measurementValue; -0.1 after a '-'; assumes one fractional digit

        public void reset() {
            parsingValue = false;
            signedHashCode = 0;
            stationPointerEnd = 0;
            measurementValue = 0;
            exponent = 0.1;
        }
    }

    /** Running min/sum/max/count for one station, keyed by the station's raw name bytes. */
    private static class MeasurementAggregator {
        final byte[] stationNameBytes;
        final int stationNameHashCode;
        private double min = Double.POSITIVE_INFINITY;
        private double max = Double.NEGATIVE_INFINITY;
        private double sum;
        private long count;

        public MeasurementAggregator(final byte[] stationNameBytes, final int stationNameHashCode) {
            this.stationNameBytes = stationNameBytes;
            this.stationNameHashCode = stationNameHashCode;
        }

        /** Formats as {@code min/mean/max}, each rounded to one decimal place. */
        @Override
        public String toString() {
            // Plain concatenation rather than STR templates, which were a preview feature withdrawn in JDK 23.
            return round(min) + "/" + round(sum / count) + "/" + round(max);
        }

        private double round(final double value) {
            return Math.round(value * 10.0) / 10.0;
        }

        /** Folds {@code b}'s statistics into this aggregator and returns {@code this}. */
        private MeasurementAggregator merge(final MeasurementAggregator b) {
            this.count += b.count;
            this.min = Math.min(this.min, b.min);
            this.max = Math.max(this.max, b.max);
            this.sum += b.sum;
            return this;
        }
    }

    /**
     * Very basic hash table that implements only computeIfAbsent, since that is all the code needs.
     * It has 65,536 slots, chosen to give few collisions on the example data set; that may not hold if the
     * station list changes, but it should still perform reasonably.
     * It uses open addressing (a single array) rather than separate chaining, which is what java.util.HashMap uses,
     * and linear probing to resolve collisions, which holds up well while collisions are rare.
     * Not thread-safe; each parsing task owns its own instance.
     */
    private static class ByteArrayKeyedMap {
        private static final int BUCKET_MASK = 0xFFFF;
        private final MeasurementAggregator[] buckets = new MeasurementAggregator[BUCKET_MASK + 1];
        // Presized to 413, presumably the station count of the sample data set.
        private final List<MeasurementAggregator> compactUnorderedBuckets = new ArrayList<>(413);

        /**
         * Returns the aggregator for the first {@code keyLength} bytes of {@code key}, creating it (with a copy of
         * those bytes) if absent.
         *
         * @throws IllegalStateException if inserting would fill the last free slot of the table
         */
        public MeasurementAggregator computeIfAbsent(final byte[] key, final int keyLength, final int keyHashCode) {
            var index = keyHashCode & BUCKET_MASK;

            while (true) {
                final MeasurementAggregator maybe = buckets[index];
                if (maybe == null) {
                    // Always keep at least one empty slot so that probing for an absent key terminates.
                    if (compactUnorderedBuckets.size() >= BUCKET_MASK) {
                        throw new IllegalStateException("Too many distinct stations for " + (BUCKET_MASK + 1) + " buckets");
                    }
                    final var copiedKey = Arrays.copyOf(key, keyLength);
                    final var measurementAggregator = new MeasurementAggregator(copiedKey, keyHashCode);
                    buckets[index] = measurementAggregator;
                    compactUnorderedBuckets.add(measurementAggregator);
                    return measurementAggregator;
                }
                // Cheap hash comparison first; equal keys always have equal hashes.
                if (maybe.stationNameHashCode == keyHashCode
                        && Arrays.equals(key, 0, keyLength, maybe.stationNameBytes, 0, maybe.stationNameBytes.length)) {
                    return maybe;
                }
                index = (index + 1) & BUCKET_MASK;
            }
        }

        /** Returns every aggregator created so far, in insertion order (the live internal list). */
        public List<MeasurementAggregator> getAsUnorderedList() {
            return compactUnorderedBuckets;
        }
    }
}