1 | #include "duckdb/execution/window_segment_tree.hpp" |
2 | |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/common/algorithm.hpp" |
5 | #include "duckdb/common/helper.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | //===--------------------------------------------------------------------===// |
10 | // WindowAggregateState |
11 | //===--------------------------------------------------------------------===// |
12 | |
13 | WindowAggregateState::WindowAggregateState(AggregateObject aggr, const LogicalType &result_type_p) |
14 | : aggr(std::move(aggr)), result_type(result_type_p), state(aggr.function.state_size()), |
15 | statev(Value::POINTER(value: CastPointerToValue(src: state.data()))), |
16 | statep(Value::POINTER(value: CastPointerToValue(src: state.data()))) { |
17 | statev.SetVectorType(VectorType::FLAT_VECTOR); // Prevent conversion of results to constants |
18 | } |
19 | |
20 | WindowAggregateState::~WindowAggregateState() { |
21 | } |
22 | |
23 | void WindowAggregateState::AggregateInit() { |
24 | aggr.function.initialize(state.data()); |
25 | } |
26 | |
27 | void WindowAggregateState::AggegateFinal(Vector &result, idx_t rid) { |
28 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
29 | aggr.function.finalize(statev, aggr_input_data, result, 1, rid); |
30 | |
31 | if (aggr.function.destructor) { |
32 | aggr.function.destructor(statev, aggr_input_data, 1); |
33 | } |
34 | } |
35 | |
36 | void WindowAggregateState::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { |
37 | } |
38 | |
39 | void WindowAggregateState::Finalize() { |
40 | } |
41 | |
42 | void WindowAggregateState::Compute(Vector &result, idx_t rid, idx_t start, idx_t end) { |
43 | } |
44 | |
45 | //===--------------------------------------------------------------------===// |
46 | // WindowConstantAggregate |
47 | //===--------------------------------------------------------------------===// |
48 | |
49 | WindowConstantAggregate::WindowConstantAggregate(AggregateObject aggr, const LogicalType &result_type, |
50 | const ValidityMask &partition_mask, const idx_t count) |
51 | : WindowAggregateState(std::move(aggr), result_type), partition(0), row(0) { |
52 | |
53 | // Locate the partition boundaries |
54 | idx_t start = 0; |
55 | if (partition_mask.AllValid()) { |
56 | partition_offsets.emplace_back(args: 0); |
57 | } else { |
58 | idx_t entry_idx; |
59 | idx_t shift; |
60 | while (start < count) { |
61 | partition_mask.GetEntryIndex(row_idx: start, entry_idx, idx_in_entry&: shift); |
62 | |
63 | // If start is aligned with the start of a block, |
64 | // and the block is blank, then skip forward one block. |
65 | const auto block = partition_mask.GetValidityEntry(entry_idx); |
66 | if (partition_mask.NoneValid(entry: block) && !shift) { |
67 | start += ValidityMask::BITS_PER_VALUE; |
68 | continue; |
69 | } |
70 | |
71 | // Loop over the block |
72 | for (; shift < ValidityMask::BITS_PER_VALUE && start < count; ++shift, ++start) { |
73 | if (partition_mask.RowIsValid(entry: block, idx_in_entry: shift)) { |
74 | partition_offsets.emplace_back(args&: start); |
75 | } |
76 | } |
77 | } |
78 | } |
79 | |
80 | // Initialise the vector for caching the results |
81 | results = make_uniq<Vector>(args: result_type, args: partition_offsets.size()); |
82 | partition_offsets.emplace_back(args: count); |
83 | |
84 | // Start the first aggregate |
85 | AggregateInit(); |
86 | } |
87 | |
88 | void WindowConstantAggregate::Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) { |
89 | const auto chunk_begin = row; |
90 | const auto chunk_end = chunk_begin + payload_chunk.size(); |
91 | |
92 | if (!inputs.ColumnCount() && payload_chunk.ColumnCount()) { |
93 | inputs.Initialize(allocator&: Allocator::DefaultAllocator(), types: payload_chunk.GetTypes()); |
94 | } |
95 | |
96 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
97 | idx_t begin = 0; |
98 | idx_t filter_idx = 0; |
99 | auto partition_end = partition_offsets[partition + 1]; |
100 | while (row < chunk_end) { |
101 | if (row == partition_end) { |
102 | AggegateFinal(result&: *results, rid: partition++); |
103 | AggregateInit(); |
104 | partition_end = partition_offsets[partition + 1]; |
105 | } |
106 | partition_end = MinValue(a: partition_end, b: chunk_end); |
107 | auto end = partition_end - chunk_begin; |
108 | |
109 | inputs.Reset(); |
110 | if (filter_sel) { |
111 | // Slice to any filtered rows in [begin, end) |
112 | SelectionVector sel; |
113 | |
114 | // Find the first value in [begin, end) |
115 | for (; filter_idx < filtered; ++filter_idx) { |
116 | auto idx = filter_sel->get_index(idx: filter_idx); |
117 | if (idx >= begin) { |
118 | break; |
119 | } |
120 | } |
121 | |
122 | // Find the first value in [end, filtered) |
123 | sel.Initialize(sel: filter_sel->data() + filter_idx); |
124 | idx_t nsel = 0; |
125 | for (; filter_idx < filtered; ++filter_idx, ++nsel) { |
126 | auto idx = filter_sel->get_index(idx: filter_idx); |
127 | if (idx >= end) { |
128 | break; |
129 | } |
130 | } |
131 | |
132 | if (nsel != inputs.size()) { |
133 | inputs.Slice(other&: payload_chunk, sel, count: nsel); |
134 | } |
135 | } else { |
136 | // Slice to [begin, end) |
137 | if (begin) { |
138 | for (idx_t c = 0; c < payload_chunk.ColumnCount(); ++c) { |
139 | inputs.data[c].Slice(other&: payload_chunk.data[c], offset: begin, end); |
140 | } |
141 | } else { |
142 | inputs.Reference(chunk&: payload_chunk); |
143 | } |
144 | inputs.SetCardinality(end - begin); |
145 | } |
146 | |
147 | // Aggregate the filtered rows into a single state |
148 | const auto count = inputs.size(); |
149 | if (aggr.function.simple_update) { |
150 | aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state.data(), count); |
151 | } else { |
152 | aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); |
153 | } |
154 | |
155 | // Skip filtered rows too! |
156 | row += end - begin; |
157 | begin = end; |
158 | } |
159 | } |
160 | |
161 | void WindowConstantAggregate::Finalize() { |
162 | AggegateFinal(result&: *results, rid: partition++); |
163 | |
164 | partition = 0; |
165 | row = 0; |
166 | } |
167 | |
168 | void WindowConstantAggregate::Compute(Vector &target, idx_t rid, idx_t start, idx_t end) { |
169 | // Find the partition containing [start, end) |
170 | while (start < partition_offsets[partition] || partition_offsets[partition + 1] <= start) { |
171 | ++partition; |
172 | } |
173 | D_ASSERT(partition_offsets[partition] <= start); |
174 | D_ASSERT(partition + 1 < partition_offsets.size()); |
175 | D_ASSERT(end <= partition_offsets[partition + 1]); |
176 | |
177 | // Copy the value |
178 | VectorOperations::Copy(source: *results, target, source_count: partition + 1, source_offset: partition, target_offset: rid); |
179 | } |
180 | |
181 | //===--------------------------------------------------------------------===// |
182 | // WindowSegmentTree |
183 | //===--------------------------------------------------------------------===// |
184 | WindowSegmentTree::WindowSegmentTree(AggregateObject aggr_p, const LogicalType &result_type_p, DataChunk *input, |
185 | const ValidityMask &filter_mask_p, WindowAggregationMode mode_p) |
186 | : aggr(std::move(aggr_p)), result_type(result_type_p), state(aggr.function.state_size()), |
187 | statep(Value::POINTER(value: CastPointerToValue(src: state.data()))), frame(0, 0), |
188 | statev(Value::POINTER(value: CastPointerToValue(src: state.data()))), internal_nodes(0), input_ref(input), |
189 | filter_mask(filter_mask_p), mode(mode_p) { |
190 | statep.Flatten(count: input->size()); |
191 | statev.SetVectorType(VectorType::FLAT_VECTOR); // Prevent conversion of results to constants |
192 | |
193 | if (input_ref && input_ref->ColumnCount() > 0) { |
194 | filter_sel.Initialize(count: input->size()); |
195 | inputs.Initialize(allocator&: Allocator::DefaultAllocator(), types: input_ref->GetTypes()); |
196 | // if we have a frame-by-frame method, share the single state |
197 | if (aggr.function.window && UseWindowAPI()) { |
198 | AggregateInit(); |
199 | inputs.Reference(chunk&: *input_ref); |
200 | } else { |
201 | inputs.SetCapacity(*input_ref); |
202 | if (aggr.function.combine && UseCombineAPI()) { |
203 | ConstructTree(); |
204 | } |
205 | } |
206 | } |
207 | } |
208 | |
209 | WindowSegmentTree::~WindowSegmentTree() { |
210 | if (!aggr.function.destructor) { |
211 | // nothing to destroy |
212 | return; |
213 | } |
214 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
215 | // call the destructor for all the intermediate states |
216 | data_ptr_t address_data[STANDARD_VECTOR_SIZE]; |
217 | Vector addresses(LogicalType::POINTER, data_ptr_cast(src: address_data)); |
218 | idx_t count = 0; |
219 | for (idx_t i = 0; i < internal_nodes; i++) { |
220 | address_data[count++] = data_ptr_t(levels_flat_native.get() + i * state.size()); |
221 | if (count == STANDARD_VECTOR_SIZE) { |
222 | aggr.function.destructor(addresses, aggr_input_data, count); |
223 | count = 0; |
224 | } |
225 | } |
226 | if (count > 0) { |
227 | aggr.function.destructor(addresses, aggr_input_data, count); |
228 | } |
229 | |
230 | if (aggr.function.window && UseWindowAPI()) { |
231 | aggr.function.destructor(statev, aggr_input_data, 1); |
232 | } |
233 | } |
234 | |
235 | void WindowSegmentTree::AggregateInit() { |
236 | aggr.function.initialize(state.data()); |
237 | } |
238 | |
239 | void WindowSegmentTree::AggegateFinal(Vector &result, idx_t rid) { |
240 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
241 | aggr.function.finalize(statev, aggr_input_data, result, 1, rid); |
242 | |
243 | if (aggr.function.destructor) { |
244 | aggr.function.destructor(statev, aggr_input_data, 1); |
245 | } |
246 | } |
247 | |
248 | void WindowSegmentTree::(idx_t begin, idx_t end) { |
249 | const auto size = end - begin; |
250 | |
251 | auto &chunk = *input_ref; |
252 | const auto input_count = input_ref->ColumnCount(); |
253 | inputs.SetCardinality(size); |
254 | for (idx_t i = 0; i < input_count; ++i) { |
255 | auto &v = inputs.data[i]; |
256 | auto &vec = chunk.data[i]; |
257 | v.Slice(other&: vec, offset: begin, end); |
258 | v.Verify(count: size); |
259 | } |
260 | |
261 | // Slice to any filtered rows |
262 | if (!filter_mask.AllValid()) { |
263 | idx_t filtered = 0; |
264 | for (idx_t i = begin; i < end; ++i) { |
265 | if (filter_mask.RowIsValid(row_idx: i)) { |
266 | filter_sel.set_index(idx: filtered++, loc: i - begin); |
267 | } |
268 | } |
269 | if (filtered != inputs.size()) { |
270 | inputs.Slice(sel_vector: filter_sel, count: filtered); |
271 | } |
272 | } |
273 | } |
274 | |
275 | void WindowSegmentTree::WindowSegmentValue(idx_t l_idx, idx_t begin, idx_t end) { |
276 | D_ASSERT(begin <= end); |
277 | if (begin == end || inputs.ColumnCount() == 0) { |
278 | return; |
279 | } |
280 | |
281 | const auto count = end - begin; |
282 | Vector s(statep, 0, count); |
283 | if (l_idx == 0) { |
284 | ExtractFrame(begin, end); |
285 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
286 | D_ASSERT(!inputs.data.empty()); |
287 | aggr.function.update(&inputs.data[0], aggr_input_data, input_ref->ColumnCount(), s, inputs.size()); |
288 | } else { |
289 | // find out where the states begin |
290 | data_ptr_t begin_ptr = levels_flat_native.get() + state.size() * (begin + levels_flat_start[l_idx - 1]); |
291 | // set up a vector of pointers that point towards the set of states |
292 | Vector v(LogicalType::POINTER, count); |
293 | auto pdata = FlatVector::GetData<data_ptr_t>(vector&: v); |
294 | for (idx_t i = 0; i < count; i++) { |
295 | pdata[i] = begin_ptr + i * state.size(); |
296 | } |
297 | v.Verify(count); |
298 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
299 | aggr.function.combine(v, s, aggr_input_data, count); |
300 | } |
301 | } |
302 | |
303 | void WindowSegmentTree::ConstructTree() { |
304 | D_ASSERT(input_ref); |
305 | D_ASSERT(inputs.ColumnCount() > 0); |
306 | |
307 | // compute space required to store internal nodes of segment tree |
308 | internal_nodes = 0; |
309 | idx_t level_nodes = input_ref->size(); |
310 | do { |
311 | level_nodes = (level_nodes + (TREE_FANOUT - 1)) / TREE_FANOUT; |
312 | internal_nodes += level_nodes; |
313 | } while (level_nodes > 1); |
314 | levels_flat_native = make_unsafe_uniq_array<data_t>(n: internal_nodes * state.size()); |
315 | levels_flat_start.push_back(x: 0); |
316 | |
317 | idx_t levels_flat_offset = 0; |
318 | idx_t level_current = 0; |
319 | // level 0 is data itself |
320 | idx_t level_size; |
321 | // iterate over the levels of the segment tree |
322 | while ((level_size = (level_current == 0 ? input_ref->size() |
323 | : levels_flat_offset - levels_flat_start[level_current - 1])) > 1) { |
324 | for (idx_t pos = 0; pos < level_size; pos += TREE_FANOUT) { |
325 | // compute the aggregate for this entry in the segment tree |
326 | AggregateInit(); |
327 | WindowSegmentValue(l_idx: level_current, begin: pos, end: MinValue(a: level_size, b: pos + TREE_FANOUT)); |
328 | |
329 | memcpy(dest: levels_flat_native.get() + (levels_flat_offset * state.size()), src: state.data(), n: state.size()); |
330 | |
331 | levels_flat_offset++; |
332 | } |
333 | |
334 | levels_flat_start.push_back(x: levels_flat_offset); |
335 | level_current++; |
336 | } |
337 | |
338 | // Corner case: single element in the window |
339 | if (levels_flat_offset == 0) { |
340 | aggr.function.initialize(levels_flat_native.get()); |
341 | } |
342 | } |
343 | |
344 | void WindowSegmentTree::Compute(Vector &result, idx_t rid, idx_t begin, idx_t end) { |
345 | D_ASSERT(input_ref); |
346 | |
347 | // If we have a window function, use that |
348 | if (aggr.function.window && UseWindowAPI()) { |
349 | // Frame boundaries |
350 | auto prev = frame; |
351 | frame = FrameBounds(begin, end); |
352 | |
353 | // Extract the range |
354 | AggregateInputData aggr_input_data(aggr.GetFunctionData(), Allocator::DefaultAllocator()); |
355 | aggr.function.window(input_ref->data.data(), filter_mask, aggr_input_data, inputs.ColumnCount(), state.data(), |
356 | frame, prev, result, rid, 0); |
357 | return; |
358 | } |
359 | |
360 | AggregateInit(); |
361 | |
362 | // Aggregate everything at once if we can't combine states |
363 | if (!aggr.function.combine || !UseCombineAPI()) { |
364 | WindowSegmentValue(l_idx: 0, begin, end); |
365 | AggegateFinal(result, rid); |
366 | return; |
367 | } |
368 | |
369 | for (idx_t l_idx = 0; l_idx < levels_flat_start.size() + 1; l_idx++) { |
370 | idx_t parent_begin = begin / TREE_FANOUT; |
371 | idx_t parent_end = end / TREE_FANOUT; |
372 | if (parent_begin == parent_end) { |
373 | WindowSegmentValue(l_idx, begin, end); |
374 | break; |
375 | } |
376 | idx_t group_begin = parent_begin * TREE_FANOUT; |
377 | if (begin != group_begin) { |
378 | WindowSegmentValue(l_idx, begin, end: group_begin + TREE_FANOUT); |
379 | parent_begin++; |
380 | } |
381 | idx_t group_end = parent_end * TREE_FANOUT; |
382 | if (end != group_end) { |
383 | WindowSegmentValue(l_idx, begin: group_end, end); |
384 | } |
385 | begin = parent_begin; |
386 | end = parent_end; |
387 | } |
388 | |
389 | AggegateFinal(result, rid); |
390 | } |
391 | |
392 | } // namespace duckdb |
393 | |