1#include "duckdb/parser/parsed_expression.hpp"
2#include "duckdb/parser/transformer.hpp"
3#include "duckdb/parser/query_node/select_node.hpp"
4#include "duckdb/parser/expression_map.hpp"
5#include "duckdb/parser/expression/function_expression.hpp"
6
7namespace duckdb {
8
9static void CheckGroupingSetMax(idx_t count) {
10 static constexpr const idx_t MAX_GROUPING_SETS = 65535;
11 if (count > MAX_GROUPING_SETS) {
12 throw ParserException("Maximum grouping set count of %d exceeded", MAX_GROUPING_SETS);
13 }
14}
15
16static void CheckGroupingSetCubes(idx_t current_count, idx_t cube_count) {
17 idx_t combinations = 1;
18 for (idx_t i = 0; i < cube_count; i++) {
19 combinations *= 2;
20 CheckGroupingSetMax(count: current_count + combinations);
21 }
22}
23
24struct GroupingExpressionMap {
25 parsed_expression_map_t<idx_t> map;
26};
27
28static GroupingSet VectorToGroupingSet(vector<idx_t> &indexes) {
29 GroupingSet result;
30 for (idx_t i = 0; i < indexes.size(); i++) {
31 result.insert(x: indexes[i]);
32 }
33 return result;
34}
35
36static void MergeGroupingSet(GroupingSet &result, GroupingSet &other) {
37 CheckGroupingSetMax(count: result.size() + other.size());
38 result.insert(first: other.begin(), last: other.end());
39}
40
41void Transformer::AddGroupByExpression(unique_ptr<ParsedExpression> expression, GroupingExpressionMap &map,
42 GroupByNode &result, vector<idx_t> &result_set) {
43 if (expression->type == ExpressionType::FUNCTION) {
44 auto &func = expression->Cast<FunctionExpression>();
45 if (func.function_name == "row") {
46 for (auto &child : func.children) {
47 AddGroupByExpression(expression: std::move(child), map, result, result_set);
48 }
49 return;
50 }
51 }
52 auto entry = map.map.find(x: *expression);
53 idx_t result_idx;
54 if (entry == map.map.end()) {
55 result_idx = result.group_expressions.size();
56 map.map[*expression] = result_idx;
57 result.group_expressions.push_back(x: std::move(expression));
58 } else {
59 result_idx = entry->second;
60 }
61 result_set.push_back(x: result_idx);
62}
63
64static void AddCubeSets(const GroupingSet &current_set, vector<GroupingSet> &result_set,
65 vector<GroupingSet> &result_sets, idx_t start_idx = 0) {
66 CheckGroupingSetMax(count: result_sets.size());
67 result_sets.push_back(x: current_set);
68 for (idx_t k = start_idx; k < result_set.size(); k++) {
69 auto child_set = current_set;
70 MergeGroupingSet(result&: child_set, other&: result_set[k]);
71 AddCubeSets(current_set: child_set, result_set, result_sets, start_idx: k + 1);
72 }
73}
74
75void Transformer::TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map,
76 GroupByNode &result, vector<idx_t> &indexes) {
77 auto expression = TransformExpression(node&: n);
78 AddGroupByExpression(expression: std::move(expression), map, result, result_set&: indexes);
79}
80
81// If one GROUPING SETS clause is nested inside another,
82// the effect is the same as if all the elements of the inner clause had been written directly in the outer clause.
83void Transformer::TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result,
84 vector<GroupingSet> &result_sets) {
85 if (n.type == duckdb_libpgquery::T_PGGroupingSet) {
86 auto &grouping_set = PGCast<duckdb_libpgquery::PGGroupingSet>(node&: n);
87 switch (grouping_set.kind) {
88 case duckdb_libpgquery::GROUPING_SET_EMPTY:
89 result_sets.emplace_back();
90 break;
91 case duckdb_libpgquery::GROUPING_SET_ALL: {
92 result.aggregate_handling = AggregateHandling::FORCE_AGGREGATES;
93 break;
94 }
95 case duckdb_libpgquery::GROUPING_SET_SETS: {
96 for (auto node = grouping_set.content->head; node; node = node->next) {
97 auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value);
98 TransformGroupByNode(n&: *pg_node, map, result, result_sets);
99 }
100 break;
101 }
102 case duckdb_libpgquery::GROUPING_SET_ROLLUP: {
103 vector<GroupingSet> rollup_sets;
104 for (auto node = grouping_set.content->head; node; node = node->next) {
105 auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value);
106 vector<idx_t> rollup_set;
107 TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: rollup_set);
108 rollup_sets.push_back(x: VectorToGroupingSet(indexes&: rollup_set));
109 }
110 // generate the subsets of the rollup set and add them to the grouping sets
111 GroupingSet current_set;
112 result_sets.push_back(x: current_set);
113 for (idx_t i = 0; i < rollup_sets.size(); i++) {
114 MergeGroupingSet(result&: current_set, other&: rollup_sets[i]);
115 result_sets.push_back(x: current_set);
116 }
117 break;
118 }
119 case duckdb_libpgquery::GROUPING_SET_CUBE: {
120 vector<GroupingSet> cube_sets;
121 for (auto node = grouping_set.content->head; node; node = node->next) {
122 auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value);
123 vector<idx_t> cube_set;
124 TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: cube_set);
125 cube_sets.push_back(x: VectorToGroupingSet(indexes&: cube_set));
126 }
127 // generate the subsets of the rollup set and add them to the grouping sets
128 CheckGroupingSetCubes(current_count: result_sets.size(), cube_count: cube_sets.size());
129
130 GroupingSet current_set;
131 AddCubeSets(current_set, result_set&: cube_sets, result_sets, start_idx: 0);
132 break;
133 }
134 default:
135 throw InternalException("Unsupported GROUPING SET type %d", grouping_set.kind);
136 }
137 } else {
138 vector<idx_t> indexes;
139 TransformGroupByExpression(n, map, result&: result.groups, indexes);
140 result_sets.push_back(x: VectorToGroupingSet(indexes));
141 }
142}
143
144// If multiple grouping items are specified in a single GROUP BY clause,
145// then the final list of grouping sets is the cross product of the individual items.
146bool Transformer::TransformGroupBy(optional_ptr<duckdb_libpgquery::PGList> group, SelectNode &select_node) {
147 if (!group) {
148 return false;
149 }
150 auto &result = select_node.groups;
151 GroupingExpressionMap map;
152 for (auto node = group->head; node != nullptr; node = node->next) {
153 auto n = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value);
154 vector<GroupingSet> result_sets;
155 TransformGroupByNode(n&: *n, map, result&: select_node, result_sets);
156 CheckGroupingSetMax(count: result_sets.size());
157 if (result.grouping_sets.empty()) {
158 // no grouping sets yet: use the current set of grouping sets
159 result.grouping_sets = std::move(result_sets);
160 } else {
161 // compute the cross product
162 vector<GroupingSet> new_sets;
163 idx_t grouping_set_count = result.grouping_sets.size() * result_sets.size();
164 CheckGroupingSetMax(count: grouping_set_count);
165 new_sets.reserve(n: grouping_set_count);
166 for (idx_t current_idx = 0; current_idx < result.grouping_sets.size(); current_idx++) {
167 auto &current_set = result.grouping_sets[current_idx];
168 for (idx_t new_idx = 0; new_idx < result_sets.size(); new_idx++) {
169 auto &new_set = result_sets[new_idx];
170 GroupingSet set;
171 set.insert(first: current_set.begin(), last: current_set.end());
172 set.insert(first: new_set.begin(), last: new_set.end());
173 new_sets.push_back(x: std::move(set));
174 }
175 }
176 result.grouping_sets = std::move(new_sets);
177 }
178 }
179 if (result.group_expressions.size() == 1 && result.grouping_sets.size() == 1 &&
180 ExpressionIsEmptyStar(expr&: *result.group_expressions[0])) {
181 // GROUP BY *
182 result.group_expressions.clear();
183 result.grouping_sets.clear();
184 select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES;
185 }
186 return true;
187}
188
189} // namespace duckdb
190