#include "duckdb/execution/aggregate_hashtable.hpp"

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/radix_partitioning.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/types/row/tuple_data_iterator.hpp"
#include "duckdb/common/vector_operations/unary_executor.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/storage/buffer_manager.hpp"

#include <cmath>

namespace duckdb {

using ValidityBytes = TupleDataLayout::ValidityBytes;

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types,
                                                     vector<LogicalType> payload_types,
                                                     const vector<BoundAggregateExpression *> &bindings,
                                                     HtEntryType entry_type, idx_t initial_capacity)
    : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types),
                                AggregateObject::CreateAggregateObjects(bindings), entry_type, initial_capacity) {
}

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types)
    : GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector<AggregateObject>()) {
}

AggregateHTAppendState::AggregateHTAppendState()
    : ht_offsets(LogicalTypeId::BIGINT), hash_salts(LogicalTypeId::SMALLINT),
      group_compare_vector(STANDARD_VECTOR_SIZE), no_match_vector(STANDARD_VECTOR_SIZE),
      empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), addresses(LogicalType::POINTER),
      chunk_state_initialized(false) {
}

GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     vector<LogicalType> group_types_p,
                                                     vector<LogicalType> payload_types_p,
                                                     vector<AggregateObject> aggregate_objects_p,
                                                     HtEntryType entry_type, idx_t initial_capacity)
    : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)),
      entry_type(entry_type), capacity(0), is_finalized(false),
      aggregate_allocator(make_shared<ArenaAllocator>(allocator)) {
    // Append hash column to the end and initialise the row layout
    group_types_p.emplace_back(LogicalType::HASH);
    layout.Initialize(std::move(group_types_p), std::move(aggregate_objects_p));
    tuple_size = layout.GetRowWidth();
    tuples_per_block = Storage::BLOCK_SIZE / tuple_size;

    // HT layout
    hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1];
    data_collection = make_uniq<TupleDataCollection>(buffer_manager, layout);
    data_collection->InitializeAppend(td_pin_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);

    hashes_hdl = buffer_manager.Allocate(Storage::BLOCK_SIZE);
    hashes_hdl_ptr = hashes_hdl.Ptr();

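    // Each hash-table entry stores the upper bits of the group's hash as a "salt", so most
    // non-matching slots can be rejected without touching the row data. hash_prefix_shift
    // selects those upper bits for the chosen entry width.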
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64: {
        hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_64::salt)) * 8;
        Resize<aggr_ht_entry_64>(initial_capacity);
        break;
    }
    case HtEntryType::HT_WIDTH_32: {
        hash_prefix_shift = (HASH_WIDTH - sizeof(aggr_ht_entry_32::salt)) * 8;
        Resize<aggr_ht_entry_32>(initial_capacity);
        break;
    }
    default:
        throw InternalException("Unknown HT entry width");
    }

    predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_EQUAL);
}

GroupedAggregateHashTable::~GroupedAggregateHashTable() {
    Destroy();
}

void GroupedAggregateHashTable::Destroy() {
    if (data_collection->Count() == 0) {
        return;
    }

    // Check if there is an aggregate with a destructor
    bool has_destructor = false;
    for (auto &aggr : layout.GetAggregates()) {
        if (aggr.function.destructor) {
            has_destructor = true;
        }
    }
    if (!has_destructor) {
        return;
    }

    // There are aggregates with destructors: Call the destructor for each of the aggregates
    RowOperationsState state(aggregate_allocator->GetAllocator());
    TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false);
    auto &row_locations = iterator.GetChunkState().row_locations;
    do {
        RowOperations::DestroyStates(state, layout, row_locations, iterator.GetCurrentChunkCount());
    } while (iterator.Next());
    data_collection->Reset();
}

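// Debug-only consistency check: every occupied entry must point at a valid row whose stored
// hash agrees with the entry's salt, and the number of occupied entries must equal Count().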
template <class ENTRY>
void GroupedAggregateHashTable::VerifyInternal() {
    auto hashes_ptr = (ENTRY *)hashes_hdl_ptr;
    idx_t count = 0;
    for (idx_t i = 0; i < capacity; i++) {
        if (hashes_ptr[i].page_nr > 0) {
            D_ASSERT(hashes_ptr[i].page_offset < tuples_per_block);
            D_ASSERT(hashes_ptr[i].page_nr <= payload_hds_ptrs.size());
            auto ptr = payload_hds_ptrs[hashes_ptr[i].page_nr - 1] + ((hashes_ptr[i].page_offset) * tuple_size);
            auto hash = Load<hash_t>(ptr + hash_offset);
            D_ASSERT((hashes_ptr[i].salt) == (hash >> hash_prefix_shift));

            count++;
        }
    }
    (void)count;
    D_ASSERT(count == Count());
}

idx_t GroupedAggregateHashTable::InitialCapacity() {
    return STANDARD_VECTOR_SIZE * 2ULL;
}

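// The maximum capacity follows from the bit widths of the entry's page_nr and page_offset
// fields: the narrower 32-bit entry can address fewer pages, but halves the per-entry space.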
idx_t GroupedAggregateHashTable::GetMaxCapacity(HtEntryType entry_type, idx_t tuple_size) {
    idx_t max_pages;
    idx_t max_tuples;

    switch (entry_type) {
    case HtEntryType::HT_WIDTH_32:
        max_pages = NumericLimits<uint8_t>::Maximum();
        max_tuples = NumericLimits<uint16_t>::Maximum();
        break;
    case HtEntryType::HT_WIDTH_64:
        max_pages = NumericLimits<uint32_t>::Maximum();
        max_tuples = NumericLimits<uint16_t>::Maximum();
        break;
    default:
        throw InternalException("Unsupported hash table width");
    }

    return max_pages * MinValue(max_tuples, (idx_t)Storage::BLOCK_SIZE / tuple_size);
}

idx_t GroupedAggregateHashTable::MaxCapacity() {
    return GetMaxCapacity(entry_type, tuple_size);
}

void GroupedAggregateHashTable::Verify() {
#ifdef DEBUG
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_32:
        VerifyInternal<aggr_ht_entry_32>();
        break;
    case HtEntryType::HT_WIDTH_64:
        VerifyInternal<aggr_ht_entry_64>();
        break;
    }
#endif
}

template <class ENTRY>
void GroupedAggregateHashTable::Resize(idx_t size) {
    D_ASSERT(!is_finalized);
    D_ASSERT(size >= STANDARD_VECTOR_SIZE);
    D_ASSERT(IsPowerOfTwo(size));

    if (size < capacity) {
        throw InternalException("Cannot downsize a hash table!");
    }
    capacity = size;

    bitmask = capacity - 1;
    const auto byte_size = capacity * sizeof(ENTRY);
    if (byte_size > (idx_t)Storage::BLOCK_SIZE) {
        hashes_hdl = buffer_manager.Allocate(byte_size);
        hashes_hdl_ptr = hashes_hdl.Ptr();
    }
    memset(hashes_hdl_ptr, 0, byte_size);

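    // Rebuild the first (pointer) part of the HT: walk all existing rows, recompute each row's
    // slot from the hash stored in the row itself, and re-insert its (page_nr, page_offset)
    // into the freshly zeroed entry array using linear probing.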
    if (Count() != 0) {
        D_ASSERT(!payload_hds_ptrs.empty());
        auto hashes_arr = (ENTRY *)hashes_hdl_ptr;

        idx_t block_id = 0;
        auto block_pointer = payload_hds_ptrs[block_id];
        auto block_end = block_pointer + tuples_per_block * tuple_size;

        TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false);
        const auto row_locations = iterator.GetRowLocations();
        do {
            for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) {
                const auto &row_location = row_locations[i];
                if (row_location > block_end || row_location < block_pointer) {
                    block_id++;
                    D_ASSERT(block_id < payload_hds_ptrs.size());
                    block_pointer = payload_hds_ptrs[block_id];
                    block_end = block_pointer + tuples_per_block * tuple_size;
                }
                D_ASSERT(row_location >= block_pointer && row_location < block_end);
                D_ASSERT((row_location - block_pointer) % tuple_size == 0);

                const auto hash = Load<hash_t>(row_location + hash_offset);
                D_ASSERT((hash & bitmask) == (hash % capacity));
                D_ASSERT(hash >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());

                auto entry_idx = (idx_t)hash & bitmask;
                while (hashes_arr[entry_idx].page_nr > 0) {
                    entry_idx++;
                    if (entry_idx >= capacity) {
                        entry_idx = 0;
                    }
                }

                auto &ht_entry = hashes_arr[entry_idx];
                D_ASSERT(!ht_entry.page_nr);
                ht_entry.salt = hash >> hash_prefix_shift;
                ht_entry.page_nr = block_id + 1;
                ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
            }
        } while (iterator.Next());
    }

    Verify();
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          AggregateType filter) {
    unsafe_vector<idx_t> aggregate_filter;

    auto &aggregates = layout.GetAggregates();
    for (idx_t i = 0; i < aggregates.size(); i++) {
        auto &aggregate = aggregates[i];
        if (aggregate.aggr_type == filter) {
            aggregate_filter.push_back(i);
        }
    }
    return AddChunk(state, groups, payload, aggregate_filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, DataChunk &payload,
                                          const unsafe_vector<idx_t> &filter) {
    Vector hashes(LogicalType::HASH);
    groups.Hash(hashes);

    return AddChunk(state, groups, hashes, payload, filter);
}

idx_t GroupedAggregateHashTable::AddChunk(AggregateHTAppendState &state, DataChunk &groups, Vector &group_hashes,
                                          DataChunk &payload, const unsafe_vector<idx_t> &filter) {
    D_ASSERT(!is_finalized);
    if (groups.size() == 0) {
        return 0;
    }

#ifdef DEBUG
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    for (idx_t i = 0; i < groups.ColumnCount(); i++) {
        D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]);
    }
#endif

    auto new_group_count = FindOrCreateGroups(state, groups, group_hashes, state.addresses, state.new_groups);
    VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size());

    // Now every cell has an entry, update the aggregates
    auto &aggregates = layout.GetAggregates();
    idx_t filter_idx = 0;
    idx_t payload_idx = 0;
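    // "filter" holds the (sorted) indices of the aggregates to update; skipped aggregates must
    // still advance payload_idx and the state addresses so that subsequent updates stay aligned
    // with the row layout.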
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    for (idx_t i = 0; i < aggregates.size(); i++) {
        auto &aggr = aggregates[i];
        if (filter_idx >= filter.size() || i < filter[filter_idx]) {
            // Skip all the aggregates that are not in the filter
            payload_idx += aggr.child_count;
            VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
            continue;
        }
        D_ASSERT(i == filter[filter_idx]);

        if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) {
            RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload,
                                                payload_idx);
        } else {
            RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size());
        }

        // Move to the next aggregate
        payload_idx += aggr.child_count;
        VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size());
        filter_idx++;
    }

    Verify();
    return new_group_count;
}

void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) {
    groups.Verify();
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    for (idx_t i = 0; i < result.ColumnCount(); i++) {
        D_ASSERT(result.data[i].GetType() == payload_types[i]);
    }
    result.SetCardinality(groups);
    if (groups.size() == 0) {
        return;
    }

    // find the groups associated with the addresses
    // FIXME: this should not use the FindOrCreateGroups, creating them is unnecessary
    AggregateHTAppendState append_state;
    Vector addresses(LogicalType::POINTER);
    FindOrCreateGroups(append_state, groups, addresses);
    // now fetch the aggregates
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    RowOperations::FinalizeStates(row_state, layout, addresses, result, 0);
}

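// The HT is grown once the number of groups exceeds this fraction of the capacity, keeping the
// fill ratio low enough for linear probing to stay cheap.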
idx_t GroupedAggregateHashTable::ResizeThreshold() {
    return capacity / LOAD_FACTOR;
}

template <class ENTRY>
idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(AggregateHTAppendState &state, DataChunk &groups,
                                                            Vector &group_hashes_v, Vector &addresses_v,
                                                            SelectionVector &new_groups_out) {
    D_ASSERT(!is_finalized);
    D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount());
    D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH);
    D_ASSERT(state.ht_offsets.GetVectorType() == VectorType::FLAT_VECTOR);
    D_ASSERT(state.ht_offsets.GetType() == LogicalType::BIGINT);
    D_ASSERT(addresses_v.GetType() == LogicalType::POINTER);
    D_ASSERT(state.hash_salts.GetType() == LogicalType::SMALLINT);

    if (Count() + groups.size() > MaxCapacity()) {
        throw InternalException("Hash table capacity reached");
    }

    // Resize at 50% capacity, also need to fit the entire vector
    if (capacity - Count() <= groups.size() || Count() > ResizeThreshold()) {
        Verify();
        Resize<ENTRY>(capacity * 2);
    }
    D_ASSERT(capacity - Count() >= groups.size()); // we need to be able to fit at least one vector of data

    group_hashes_v.Flatten(groups.size());
    auto group_hashes = FlatVector::GetData<hash_t>(group_hashes_v);

    addresses_v.Flatten(groups.size());
    auto addresses = FlatVector::GetData<data_ptr_t>(addresses_v);

    // Compute the entry in the table based on the hash using a modulo,
    // and precompute the hash salts for faster comparison below
    auto ht_offsets_ptr = FlatVector::GetData<uint64_t>(state.ht_offsets);
    auto hash_salts_ptr = FlatVector::GetData<uint16_t>(state.hash_salts);
    for (idx_t r = 0; r < groups.size(); r++) {
        auto element = group_hashes[r];
        D_ASSERT((element & bitmask) == (element % capacity));
        ht_offsets_ptr[r] = element & bitmask;
        hash_salts_ptr[r] = element >> hash_prefix_shift;
    }
    // we start out with all entries [0, 1, 2, ..., groups.size()]
    const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector();

    // Make a chunk that references the groups and the hashes and convert to unified format
    if (state.group_chunk.ColumnCount() == 0) {
        state.group_chunk.InitializeEmpty(layout.GetTypes());
    }
    D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size());
    for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) {
        state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]);
    }
    state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v);
    state.group_chunk.SetCardinality(groups);

    // convert all vectors to unified format
    if (!state.chunk_state_initialized) {
        data_collection->InitializeAppend(state.chunk_state);
        state.chunk_state_initialized = true;
    }
    TupleDataCollection::ToUnifiedFormat(state.chunk_state, state.group_chunk);
    if (!state.group_data) {
        state.group_data = make_unsafe_uniq_array<UnifiedVectorFormat>(state.group_chunk.ColumnCount());
    }
    TupleDataCollection::GetVectorData(state.chunk_state, state.group_data.get());

    idx_t new_group_count = 0;
    idx_t remaining_entries = groups.size();
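    // Probing loop: classify every remaining group as landing on an empty slot, on an occupied
    // slot with a matching salt, or on a salt mismatch. New groups are appended and their slots
    // filled in, salted matches are verified with a full group comparison, and anything left
    // over probes the next slot in the next iteration.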
    while (remaining_entries > 0) {
        idx_t new_entry_count = 0;
        idx_t need_compare_count = 0;
        idx_t no_match_count = 0;

        // For each remaining entry, figure out whether or not it belongs to a full or empty group
        for (idx_t i = 0; i < remaining_entries; i++) {
            const idx_t index = sel_vector->get_index(i);
            auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
            if (ht_entry.page_nr == 0) { // Cell is unoccupied (we use page number 0 as an "unused" marker)
                D_ASSERT(group_hashes[index] >> hash_prefix_shift <= NumericLimits<uint16_t>::Maximum());
                D_ASSERT(payload_hds_ptrs.size() < NumericLimits<uint32_t>::Maximum());

                // Set page nr to 1 for now to mark it as occupied (will be corrected later) and set the salt
                ht_entry.page_nr = 1;
                ht_entry.salt = group_hashes[index] >> hash_prefix_shift;

                // Update selection lists for outer loops
                state.empty_vector.set_index(new_entry_count++, index);
                new_groups_out.set_index(new_group_count++, index);
            } else { // Cell is occupied: Compare salts
                if (ht_entry.salt == hash_salts_ptr[index]) {
                    state.group_compare_vector.set_index(need_compare_count++, index);
                } else {
                    state.no_match_vector.set_index(no_match_count++, index);
                }
            }
        }

        if (new_entry_count != 0) {
            // Append everything that belongs to an empty group
            data_collection->AppendUnified(td_pin_state, state.chunk_state, state.group_chunk, state.empty_vector,
                                           new_entry_count);
            RowOperations::InitializeStates(layout, state.chunk_state.row_locations,
                                            *FlatVector::IncrementalSelectionVector(), new_entry_count);

            // Get the pointers to the (possibly) newly created blocks of the data collection
            idx_t block_id = payload_hds_ptrs.empty() ? 0 : payload_hds_ptrs.size() - 1;
            UpdateBlockPointers();
            auto block_pointer = payload_hds_ptrs[block_id];
            auto block_end = block_pointer + tuples_per_block * tuple_size;

            // Set the page nrs/offsets in the 1st part of the HT now that the data has been appended
            const auto row_locations = FlatVector::GetData<data_ptr_t>(state.chunk_state.row_locations);
            for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) {
                const auto &row_location = row_locations[new_entry_idx];
                if (row_location > block_end || row_location < block_pointer) {
                    block_id++;
                    D_ASSERT(block_id < payload_hds_ptrs.size());
                    block_pointer = payload_hds_ptrs[block_id];
                    block_end = block_pointer + tuples_per_block * tuple_size;
                }
                D_ASSERT(row_location >= block_pointer && row_location < block_end);
                D_ASSERT((row_location - block_pointer) % tuple_size == 0);
                const auto index = state.empty_vector.get_index(new_entry_idx);
                auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
                ht_entry.page_nr = block_id + 1;
                ht_entry.page_offset = (row_location - block_pointer) / tuple_size;
                addresses[index] = row_location;
            }
        }

        if (need_compare_count != 0) {
            // Get the pointers to the rows that need to be compared
            for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) {
                const auto index = state.group_compare_vector.get_index(need_compare_idx);
                const auto &ht_entry = *(((ENTRY *)this->hashes_hdl_ptr) + ht_offsets_ptr[index]);
                auto page_ptr = payload_hds_ptrs[ht_entry.page_nr - 1];
                auto page_offset = ht_entry.page_offset * tuple_size;
                addresses[index] = page_ptr + page_offset;
            }

            // Perform group comparisons
            RowOperations::Match(state.group_chunk, state.group_data.get(), layout, addresses_v, predicates,
                                 state.group_compare_vector, need_compare_count, &state.no_match_vector,
                                 no_match_count);
        }

        // Linear probing: each of the entries that do not match moves to the next entry in the HT
        for (idx_t i = 0; i < no_match_count; i++) {
            idx_t index = state.no_match_vector.get_index(i);
            ht_offsets_ptr[index]++;
            if (ht_offsets_ptr[index] >= capacity) {
                ht_offsets_ptr[index] = 0;
            }
        }
        sel_vector = &state.no_match_vector;
        remaining_entries = no_match_count;
    }

    return new_group_count;
}

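// Cache the pointers to the pinned row blocks so that hash entries can address their rows
// through (page_nr, page_offset) without going through the buffer manager.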
void GroupedAggregateHashTable::UpdateBlockPointers() {
    for (const auto &id_and_handle : td_pin_state.row_handles) {
        const auto &id = id_and_handle.first;
        const auto &handle = id_and_handle.second;
        if (payload_hds_ptrs.empty() || id > payload_hds_ptrs.size() - 1) {
            payload_hds_ptrs.resize(id + 1);
        }
        payload_hds_ptrs[id] = handle.Ptr();
    }
}

// this is to support distinct aggregations where we need to record whether we
// have already seen a value for a group
idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &group_hashes, Vector &addresses_out,
                                                    SelectionVector &new_groups_out) {
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64:
        return FindOrCreateGroupsInternal<aggr_ht_entry_64>(state, groups, group_hashes, addresses_out,
                                                            new_groups_out);
    case HtEntryType::HT_WIDTH_32:
        return FindOrCreateGroupsInternal<aggr_ht_entry_32>(state, groups, group_hashes, addresses_out,
                                                            new_groups_out);
    default:
        throw InternalException("Unknown HT entry width");
    }
}

void GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                   Vector &addresses) {
    // create a dummy new_groups sel vector
    FindOrCreateGroups(state, groups, addresses, state.new_groups);
}

idx_t GroupedAggregateHashTable::FindOrCreateGroups(AggregateHTAppendState &state, DataChunk &groups,
                                                    Vector &addresses_out, SelectionVector &new_groups_out) {
    Vector hashes(LogicalType::HASH);
    groups.Hash(hashes);
    return FindOrCreateGroups(state, groups, hashes, addresses_out, new_groups_out);
}

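// Helper state for Combine: scans the source HT's row data (group columns plus the stored hash
// column) chunk by chunk so the rows can be re-inserted into the target HT.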
struct FlushMoveState {
    explicit FlushMoveState(TupleDataCollection &collection_p)
        : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER),
          new_groups_sel(STANDARD_VECTOR_SIZE) {
        const auto &layout = collection.GetLayout();
        vector<column_t> column_ids;
        column_ids.reserve(layout.ColumnCount() - 1);
        for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) {
            column_ids.emplace_back(col_idx);
        }
        // FIXME DESTROY_AFTER_DONE if we make it possible to pass a selection vector to RowOperations::DestroyStates?
        collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::UNPIN_AFTER_DONE);
        collection.InitializeScanChunk(scan_state, groups);
        hash_col_idx = layout.ColumnCount() - 1;
    }

    bool Scan();

    TupleDataCollection &collection;
    TupleDataScanState scan_state;
    DataChunk groups;

    idx_t hash_col_idx;
    Vector hashes;

    AggregateHTAppendState append_state;
    Vector group_addresses;
    SelectionVector new_groups_sel;
};

bool FlushMoveState::Scan() {
    if (collection.Scan(scan_state, groups)) {
        collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(),
                          groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector());
        return true;
    }

    collection.FinalizePinState(scan_state.pin_state);
    return false;
}

void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) {
    D_ASSERT(!is_finalized);

    D_ASSERT(other.layout.GetAggrWidth() == layout.GetAggrWidth());
    D_ASSERT(other.layout.GetDataWidth() == layout.GetDataWidth());
    D_ASSERT(other.layout.GetRowWidth() == layout.GetRowWidth());

    if (other.Count() == 0) {
        return;
    }

    FlushMoveState state(*other.data_collection);
    RowOperationsState row_state(aggregate_allocator->GetAllocator());
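    // For every chunk scanned from the other HT, find or create the corresponding groups here,
    // then merge the source aggregate states into the target rows.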
    while (state.Scan()) {
        FindOrCreateGroups(state.append_state, state.groups, state.hashes, state.group_addresses,
                           state.new_groups_sel);
        RowOperations::CombineStates(row_state, layout, state.scan_state.chunk_state.row_locations,
                                     state.group_addresses, state.groups.size());
    }

    Verify();
}

void GroupedAggregateHashTable::Partition(vector<GroupedAggregateHashTable *> &partition_hts, idx_t radix_bits) {
    const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits);
    D_ASSERT(partition_hts.size() == num_partitions);

    // Partition the data
    auto partitioned_data =
        make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1);
    partitioned_data->Partition(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED);
    D_ASSERT(partitioned_data->GetPartitions().size() == num_partitions);

    // Move the partitioned data collections to the partitioned hash tables and initialize the 1st part of the HT
    auto &partitions = partitioned_data->GetPartitions();
    for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) {
        auto &partition_ht = *partition_hts[partition_idx];
        partition_ht.data_collection = std::move(partitions[partition_idx]);
        partition_ht.aggregate_allocator = aggregate_allocator;
        partition_ht.InitializeFirstPart();
        partition_ht.Verify();
    }
}

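// After Partition() only the row data has been moved, so each partition HT rebuilds its block
// pointer table and hash entries here before it can be probed or combined into.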
void GroupedAggregateHashTable::InitializeFirstPart() {
    data_collection->GetBlockPointers(payload_hds_ptrs);
    auto size = MaxValue<idx_t>(NextPowerOfTwo(Count() * 2L), capacity);
    switch (entry_type) {
    case HtEntryType::HT_WIDTH_64:
        Resize<aggr_ht_entry_64>(size);
        break;
    case HtEntryType::HT_WIDTH_32:
        Resize<aggr_ht_entry_32>(size);
        break;
    default:
        throw InternalException("Unknown HT entry width");
    }
}

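// Scans the group columns into "result" and finalizes the aggregate states into the trailing
// result columns (one per aggregate).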
idx_t GroupedAggregateHashTable::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate,
                                      DataChunk &result) {
    data_collection->Scan(gstate, lstate, result);

    RowOperationsState row_state(aggregate_allocator->GetAllocator());
    const auto group_cols = layout.ColumnCount() - 1;
    RowOperations::FinalizeStates(row_state, layout, lstate.chunk_state.row_locations, result, group_cols);

    return result.size();
}

void GroupedAggregateHashTable::Finalize() {
    if (is_finalized) {
        return;
    }

    // Early release hashes (not needed for partition/scan) and data collection (will be pinned again when scanning)
    hashes_hdl.Destroy();
    data_collection->FinalizePinState(td_pin_state);
    data_collection->Unpin();

    is_finalized = true;
}

} // namespace duckdb