1#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp"
2
3namespace duckdb {
4
5idx_t GroupedAggregateData::GroupCount() const {
6 return groups.size();
7}
8
9const vector<vector<idx_t>> &GroupedAggregateData::GetGroupingFunctions() const {
10 return grouping_functions;
11}
12
13void GroupedAggregateData::InitializeGroupby(vector<unique_ptr<Expression>> groups,
14 vector<unique_ptr<Expression>> expressions,
15 vector<unsafe_vector<idx_t>> grouping_functions) {
16 InitializeGroupbyGroups(groups: std::move(groups));
17 vector<LogicalType> payload_types_filters;
18
19 SetGroupingFunctions(grouping_functions);
20
21 filter_count = 0;
22 for (auto &expr : expressions) {
23 D_ASSERT(expr->expression_class == ExpressionClass::BOUND_AGGREGATE);
24 D_ASSERT(expr->IsAggregate());
25 auto &aggr = expr->Cast<BoundAggregateExpression>();
26 bindings.push_back(x: &aggr);
27
28 aggregate_return_types.push_back(x: aggr.return_type);
29 for (auto &child : aggr.children) {
30 payload_types.push_back(x: child->return_type);
31 }
32 if (aggr.filter) {
33 filter_count++;
34 payload_types_filters.push_back(x: aggr.filter->return_type);
35 }
36 if (!aggr.function.combine) {
37 throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name);
38 }
39 aggregates.push_back(x: std::move(expr));
40 }
41 for (const auto &pay_filters : payload_types_filters) {
42 payload_types.push_back(x: pay_filters);
43 }
44}
45
46void GroupedAggregateData::InitializeDistinct(const unique_ptr<Expression> &aggregate,
47 const vector<unique_ptr<Expression>> *groups_p) {
48 auto &aggr = aggregate->Cast<BoundAggregateExpression>();
49 D_ASSERT(aggr.IsDistinct());
50
51 // Add the (empty in ungrouped case) groups of the aggregates
52 InitializeDistinctGroups(groups: groups_p);
53
54 // bindings.push_back(&aggr);
55 filter_count = 0;
56 aggregate_return_types.push_back(x: aggr.return_type);
57 for (idx_t i = 0; i < aggr.children.size(); i++) {
58 auto &child = aggr.children[i];
59 group_types.push_back(x: child->return_type);
60 groups.push_back(x: child->Copy());
61 payload_types.push_back(x: child->return_type);
62 if (aggr.filter) {
63 filter_count++;
64 }
65 }
66 if (!aggr.function.combine) {
67 throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name);
68 }
69}
70
71void GroupedAggregateData::InitializeDistinctGroups(const vector<unique_ptr<Expression>> *groups_p) {
72 if (!groups_p) {
73 return;
74 }
75 for (auto &expr : *groups_p) {
76 group_types.push_back(x: expr->return_type);
77 groups.push_back(x: expr->Copy());
78 }
79}
80
81void GroupedAggregateData::InitializeGroupbyGroups(vector<unique_ptr<Expression>> groups) {
82 // Add all the expressions of the group by clause
83 for (auto &expr : groups) {
84 group_types.push_back(x: expr->return_type);
85 }
86 this->groups = std::move(groups);
87}
88
89void GroupedAggregateData::SetGroupingFunctions(vector<unsafe_vector<idx_t>> &functions) {
90 grouping_functions.reserve(n: functions.size());
91 for (idx_t i = 0; i < functions.size(); i++) {
92 grouping_functions.push_back(x: std::move(functions[i]));
93 }
94}
95
96} // namespace duckdb
97