1#include "duckdb/execution/perfect_aggregate_hashtable.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/common/row_operations/row_operations.hpp"
4
5namespace duckdb {
6
7PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator,
8 const vector<LogicalType> &group_types_p,
9 vector<LogicalType> payload_types_p,
10 vector<AggregateObject> aggregate_objects_p,
11 vector<Value> group_minima_p, vector<idx_t> required_bits_p)
12 : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)),
13 addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0),
14 group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), aggregate_allocator(allocator) {
15 for (auto &group_bits : required_bits) {
16 total_required_bits += group_bits;
17 }
18 // the total amount of groups we allocate space for is 2^required_bits
19 total_groups = (uint64_t)1 << total_required_bits;
20 // we don't need to store the groups in a perfect hash table, since the group keys can be deduced by their location
21 grouping_columns = group_types_p.size();
22 layout.Initialize(aggregates_p: std::move(aggregate_objects_p));
23 tuple_size = layout.GetRowWidth();
24
25 // allocate and null initialize the data
26 owned_data = make_unsafe_uniq_array<data_t>(n: tuple_size * total_groups);
27 data = owned_data.get();
28
29 // set up the empty payloads for every tuple, and initialize the "occupied" flag to false
30 group_is_set = make_unsafe_uniq_array<bool>(n: total_groups);
31 memset(s: group_is_set.get(), c: 0, n: total_groups * sizeof(bool));
32
33 // initialize the hash table for each entry
34 auto address_data = FlatVector::GetData<uintptr_t>(vector&: addresses);
35 idx_t init_count = 0;
36 for (idx_t i = 0; i < total_groups; i++) {
37 address_data[init_count] = uintptr_t(data) + (tuple_size * i);
38 init_count++;
39 if (init_count == STANDARD_VECTOR_SIZE) {
40 RowOperations::InitializeStates(layout, addresses, sel: *FlatVector::IncrementalSelectionVector(), count: init_count);
41 init_count = 0;
42 }
43 }
44 RowOperations::InitializeStates(layout, addresses, sel: *FlatVector::IncrementalSelectionVector(), count: init_count);
45}
46
47PerfectAggregateHashTable::~PerfectAggregateHashTable() {
48 Destroy();
49}
50
51template <class T>
52static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data,
53 idx_t current_shift, idx_t count) {
54 auto data = UnifiedVectorFormat::GetData<T>(group_data);
55 auto min_val = min.GetValueUnsafe<T>();
56 if (!group_data.validity.AllValid()) {
57 for (idx_t i = 0; i < count; i++) {
58 auto index = group_data.sel->get_index(idx: i);
59 // check if the value is NULL
60 // NULL groups are considered as "0" in the hash table
61 // that is to say, they have no effect on the position of the element (because 0 << shift is 0)
62 // we only need to handle non-null values here
63 if (group_data.validity.RowIsValid(row_idx: index)) {
64 D_ASSERT(data[index] >= min_val);
65 uintptr_t adjusted_value = (data[index] - min_val) + 1;
66 address_data[i] += adjusted_value << current_shift;
67 }
68 }
69 } else {
70 // no null values: we can directly compute the addresses
71 for (idx_t i = 0; i < count; i++) {
72 auto index = group_data.sel->get_index(idx: i);
73 uintptr_t adjusted_value = (data[index] - min_val) + 1;
74 address_data[i] += adjusted_value << current_shift;
75 }
76 }
77}
78
79static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) {
80 UnifiedVectorFormat vdata;
81 group.ToUnifiedFormat(count, data&: vdata);
82
83 switch (group.GetType().InternalType()) {
84 case PhysicalType::INT8:
85 ComputeGroupLocationTemplated<int8_t>(group_data&: vdata, min, address_data, current_shift, count);
86 break;
87 case PhysicalType::INT16:
88 ComputeGroupLocationTemplated<int16_t>(group_data&: vdata, min, address_data, current_shift, count);
89 break;
90 case PhysicalType::INT32:
91 ComputeGroupLocationTemplated<int32_t>(group_data&: vdata, min, address_data, current_shift, count);
92 break;
93 case PhysicalType::INT64:
94 ComputeGroupLocationTemplated<int64_t>(group_data&: vdata, min, address_data, current_shift, count);
95 break;
96 default:
97 throw InternalException("Unsupported group type for perfect aggregate hash table");
98 }
99}
100
101void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) {
102 // first we need to find the location in the HT of each of the groups
103 auto address_data = FlatVector::GetData<uintptr_t>(vector&: addresses);
104 // zero-initialize the address data
105 memset(s: address_data, c: 0, n: groups.size() * sizeof(uintptr_t));
106 D_ASSERT(groups.ColumnCount() == group_minima.size());
107
108 // then compute the actual group location by iterating over each of the groups
109 idx_t current_shift = total_required_bits;
110 for (idx_t i = 0; i < groups.ColumnCount(); i++) {
111 current_shift -= required_bits[i];
112 ComputeGroupLocation(group&: groups.data[i], min&: group_minima[i], address_data, current_shift, count: groups.size());
113 }
114 // now we have the HT entry number for every tuple
115 // compute the actual pointer to the data by adding it to the base HT pointer and multiplying by the tuple size
116 for (idx_t i = 0; i < groups.size(); i++) {
117 const auto group = address_data[i];
118 D_ASSERT(group < total_groups);
119 group_is_set[group] = true;
120 address_data[i] = uintptr_t(data) + group * tuple_size;
121 }
122
123 // after finding the group location we update the aggregates
124 idx_t payload_idx = 0;
125 auto &aggregates = layout.GetAggregates();
126 RowOperationsState row_state(aggregate_allocator.GetAllocator());
127 for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
128 auto &aggregate = aggregates[aggr_idx];
129 auto input_count = (idx_t)aggregate.child_count;
130 if (aggregate.filter) {
131 RowOperations::UpdateFilteredStates(state&: row_state, filter_data&: filter_set.GetFilterData(aggr_idx), aggr&: aggregate, addresses,
132 payload, arg_idx: payload_idx);
133 } else {
134 RowOperations::UpdateStates(state&: row_state, aggr&: aggregate, addresses, payload, arg_idx: payload_idx, count: payload.size());
135 }
136 // move to the next aggregate
137 payload_idx += input_count;
138 VectorOperations::AddInPlace(left&: addresses, delta: aggregate.payload_size, count: payload.size());
139 }
140}
141
142void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) {
143 D_ASSERT(total_groups == other.total_groups);
144 D_ASSERT(tuple_size == other.tuple_size);
145
146 Vector source_addresses(LogicalType::POINTER);
147 Vector target_addresses(LogicalType::POINTER);
148 auto source_addresses_ptr = FlatVector::GetData<data_ptr_t>(vector&: source_addresses);
149 auto target_addresses_ptr = FlatVector::GetData<data_ptr_t>(vector&: target_addresses);
150
151 // iterate over all entries of both hash tables and call combine for all entries that can be combined
152 data_ptr_t source_ptr = other.data;
153 data_ptr_t target_ptr = data;
154 idx_t combine_count = 0;
155 RowOperationsState row_state(aggregate_allocator.GetAllocator());
156 for (idx_t i = 0; i < total_groups; i++) {
157 auto has_entry_source = other.group_is_set[i];
158 // we only have any work to do if the source has an entry for this group
159 if (has_entry_source) {
160 group_is_set[i] = true;
161 source_addresses_ptr[combine_count] = source_ptr;
162 target_addresses_ptr[combine_count] = target_ptr;
163 combine_count++;
164 if (combine_count == STANDARD_VECTOR_SIZE) {
165 RowOperations::CombineStates(state&: row_state, layout, sources&: source_addresses, targets&: target_addresses, count: combine_count);
166 combine_count = 0;
167 }
168 }
169 source_ptr += tuple_size;
170 target_ptr += tuple_size;
171 }
172 RowOperations::CombineStates(state&: row_state, layout, sources&: source_addresses, targets&: target_addresses, count: combine_count);
173}
174
175template <class T>
176static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift,
177 idx_t entry_count, Vector &result) {
178 auto data = FlatVector::GetData<T>(result);
179 auto &validity_mask = FlatVector::Validity(vector&: result);
180 auto min_data = min.GetValueUnsafe<T>();
181 for (idx_t i = 0; i < entry_count; i++) {
182 // extract the value of this group from the total group index
183 auto group_index = (group_values[i] >> shift) & mask;
184 if (group_index == 0) {
185 // if it is 0, the value is NULL
186 validity_mask.SetInvalid(i);
187 } else {
188 // otherwise we add the value (minus 1) to the min value
189 data[i] = min_data + group_index - 1;
190 }
191 }
192}
193
194static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift,
195 idx_t entry_count, Vector &result) {
196 // construct the mask for this entry
197 idx_t mask = ((uint64_t)1 << required_bits) - 1;
198 switch (result.GetType().InternalType()) {
199 case PhysicalType::INT8:
200 ReconstructGroupVectorTemplated<int8_t>(group_values, min, mask, shift, entry_count, result);
201 break;
202 case PhysicalType::INT16:
203 ReconstructGroupVectorTemplated<int16_t>(group_values, min, mask, shift, entry_count, result);
204 break;
205 case PhysicalType::INT32:
206 ReconstructGroupVectorTemplated<int32_t>(group_values, min, mask, shift, entry_count, result);
207 break;
208 case PhysicalType::INT64:
209 ReconstructGroupVectorTemplated<int64_t>(group_values, min, mask, shift, entry_count, result);
210 break;
211 default:
212 throw InternalException("Invalid type for perfect aggregate HT group");
213 }
214}
215
216void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) {
217 auto data_pointers = FlatVector::GetData<data_ptr_t>(vector&: addresses);
218 uint32_t group_values[STANDARD_VECTOR_SIZE];
219
220 // iterate over the HT until we either have exhausted the entire HT, or
221 idx_t entry_count = 0;
222 for (; scan_position < total_groups; scan_position++) {
223 if (group_is_set[scan_position]) {
224 // this group is set: add it to the set of groups to extract
225 data_pointers[entry_count] = data + tuple_size * scan_position;
226 group_values[entry_count] = scan_position;
227 entry_count++;
228 if (entry_count == STANDARD_VECTOR_SIZE) {
229 scan_position++;
230 break;
231 }
232 }
233 }
234 if (entry_count == 0) {
235 // no entries found
236 return;
237 }
238 // first reconstruct the groups from the group index
239 idx_t shift = total_required_bits;
240 for (idx_t i = 0; i < grouping_columns; i++) {
241 shift -= required_bits[i];
242 ReconstructGroupVector(group_values, min&: group_minima[i], required_bits: required_bits[i], shift, entry_count, result&: result.data[i]);
243 }
244 // then construct the payloads
245 result.SetCardinality(entry_count);
246 RowOperationsState row_state(aggregate_allocator.GetAllocator());
247 RowOperations::FinalizeStates(state&: row_state, layout, addresses, result, aggr_idx: grouping_columns);
248}
249
250void PerfectAggregateHashTable::Destroy() {
251 // check if there is any destructor to call
252 bool has_destructor = false;
253 for (auto &aggr : layout.GetAggregates()) {
254 if (aggr.function.destructor) {
255 has_destructor = true;
256 }
257 }
258 if (!has_destructor) {
259 return;
260 }
261 // there are aggregates with destructors: loop over the hash table
262 // and call the destructor method for each of the aggregates
263 auto data_pointers = FlatVector::GetData<data_ptr_t>(vector&: addresses);
264 idx_t count = 0;
265
266 // iterate over all initialised slots of the hash table
267 RowOperationsState row_state(aggregate_allocator.GetAllocator());
268 data_ptr_t payload_ptr = data;
269 for (idx_t i = 0; i < total_groups; i++) {
270 if (group_is_set[i]) {
271 data_pointers[count++] = payload_ptr;
272 if (count == STANDARD_VECTOR_SIZE) {
273 RowOperations::DestroyStates(state&: row_state, layout, addresses, count);
274 count = 0;
275 }
276 }
277 payload_ptr += tuple_size;
278 }
279 RowOperations::DestroyStates(state&: row_state, layout, addresses, count);
280}
281
282} // namespace duckdb
283