#include "duckdb/execution/perfect_aggregate_hashtable.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"

namespace duckdb {

PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator,
                                                     const vector<LogicalType> &group_types_p,
                                                     vector<LogicalType> payload_types_p,
                                                     vector<AggregateObject> aggregate_objects_p,
                                                     vector<Value> group_minima_p, vector<idx_t> required_bits_p)
    : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)),
      addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0),
      group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), aggregate_allocator(allocator) {
	for (auto &group_bits : required_bits) {
		total_required_bits += group_bits;
	}
	// the total number of groups we allocate space for is 2^total_required_bits
	total_groups = (uint64_t)1 << total_required_bits;
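	// e.g. two group columns requiring 2 and 3 bits respectively yield 2^(2+3) = 32 slots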
	// we don't need to store the groups in a perfect hash table, since the group keys can be deduced from their location
	grouping_columns = group_types_p.size();
	layout.Initialize(std::move(aggregate_objects_p));
	tuple_size = layout.GetRowWidth();

	// allocate and null initialize the data
	owned_data = make_unsafe_uniq_array<data_t>(tuple_size * total_groups);
	data = owned_data.get();

	// set up the empty payloads for every tuple, and initialize the "occupied" flag to false
	group_is_set = make_unsafe_uniq_array<bool>(total_groups);
	memset(group_is_set.get(), 0, total_groups * sizeof(bool));

	// initialize the hash table for each entry
	auto address_data = FlatVector::GetData<uintptr_t>(addresses);
	idx_t init_count = 0;
	for (idx_t i = 0; i < total_groups; i++) {
		address_data[init_count] = uintptr_t(data) + (tuple_size * i);
		init_count++;
		if (init_count == STANDARD_VECTOR_SIZE) {
			RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count);
			init_count = 0;
		}
	}
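	// initialize the final partial batch of states left over after the last full vector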
	RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count);
}

PerfectAggregateHashTable::~PerfectAggregateHashTable() {
	Destroy();
}

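// compute, for each row, the contribution of one group column to the bit-packed slot index:
// values are normalized to (value - min + 1) so that 0 can represent NULL, then shifted into
// this column's bit range; e.g. with min = 100 and 2 required bits, values 100..102 map to 1..3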
template <class T>
static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data,
                                          idx_t current_shift, idx_t count) {
	auto data = UnifiedVectorFormat::GetData<T>(group_data);
	auto min_val = min.GetValueUnsafe<T>();
	if (!group_data.validity.AllValid()) {
		for (idx_t i = 0; i < count; i++) {
			auto index = group_data.sel->get_index(i);
			// check if the value is NULL
			// NULL groups are considered as "0" in the hash table
			// that is to say, they have no effect on the position of the element (because 0 << shift is 0)
			// we only need to handle non-null values here
			if (group_data.validity.RowIsValid(index)) {
				D_ASSERT(data[index] >= min_val);
				// the +1 reserves the group index 0 for NULL values
				uintptr_t adjusted_value = (data[index] - min_val) + 1;
				address_data[i] += adjusted_value << current_shift;
			}
		}
	} else {
		// no NULL values: we can compute the addresses directly
		for (idx_t i = 0; i < count; i++) {
			auto index = group_data.sel->get_index(i);
			uintptr_t adjusted_value = (data[index] - min_val) + 1;
			address_data[i] += adjusted_value << current_shift;
		}
	}
}

static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) {
	UnifiedVectorFormat vdata;
	group.ToUnifiedFormat(count, vdata);

	switch (group.GetType().InternalType()) {
	case PhysicalType::INT8:
		ComputeGroupLocationTemplated<int8_t>(vdata, min, address_data, current_shift, count);
		break;
	case PhysicalType::INT16:
		ComputeGroupLocationTemplated<int16_t>(vdata, min, address_data, current_shift, count);
		break;
	case PhysicalType::INT32:
		ComputeGroupLocationTemplated<int32_t>(vdata, min, address_data, current_shift, count);
		break;
	case PhysicalType::INT64:
		ComputeGroupLocationTemplated<int64_t>(vdata, min, address_data, current_shift, count);
		break;
	default:
		throw InternalException("Unsupported group type for perfect aggregate hash table");
	}
}

void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) {
	// first we need to find the location in the HT of each of the groups
	auto address_data = FlatVector::GetData<uintptr_t>(addresses);
	// zero-initialize the address data
	memset(address_data, 0, groups.size() * sizeof(uintptr_t));
	D_ASSERT(groups.ColumnCount() == group_minima.size());

	// then compute the actual group location by iterating over each of the groups
	// the group columns are packed into the slot index from the most significant bits downwards
	idx_t current_shift = total_required_bits;
	for (idx_t i = 0; i < groups.ColumnCount(); i++) {
		current_shift -= required_bits[i];
		ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size());
	}
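	// e.g. for two group columns with required bits (2, 3) and minima (0, 10), a row with
	// group values (1, 12) maps to slot ((1 - 0 + 1) << 3) + ((12 - 10 + 1) << 0) = 19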
	// now we have the HT entry number for every tuple
	// compute the actual pointer to the data by multiplying the entry number by the tuple size
	// and adding it to the base HT pointer
	for (idx_t i = 0; i < groups.size(); i++) {
		const auto group = address_data[i];
		D_ASSERT(group < total_groups);
		group_is_set[group] = true;
		address_data[i] = uintptr_t(data) + group * tuple_size;
	}

	// after finding the group locations we update the aggregates
	idx_t payload_idx = 0;
	auto &aggregates = layout.GetAggregates();
	RowOperationsState row_state(aggregate_allocator.GetAllocator());
	for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
		auto &aggregate = aggregates[aggr_idx];
		auto input_count = (idx_t)aggregate.child_count;
		if (aggregate.filter) {
			RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses,
			                                    payload, payload_idx);
		} else {
			RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size());
		}
		// move to the next aggregate: advance the payload columns and shift the addresses
		// past this aggregate's state within each row
		payload_idx += input_count;
		VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size());
	}
}

void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) {
	D_ASSERT(total_groups == other.total_groups);
	D_ASSERT(tuple_size == other.tuple_size);

	Vector source_addresses(LogicalType::POINTER);
	Vector target_addresses(LogicalType::POINTER);
	auto source_addresses_ptr = FlatVector::GetData<data_ptr_t>(source_addresses);
	auto target_addresses_ptr = FlatVector::GetData<data_ptr_t>(target_addresses);

	// iterate over all entries of both hash tables and call combine for all entries that can be combined
	data_ptr_t source_ptr = other.data;
	data_ptr_t target_ptr = data;
	idx_t combine_count = 0;
	RowOperationsState row_state(aggregate_allocator.GetAllocator());
	for (idx_t i = 0; i < total_groups; i++) {
		auto has_entry_source = other.group_is_set[i];
		// we only have any work to do if the source has an entry for this group
		if (has_entry_source) {
			group_is_set[i] = true;
			source_addresses_ptr[combine_count] = source_ptr;
			target_addresses_ptr[combine_count] = target_ptr;
			combine_count++;
			if (combine_count == STANDARD_VECTOR_SIZE) {
				RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count);
				combine_count = 0;
			}
		}
		source_ptr += tuple_size;
		target_ptr += tuple_size;
	}
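	// combine the final partial batch of states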
	RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count);
}

template <class T>
static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift,
                                            idx_t entry_count, Vector &result) {
	auto data = FlatVector::GetData<T>(result);
	auto &validity_mask = FlatVector::Validity(result);
	auto min_data = min.GetValueUnsafe<T>();
	for (idx_t i = 0; i < entry_count; i++) {
		// extract the value of this group from the total group index
		auto group_index = (group_values[i] >> shift) & mask;
		if (group_index == 0) {
			// if it is 0, the value is NULL
			validity_mask.SetInvalid(i);
		} else {
			// otherwise we add the group index (minus 1) to the minimum value
			data[i] = min_data + group_index - 1;
		}
	}
}
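// e.g. with min = 100: group index 0 reconstructs to NULL, 1 to 100, and 3 to 102,
// the exact inverse of the (value - min + 1) mapping used in ComputeGroupLocationTemplated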

static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift,
                                   idx_t entry_count, Vector &result) {
	// construct the mask for this entry
	idx_t mask = ((uint64_t)1 << required_bits) - 1;
	switch (result.GetType().InternalType()) {
	case PhysicalType::INT8:
		ReconstructGroupVectorTemplated<int8_t>(group_values, min, mask, shift, entry_count, result);
		break;
	case PhysicalType::INT16:
		ReconstructGroupVectorTemplated<int16_t>(group_values, min, mask, shift, entry_count, result);
		break;
	case PhysicalType::INT32:
		ReconstructGroupVectorTemplated<int32_t>(group_values, min, mask, shift, entry_count, result);
		break;
	case PhysicalType::INT64:
		ReconstructGroupVectorTemplated<int64_t>(group_values, min, mask, shift, entry_count, result);
		break;
	default:
		throw InternalException("Invalid type for perfect aggregate HT group");
	}
}

void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) {
	auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses);
	uint32_t group_values[STANDARD_VECTOR_SIZE];

	// iterate over the HT until we have either exhausted the entire HT, or collected
	// a full vector (STANDARD_VECTOR_SIZE) of occupied entries
	idx_t entry_count = 0;
	for (; scan_position < total_groups; scan_position++) {
		if (group_is_set[scan_position]) {
			// this group is set: add it to the set of groups to extract
			data_pointers[entry_count] = data + tuple_size * scan_position;
			group_values[entry_count] = scan_position;
			entry_count++;
			if (entry_count == STANDARD_VECTOR_SIZE) {
				scan_position++;
				break;
			}
		}
	}
	if (entry_count == 0) {
		// no entries found
		return;
	}
	// first reconstruct the groups from the group index
	idx_t shift = total_required_bits;
	for (idx_t i = 0; i < grouping_columns; i++) {
		shift -= required_bits[i];
		ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]);
	}
	// then construct the payloads
	result.SetCardinality(entry_count);
	RowOperationsState row_state(aggregate_allocator.GetAllocator());
	RowOperations::FinalizeStates(row_state, layout, addresses, result, grouping_columns);
}

void PerfectAggregateHashTable::Destroy() {
	// check if there is any destructor to call
	bool has_destructor = false;
	for (auto &aggr : layout.GetAggregates()) {
		if (aggr.function.destructor) {
			has_destructor = true;
		}
	}
	if (!has_destructor) {
		return;
	}
	// there are aggregates with destructors: loop over the hash table
	// and call the destructor method for each of the aggregates
	auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses);
	idx_t count = 0;

	// iterate over all initialised slots of the hash table
	RowOperationsState row_state(aggregate_allocator.GetAllocator());
	data_ptr_t payload_ptr = data;
	for (idx_t i = 0; i < total_groups; i++) {
		if (group_is_set[i]) {
			data_pointers[count++] = payload_ptr;
			if (count == STANDARD_VECTOR_SIZE) {
				RowOperations::DestroyStates(row_state, layout, addresses, count);
				count = 0;
			}
		}
		payload_ptr += tuple_size;
	}
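	// destroy the final partial batch of states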
	RowOperations::DestroyStates(row_state, layout, addresses, count);
}

} // namespace duckdb