#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/common/algorithm.hpp"

namespace duckdb {

//! Shared information about a collection of distinct aggregates
DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector<unique_ptr<Expression>> &aggregates,
                                                                 vector<idx_t> indices)
    : indices(std::move(indices)), aggregates(aggregates) {
    table_count = CreateTableIndexMap();

    const idx_t aggregate_count = aggregates.size();

    total_child_count = 0;
    for (idx_t i = 0; i < aggregate_count; i++) {
        auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();

        if (!aggregate.IsDistinct()) {
            continue;
        }
        total_child_count += aggregate.children.size();
    }
}
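
// Illustrative example (hypothetical query, not taken from this codebase's tests):
//   SELECT COUNT(DISTINCT a), SUM(DISTINCT a), AVG(DISTINCT b) FROM tbl;
// All three aggregates are distinct, so 'indices' is {0, 1, 2} and 'total_child_count' is 3.
// If the planner binds both references to 'a' to the same payload column (so that
// FindMatchingAggregate below considers COUNT and SUM equal), CreateTableIndexMap assigns
// them one shared table and 'table_count' becomes 2.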

//! Stateful data for the distinct aggregates

DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client)
    : child_executor(client) {

    radix_states.resize(data.info.table_count);
    distinct_output_chunks.resize(data.info.table_count);

    idx_t aggregate_count = data.info.aggregates.size();
    for (idx_t i = 0; i < aggregate_count; i++) {
        auto &aggregate = data.info.aggregates[i]->Cast<BoundAggregateExpression>();

        // Initialize the child executor and get the payload types for every aggregate
        for (auto &child : aggregate.children) {
            child_executor.AddExpression(*child);
        }
        if (!aggregate.IsDistinct()) {
            continue;
        }
        D_ASSERT(data.info.table_map.count(i));
        idx_t table_idx = data.info.table_map.at(i);
        if (data.radix_tables[table_idx] == nullptr) {
            //! This table is unused because the aggregate shares its data with another
            continue;
        }

        // Get the global sink state for the aggregate
        auto &radix_table = *data.radix_tables[table_idx];
        radix_states[table_idx] = radix_table.GetGlobalSinkState(client);

        // Fill the chunk_types (group_by + children)
        vector<LogicalType> chunk_types;
        for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) {
            chunk_types.push_back(group_type);
        }
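        // Note: for a distinct table the group_types already contain the aggregate's children
        // (they are added as extra group columns when the table is initialized for DISTINCT),
        // so these types cover both the GROUP BY columns and the de-duplicated inputs.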

        // This is used in Finalize to get the data from the radix table
        distinct_output_chunks[table_idx] = make_uniq<DataChunk>();
        distinct_output_chunks[table_idx]->Initialize(client, chunk_types);
    }
}

//! Persistent + shared (read-only) data for the distinct aggregates
DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info)
    : DistinctAggregateData(info, {}, nullptr) {
}

DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups,
                                             const vector<unique_ptr<Expression>> *group_expressions)
    : info(info) {
    grouped_aggregate_data.resize(info.table_count);
    radix_tables.resize(info.table_count);
    grouping_sets.resize(info.table_count);

    for (auto &i : info.indices) {
        auto &aggregate = info.aggregates[i]->Cast<BoundAggregateExpression>();

        D_ASSERT(info.table_map.count(i));
        idx_t table_idx = info.table_map.at(i);
        if (radix_tables[table_idx] != nullptr) {
            //! This aggregate shares a table with another aggregate, and the table is already initialized
            continue;
        }
        // The grouping set contains the indices of the chunk that correspond to the data vectors
        // that will be used to figure out in which bucket the payload should be put
        auto &grouping_set = grouping_sets[table_idx];
        //! Populate the grouping set with the original groups, followed by the children of the aggregate
        for (auto &group : groups) {
            grouping_set.insert(group);
        }
        idx_t group_by_size = group_expressions ? group_expressions->size() : 0;
        for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) {
            grouping_set.insert(set_idx + group_by_size);
        }
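        // Resulting layout of the grouping set: indices [0, group_by_size) refer to the original
        // GROUP BY columns, followed by one index per child of this distinct aggregate.
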
        // Create the hashtable for the aggregate
        grouped_aggregate_data[table_idx] = make_uniq<GroupedAggregateData>();
        grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions);
        radix_tables[table_idx] =
            make_uniq<RadixPartitionedHashTable>(grouping_set, *grouped_aggregate_data[table_idx]);

        // Fill the chunk_types (only contains the payload of the distinct aggregates)
        vector<LogicalType> chunk_types;
        for (auto &child_p : aggregate.children) {
            chunk_types.push_back(child_p->return_type);
        }
    }
}

using aggr_ref_t = reference<BoundAggregateExpression>;

struct FindMatchingAggregate {
    explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) {
    }
    bool operator()(const aggr_ref_t other_r) {
        auto &other = other_r.get();
        auto &aggr = aggr_r.get();
        if (other.children.size() != aggr.children.size()) {
            return false;
        }
        if (!Expression::Equals(aggr.filter, other.filter)) {
            return false;
        }
        for (idx_t i = 0; i < aggr.children.size(); i++) {
            auto &other_child = other.children[i]->Cast<BoundReferenceExpression>();
            auto &aggr_child = aggr.children[i]->Cast<BoundReferenceExpression>();
            if (other_child.index != aggr_child.index) {
                return false;
            }
        }
        return true;
    }
    const aggr_ref_t aggr_r;
};
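
// Two distinct aggregates may share a table only if they have the same filter and their children
// reference exactly the same input columns (compared by BoundReferenceExpression::index).
// Illustrative (hypothetical) example: COUNT(DISTINCT a) and SUM(DISTINCT a) can share one table,
// while COUNT(DISTINCT a) FILTER (WHERE b > 0) matches neither of them because its filter differs.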

idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() {
    vector<aggr_ref_t> table_inputs;

    D_ASSERT(table_map.empty());
    for (auto &agg_idx : indices) {
        D_ASSERT(agg_idx < aggregates.size());
        auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>();

        auto matching_inputs =
            std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate)));
        if (matching_inputs != table_inputs.end()) {
            //! Assign the existing table to the aggregate
            idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs);
            table_map[agg_idx] = found_idx;
            continue;
        }
        //! Create a new table and assign its index to the aggregate
        table_map[agg_idx] = table_inputs.size();
        table_inputs.push_back(std::ref(aggregate));
    }
    //! Every distinct aggregate needs to be assigned an index
    D_ASSERT(table_map.size() == indices.size());
    //! There cannot be more tables than there are distinct aggregates
    D_ASSERT(table_inputs.size() <= indices.size());

    return table_inputs.size();
}

bool DistinctAggregateCollectionInfo::AnyDistinct() const {
    return !indices.empty();
}

const unsafe_vector<idx_t> &DistinctAggregateCollectionInfo::Indices() const {
    return this->indices;
}

static vector<idx_t> GetDistinctIndices(vector<unique_ptr<Expression>> &aggregates) {
    vector<idx_t> distinct_indices;
    for (idx_t i = 0; i < aggregates.size(); i++) {
        auto &aggregate = aggregates[i];
        auto &aggr = aggregate->Cast<BoundAggregateExpression>();
        if (aggr.IsDistinct()) {
            distinct_indices.push_back(i);
        }
    }
    return distinct_indices;
}

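//! Returns nullptr when none of the aggregates are distinct, so callers can skip setting up the
//! distinct-aggregate state entirely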
unique_ptr<DistinctAggregateCollectionInfo>
DistinctAggregateCollectionInfo::Create(vector<unique_ptr<Expression>> &aggregates) {
    vector<idx_t> indices = GetDistinctIndices(aggregates);
    if (indices.empty()) {
        return nullptr;
    }
    return make_uniq<DistinctAggregateCollectionInfo>(aggregates, std::move(indices));
}

bool DistinctAggregateData::IsDistinct(idx_t index) const {
    bool is_distinct = !radix_tables.empty() && info.table_map.count(index);
#ifdef DEBUG
    //! Make sure that if it is distinct, it's also in the indices
    //! And if it's not distinct, that it's also not in the indices
    bool found = false;
    for (auto &idx : info.indices) {
        if (idx == index) {
            found = true;
            break;
        }
    }
    D_ASSERT(found == is_distinct);
#endif
    return is_distinct;
}

} // namespace duckdb