1 | #include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" |
---|---|
2 | |
3 | namespace duckdb { |
4 | |
5 | idx_t GroupedAggregateData::GroupCount() const { |
6 | return groups.size(); |
7 | } |
8 | |
9 | const vector<vector<idx_t>> &GroupedAggregateData::GetGroupingFunctions() const { |
10 | return grouping_functions; |
11 | } |
12 | |
13 | void GroupedAggregateData::InitializeGroupby(vector<unique_ptr<Expression>> groups, |
14 | vector<unique_ptr<Expression>> expressions, |
15 | vector<unsafe_vector<idx_t>> grouping_functions) { |
16 | InitializeGroupbyGroups(groups: std::move(groups)); |
17 | vector<LogicalType> payload_types_filters; |
18 | |
19 | SetGroupingFunctions(grouping_functions); |
20 | |
21 | filter_count = 0; |
22 | for (auto &expr : expressions) { |
23 | D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE); |
24 | D_ASSERT(expr->IsAggregate()); |
25 | auto &aggr = expr->Cast<BoundAggregateExpression>(); |
26 | bindings.push_back(x: &aggr); |
27 | |
28 | aggregate_return_types.push_back(x: aggr.return_type); |
29 | for (auto &child : aggr.children) { |
30 | payload_types.push_back(x: child->return_type); |
31 | } |
32 | if (aggr.filter) { |
33 | filter_count++; |
34 | payload_types_filters.push_back(x: aggr.filter->return_type); |
35 | } |
36 | if (!aggr.function.combine) { |
37 | throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); |
38 | } |
39 | aggregates.push_back(x: std::move(expr)); |
40 | } |
41 | for (const auto &pay_filters : payload_types_filters) { |
42 | payload_types.push_back(x: pay_filters); |
43 | } |
44 | } |
45 | |
46 | void GroupedAggregateData::InitializeDistinct(const unique_ptr<Expression> &aggregate, |
47 | const vector<unique_ptr<Expression>> *groups_p) { |
48 | auto &aggr = aggregate->Cast<BoundAggregateExpression>(); |
49 | D_ASSERT(aggr.IsDistinct()); |
50 | |
51 | // Add the (empty in ungrouped case) groups of the aggregates |
52 | InitializeDistinctGroups(groups: groups_p); |
53 | |
54 | // bindings.push_back(&aggr); |
55 | filter_count = 0; |
56 | aggregate_return_types.push_back(x: aggr.return_type); |
57 | for (idx_t i = 0; i < aggr.children.size(); i++) { |
58 | auto &child = aggr.children[i]; |
59 | group_types.push_back(x: child->return_type); |
60 | groups.push_back(x: child->Copy()); |
61 | payload_types.push_back(x: child->return_type); |
62 | if (aggr.filter) { |
63 | filter_count++; |
64 | } |
65 | } |
66 | if (!aggr.function.combine) { |
67 | throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); |
68 | } |
69 | } |
70 | |
71 | void GroupedAggregateData::InitializeDistinctGroups(const vector<unique_ptr<Expression>> *groups_p) { |
72 | if (!groups_p) { |
73 | return; |
74 | } |
75 | for (auto &expr : *groups_p) { |
76 | group_types.push_back(x: expr->return_type); |
77 | groups.push_back(x: expr->Copy()); |
78 | } |
79 | } |
80 | |
81 | void GroupedAggregateData::InitializeGroupbyGroups(vector<unique_ptr<Expression>> groups) { |
82 | // Add all the expressions of the group by clause |
83 | for (auto &expr : groups) { |
84 | group_types.push_back(x: expr->return_type); |
85 | } |
86 | this->groups = std::move(groups); |
87 | } |
88 | |
89 | void GroupedAggregateData::SetGroupingFunctions(vector<unsafe_vector<idx_t>> &functions) { |
90 | grouping_functions.reserve(n: functions.size()); |
91 | for (idx_t i = 0; i < functions.size(); i++) { |
92 | grouping_functions.push_back(x: std::move(functions[i])); |
93 | } |
94 | } |
95 | |
96 | } // namespace duckdb |
97 |