#include "duckdb/execution/aggregate_hashtable.hpp"

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/radix_partitioning.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/types/row/tuple_data_iterator.hpp"
#include "duckdb/common/vector_operations/unary_executor.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/storage/buffer_manager.hpp"

#include <cmath>

namespace duckdb {

using ValidityBytes = TupleDataLayout::ValidityBytes;

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types, vector<LogicalType> payload_types,
                                                     const vector<BoundAggregateExpression *> &bindings,
                                                     HtEntryType entry_type, idx_t initial_capacity)
    : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types),
                                AggregateObject::CreateAggregateObjects(bindings), entry_type, initial_capacity) {
}

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types)
    : GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector<AggregateObject>()) {
}

AggregateHTAppendState::AggregateHTAppendState()
    : ht_offsets(LogicalTypeId::BIGINT), hash_salts(LogicalTypeId::SMALLINT),
      group_compare_vector(STANDARD_VECTOR_SIZE), no_match_vector(STANDARD_VECTOR_SIZE),
      empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), addresses(LogicalType::POINTER),
      chunk_state_initialized(false) {
}

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types_p,
                                                     vector<LogicalType> payload_types_p,
                                                     vector<AggregateObject> aggregate_objects_p,
                                                     HtEntryType entry_type, idx_t initial_capacity)
    : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)),
      entry_type(entry_type), capacity(0), is_finalized(false),
      aggregate_allocator(make_shared<ArenaAllocator>(allocator)) {
    // Append hash column to the end and initialise the row layout
    group_types_p.emplace_back(LogicalType::HASH);
    layout.Initialize(std::move(group_types_p), std::move(aggregate_objects_p));
    tuple_size = layout.GetRowWidth();
    tuples_per_block = Storage::BLOCK_SIZE / tuple_size;

    // HT layout
    hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1];
    data_collection = make_uniq<TupleDataCollection>(buffer_manager, layout);
    data_collection->InitializeAppend(td_pin_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);

    hashes_hdl = buffer_manager.Allocate(Storage::BLOCK_SIZE);
    hashes_hdl_ptr = hashes_hdl.Ptr();

    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64: {
        hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_64::salt)) * 8;
        Resize<aggr_ht_entry_64>(initial_capacity);
        break;
    }
    case HtEntryType::HT_WIDTH_32: {
        hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_32::salt)) * 8;
        Resize<aggr_ht_entry_32>(initial_capacity);
        break;
    }
    default:
        throw InternalException("Unknown HT entry width");
    }

    predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_EQUAL);
}

GroupedAggregateHashTable::~GroupedAggregateHashTable() {
    Destroy();
}

void GroupedAggregateHashTable::Destroy() {
    if (data_collection->Count() == 0) {
        return;
    }

    // Check if there is an aggregate with a destructor
    bool has_destructor = false;
    for (auto &aggr : layout.GetAggregates()) {
        if (aggr.function.destructor) {
            has_destructor = true;
        }
    }
    if (!has_destructor) {
        return;
    }

    // There are aggregates with destructors: Call the destructor for each of the aggregates
    RowOperationsState state(aggregate_allocator->GetAllocator());
    TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false);
    auto &row_locations = iterator.GetChunkState().row_locations;
    do {
        RowOperations::DestroyStates(state, layout, row_locations, iterator.GetCurrentChunkCount());
    } while (iterator.Next());
    data_collection->Reset();
}

template <class ENTRY>
void GroupedAggregateHashTable::VerifyInternal() {
    auto hashes_ptr = (ENTRY *)hashes_hdl_ptr;
    idx_t count = 0;
    for (idx_t i = 0; i < capacity; i++) {
        if (hashes_ptr[i].page_nr > 0) {
            D_ASSERT(hashes_ptr[i].page_offset < tuples_per_block);
            D_ASSERT(hashes_ptr[i].page_nr <= payload_hds_ptrs.size());
            auto ptr = payload_hds_ptrs[hashes_ptr[i].page_nr - 1] + ((hashes_ptr[i].page_offset) * tuple_size);
            auto hash = Load<hash_t>(ptr + hash_offset);
            D_ASSERT((hashes_ptr[i].salt) == (hash >> hash_prefix_shift));

            count++;
        }
    }
    (void)count;
    D_ASSERT(count == Count());
}

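// Capacity limits: the table starts out at two vectors' worth of entries; the maximum capacity is bounded by how
// many pages and tuples per page the entry's page number and page offset fields can address.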
idx_t GroupedAggregateHashTable::InitialCapacity() {
    return STANDARD_VECTOR_SIZE * 2ULL;
}

idx_t GroupedAggregateHashTable::GetMaxCapacity(HtEntryType entry_type, idx_t tuple_size) {
    idx_t max_pages;
    idx_t max_tuples;

    switch (entry_type) {
    case HtEntryType::HT_WIDTH_32:
        max_pages = NumericLimits<uint8_t>::Maximum();
        max_tuples = NumericLimits<uint16_t>::Maximum();
        break;
    case HtEntryType::HT_WIDTH_64:
        max_pages = NumericLimits<uint32_t>::Maximum();
        max_tuples = NumericLimits<uint16_t>::Maximum();
        break;
    default:
        throw InternalException("Unsupported hash table width");
    }

    return max_pages * MinValue(max_tuples, (idx_t)Storage::BLOCK_SIZE / tuple_size);
}

idx_t GroupedAggregateHashTable::MaxCapacity() {
    return GetMaxCapacity(entry_type, tuple_size);
}

void GroupedAggregateHashTable::Verify() {
#ifdef DEBUG
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_32:
        VerifyInternal<aggr_ht_entry_32>();
        break;
    case HtEntryType::HT_WIDTH_64:
        VerifyInternal<aggr_ht_entry_64>();
        break;
    }
#endif
}

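// Resize allocates a zeroed entry array for the new (power-of-two) capacity; if the table already contains data,
// every row in the data collection is re-inserted into the new entry array using linear probing.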
template <class ENTRY>
void GroupedAggregateHashTable::Resize(idx_t size) {
    D_ASSERT(!is_finalized);
    D_ASSERT(size >= STANDARD_VECTOR_SIZE);
    D_ASSERT(IsPowerOfTwo(size));

    if (size < capacity) {
        throw InternalException("Cannot downsize a hash table!");
    }
    capacity = size;

    bitmask = capacity - 1;
    const auto byte_size = capacity * sizeof(ENTRY);
    if (byte_size > (idx_t)Storage::BLOCK_SIZE) {
        hashes_hdl = buffer_manager.Allocate(byte_size);
        hashes_hdl_ptr = hashes_hdl.Ptr();
    }
    memset(hashes_hdl_ptr, 0, byte_size);

    if (Count() != 0) {
        D_ASSERT(!payload_hds_ptrs.empty());
        auto hashes_arr = (ENTRY *)hashes_hdl_ptr;

        idx_t block_id = 0;
        auto block_pointer = payload_hds_ptrs[block_id];
        auto block_end = block_pointer + tuples_per_block * tuple_size;

        TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false);
        const auto row_locations = iterator.GetRowLocations();
        do {
            for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) {
                const auto &row_location = row_locations[i];
                if (row_location > block_end || row_location < block_pointer) {
                    block_id++;
                    D_ASSERT(block_id < payload_hds_ptrs.size());
                    block_pointer = payload_hds_ptrs[block_id];
                    block_end = block_pointer + tuples_per_block * tuple_size;
                }
                D_ASSERT(row_location >= block_pointer && row_location < block_end);
                D_ASSERT((row_location - block_pointer) % tuple_size == 0);

                const auto hash = Load<hash_t>(row_location + hash_offset);
                D_ASSERT((hash & bitmask) == (hash % capacity));
                D_ASSERT(hash >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());

                auto entry_idx = (idx_t)hash & bitmask;
                while (hashes_arr[entry_idx].page_nr > 0) {
                    entry_idx++;
                    if (entry_idx >= capacity) {
                        entry_idx = 0;
                    }
                }

                auto &ht_entry = hashes_arr[entry_idx];
                D_ASSERT(!ht_entry.page_nr);
                ht_entry.salt = hash >> hash_prefix_shift;
                ht_entry.page_nr = block_id + 1;
                ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
            }
        } while (iterator.Next());
    }

    Verify();
}

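// AddChunk finds or creates the group for every row in 'groups' and then updates the aggregate states selected by
// the filter with the corresponding 'payload' values; the return value is the number of newly created groups.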
idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          AggregateType filter) {
    unsafe_vector<idx_t> aggregate_filter;

    auto &aggregates = layout.GetAggregates();
    for (idx_t i = 0; i < aggregates.size(); i++) {
        auto &aggregate = aggregates[i];
        if (aggregate.aggr_type == filter) {
            aggregate_filter.push_back(i);
        }
    }
    return AddChunk(state, groups, payload, aggregate_filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          const unsafe_vector<idx_t> &filter) {
    Vector hashes(LogicalType::HASH);
    groups.Hash(hashes);

    return AddChunk(state, groups, hashes, payload, filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, Vector &group_hashes,
                                          DataChunk &payload, const unsafe_vector<idx_t> &filter) {
    D_ASSERT(!is_finalized);
    if (groups.size() == 0) {
        return 0;
    }

#ifdef DEBUG
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    for (idx_t i = 0; i < groups.ColumnCount(); i++) {
        D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]);
    }
#endif

    auto new_group_count = FindOrCreateGroups(state, groups, group_hashes, state.addresses, state.new_groups);
    VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size());

    // Now every cell has an entry, update the aggregates
    auto &aggregates = layout.GetAggregates();
    idx_t filter_idx = 0;
    idx_t payload_idx = 0;
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    for (idx_t i = 0; i < aggregates.size(); i++) {
        auto &aggr = aggregates[i];
        if (filter_idx >= filter.size() || i < filter[filter_idx]) {
            // Skip all the aggregates that are not in the filter
            payload_idx += aggr.child_count;
            VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
            continue;
        }
        D_ASSERT(i == filter[filter_idx]);

        if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) {
            RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload,
                                                payload_idx);
        } else {
            RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size());
        }

        // Move to the next aggregate
        payload_idx += aggr.child_count;
        VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
        filter_idx++;
    }

    Verify();
    return new_group_count;
}

void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) {
    groups.Verify();
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    for (idx_t i = 0; i < result.ColumnCount(); i++) {
        D_ASSERT(result.data[i].GetType() == payload_types[i]);
    }
    result.SetCardinality(groups);
    if (groups.size() == 0) {
        return;
    }

    // find the groups associated with the addresses
    // FIXME: this should not use the FindOrCreateGroups, creating them is unnecessary
    AggregateHTAppendState append_state;
    Vector addresses(LogicalType::POINTER);
    FindOrCreateGroups(append_state, groups, addresses);
    // now fetch the aggregates
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    RowOperations::FinalizeStates(row_state, layout, addresses, result, 0);
}

idx_t GroupedAggregateHashTable::ResizeThreshold() {
    return capacity / LOAD_FACTOR;
}

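// Probes the entry array for every input row: an empty entry creates a new group (the row is appended to the data
// collection), an occupied entry is first checked against the hash salt and then against the full group values;
// rows that do not match keep probing linearly. Writes each row's group address to 'addresses_v' and returns the
// number of newly created groups.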
template <class ENTRY>
idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(AggregateHTAppendState &state, DataChunk &groups,
                                                            Vector &group_hashes_v, Vector &addresses_v,
                                                            SelectionVector &new_groups_out) {
    D_ASSERT(!is_finalized);
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH);
    D_ASSERT(state.ht_offsets.GetVectorType() == VectorType::FLAT_VECTOR);
    D_ASSERT(state.ht_offsets.GetType() == LogicalType::BIGINT);
    D_ASSERT(addresses_v.GetType() == LogicalType::POINTER);
    D_ASSERT(state.hash_salts.GetType() == LogicalType::SMALLINT);

    if (Count() + groups.size() > MaxCapacity()) {
        throw InternalException("Hash table capacity reached");
    }

    // Resize at 50% capacity, also need to fit the entire vector
    if (capacity - Count() <= groups.size() || Count() > ResizeThreshold()) {
        Verify();
        Resize<ENTRY>(capacity * 2);
    }
    D_ASSERT(capacity - Count() >= groups.size()); // we need to be able to fit at least one vector of data

    group_hashes_v.Flatten(groups.size());
    auto group_hashes = FlatVector::GetData<hash_t>(group_hashes_v);

    addresses_v.Flatten(groups.size());
    auto addresses = FlatVector::GetData<data_ptr_t>(addresses_v);

    // Compute the entry in the table based on the hash using a modulo,
    // and precompute the hash salts for faster comparison below
    auto ht_offsets_ptr = FlatVector::GetData<uint64_t>(state.ht_offsets);
    auto hash_salts_ptr = FlatVector::GetData<uint16_t>(state.hash_salts);
    for (idx_t r = 0; r < groups.size(); r++) {
        auto element = group_hashes[r];
        D_ASSERT((element & bitmask) == (element % capacity));
        ht_offsets_ptr[r] = element & bitmask;
        hash_salts_ptr[r] = element >> hash_prefix_shift;
    }
    // we start out with all entries [0, 1, 2, ..., groups.size()]
    const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector();

    // Make a chunk that references the groups and the hashes and convert to unified format
    if (state.group_chunk.ColumnCount() == 0) {
        state.group_chunk.InitializeEmpty(layout.GetTypes());
    }
    D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size());
    for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) {
        state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]);
    }
    state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v);
    state.group_chunk.SetCardinality(groups);

    // convert all vectors to unified format
    if (!state.chunk_state_initialized) {
        data_collection->InitializeAppend(state.chunk_state);
        state.chunk_state_initialized = true;
    }
    TupleDataCollection::ToUnifiedFormat(state.chunk_state, state.group_chunk);
    if (!state.group_data) {
        state.group_data = make_unsafe_uniq_array<UnifiedVectorFormat>(state.group_chunk.ColumnCount());
    }
    TupleDataCollection::GetVectorData(state.chunk_state, state.group_data.get());

    idx_t new_group_count = 0;
    idx_t remaining_entries = groups.size();
    while (remaining_entries > 0) {
        idx_t new_entry_count = 0;
        idx_t need_compare_count = 0;
        idx_t no_match_count = 0;

        // For each remaining entry, figure out whether or not it belongs to a full or empty group
        for (idx_t i = 0; i < remaining_entries; i++) {
            const idx_t index = sel_vector->get_index(i);
            auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
            if (ht_entry.page_nr == 0) { // Cell is unoccupied (we use page number 0 as a "unused marker")
                D_ASSERT(group_hashes[index] >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());
                D_ASSERT(payload_hds_ptrs.size() < NumericLimits<uint32_t>::Maximum());

                // Set page nr to 1 for now to mark it as occupied (will be corrected later) and set the salt
                ht_entry.page_nr = 1;
                ht_entry.salt = group_hashes[index] >> hash_prefix_shift;

                // Update selection lists for outer loops
                state.empty_vector.set_index(new_entry_count++, index);
                new_groups_out.set_index(new_group_count++, index);
            } else { // Cell is occupied: Compare salts
                if (ht_entry.salt == hash_salts_ptr[index]) {
                    state.group_compare_vector.set_index(need_compare_count++, index);
                } else {
                    state.no_match_vector.set_index(no_match_count++, index);
                }
            }
        }

        if (new_entry_count != 0) {
            // Append everything that belongs to an empty group
            data_collection->AppendUnified(td_pin_state, state.chunk_state, state.group_chunk, state.empty_vector,
                                           new_entry_count);
            RowOperations::InitializeStates(layout, state.chunk_state.row_locations,
                                            *FlatVector::IncrementalSelectionVector(), new_entry_count);

            // Get the pointers to the (possibly) newly created blocks of the data collection
            idx_t block_id = payload_hds_ptrs.empty() ? 0 : payload_hds_ptrs.size() - 1;
            UpdateBlockPointers();
            auto block_pointer = payload_hds_ptrs[block_id];
            auto block_end = block_pointer + tuples_per_block * tuple_size;

            // Set the page nrs/offsets in the 1st part of the HT now that the data has been appended
            const auto row_locations = FlatVector::GetData<data_ptr_t>(state.chunk_state.row_locations);
            for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) {
                const auto &row_location = row_locations[new_entry_idx];
                if (row_location > block_end || row_location < block_pointer) {
                    block_id++;
                    D_ASSERT(block_id < payload_hds_ptrs.size());
                    block_pointer = payload_hds_ptrs[block_id];
                    block_end = block_pointer + tuples_per_block * tuple_size;
                }
                D_ASSERT(row_location >= block_pointer && row_location < block_end);
                D_ASSERT((row_location - block_pointer) % tuple_size == 0);
                const auto index = state.empty_vector.get_index(new_entry_idx);
                auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
                ht_entry.page_nr = block_id + 1;
                ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
                addresses[index] = row_location;
            }
        }

        if (need_compare_count != 0) {
            // Get the pointers to the rows that need to be compared
            for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) {
                const auto index = state.group_compare_vector.get_index(need_compare_idx);
                const auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
                auto page_ptr = payload_hds_ptrs[ht_entry.page_nr - 1];
                auto page_offset = ht_entry.page_offset * tuple_size;
                addresses[index] = page_ptr + page_offset;
            }

            // Perform group comparisons
            RowOperations::Match(state.group_chunk, state.group_data.get(), layout, addresses_v, predicates,
                                 state.group_compare_vector, need_compare_count, &state.no_match_vector,
                                 no_match_count);
        }

        // Linear probing: each of the entries that do not match move to the next entry in the HT
        for (idx_t i = 0; i < no_match_count; i++) {
            idx_t index = state.no_match_vector.get_index(i);
            ht_offsets_ptr[index]++;
            if (ht_offsets_ptr[index] >= capacity) {
                ht_offsets_ptr[index] = 0;
            }
        }
        sel_vector = &state.no_match_vector;
        remaining_entries = no_match_count;
    }

    return new_group_count;
}

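// Refresh the cached block pointers from the row blocks currently pinned in td_pin_state, so that entry page
// numbers/offsets can be translated into row addresses.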
void GroupedAggregateHashTable::UpdateBlockPointers() {
    for (const auto &id_and_handle : td_pin_state.row_handles) {
        const auto &id = id_and_handle.first;
        const auto &handle = id_and_handle.second;
        if (payload_hds_ptrs.empty() || id > payload_hds_ptrs.size() - 1) {
            payload_hds_ptrs.resize(id + 1);
        }
        payload_hds_ptrs[id] = handle.Ptr();
    }
}

// this is to support distinct aggregations where we need to record whether we
// have already seen a value for a group
idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &group_hashes, Vector &addresses_out,
                                                    SelectionVector &new_groups_out) {
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64:
        return FindOrCreateGroupsInternal<aggr_ht_entry_64>(state, groups, group_hashes, addresses_out,
                                                            new_groups_out);
    case HtEntryType::HT_WIDTH_32:
        return FindOrCreateGroupsInternal<aggr_ht_entry_32>(state, groups, group_hashes, addresses_out,
                                                            new_groups_out);
    default:
        throw InternalException("Unknown HT entry width");
    }
}

void GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                   Vector &addresses) {
    // create a dummy new_groups sel vector
    FindOrCreateGroups(state, groups, addresses, state.new_groups);
}

idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &addresses_out, SelectionVector &new_groups_out) {
    Vector hashes(LogicalType::HASH);
    groups.Hash(hashes);
    return FindOrCreateGroups(state, groups, hashes, addresses_out, new_groups_out);
}

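// Helper state for Combine: scans another hash table's data collection, materializing the group columns and the
// stored hash column so that the rows can be re-inserted into this hash table.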
struct FlushMoveState {
    explicit FlushMoveState(TupleDataCollection &collection_p)
        : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER),
          new_groups_sel(STANDARD_VECTOR_SIZE) {
        const auto &layout = collection.GetLayout();
        vector<column_t> column_ids;
        column_ids.reserve(layout.ColumnCount() - 1);
        for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) {
            column_ids.emplace_back(col_idx);
        }
        // FIXME DESTROY_AFTER_DONE if we make it possible to pass a selection vector to RowOperations::DestroyStates?
        collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::UNPIN_AFTER_DONE);
        collection.InitializeScanChunk(scan_state, groups);
        hash_col_idx = layout.ColumnCount() - 1;
    }

    bool Scan();

    TupleDataCollection &collection;
    TupleDataScanState scan_state;
    DataChunk groups;

    idx_t hash_col_idx;
    Vector hashes;

    AggregateHTAppendState append_state;
    Vector group_addresses;
    SelectionVector new_groups_sel;
};

bool FlushMoveState::Scan() {
    if (collection.Scan(scan_state, groups)) {
        collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(),
                          groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector());
        return true;
    }

    collection.FinalizePinState(scan_state.pin_state);
    return false;
}

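// Merge 'other' into this hash table: scan its rows, find or create the matching groups here, and combine the
// aggregate states of both sides.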
void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) {
    D_ASSERT(!is_finalized);

    D_ASSERT(other.layout.GetAggrWidth() == layout.GetAggrWidth());
    D_ASSERT(other.layout.GetDataWidth() == layout.GetDataWidth());
    D_ASSERT(other.layout.GetRowWidth() == layout.GetRowWidth());

    if (other.Count() == 0) {
        return;
    }

    FlushMoveState state(*other.data_collection);
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    while (state.Scan()) {
        FindOrCreateGroups(state.append_state, state.groups, state.hashes, state.group_addresses,
                           state.new_groups_sel);
        RowOperations::CombineStates(row_state, layout, state.scan_state.chunk_state.row_locations,
                                     state.group_addresses, state.groups.size());
    }

    Verify();
}

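// Radix-partition this table's data collection on the hash column and move each partition into the corresponding
// target hash table, which then rebuilds its entry array.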
void GroupedAggregateHashTable::Partition(vector<GroupedAggregateHashTable *> &partition_hts, idx_t radix_bits) {
    const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits);
    D_ASSERT(partition_hts.size() == num_partitions);

    // Partition the data
    auto partitioned_data =
        make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1);
    partitioned_data->Partition(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
    D_ASSERT(partitioned_data->GetPartitions().size() == num_partitions);

    // Move the partitioned data collections to the partitioned hash tables and initialize the 1st part of the HT
    auto &partitions = partitioned_data->GetPartitions();
    for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) {
        auto &partition_ht = *partition_hts[partition_idx];
        partition_ht.data_collection = std::move(partitions[partition_idx]);
        partition_ht.aggregate_allocator = aggregate_allocator;
        partition_ht.InitializeFirstPart();
        partition_ht.Verify();
    }
}

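// Rebuild the entry array (the "first part" of the HT) from the current data collection, e.g., after a partitioned
// data collection has been moved in by Partition.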
void GroupedAggregateHashTable::InitializeFirstPart() {
    data_collection->GetBlockPointers(payload_hds_ptrs);
    auto size = MaxValue<idx_t>(NextPowerOfTwo(Count() * 2L), capacity);
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64:
        Resize<aggr_ht_entry_64>(size);
        break;
    case HtEntryType::HT_WIDTH_32:
        Resize<aggr_ht_entry_32>(size);
        break;
    default:
        throw InternalException("Unknown HT entry width");
    }
}

idx_t GroupedAggregateHashTable::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate,
                                      DataChunk &result) {
    data_collection->Scan(gstate, lstate, result);

    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    const auto group_cols = layout.ColumnCount() - 1;
    RowOperations::FinalizeStates(row_state, layout, lstate.chunk_state.row_locations, result, group_cols);

    return result.size();
}

void GroupedAggregateHashTable::Finalize() {
    if (is_finalized) {
        return;
    }

    // Early release hashes (not needed for partition/scan) and data collection (will be pinned again when scanning)
    hashes_hdl.Destroy();
    data_collection->FinalizePinState(td_pin_state);
    data_collection->Unpin();

    is_finalized = true;
}

} // namespace duckdb