1#include "duckdb/parser/parsed_expression_iterator.hpp"
2
3#include "duckdb/parser/expression/list.hpp"
4#include "duckdb/parser/query_node.hpp"
5#include "duckdb/parser/query_node/recursive_cte_node.hpp"
6#include "duckdb/parser/query_node/select_node.hpp"
7#include "duckdb/parser/query_node/set_operation_node.hpp"
8#include "duckdb/parser/tableref/list.hpp"
9
10namespace duckdb {
11
12void ParsedExpressionIterator::EnumerateChildren(const ParsedExpression &expression,
13 const std::function<void(const ParsedExpression &child)> &callback) {
14 EnumerateChildren(expr&: (ParsedExpression &)expression, callback: [&](unique_ptr<ParsedExpression> &child) {
15 D_ASSERT(child);
16 callback(*child);
17 });
18}
19
20void ParsedExpressionIterator::EnumerateChildren(ParsedExpression &expr,
21 const std::function<void(ParsedExpression &child)> &callback) {
22 EnumerateChildren(expr, callback: [&](unique_ptr<ParsedExpression> &child) {
23 D_ASSERT(child);
24 callback(*child);
25 });
26}
27
28void ParsedExpressionIterator::EnumerateChildren(
29 ParsedExpression &expr, const std::function<void(unique_ptr<ParsedExpression> &child)> &callback) {
30 switch (expr.expression_class) {
31 case ExpressionClass::BETWEEN: {
32 auto &cast_expr = expr.Cast<BetweenExpression>();
33 callback(cast_expr.input);
34 callback(cast_expr.lower);
35 callback(cast_expr.upper);
36 break;
37 }
38 case ExpressionClass::CASE: {
39 auto &case_expr = expr.Cast<CaseExpression>();
40 for (auto &check : case_expr.case_checks) {
41 callback(check.when_expr);
42 callback(check.then_expr);
43 }
44 callback(case_expr.else_expr);
45 break;
46 }
47 case ExpressionClass::CAST: {
48 auto &cast_expr = expr.Cast<CastExpression>();
49 callback(cast_expr.child);
50 break;
51 }
52 case ExpressionClass::COLLATE: {
53 auto &cast_expr = expr.Cast<CollateExpression>();
54 callback(cast_expr.child);
55 break;
56 }
57 case ExpressionClass::COMPARISON: {
58 auto &comp_expr = expr.Cast<ComparisonExpression>();
59 callback(comp_expr.left);
60 callback(comp_expr.right);
61 break;
62 }
63 case ExpressionClass::CONJUNCTION: {
64 auto &conj_expr = expr.Cast<ConjunctionExpression>();
65 for (auto &child : conj_expr.children) {
66 callback(child);
67 }
68 break;
69 }
70
71 case ExpressionClass::FUNCTION: {
72 auto &func_expr = expr.Cast<FunctionExpression>();
73 for (auto &child : func_expr.children) {
74 callback(child);
75 }
76 if (func_expr.filter) {
77 callback(func_expr.filter);
78 }
79 if (func_expr.order_bys) {
80 for (auto &order : func_expr.order_bys->orders) {
81 callback(order.expression);
82 }
83 }
84 break;
85 }
86 case ExpressionClass::LAMBDA: {
87 auto &lambda_expr = expr.Cast<LambdaExpression>();
88 callback(lambda_expr.lhs);
89 callback(lambda_expr.expr);
90 break;
91 }
92 case ExpressionClass::OPERATOR: {
93 auto &op_expr = expr.Cast<OperatorExpression>();
94 for (auto &child : op_expr.children) {
95 callback(child);
96 }
97 break;
98 }
99 case ExpressionClass::STAR: {
100 auto &star_expr = expr.Cast<StarExpression>();
101 if (star_expr.expr) {
102 callback(star_expr.expr);
103 }
104 break;
105 }
106 case ExpressionClass::SUBQUERY: {
107 auto &subquery_expr = expr.Cast<SubqueryExpression>();
108 if (subquery_expr.child) {
109 callback(subquery_expr.child);
110 }
111 break;
112 }
113 case ExpressionClass::WINDOW: {
114 auto &window_expr = expr.Cast<WindowExpression>();
115 for (auto &partition : window_expr.partitions) {
116 callback(partition);
117 }
118 for (auto &order : window_expr.orders) {
119 callback(order.expression);
120 }
121 for (auto &child : window_expr.children) {
122 callback(child);
123 }
124 if (window_expr.filter_expr) {
125 callback(window_expr.filter_expr);
126 }
127 if (window_expr.start_expr) {
128 callback(window_expr.start_expr);
129 }
130 if (window_expr.end_expr) {
131 callback(window_expr.end_expr);
132 }
133 if (window_expr.offset_expr) {
134 callback(window_expr.offset_expr);
135 }
136 if (window_expr.default_expr) {
137 callback(window_expr.default_expr);
138 }
139 break;
140 }
141 case ExpressionClass::BOUND_EXPRESSION:
142 case ExpressionClass::COLUMN_REF:
143 case ExpressionClass::CONSTANT:
144 case ExpressionClass::DEFAULT:
145 case ExpressionClass::PARAMETER:
146 case ExpressionClass::POSITIONAL_REFERENCE:
147 // these node types have no children
148 break;
149 default:
150 // called on non ParsedExpression type!
151 throw NotImplementedException("Unimplemented expression class");
152 }
153}
154
155void ParsedExpressionIterator::EnumerateQueryNodeModifiers(
156 QueryNode &node, const std::function<void(unique_ptr<ParsedExpression> &child)> &callback) {
157
158 for (auto &modifier : node.modifiers) {
159 switch (modifier->type) {
160 case ResultModifierType::LIMIT_MODIFIER: {
161 auto &limit_modifier = modifier->Cast<LimitModifier>();
162 if (limit_modifier.limit) {
163 callback(limit_modifier.limit);
164 }
165 if (limit_modifier.offset) {
166 callback(limit_modifier.offset);
167 }
168 } break;
169
170 case ResultModifierType::LIMIT_PERCENT_MODIFIER: {
171 auto &limit_modifier = modifier->Cast<LimitPercentModifier>();
172 if (limit_modifier.limit) {
173 callback(limit_modifier.limit);
174 }
175 if (limit_modifier.offset) {
176 callback(limit_modifier.offset);
177 }
178 } break;
179
180 case ResultModifierType::ORDER_MODIFIER: {
181 auto &order_modifier = modifier->Cast<OrderModifier>();
182 for (auto &order : order_modifier.orders) {
183 callback(order.expression);
184 }
185 } break;
186
187 case ResultModifierType::DISTINCT_MODIFIER: {
188 auto &distinct_modifier = modifier->Cast<DistinctModifier>();
189 for (auto &target : distinct_modifier.distinct_on_targets) {
190 callback(target);
191 }
192 } break;
193
194 // do nothing
195 default:
196 break;
197 }
198 }
199}
200
201void ParsedExpressionIterator::EnumerateTableRefChildren(
202 TableRef &ref, const std::function<void(unique_ptr<ParsedExpression> &child)> &callback) {
203 switch (ref.type) {
204 case TableReferenceType::EXPRESSION_LIST: {
205 auto &el_ref = ref.Cast<ExpressionListRef>();
206 for (idx_t i = 0; i < el_ref.values.size(); i++) {
207 for (idx_t j = 0; j < el_ref.values[i].size(); j++) {
208 callback(el_ref.values[i][j]);
209 }
210 }
211 break;
212 }
213 case TableReferenceType::JOIN: {
214 auto &j_ref = ref.Cast<JoinRef>();
215 EnumerateTableRefChildren(ref&: *j_ref.left, callback);
216 EnumerateTableRefChildren(ref&: *j_ref.right, callback);
217 if (j_ref.condition) {
218 callback(j_ref.condition);
219 }
220 break;
221 }
222 case TableReferenceType::PIVOT: {
223 auto &p_ref = ref.Cast<PivotRef>();
224 EnumerateTableRefChildren(ref&: *p_ref.source, callback);
225 for (auto &aggr : p_ref.aggregates) {
226 callback(aggr);
227 }
228 break;
229 }
230 case TableReferenceType::SUBQUERY: {
231 auto &sq_ref = ref.Cast<SubqueryRef>();
232 EnumerateQueryNodeChildren(node&: *sq_ref.subquery->node, callback);
233 break;
234 }
235 case TableReferenceType::TABLE_FUNCTION: {
236 auto &tf_ref = ref.Cast<TableFunctionRef>();
237 callback(tf_ref.function);
238 break;
239 }
240 case TableReferenceType::BASE_TABLE:
241 case TableReferenceType::EMPTY:
242 // these TableRefs do not need to be unfolded
243 break;
244 case TableReferenceType::INVALID:
245 case TableReferenceType::CTE:
246 throw NotImplementedException("TableRef type not implemented for traversal");
247 }
248}
249
250void ParsedExpressionIterator::EnumerateQueryNodeChildren(
251 QueryNode &node, const std::function<void(unique_ptr<ParsedExpression> &child)> &callback) {
252 switch (node.type) {
253 case QueryNodeType::RECURSIVE_CTE_NODE: {
254 auto &rcte_node = node.Cast<RecursiveCTENode>();
255 EnumerateQueryNodeChildren(node&: *rcte_node.left, callback);
256 EnumerateQueryNodeChildren(node&: *rcte_node.right, callback);
257 break;
258 }
259 case QueryNodeType::SELECT_NODE: {
260 auto &sel_node = node.Cast<SelectNode>();
261 for (idx_t i = 0; i < sel_node.select_list.size(); i++) {
262 callback(sel_node.select_list[i]);
263 }
264 for (idx_t i = 0; i < sel_node.groups.group_expressions.size(); i++) {
265 callback(sel_node.groups.group_expressions[i]);
266 }
267 if (sel_node.where_clause) {
268 callback(sel_node.where_clause);
269 }
270 if (sel_node.having) {
271 callback(sel_node.having);
272 }
273 if (sel_node.qualify) {
274 callback(sel_node.qualify);
275 }
276
277 EnumerateTableRefChildren(ref&: *sel_node.from_table.get(), callback);
278 break;
279 }
280 case QueryNodeType::SET_OPERATION_NODE: {
281 auto &setop_node = node.Cast<SetOperationNode>();
282 EnumerateQueryNodeChildren(node&: *setop_node.left, callback);
283 EnumerateQueryNodeChildren(node&: *setop_node.right, callback);
284 break;
285 }
286 default:
287 throw NotImplementedException("QueryNode type not implemented for traversal");
288 }
289
290 if (!node.modifiers.empty()) {
291 EnumerateQueryNodeModifiers(node, callback);
292 }
293
294 for (auto &kv : node.cte_map.map) {
295 EnumerateQueryNodeChildren(node&: *kv.second->query->node, callback);
296 }
297}
298
299} // namespace duckdb
300