1#include "duckdb/execution/aggregate_hashtable.hpp"
2
3#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
4#include "duckdb/common/algorithm.hpp"
5#include "duckdb/common/exception.hpp"
6#include "duckdb/common/radix_partitioning.hpp"
7#include "duckdb/common/row_operations/row_operations.hpp"
8#include "duckdb/common/types/null_value.hpp"
9#include "duckdb/common/types/row/tuple_data_iterator.hpp"
10#include "duckdb/common/vector_operations/unary_executor.hpp"
11#include "duckdb/common/vector_operations/vector_operations.hpp"
12#include "duckdb/execution/expression_executor.hpp"
13#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
14#include "duckdb/storage/buffer_manager.hpp"
15
16#include <cmath>
17
18namespace duckdb {
19
20using ValidityBytes = TupleDataLayout::ValidityBytes;
21
22GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
23 vector<LogicalType> group_types, vector<LogicalType> payload_types,
24 const vector<BoundAggregateExpression *> &bindings,
25 HtEntryType entry_type, idx_t initial_capacity)
26 : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types),
27 AggregateObject::CreateAggregateObjects(bindings), entry_type, initial_capacity) {
28}
29
30GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
31 vector<LogicalType> group_types)
32 : GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector<AggregateObject>()) {
33}
34
35AggregateHTAppendState::AggregateHTAppendState()
36 : ht_offsets(LogicalTypeId::BIGINT), hash_salts(LogicalTypeId::SMALLINT),
37 group_compare_vector(STANDARD_VECTOR_SIZE), no_match_vector(STANDARD_VECTOR_SIZE),
38 empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), addresses(LogicalType::POINTER),
39 chunk_state_initialized(false) {
40}
41
42GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
43 vector<LogicalType> group_types_p,
44 vector<LogicalType> payload_types_p,
45 vector<AggregateObject> aggregate_objects_p,
46 HtEntryType entry_type, idx_t initial_capacity)
47 : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)),
48 entry_type(entry_type), capacity(0), is_finalized(false),
49 aggregate_allocator(make_shared<ArenaAllocator>(args&: allocator)) {
50 // Append hash column to the end and initialise the row layout
51 group_types_p.emplace_back(args: LogicalType::HASH);
52 layout.Initialize(types_p: std::move(group_types_p), aggregates_p: std::move(aggregate_objects_p));
53 tuple_size = layout.GetRowWidth();
54 tuples_per_block = Storage::BLOCK_SIZE / tuple_size;
55
56 // HT layout
57 hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1];
58 data_collection = make_uniq<TupleDataCollection>(args&: buffer_manager, args&: layout);
59 data_collection->InitializeAppend(pin_state&: td_pin_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
60
61 hashes_hdl = buffer_manager.Allocate(block_size: Storage::BLOCK_SIZE);
62 hashes_hdl_ptr = hashes_hdl.Ptr();
63
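	// The salt stored in each HT entry is the uppermost bits of the group hash; how far the hash must be shifted
	// depends on the width of the salt field of the chosen entry type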
	switch (entry_type) {
	case HtEntryType::HT_WIDTH_64: {
		hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_64::salt)) * 8;
		Resize<aggr_ht_entry_64>(initial_capacity);
		break;
	}
	case HtEntryType::HT_WIDTH_32: {
		hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_32::salt)) * 8;
		Resize<aggr_ht_entry_32>(initial_capacity);
		break;
	}
	default:
		throw InternalException("Unknown HT entry width");
	}

	predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_EQUAL);
}

GroupedAggregateHashTable::~GroupedAggregateHashTable() {
	Destroy();
}

void GroupedAggregateHashTable::Destroy() {
	if (data_collection->Count() == 0) {
		return;
	}

	// Check if there is an aggregate with a destructor
	bool has_destructor = false;
	for (auto &aggr : layout.GetAggregates()) {
		if (aggr.function.destructor) {
			has_destructor = true;
		}
	}
	if (!has_destructor) {
		return;
	}

	// There are aggregates with destructors: Call the destructor for each of the aggregates
	RowOperationsState state(aggregate_allocator->GetAllocator());
	TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false);
	auto &row_locations = iterator.GetChunkState().row_locations;
	do {
		RowOperations::DestroyStates(state, layout, row_locations, iterator.GetCurrentChunkCount());
	} while (iterator.Next());
	data_collection->Reset();
}

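// Debug-only check: walk every occupied entry in the pointer table and verify that its page number/offset point at a
// valid row whose stored hash matches the entry's salt, and that the number of occupied entries equals Count()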
template <class ENTRY>
void GroupedAggregateHashTable::VerifyInternal() {
	auto hashes_ptr = (ENTRY *)hashes_hdl_ptr;
	idx_t count = 0;
	for (idx_t i = 0; i < capacity; i++) {
		if (hashes_ptr[i].page_nr > 0) {
			D_ASSERT(hashes_ptr[i].page_offset < tuples_per_block);
			D_ASSERT(hashes_ptr[i].page_nr <= payload_hds_ptrs.size());
			auto ptr = payload_hds_ptrs[hashes_ptr[i].page_nr - 1] + ((hashes_ptr[i].page_offset) * tuple_size);
			auto hash = Load<hash_t>(ptr + hash_offset);
			D_ASSERT((hashes_ptr[i].salt) == (hash >> hash_prefix_shift));

			count++;
		}
	}
	(void)count;
	D_ASSERT(count == Count());
}

idx_t GroupedAggregateHashTable::InitialCapacity() {
	return STANDARD_VECTOR_SIZE * 2ULL;
}

idx_t GroupedAggregateHashTable::GetMaxCapacity(HtEntryType entry_type, idx_t tuple_size) {
	idx_t max_pages;
	idx_t max_tuples;

	switch (entry_type) {
	case HtEntryType::HT_WIDTH_32:
		max_pages = NumericLimits<uint8_t>::Maximum();
		max_tuples = NumericLimits<uint16_t>::Maximum();
		break;
	case HtEntryType::HT_WIDTH_64:
		max_pages = NumericLimits<uint32_t>::Maximum();
		max_tuples = NumericLimits<uint16_t>::Maximum();
		break;
	default:
		throw InternalException("Unsupported hash table width");
	}

	return max_pages * MinValue(max_tuples, (idx_t)Storage::BLOCK_SIZE / tuple_size);
}

idx_t GroupedAggregateHashTable::MaxCapacity() {
	return GetMaxCapacity(entry_type, tuple_size);
}

void GroupedAggregateHashTable::Verify() {
#ifdef DEBUG
	switch (entry_type) {
	case HtEntryType::HT_WIDTH_32:
		VerifyInternal<aggr_ht_entry_32>();
		break;
	case HtEntryType::HT_WIDTH_64:
		VerifyInternal<aggr_ht_entry_64>();
		break;
	}
#endif
}

template <class ENTRY>
void GroupedAggregateHashTable::Resize(idx_t size) {
	D_ASSERT(!is_finalized);
	D_ASSERT(size >= STANDARD_VECTOR_SIZE);
	D_ASSERT(IsPowerOfTwo(size));

	if (size < capacity) {
		throw InternalException("Cannot downsize a hash table!");
	}
	capacity = size;

	bitmask = capacity - 1;
	const auto byte_size = capacity * sizeof(ENTRY);
	if (byte_size > (idx_t)Storage::BLOCK_SIZE) {
		hashes_hdl = buffer_manager.Allocate(byte_size);
		hashes_hdl_ptr = hashes_hdl.Ptr();
	}
	memset(hashes_hdl_ptr, 0, byte_size);

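	// Rehash: if the table already contains data, walk the existing data blocks in order and re-insert every row
	// into the zeroed entry array, recomputing each row's slot from its stored hash and linear probing to the first
	// free entry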
	if (Count() != 0) {
		D_ASSERT(!payload_hds_ptrs.empty());
		auto hashes_arr = (ENTRY *)hashes_hdl_ptr;

		idx_t block_id = 0;
		auto block_pointer = payload_hds_ptrs[block_id];
		auto block_end = block_pointer + tuples_per_block * tuple_size;

		TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false);
		const auto row_locations = iterator.GetRowLocations();
		do {
			for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) {
				const auto &row_location = row_locations[i];
				if (row_location > block_end || row_location < block_pointer) {
					block_id++;
					D_ASSERT(block_id < payload_hds_ptrs.size());
					block_pointer = payload_hds_ptrs[block_id];
					block_end = block_pointer + tuples_per_block * tuple_size;
				}
				D_ASSERT(row_location >= block_pointer && row_location < block_end);
				D_ASSERT((row_location - block_pointer) % tuple_size == 0);

				const auto hash = Load<hash_t>(row_location + hash_offset);
				D_ASSERT((hash & bitmask) == (hash % capacity));
				D_ASSERT(hash >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());

				auto entry_idx = (idx_t)hash & bitmask;
				while (hashes_arr[entry_idx].page_nr > 0) {
					entry_idx++;
					if (entry_idx >= capacity) {
						entry_idx = 0;
					}
				}

				auto &ht_entry = hashes_arr[entry_idx];
				D_ASSERT(!ht_entry.page_nr);
				ht_entry.salt = hash >> hash_prefix_shift;
				ht_entry.page_nr = block_id + 1;
				ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
			}
		} while (iterator.Next());
	}

	Verify();
}

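// Add a chunk of groups and payload values, updating only the aggregates whose AggregateType matches the given
// filter (e.g. only the DISTINCT or only the non-DISTINCT aggregates)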
idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          AggregateType filter) {
	unsafe_vector<idx_t> aggregate_filter;

	auto &aggregates = layout.GetAggregates();
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i];
		if (aggregate.aggr_type == filter) {
			aggregate_filter.push_back(i);
		}
	}
	return AddChunk(state, groups, payload, aggregate_filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          const unsafe_vector<idx_t> &filter) {
	Vector hashes(LogicalType::HASH);
	groups.Hash(hashes);

	return AddChunk(state, groups, hashes, payload, filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, Vector &group_hashes,
                                          DataChunk &payload, const unsafe_vector<idx_t> &filter) {
	D_ASSERT(!is_finalized);
	if (groups.size() == 0) {
		return 0;
	}

#ifdef DEBUG
	D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
	for (idx_t i = 0; i < groups.ColumnCount(); i++) {
		D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]);
	}
#endif

	auto new_group_count = FindOrCreateGroups(state, groups, group_hashes, state.addresses, state.new_groups);
	VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size());

	// Now every cell has an entry, update the aggregates
	auto &aggregates = layout.GetAggregates();
	idx_t filter_idx = 0;
	idx_t payload_idx = 0;
	RowOperationsState row_state(aggregate_allocator->GetAllocator());
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggr = aggregates[i];
		if (filter_idx >= filter.size() || i < filter[filter_idx]) {
			// Skip all the aggregates that are not in the filter
			payload_idx += aggr.child_count;
			VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
			continue;
		}
		D_ASSERT(i == filter[filter_idx]);

		if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) {
			RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload,
			                                    payload_idx);
		} else {
			RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size());
		}

		// Move to the next aggregate
		payload_idx += aggr.child_count;
		VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
		filter_idx++;
	}

	Verify();
	return new_group_count;
}

void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) {
	groups.Verify();
	D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
	for (idx_t i = 0; i < result.ColumnCount(); i++) {
		D_ASSERT(result.data[i].GetType() == payload_types[i]);
	}
	result.SetCardinality(groups);
	if (groups.size() == 0) {
		return;
	}

	// find the groups associated with the addresses
	// FIXME: this should not use FindOrCreateGroups, creating them is unnecessary
	AggregateHTAppendState append_state;
	Vector addresses(LogicalType::POINTER);
	FindOrCreateGroups(append_state, groups, addresses);
	// now fetch the aggregates
	RowOperationsState row_state(aggregate_allocator->GetAllocator());
	RowOperations::FinalizeStates(row_state, layout, addresses, result, 0);
}

idx_t GroupedAggregateHashTable::ResizeThreshold() {
	return capacity / LOAD_FACTOR;
}

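// The hash table consists of two parts: a flat array of small entries (page number, page offset, and a hash salt)
// used for probing, and the row-major data blocks that hold the group values and aggregate states the entries
// point to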
template <class ENTRY>
idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(AggregateHTAppendState &state, DataChunk &groups,
                                                            Vector &group_hashes_v, Vector &addresses_v,
                                                            SelectionVector &new_groups_out) {
	D_ASSERT(!is_finalized);
	D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
	D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH);
	D_ASSERT(state.ht_offsets.GetVectorType() == VectorType::FLAT_VECTOR);
	D_ASSERT(state.ht_offsets.GetType() == LogicalType::BIGINT);
	D_ASSERT(addresses_v.GetType() == LogicalType::POINTER);
	D_ASSERT(state.hash_salts.GetType() == LogicalType::SMALLINT);

	if (Count() + groups.size() > MaxCapacity()) {
		throw InternalException("Hash table capacity reached");
	}

	// Resize at 50% capacity, also need to fit the entire vector
	if (capacity - Count() <= groups.size() || Count() > ResizeThreshold()) {
		Verify();
		Resize<ENTRY>(capacity * 2);
	}
	D_ASSERT(capacity - Count() >= groups.size()); // we need to be able to fit at least one vector of data

	group_hashes_v.Flatten(groups.size());
	auto group_hashes = FlatVector::GetData<hash_t>(group_hashes_v);

	addresses_v.Flatten(groups.size());
	auto addresses = FlatVector::GetData<data_ptr_t>(addresses_v);

	// Compute the entry in the table based on the hash using a modulo,
	// and precompute the hash salts for faster comparison below
	auto ht_offsets_ptr = FlatVector::GetData<uint64_t>(state.ht_offsets);
	auto hash_salts_ptr = FlatVector::GetData<uint16_t>(state.hash_salts);
	for (idx_t r = 0; r < groups.size(); r++) {
		auto element = group_hashes[r];
		D_ASSERT((element & bitmask) == (element % capacity));
		ht_offsets_ptr[r] = element & bitmask;
		hash_salts_ptr[r] = element >> hash_prefix_shift;
	}
	// we start out with all entries [0, 1, 2, ..., groups.size()]
	const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector();

	// Make a chunk that references the groups and the hashes and convert to unified format
	if (state.group_chunk.ColumnCount() == 0) {
		state.group_chunk.InitializeEmpty(layout.GetTypes());
	}
	D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size());
	for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) {
		state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]);
	}
	state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v);
	state.group_chunk.SetCardinality(groups);

	// convert all vectors to unified format
	if (!state.chunk_state_initialized) {
		data_collection->InitializeAppend(state.chunk_state);
		state.chunk_state_initialized = true;
	}
	TupleDataCollection::ToUnifiedFormat(state.chunk_state, state.group_chunk);
	if (!state.group_data) {
		state.group_data = make_unsafe_uniq_array<UnifiedVectorFormat>(state.group_chunk.ColumnCount());
	}
	TupleDataCollection::GetVectorData(state.chunk_state, state.group_data.get());

	idx_t new_group_count = 0;
	idx_t remaining_entries = groups.size();
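	// Probe loop: classify every remaining group as a new group (empty entry), a possible match (salt matches), or
	// a mismatch; mismatches advance to the next entry (linear probing) and are retried in the next iteration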
	while (remaining_entries > 0) {
		idx_t new_entry_count = 0;
		idx_t need_compare_count = 0;
		idx_t no_match_count = 0;

		// For each remaining entry, figure out whether it belongs to a full or empty group
		for (idx_t i = 0; i < remaining_entries; i++) {
			const idx_t index = sel_vector->get_index(i);
			auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
			if (ht_entry.page_nr == 0) { // Cell is unoccupied (we use page number 0 as an "unused" marker)
				D_ASSERT(group_hashes[index] >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());
				D_ASSERT(payload_hds_ptrs.size() < NumericLimits<uint32_t>::Maximum());

				// Set page nr to 1 for now to mark it as occupied (will be corrected later) and set the salt
				ht_entry.page_nr = 1;
				ht_entry.salt = group_hashes[index] >> hash_prefix_shift;

				// Update selection lists for outer loops
				state.empty_vector.set_index(new_entry_count++, index);
				new_groups_out.set_index(new_group_count++, index);
			} else { // Cell is occupied: Compare salts
				if (ht_entry.salt == hash_salts_ptr[index]) {
					state.group_compare_vector.set_index(need_compare_count++, index);
				} else {
					state.no_match_vector.set_index(no_match_count++, index);
				}
			}
		}

		if (new_entry_count != 0) {
			// Append everything that belongs to an empty group
			data_collection->AppendUnified(td_pin_state, state.chunk_state, state.group_chunk, state.empty_vector,
			                               new_entry_count);
			RowOperations::InitializeStates(layout, state.chunk_state.row_locations,
			                                *FlatVector::IncrementalSelectionVector(), new_entry_count);

			// Get the pointers to the (possibly) newly created blocks of the data collection
			idx_t block_id = payload_hds_ptrs.empty() ? 0 : payload_hds_ptrs.size() - 1;
			UpdateBlockPointers();
			auto block_pointer = payload_hds_ptrs[block_id];
			auto block_end = block_pointer + tuples_per_block * tuple_size;

			// Set the page nrs/offsets in the 1st part of the HT now that the data has been appended
			const auto row_locations = FlatVector::GetData<data_ptr_t>(state.chunk_state.row_locations);
			for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) {
				const auto &row_location = row_locations[new_entry_idx];
				if (row_location > block_end || row_location < block_pointer) {
					block_id++;
					D_ASSERT(block_id < payload_hds_ptrs.size());
					block_pointer = payload_hds_ptrs[block_id];
					block_end = block_pointer + tuples_per_block * tuple_size;
				}
				D_ASSERT(row_location >= block_pointer && row_location < block_end);
				D_ASSERT((row_location - block_pointer) % tuple_size == 0);
				const auto index = state.empty_vector.get_index(new_entry_idx);
				auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
				ht_entry.page_nr = block_id + 1;
				ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
				addresses[index] = row_location;
			}
		}

		if (need_compare_count != 0) {
			// Get the pointers to the rows that need to be compared
			for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) {
				const auto index = state.group_compare_vector.get_index(need_compare_idx);
				const auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
				auto page_ptr = payload_hds_ptrs[ht_entry.page_nr - 1];
				auto page_offset = ht_entry.page_offset * tuple_size;
				addresses[index] = page_ptr + page_offset;
			}

			// Perform group comparisons
			RowOperations::Match(state.group_chunk, state.group_data.get(), layout, addresses_v, predicates,
			                     state.group_compare_vector, need_compare_count, &state.no_match_vector,
			                     no_match_count);
		}

		// Linear probing: each of the entries that do not match moves to the next entry in the HT
		for (idx_t i = 0; i < no_match_count; i++) {
			idx_t index = state.no_match_vector.get_index(i);
			ht_offsets_ptr[index]++;
			if (ht_offsets_ptr[index] >= capacity) {
				ht_offsets_ptr[index] = 0;
			}
		}
		sel_vector = &state.no_match_vector;
		remaining_entries = no_match_count;
	}

	return new_group_count;
}

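// Refresh the cached block pointers from the pinned row blocks of the data collection, growing the cache whenever
// new blocks have been allocated since the last append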
void GroupedAggregateHashTable::UpdateBlockPointers() {
	for (const auto &id_and_handle : td_pin_state.row_handles) {
		const auto &id = id_and_handle.first;
		const auto &handle = id_and_handle.second;
		if (payload_hds_ptrs.empty() || id > payload_hds_ptrs.size() - 1) {
			payload_hds_ptrs.resize(id + 1);
		}
		payload_hds_ptrs[id] = handle.Ptr();
	}
}

// this is to support distinct aggregations where we need to record whether we
// have already seen a value for a group
idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &group_hashes, Vector &addresses_out,
                                                    SelectionVector &new_groups_out) {
	switch (entry_type) {
	case HtEntryType::HT_WIDTH_64:
		return FindOrCreateGroupsInternal<aggr_ht_entry_64>(state, groups, group_hashes, addresses_out,
		                                                    new_groups_out);
	case HtEntryType::HT_WIDTH_32:
		return FindOrCreateGroupsInternal<aggr_ht_entry_32>(state, groups, group_hashes, addresses_out,
		                                                    new_groups_out);
	default:
		throw InternalException("Unknown HT entry width");
	}
}

void GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                   Vector &addresses) {
	// create a dummy new_groups sel vector
	FindOrCreateGroups(state, groups, addresses, state.new_groups);
}

idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &addresses_out, SelectionVector &new_groups_out) {
	Vector hashes(LogicalType::HASH);
	groups.Hash(hashes);
	return FindOrCreateGroups(state, groups, hashes, addresses_out, new_groups_out);
}

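// Helper state for combining hash tables: scans another table's row data chunk by chunk, gathering the stored group
// hashes so the rows can be re-inserted into this table without rehashing the group columns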
struct FlushMoveState {
	explicit FlushMoveState(TupleDataCollection &collection_p)
	    : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER),
	      new_groups_sel(STANDARD_VECTOR_SIZE) {
		const auto &layout = collection.GetLayout();
		vector<column_t> column_ids;
		column_ids.reserve(layout.ColumnCount() - 1);
		for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) {
			column_ids.emplace_back(col_idx);
		}
		// FIXME DESTROY_AFTER_DONE if we make it possible to pass a selection vector to RowOperations::DestroyStates?
		collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::UNPIN_AFTER_DONE);
		collection.InitializeScanChunk(scan_state, groups);
		hash_col_idx = layout.ColumnCount() - 1;
	}

	bool Scan();

	TupleDataCollection &collection;
	TupleDataScanState scan_state;
	DataChunk groups;

	idx_t hash_col_idx;
	Vector hashes;

	AggregateHTAppendState append_state;
	Vector group_addresses;
	SelectionVector new_groups_sel;
};

bool FlushMoveState::Scan() {
	if (collection.Scan(scan_state, groups)) {
		collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(),
		                  groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector());
		return true;
	}

	collection.FinalizePinState(scan_state.pin_state);
	return false;
}

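// Merge another hash table into this one: for every row of the other table, find or create the corresponding group
// here and combine the aggregate states in place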
void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) {
	D_ASSERT(!is_finalized);

	D_ASSERT(other.layout.GetAggrWidth() == layout.GetAggrWidth());
	D_ASSERT(other.layout.GetDataWidth() == layout.GetDataWidth());
	D_ASSERT(other.layout.GetRowWidth() == layout.GetRowWidth());

	if (other.Count() == 0) {
		return;
	}

	FlushMoveState state(*other.data_collection);
	RowOperationsState row_state(aggregate_allocator->GetAllocator());
	while (state.Scan()) {
		FindOrCreateGroups(state.append_state, state.groups, state.hashes, state.group_addresses,
		                   state.new_groups_sel);
		RowOperations::CombineStates(row_state, layout, state.scan_state.chunk_state.row_locations,
		                             state.group_addresses, state.groups.size());
	}

	Verify();
}

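// Radix-partition the row data on the group hash column and hand each partition's data collection to the
// corresponding target hash table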
void GroupedAggregateHashTable::Partition(vector<GroupedAggregateHashTable *> &partition_hts, idx_t radix_bits) {
	const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits);
	D_ASSERT(partition_hts.size() == num_partitions);

	// Partition the data
	auto partitioned_data =
	    make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1);
	partitioned_data->Partition(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
	D_ASSERT(partitioned_data->GetPartitions().size() == num_partitions);

	// Move the partitioned data collections to the partitioned hash tables and initialize the 1st part of the HT
	auto &partitions = partitioned_data->GetPartitions();
	for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) {
		auto &partition_ht = *partition_hts[partition_idx];
		partition_ht.data_collection = std::move(partitions[partition_idx]);
		partition_ht.aggregate_allocator = aggregate_allocator;
		partition_ht.InitializeFirstPart();
		partition_ht.Verify();
	}
}

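// Rebuild the pointer table of a partition that was just handed its data collection: cache the block pointers and
// size the entry array to at least twice the row count, rounded up to a power of two (or the current capacity,
// whichever is larger)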
void GroupedAggregateHashTable::InitializeFirstPart() {
	data_collection->GetBlockPointers(payload_hds_ptrs);
	auto size = MaxValue<idx_t>(NextPowerOfTwo(Count() * 2L), capacity);
	switch (entry_type) {
	case HtEntryType::HT_WIDTH_64:
		Resize<aggr_ht_entry_64>(size);
		break;
	case HtEntryType::HT_WIDTH_32:
		Resize<aggr_ht_entry_32>(size);
		break;
	default:
		throw InternalException("Unknown HT entry width");
	}
}

idx_t GroupedAggregateHashTable::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate,
                                      DataChunk &result) {
	data_collection->Scan(gstate, lstate, result);

	RowOperationsState row_state(aggregate_allocator->GetAllocator());
	const auto group_cols = layout.ColumnCount() - 1;
	RowOperations::FinalizeStates(row_state, layout, lstate.chunk_state.row_locations, result, group_cols);

	return result.size();
}

void GroupedAggregateHashTable::Finalize() {
	if (is_finalized) {
		return;
	}

	// Early release hashes (not needed for partition/scan) and data collection (will be pinned again when scanning)
	hashes_hdl.Destroy();
	data_collection->FinalizePinState(td_pin_state);
	data_collection->Unpin();

	is_finalized = true;
}

} // namespace duckdb