| 1 | #include "duckdb/execution/perfect_aggregate_hashtable.hpp" |
| 2 | #include "duckdb/execution/expression_executor.hpp" |
| 3 | #include "duckdb/common/row_operations/row_operations.hpp" |
| 4 | |
| 5 | namespace duckdb { |
| 6 | |
| 7 | PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator, |
| 8 | const vector<LogicalType> &group_types_p, |
| 9 | vector<LogicalType> payload_types_p, |
| 10 | vector<AggregateObject> aggregate_objects_p, |
| 11 | vector<Value> group_minima_p, vector<idx_t> required_bits_p) |
| 12 | : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), |
| 13 | addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0), |
| 14 | group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), aggregate_allocator(allocator) { |
| 15 | for (auto &group_bits : required_bits) { |
| 16 | total_required_bits += group_bits; |
| 17 | } |
    // the total number of groups we allocate space for is 2^total_required_bits
    total_groups = (uint64_t)1 << total_required_bits;
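    // Illustrative example (hypothetical values, not from this file): with two group columns
    // and required_bits = {5, 3}, total_required_bits is 8 and total_groups is 256; the group
    // index is a packed bit field in which the first column occupies the upper 5 bits and the
    // second column the lower 3 bits, matching the shift computation in AddChunk below.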
    // we don't need to store the groups in the perfect hash table, since the group keys
    // can be deduced from their location
    grouping_columns = group_types_p.size();
    layout.Initialize(std::move(aggregate_objects_p));
    tuple_size = layout.GetRowWidth();

    // allocate and null-initialize the data
    owned_data = make_unsafe_uniq_array<data_t>(tuple_size * total_groups);
    data = owned_data.get();

    // set up the empty payloads for every tuple, and initialize the "occupied" flag to false
    group_is_set = make_unsafe_uniq_array<bool>(total_groups);
    memset(group_is_set.get(), 0, total_groups * sizeof(bool));

    // initialize the aggregate states of every entry, in vector-sized batches
    auto address_data = FlatVector::GetData<uintptr_t>(addresses);
    idx_t init_count = 0;
    for (idx_t i = 0; i < total_groups; i++) {
        address_data[init_count] = uintptr_t(data) + (tuple_size * i);
        init_count++;
        if (init_count == STANDARD_VECTOR_SIZE) {
            RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count);
            init_count = 0;
        }
    }
    RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count);
}
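
// Typical lifecycle of this table, sketched from the methods defined in this file:
// AddChunk is called once per incoming (groups, payload) chunk, Combine merges a
// partial table (e.g. one built by another thread) into this one, and Scan is called
// repeatedly with a caller-maintained scan_position until no more entries are produced.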

PerfectAggregateHashTable::~PerfectAggregateHashTable() {
    Destroy();
}

template <class T>
static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data,
                                          idx_t current_shift, idx_t count) {
    auto data = UnifiedVectorFormat::GetData<T>(group_data);
    auto min_val = min.GetValueUnsafe<T>();
    if (!group_data.validity.AllValid()) {
        for (idx_t i = 0; i < count; i++) {
            auto index = group_data.sel->get_index(i);
            // check if the value is NULL
            // NULL groups are treated as "0" in the hash table
            // that is to say, they have no effect on the position of the element (because 0 << shift is 0)
            // we only need to handle non-NULL values here
            if (group_data.validity.RowIsValid(index)) {
                D_ASSERT(data[index] >= min_val);
                uintptr_t adjusted_value = (data[index] - min_val) + 1;
                address_data[i] += adjusted_value << current_shift;
            }
        }
    } else {
        // no NULL values: we can directly compute the addresses
        for (idx_t i = 0; i < count; i++) {
            auto index = group_data.sel->get_index(i);
            uintptr_t adjusted_value = (data[index] - min_val) + 1;
            address_data[i] += adjusted_value << current_shift;
        }
    }
}
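
// Worked example for ComputeGroupLocationTemplated (hypothetical values): with min_val = 100
// and current_shift = 3, a group value of 103 yields adjusted_value = (103 - 100) + 1 = 4 and
// contributes 4 << 3 = 32 to the group index; a NULL group leaves its bits at 0, which is why
// adjusted values start at 1 and 0 is reserved for NULL.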

static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) {
    UnifiedVectorFormat vdata;
    group.ToUnifiedFormat(count, vdata);

    switch (group.GetType().InternalType()) {
    case PhysicalType::INT8:
        ComputeGroupLocationTemplated<int8_t>(vdata, min, address_data, current_shift, count);
        break;
    case PhysicalType::INT16:
        ComputeGroupLocationTemplated<int16_t>(vdata, min, address_data, current_shift, count);
        break;
    case PhysicalType::INT32:
        ComputeGroupLocationTemplated<int32_t>(vdata, min, address_data, current_shift, count);
        break;
    case PhysicalType::INT64:
        ComputeGroupLocationTemplated<int64_t>(vdata, min, address_data, current_shift, count);
        break;
    default:
        throw InternalException("Unsupported group type for perfect aggregate hash table");
    }
}

void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) {
    // first we need to find the location in the HT of each of the groups
    auto address_data = FlatVector::GetData<uintptr_t>(addresses);
    // zero-initialize the address data
    memset(address_data, 0, groups.size() * sizeof(uintptr_t));
    D_ASSERT(groups.ColumnCount() == group_minima.size());

    // then compute the actual group location by iterating over each of the groups
    idx_t current_shift = total_required_bits;
    for (idx_t i = 0; i < groups.ColumnCount(); i++) {
        current_shift -= required_bits[i];
        ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size());
    }
    // now we have the HT entry number for every tuple
    // compute the actual data pointer by multiplying the entry number by the tuple size
    // and adding it to the base HT pointer
    for (idx_t i = 0; i < groups.size(); i++) {
        const auto group = address_data[i];
        D_ASSERT(group < total_groups);
        group_is_set[group] = true;
        address_data[i] = uintptr_t(data) + group * tuple_size;
    }
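
    // Illustrative example (hypothetical values): with tuple_size = 24, the tuple for group
    // index 5 lives at data + 5 * 24, i.e. 120 bytes past the start of the table.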

    // after finding the group location we update the aggregates
    idx_t payload_idx = 0;
    auto &aggregates = layout.GetAggregates();
    RowOperationsState row_state(aggregate_allocator.GetAllocator());
    for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
        auto &aggregate = aggregates[aggr_idx];
        auto input_count = (idx_t)aggregate.child_count;
        if (aggregate.filter) {
            RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses,
                                                payload, payload_idx);
        } else {
            RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size());
        }
        // move to the next aggregate
        payload_idx += input_count;
        VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size());
    }
}

void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) {
    D_ASSERT(total_groups == other.total_groups);
    D_ASSERT(tuple_size == other.tuple_size);

    Vector source_addresses(LogicalType::POINTER);
    Vector target_addresses(LogicalType::POINTER);
    auto source_addresses_ptr = FlatVector::GetData<data_ptr_t>(source_addresses);
    auto target_addresses_ptr = FlatVector::GetData<data_ptr_t>(target_addresses);

    // iterate over all entries of both hash tables and call combine for all entries that can be combined
    data_ptr_t source_ptr = other.data;
    data_ptr_t target_ptr = data;
    idx_t combine_count = 0;
    RowOperationsState row_state(aggregate_allocator.GetAllocator());
    for (idx_t i = 0; i < total_groups; i++) {
        auto has_entry_source = other.group_is_set[i];
        // we only have any work to do if the source has an entry for this group
        if (has_entry_source) {
            group_is_set[i] = true;
            source_addresses_ptr[combine_count] = source_ptr;
            target_addresses_ptr[combine_count] = target_ptr;
            combine_count++;
            if (combine_count == STANDARD_VECTOR_SIZE) {
                RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count);
                combine_count = 0;
            }
        }
        source_ptr += tuple_size;
        target_ptr += tuple_size;
    }
    RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count);
}

template <class T>
static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift,
                                            idx_t entry_count, Vector &result) {
    auto data = FlatVector::GetData<T>(result);
    auto &validity_mask = FlatVector::Validity(result);
    auto min_data = min.GetValueUnsafe<T>();
    for (idx_t i = 0; i < entry_count; i++) {
        // extract the value of this group from the total group index
        auto group_index = (group_values[i] >> shift) & mask;
        if (group_index == 0) {
            // if it is 0, the value is NULL
            validity_mask.SetInvalid(i);
        } else {
            // otherwise we add the value (minus 1) to the min value
            data[i] = min_data + group_index - 1;
        }
    }
}
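
// Worked example for ReconstructGroupVectorTemplated (hypothetical values): with shift = 3 and
// a 3-bit mask, a packed group value of 0b101011 yields group_index = (0b101011 >> 3) & 0b111 = 5,
// which decodes to min_data + 5 - 1 = min_data + 4; a group_index of 0 decodes to NULL, mirroring
// the "+ 1" encoding in ComputeGroupLocationTemplated above.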

static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift,
                                   idx_t entry_count, Vector &result) {
    // construct the mask for this entry
    idx_t mask = ((uint64_t)1 << required_bits) - 1;
    switch (result.GetType().InternalType()) {
    case PhysicalType::INT8:
        ReconstructGroupVectorTemplated<int8_t>(group_values, min, mask, shift, entry_count, result);
        break;
    case PhysicalType::INT16:
        ReconstructGroupVectorTemplated<int16_t>(group_values, min, mask, shift, entry_count, result);
        break;
    case PhysicalType::INT32:
        ReconstructGroupVectorTemplated<int32_t>(group_values, min, mask, shift, entry_count, result);
        break;
    case PhysicalType::INT64:
        ReconstructGroupVectorTemplated<int64_t>(group_values, min, mask, shift, entry_count, result);
        break;
    default:
        throw InternalException("Invalid type for perfect aggregate HT group");
    }
}
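
// Note on Scan (usage sketch inferred from the signature): scan_position is caller-maintained
// state; each call emits at most STANDARD_VECTOR_SIZE groups and advances scan_position past
// them, so callers typically invoke Scan in a loop until the result chunk comes back empty.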

void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) {
    auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses);
    uint32_t group_values[STANDARD_VECTOR_SIZE];

    // iterate over the HT until we have either exhausted the entire HT,
    // or collected a full vector (STANDARD_VECTOR_SIZE) of entries
    idx_t entry_count = 0;
    for (; scan_position < total_groups; scan_position++) {
        if (group_is_set[scan_position]) {
            // this group is set: add it to the set of groups to extract
            data_pointers[entry_count] = data + tuple_size * scan_position;
            group_values[entry_count] = scan_position;
            entry_count++;
            if (entry_count == STANDARD_VECTOR_SIZE) {
                scan_position++;
                break;
            }
        }
    }
    if (entry_count == 0) {
        // no entries found
        return;
    }
    // first reconstruct the groups from the group index
    idx_t shift = total_required_bits;
    for (idx_t i = 0; i < grouping_columns; i++) {
        shift -= required_bits[i];
        ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]);
    }
    // then construct the payloads
    result.SetCardinality(entry_count);
    RowOperationsState row_state(aggregate_allocator.GetAllocator());
    RowOperations::FinalizeStates(row_state, layout, addresses, result, grouping_columns);
}

void PerfectAggregateHashTable::Destroy() {
    // check if there is any destructor to call
    bool has_destructor = false;
    for (auto &aggr : layout.GetAggregates()) {
        if (aggr.function.destructor) {
            has_destructor = true;
        }
    }
    if (!has_destructor) {
        return;
    }
    // there are aggregates with destructors: loop over the hash table
    // and call the destructor method for each of the aggregates
    auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses);
    idx_t count = 0;

    // iterate over all initialised slots of the hash table
    RowOperationsState row_state(aggregate_allocator.GetAllocator());
    data_ptr_t payload_ptr = data;
    for (idx_t i = 0; i < total_groups; i++) {
        if (group_is_set[i]) {
            data_pointers[count++] = payload_ptr;
            if (count == STANDARD_VECTOR_SIZE) {
                RowOperations::DestroyStates(row_state, layout, addresses, count);
                count = 0;
            }
        }
        payload_ptr += tuple_size;
    }
    RowOperations::DestroyStates(row_state, layout, addresses, count);
}

} // namespace duckdb