1 | #include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" |
2 | #include "duckdb/planner/expression.hpp" |
3 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
5 | #include "duckdb/common/algorithm.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | //! Shared information about a collection of distinct aggregates |
10 | DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector<unique_ptr<Expression>> &aggregates, |
11 | vector<idx_t> indices) |
12 | : indices(std::move(indices)), aggregates(aggregates) { |
13 | table_count = CreateTableIndexMap(); |
14 | |
15 | const idx_t aggregate_count = aggregates.size(); |
16 | |
17 | total_child_count = 0; |
18 | for (idx_t i = 0; i < aggregate_count; i++) { |
19 | auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>(); |
20 | |
21 | if (!aggregate.IsDistinct()) { |
22 | continue; |
23 | } |
24 | total_child_count += aggregate.children.size(); |
25 | } |
26 | } |
27 | |
28 | //! Stateful data for the distinct aggregates |
29 | |
30 | DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client) |
31 | : child_executor(client) { |
32 | |
33 | radix_states.resize(new_size: data.info.table_count); |
34 | distinct_output_chunks.resize(new_size: data.info.table_count); |
35 | |
36 | idx_t aggregate_count = data.info.aggregates.size(); |
37 | for (idx_t i = 0; i < aggregate_count; i++) { |
38 | auto &aggregate = data.info.aggregates[i]->Cast<BoundAggregateExpression>(); |
39 | |
40 | // Initialize the child executor and get the payload types for every aggregate |
41 | for (auto &child : aggregate.children) { |
42 | child_executor.AddExpression(expr: *child); |
43 | } |
44 | if (!aggregate.IsDistinct()) { |
45 | continue; |
46 | } |
47 | D_ASSERT(data.info.table_map.count(i)); |
48 | idx_t table_idx = data.info.table_map.at(k: i); |
49 | if (data.radix_tables[table_idx] == nullptr) { |
50 | //! This table is unused because the aggregate shares its data with another |
51 | continue; |
52 | } |
53 | |
54 | // Get the global sinkstate for the aggregate |
55 | auto &radix_table = *data.radix_tables[table_idx]; |
56 | radix_states[table_idx] = radix_table.GetGlobalSinkState(context&: client); |
57 | |
58 | // Fill the chunk_types (group_by + children) |
59 | vector<LogicalType> chunk_types; |
60 | for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) { |
61 | chunk_types.push_back(x: group_type); |
62 | } |
63 | |
64 | // This is used in Finalize to get the data from the radix table |
65 | distinct_output_chunks[table_idx] = make_uniq<DataChunk>(); |
66 | distinct_output_chunks[table_idx]->Initialize(context&: client, types: chunk_types); |
67 | } |
68 | } |
69 | |
//! Persistent + shared (read-only) data for the distinct aggregates
DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info)
    : DistinctAggregateData(info, {}, nullptr) {
	// Delegates to the grouped constructor with an empty grouping set and no
	// group expressions (i.e. an ungrouped/simple aggregation).
}
74 | |
75 | DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups, |
76 | const vector<unique_ptr<Expression>> *group_expressions) |
77 | : info(info) { |
78 | grouped_aggregate_data.resize(new_size: info.table_count); |
79 | radix_tables.resize(new_size: info.table_count); |
80 | grouping_sets.resize(new_size: info.table_count); |
81 | |
82 | for (auto &i : info.indices) { |
83 | auto &aggregate = info.aggregates[i]->Cast<BoundAggregateExpression>(); |
84 | |
85 | D_ASSERT(info.table_map.count(i)); |
86 | idx_t table_idx = info.table_map.at(k: i); |
87 | if (radix_tables[table_idx] != nullptr) { |
88 | //! This aggregate shares a table with another aggregate, and the table is already initialized |
89 | continue; |
90 | } |
91 | // The grouping set contains the indices of the chunk that correspond to the data vector |
92 | // that will be used to figure out in which bucket the payload should be put |
93 | auto &grouping_set = grouping_sets[table_idx]; |
94 | //! Populate the group with the children of the aggregate |
95 | for (auto &group : groups) { |
96 | grouping_set.insert(x: group); |
97 | } |
98 | idx_t group_by_size = group_expressions ? group_expressions->size() : 0; |
99 | for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) { |
100 | grouping_set.insert(x: set_idx + group_by_size); |
101 | } |
102 | // Create the hashtable for the aggregate |
103 | grouped_aggregate_data[table_idx] = make_uniq<GroupedAggregateData>(); |
104 | grouped_aggregate_data[table_idx]->InitializeDistinct(aggregate: info.aggregates[i], groups_p: group_expressions); |
105 | radix_tables[table_idx] = |
106 | make_uniq<RadixPartitionedHashTable>(args&: grouping_set, args&: *grouped_aggregate_data[table_idx]); |
107 | |
108 | // Fill the chunk_types (only contains the payload of the distinct aggregates) |
109 | vector<LogicalType> chunk_types; |
110 | for (auto &child_p : aggregate.children) { |
111 | chunk_types.push_back(x: child_p->return_type); |
112 | } |
113 | } |
114 | } |
115 | |
116 | using aggr_ref_t = reference<BoundAggregateExpression>; |
117 | |
118 | struct FindMatchingAggregate { |
119 | explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) { |
120 | } |
121 | bool operator()(const aggr_ref_t other_r) { |
122 | auto &other = other_r.get(); |
123 | auto &aggr = aggr_r.get(); |
124 | if (other.children.size() != aggr.children.size()) { |
125 | return false; |
126 | } |
127 | if (!Expression::Equals(left: aggr.filter, right: other.filter)) { |
128 | return false; |
129 | } |
130 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
131 | auto &other_child = other.children[i]->Cast<BoundReferenceExpression>(); |
132 | auto &aggr_child = aggr.children[i]->Cast<BoundReferenceExpression>(); |
133 | if (other_child.index != aggr_child.index) { |
134 | return false; |
135 | } |
136 | } |
137 | return true; |
138 | } |
139 | const aggr_ref_t aggr_r; |
140 | }; |
141 | |
142 | idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() { |
143 | vector<aggr_ref_t> table_inputs; |
144 | |
145 | D_ASSERT(table_map.empty()); |
146 | for (auto &agg_idx : indices) { |
147 | D_ASSERT(agg_idx < aggregates.size()); |
148 | auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>(); |
149 | |
150 | auto matching_inputs = |
151 | std::find_if(first: table_inputs.begin(), last: table_inputs.end(), pred: FindMatchingAggregate(std::ref(t&: aggregate))); |
152 | if (matching_inputs != table_inputs.end()) { |
153 | //! Assign the existing table to the aggregate |
154 | idx_t found_idx = std::distance(first: table_inputs.begin(), last: matching_inputs); |
155 | table_map[agg_idx] = found_idx; |
156 | continue; |
157 | } |
158 | //! Create a new table and assign its index to the aggregate |
159 | table_map[agg_idx] = table_inputs.size(); |
160 | table_inputs.push_back(x: std::ref(t&: aggregate)); |
161 | } |
162 | //! Every distinct aggregate needs to be assigned an index |
163 | D_ASSERT(table_map.size() == indices.size()); |
164 | //! There can not be more tables than there are distinct aggregates |
165 | D_ASSERT(table_inputs.size() <= indices.size()); |
166 | |
167 | return table_inputs.size(); |
168 | } |
169 | |
170 | bool DistinctAggregateCollectionInfo::AnyDistinct() const { |
171 | return !indices.empty(); |
172 | } |
173 | |
174 | const unsafe_vector<idx_t> &DistinctAggregateCollectionInfo::Indices() const { |
175 | return this->indices; |
176 | } |
177 | |
178 | static vector<idx_t> GetDistinctIndices(vector<unique_ptr<Expression>> &aggregates) { |
179 | vector<idx_t> distinct_indices; |
180 | for (idx_t i = 0; i < aggregates.size(); i++) { |
181 | auto &aggregate = aggregates[i]; |
182 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
183 | if (aggr.IsDistinct()) { |
184 | distinct_indices.push_back(x: i); |
185 | } |
186 | } |
187 | return distinct_indices; |
188 | } |
189 | |
190 | unique_ptr<DistinctAggregateCollectionInfo> |
191 | DistinctAggregateCollectionInfo::Create(vector<unique_ptr<Expression>> &aggregates) { |
192 | vector<idx_t> indices = GetDistinctIndices(aggregates); |
193 | if (indices.empty()) { |
194 | return nullptr; |
195 | } |
196 | return make_uniq<DistinctAggregateCollectionInfo>(args&: aggregates, args: std::move(indices)); |
197 | } |
198 | |
199 | bool DistinctAggregateData::IsDistinct(idx_t index) const { |
200 | bool is_distinct = !radix_tables.empty() && info.table_map.count(x: index); |
201 | #ifdef DEBUG |
202 | //! Make sure that if it is distinct, it's also in the indices |
203 | //! And if it's not distinct, that it's also not in the indices |
204 | bool found = false; |
205 | for (auto &idx : info.indices) { |
206 | if (idx == index) { |
207 | found = true; |
208 | break; |
209 | } |
210 | } |
211 | D_ASSERT(found == is_distinct); |
212 | #endif |
213 | return is_distinct; |
214 | } |
215 | |
216 | } // namespace duckdb |
217 | |