| 1 | #include "duckdb/parser/parsed_expression.hpp" | 
| 2 | #include "duckdb/parser/transformer.hpp" | 
| 3 | #include "duckdb/parser/query_node/select_node.hpp" | 
| 4 | #include "duckdb/parser/expression_map.hpp" | 
| 5 | #include "duckdb/parser/expression/function_expression.hpp" | 
| 6 |  | 
| 7 | namespace duckdb { | 
| 8 |  | 
| 9 | static void CheckGroupingSetMax(idx_t count) { | 
| 10 | 	static constexpr const idx_t MAX_GROUPING_SETS = 65535; | 
| 11 | 	if (count > MAX_GROUPING_SETS) { | 
| 12 | 		throw ParserException("Maximum grouping set count of %d exceeded" , MAX_GROUPING_SETS); | 
| 13 | 	} | 
| 14 | } | 
| 15 |  | 
| 16 | static void CheckGroupingSetCubes(idx_t current_count, idx_t cube_count) { | 
| 17 | 	idx_t combinations = 1; | 
| 18 | 	for (idx_t i = 0; i < cube_count; i++) { | 
| 19 | 		combinations *= 2; | 
| 20 | 		CheckGroupingSetMax(count: current_count + combinations); | 
| 21 | 	} | 
| 22 | } | 
| 23 |  | 
| 24 | struct GroupingExpressionMap { | 
| 25 | 	parsed_expression_map_t<idx_t> map; | 
| 26 | }; | 
| 27 |  | 
| 28 | static GroupingSet VectorToGroupingSet(vector<idx_t> &indexes) { | 
| 29 | 	GroupingSet result; | 
| 30 | 	for (idx_t i = 0; i < indexes.size(); i++) { | 
| 31 | 		result.insert(x: indexes[i]); | 
| 32 | 	} | 
| 33 | 	return result; | 
| 34 | } | 
| 35 |  | 
| 36 | static void MergeGroupingSet(GroupingSet &result, GroupingSet &other) { | 
| 37 | 	CheckGroupingSetMax(count: result.size() + other.size()); | 
| 38 | 	result.insert(first: other.begin(), last: other.end()); | 
| 39 | } | 
| 40 |  | 
| 41 | void Transformer::AddGroupByExpression(unique_ptr<ParsedExpression> expression, GroupingExpressionMap &map, | 
| 42 |                                        GroupByNode &result, vector<idx_t> &result_set) { | 
| 43 | 	if (expression->type == ExpressionType::FUNCTION) { | 
| 44 | 		auto &func = expression->Cast<FunctionExpression>(); | 
| 45 | 		if (func.function_name == "row" ) { | 
| 46 | 			for (auto &child : func.children) { | 
| 47 | 				AddGroupByExpression(expression: std::move(child), map, result, result_set); | 
| 48 | 			} | 
| 49 | 			return; | 
| 50 | 		} | 
| 51 | 	} | 
| 52 | 	auto entry = map.map.find(x: *expression); | 
| 53 | 	idx_t result_idx; | 
| 54 | 	if (entry == map.map.end()) { | 
| 55 | 		result_idx = result.group_expressions.size(); | 
| 56 | 		map.map[*expression] = result_idx; | 
| 57 | 		result.group_expressions.push_back(x: std::move(expression)); | 
| 58 | 	} else { | 
| 59 | 		result_idx = entry->second; | 
| 60 | 	} | 
| 61 | 	result_set.push_back(x: result_idx); | 
| 62 | } | 
| 63 |  | 
| 64 | static void AddCubeSets(const GroupingSet ¤t_set, vector<GroupingSet> &result_set, | 
| 65 |                         vector<GroupingSet> &result_sets, idx_t start_idx = 0) { | 
| 66 | 	CheckGroupingSetMax(count: result_sets.size()); | 
| 67 | 	result_sets.push_back(x: current_set); | 
| 68 | 	for (idx_t k = start_idx; k < result_set.size(); k++) { | 
| 69 | 		auto child_set = current_set; | 
| 70 | 		MergeGroupingSet(result&: child_set, other&: result_set[k]); | 
| 71 | 		AddCubeSets(current_set: child_set, result_set, result_sets, start_idx: k + 1); | 
| 72 | 	} | 
| 73 | } | 
| 74 |  | 
| 75 | void Transformer::TransformGroupByExpression(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, | 
| 76 |                                              GroupByNode &result, vector<idx_t> &indexes) { | 
| 77 | 	auto expression = TransformExpression(node&: n); | 
| 78 | 	AddGroupByExpression(expression: std::move(expression), map, result, result_set&: indexes); | 
| 79 | } | 
| 80 |  | 
| 81 | // If one GROUPING SETS clause is nested inside another, | 
| 82 | // the effect is the same as if all the elements of the inner clause had been written directly in the outer clause. | 
| 83 | void Transformer::TransformGroupByNode(duckdb_libpgquery::PGNode &n, GroupingExpressionMap &map, SelectNode &result, | 
| 84 |                                        vector<GroupingSet> &result_sets) { | 
| 85 | 	if (n.type == duckdb_libpgquery::T_PGGroupingSet) { | 
| 86 | 		auto &grouping_set = PGCast<duckdb_libpgquery::PGGroupingSet>(node&: n); | 
| 87 | 		switch (grouping_set.kind) { | 
| 88 | 		case duckdb_libpgquery::GROUPING_SET_EMPTY: | 
| 89 | 			result_sets.emplace_back(); | 
| 90 | 			break; | 
| 91 | 		case duckdb_libpgquery::GROUPING_SET_ALL: { | 
| 92 | 			result.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; | 
| 93 | 			break; | 
| 94 | 		} | 
| 95 | 		case duckdb_libpgquery::GROUPING_SET_SETS: { | 
| 96 | 			for (auto node = grouping_set.content->head; node; node = node->next) { | 
| 97 | 				auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); | 
| 98 | 				TransformGroupByNode(n&: *pg_node, map, result, result_sets); | 
| 99 | 			} | 
| 100 | 			break; | 
| 101 | 		} | 
| 102 | 		case duckdb_libpgquery::GROUPING_SET_ROLLUP: { | 
| 103 | 			vector<GroupingSet> rollup_sets; | 
| 104 | 			for (auto node = grouping_set.content->head; node; node = node->next) { | 
| 105 | 				auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); | 
| 106 | 				vector<idx_t> rollup_set; | 
| 107 | 				TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: rollup_set); | 
| 108 | 				rollup_sets.push_back(x: VectorToGroupingSet(indexes&: rollup_set)); | 
| 109 | 			} | 
| 110 | 			// generate the subsets of the rollup set and add them to the grouping sets | 
| 111 | 			GroupingSet current_set; | 
| 112 | 			result_sets.push_back(x: current_set); | 
| 113 | 			for (idx_t i = 0; i < rollup_sets.size(); i++) { | 
| 114 | 				MergeGroupingSet(result&: current_set, other&: rollup_sets[i]); | 
| 115 | 				result_sets.push_back(x: current_set); | 
| 116 | 			} | 
| 117 | 			break; | 
| 118 | 		} | 
| 119 | 		case duckdb_libpgquery::GROUPING_SET_CUBE: { | 
| 120 | 			vector<GroupingSet> cube_sets; | 
| 121 | 			for (auto node = grouping_set.content->head; node; node = node->next) { | 
| 122 | 				auto pg_node = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); | 
| 123 | 				vector<idx_t> cube_set; | 
| 124 | 				TransformGroupByExpression(n&: *pg_node, map, result&: result.groups, indexes&: cube_set); | 
| 125 | 				cube_sets.push_back(x: VectorToGroupingSet(indexes&: cube_set)); | 
| 126 | 			} | 
| 127 | 			// generate the subsets of the rollup set and add them to the grouping sets | 
| 128 | 			CheckGroupingSetCubes(current_count: result_sets.size(), cube_count: cube_sets.size()); | 
| 129 |  | 
| 130 | 			GroupingSet current_set; | 
| 131 | 			AddCubeSets(current_set, result_set&: cube_sets, result_sets, start_idx: 0); | 
| 132 | 			break; | 
| 133 | 		} | 
| 134 | 		default: | 
| 135 | 			throw InternalException("Unsupported GROUPING SET type %d" , grouping_set.kind); | 
| 136 | 		} | 
| 137 | 	} else { | 
| 138 | 		vector<idx_t> indexes; | 
| 139 | 		TransformGroupByExpression(n, map, result&: result.groups, indexes); | 
| 140 | 		result_sets.push_back(x: VectorToGroupingSet(indexes)); | 
| 141 | 	} | 
| 142 | } | 
| 143 |  | 
| 144 | // If multiple grouping items are specified in a single GROUP BY clause, | 
| 145 | // then the final list of grouping sets is the cross product of the individual items. | 
| 146 | bool Transformer::TransformGroupBy(optional_ptr<duckdb_libpgquery::PGList> group, SelectNode &select_node) { | 
| 147 | 	if (!group) { | 
| 148 | 		return false; | 
| 149 | 	} | 
| 150 | 	auto &result = select_node.groups; | 
| 151 | 	GroupingExpressionMap map; | 
| 152 | 	for (auto node = group->head; node != nullptr; node = node->next) { | 
| 153 | 		auto n = PGPointerCast<duckdb_libpgquery::PGNode>(ptr: node->data.ptr_value); | 
| 154 | 		vector<GroupingSet> result_sets; | 
| 155 | 		TransformGroupByNode(n&: *n, map, result&: select_node, result_sets); | 
| 156 | 		CheckGroupingSetMax(count: result_sets.size()); | 
| 157 | 		if (result.grouping_sets.empty()) { | 
| 158 | 			// no grouping sets yet: use the current set of grouping sets | 
| 159 | 			result.grouping_sets = std::move(result_sets); | 
| 160 | 		} else { | 
| 161 | 			// compute the cross product | 
| 162 | 			vector<GroupingSet> new_sets; | 
| 163 | 			idx_t grouping_set_count = result.grouping_sets.size() * result_sets.size(); | 
| 164 | 			CheckGroupingSetMax(count: grouping_set_count); | 
| 165 | 			new_sets.reserve(n: grouping_set_count); | 
| 166 | 			for (idx_t current_idx = 0; current_idx < result.grouping_sets.size(); current_idx++) { | 
| 167 | 				auto ¤t_set = result.grouping_sets[current_idx]; | 
| 168 | 				for (idx_t new_idx = 0; new_idx < result_sets.size(); new_idx++) { | 
| 169 | 					auto &new_set = result_sets[new_idx]; | 
| 170 | 					GroupingSet set; | 
| 171 | 					set.insert(first: current_set.begin(), last: current_set.end()); | 
| 172 | 					set.insert(first: new_set.begin(), last: new_set.end()); | 
| 173 | 					new_sets.push_back(x: std::move(set)); | 
| 174 | 				} | 
| 175 | 			} | 
| 176 | 			result.grouping_sets = std::move(new_sets); | 
| 177 | 		} | 
| 178 | 	} | 
| 179 | 	if (result.group_expressions.size() == 1 && result.grouping_sets.size() == 1 && | 
| 180 | 	    ExpressionIsEmptyStar(expr&: *result.group_expressions[0])) { | 
| 181 | 		// GROUP BY * | 
| 182 | 		result.group_expressions.clear(); | 
| 183 | 		result.grouping_sets.clear(); | 
| 184 | 		select_node.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; | 
| 185 | 	} | 
| 186 | 	return true; | 
| 187 | } | 
| 188 |  | 
| 189 | } // namespace duckdb | 
| 190 |  |