1 | #include "duckdb/parser/parsed_expression.hpp" |
2 | #include "duckdb/parser/transformer.hpp" |
3 | #include "duckdb/parser/query_node/select_node.hpp" |
4 | #include "duckdb/parser/expression_map.hpp" |
5 | #include "duckdb/parser/expression/function_expression.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | static void CheckGroupingSetMax(idx_t count) { |
10 | static constexpr const idx_t MAX_GROUPING_SETS = 65535; |
11 | if (count > MAX_GROUPING_SETS) { |
12 | throw ParserException("Maximum grouping set count of %d exceeded" , MAX_GROUPING_SETS); |
13 | } |
14 | } |
15 | |
16 | static void CheckGroupingSetCubes(idx_t current_count, idx_t cube_count) { |
17 | idx_t combinations = 1; |
18 | for (idx_t i = 0; i < cube_count; i++) { |
19 | combinations *= 2; |
20 | CheckGroupingSetMax(count: current_count + combinations); |
21 | } |
22 | } |
23 | |
//! Lookup table used while transforming a GROUP BY clause: maps each parsed group expression that has
//! already been added to its index in GroupByNode::group_expressions, so that an expression appearing
//! several times (e.g. in multiple grouping sets) is stored once and shares a single group index.
struct GroupingExpressionMap {
	//! expression -> index into GroupByNode::group_expressions
	parsed_expression_map_t<idx_t> map;
};
27 | |
28 | static GroupingSet VectorToGroupingSet(vector<idx_t> &indexes) { |
29 | GroupingSet result; |
30 | for (idx_t i = 0; i < indexes.size(); i++) { |
31 | result.insert(x: indexes[i]); |
32 | } |
33 | return result; |
34 | } |
35 | |
36 | static void MergeGroupingSet(GroupingSet &result, GroupingSet &other) { |
37 | CheckGroupingSetMax(count: result.size() + other.size()); |
38 | result.insert(first: other.begin(), last: other.end()); |
39 | } |
40 | |
41 | void Transformer::AddGroupByExpression(unique_ptr<ParsedExpression> expression, GroupingExpressionMap &map, |
42 | GroupByNode &result, vector<idx_t> &result_set) { |
43 | if (expression->type == ExpressionType::FUNCTION) { |
44 | auto &func = expression->Cast<FunctionExpression>(); |
45 | if (func.function_name == "row" ) { |
46 | for (auto &child : func.children) { |
47 | AddGroupByExpression(expression: std::move(child), map, result, result_set); |
48 | } |
49 | return; |
50 | } |
51 | } |
52 | auto entry = map.map.find(x: *expression); |
53 | idx_t result_idx; |
54 | if (entry == map.map.end()) { |
55 | result_idx = result.group_expressions.size(); |
56 | map.map[*expression] = result_idx; |
57 | result.group_expressions.push_back(x: std::move(expression)); |
58 | } else { |
59 | result_idx = entry->second; |
60 | } |
61 | result_set.push_back(x: result_idx); |
62 | } |
63 | |
64 | static void AddCubeSets(const GroupingSet ¤t_set, vector<GroupingSet> &result_set, |
65 | vector<GroupingSet> &result_sets, idx_t start_idx = 0) { |
66 | CheckGroupingSetMax(count: result_sets.size()); |
67 | result_sets.push_back(x: current_set); |
68 | for (idx_t k = start_idx; k < result_set.size(); k++) { |
69 | auto child_set = current_set; |
70 | MergeGroupingSet(result&: child_set, other&: result_set[k]); |
71 | AddCubeSets(current_set: child_set, result_set, result_sets, start_idx: k + 1); |
72 | } |
73 | } |
74 | |
75 | void Transformer::TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, |
76 | GroupByNode &result, vector<idx_t> &indexes) { |
77 | auto expression = TransformExpression(node&: n); |
78 | AddGroupByExpression(expression: std::move(expression), map, result, result_set&: indexes); |
79 | } |
80 | |
81 | // If one GROUPING SETS clause is nested inside another, |
82 | // the effect is the same as if all the elements of the inner clause had been written directly in the outer clause. |
83 | void Transformer::TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result, |
84 | vector<GroupingSet> &result_sets) { |
85 | if (n.type == duckdb_libpgquery::T_PGGroupingSet) { |
86 | auto &grouping_set = PGCast<duckdb_libpgquery::PGGroupingSet>(node&: n); |
87 | switch (grouping_set.kind) { |
88 | case duckdb_libpgquery::GROUPING_SET_EMPTY: |
89 | result_sets.emplace_back(); |
90 | break; |
91 | case duckdb_libpgquery::GROUPING_SET_ALL: { |
92 | result.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; |
93 | break; |
94 | } |
95 | case duckdb_libpgquery::GROUPING_SET_SETS: { |
96 | for (auto node = grouping_set.content->head; node; node = node->next) { |
97 | auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); |
98 | TransformGroupByNode(n&: *pg_node, map, result, result_sets); |
99 | } |
100 | break; |
101 | } |
102 | case duckdb_libpgquery::GROUPING_SET_ROLLUP: { |
103 | vector<GroupingSet> rollup_sets; |
104 | for (auto node = grouping_set.content->head; node; node = node->next) { |
105 | auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); |
106 | vector<idx_t> rollup_set; |
107 | TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: rollup_set); |
108 | rollup_sets.push_back(x: VectorToGroupingSet(indexes&: rollup_set)); |
109 | } |
110 | // generate the subsets of the rollup set and add them to the grouping sets |
111 | GroupingSet current_set; |
112 | result_sets.push_back(x: current_set); |
113 | for (idx_t i = 0; i < rollup_sets.size(); i++) { |
114 | MergeGroupingSet(result&: current_set, other&: rollup_sets[i]); |
115 | result_sets.push_back(x: current_set); |
116 | } |
117 | break; |
118 | } |
119 | case duckdb_libpgquery::GROUPING_SET_CUBE: { |
120 | vector<GroupingSet> cube_sets; |
121 | for (auto node = grouping_set.content->head; node; node = node->next) { |
122 | auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); |
123 | vector<idx_t> cube_set; |
124 | TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: cube_set); |
125 | cube_sets.push_back(x: VectorToGroupingSet(indexes&: cube_set)); |
126 | } |
127 | // generate the subsets of the rollup set and add them to the grouping sets |
128 | CheckGroupingSetCubes(current_count: result_sets.size(), cube_count: cube_sets.size()); |
129 | |
130 | GroupingSet current_set; |
131 | AddCubeSets(current_set, result_set&: cube_sets, result_sets, start_idx: 0); |
132 | break; |
133 | } |
134 | default: |
135 | throw InternalException("Unsupported GROUPING SET type %d" , grouping_set.kind); |
136 | } |
137 | } else { |
138 | vector<idx_t> indexes; |
139 | TransformGroupByExpression(n, map, result&: result.groups, indexes); |
140 | result_sets.push_back(x: VectorToGroupingSet(indexes)); |
141 | } |
142 | } |
143 | |
144 | // If multiple grouping items are specified in a single GROUP BY clause, |
145 | // then the final list of grouping sets is the cross product of the individual items. |
146 | bool Transformer::TransformGroupBy(optional_ptr<duckdb_libpgquery::PGList> group, SelectNode &select_node) { |
147 | if (!group) { |
148 | return false; |
149 | } |
150 | auto &result = select_node.groups; |
151 | GroupingExpressionMap map; |
152 | for (auto node = group->head; node != nullptr; node = node->next) { |
153 | auto n = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); |
154 | vector<GroupingSet> result_sets; |
155 | TransformGroupByNode(n&: *n, map, result&: select_node, result_sets); |
156 | CheckGroupingSetMax(count: result_sets.size()); |
157 | if (result.grouping_sets.empty()) { |
158 | // no grouping sets yet: use the current set of grouping sets |
159 | result.grouping_sets = std::move(result_sets); |
160 | } else { |
161 | // compute the cross product |
162 | vector<GroupingSet> new_sets; |
163 | idx_t grouping_set_count = result.grouping_sets.size() * result_sets.size(); |
164 | CheckGroupingSetMax(count: grouping_set_count); |
165 | new_sets.reserve(n: grouping_set_count); |
166 | for (idx_t current_idx = 0; current_idx < result.grouping_sets.size(); current_idx++) { |
167 | auto ¤t_set = result.grouping_sets[current_idx]; |
168 | for (idx_t new_idx = 0; new_idx < result_sets.size(); new_idx++) { |
169 | auto &new_set = result_sets[new_idx]; |
170 | GroupingSet set; |
171 | set.insert(first: current_set.begin(), last: current_set.end()); |
172 | set.insert(first: new_set.begin(), last: new_set.end()); |
173 | new_sets.push_back(x: std::move(set)); |
174 | } |
175 | } |
176 | result.grouping_sets = std::move(new_sets); |
177 | } |
178 | } |
179 | if (result.group_expressions.size() == 1 && result.grouping_sets.size() == 1 && |
180 | ExpressionIsEmptyStar(expr&: *result.group_expressions[0])) { |
181 | // GROUP BY * |
182 | result.group_expressions.clear(); |
183 | result.grouping_sets.clear(); |
184 | select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; |
185 | } |
186 | return true; |
187 | } |
188 | |
189 | } // namespace duckdb |
190 | |