1#include "duckdb/planner/expression_iterator.hpp"
2
3#include "duckdb/planner/bound_query_node.hpp"
4#include "duckdb/planner/expression/list.hpp"
5#include "duckdb/planner/query_node/bound_select_node.hpp"
6#include "duckdb/planner/query_node/bound_set_operation_node.hpp"
7#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp"
8#include "duckdb/planner/tableref/list.hpp"
9
10namespace duckdb {
11
12void ExpressionIterator::EnumerateChildren(const Expression &expr,
13 const std::function<void(const Expression &child)> &callback) {
14 EnumerateChildren(expression&: (Expression &)expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); });
15}
16
17void ExpressionIterator::EnumerateChildren(Expression &expr, const std::function<void(Expression &child)> &callback) {
18 EnumerateChildren(expression&: expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); });
19}
20
21void ExpressionIterator::EnumerateChildren(Expression &expr,
22 const std::function<void(unique_ptr<Expression> &child)> &callback) {
23 switch (expr.expression_class) {
24 case ExpressionClass::BOUND_AGGREGATE: {
25 auto &aggr_expr = expr.Cast<BoundAggregateExpression>();
26 for (auto &child : aggr_expr.children) {
27 callback(child);
28 }
29 if (aggr_expr.filter) {
30 callback(aggr_expr.filter);
31 }
32 if (aggr_expr.order_bys) {
33 for (auto &order : aggr_expr.order_bys->orders) {
34 callback(order.expression);
35 }
36 }
37 break;
38 }
39 case ExpressionClass::BOUND_BETWEEN: {
40 auto &between_expr = expr.Cast<BoundBetweenExpression>();
41 callback(between_expr.input);
42 callback(between_expr.lower);
43 callback(between_expr.upper);
44 break;
45 }
46 case ExpressionClass::BOUND_CASE: {
47 auto &case_expr = expr.Cast<BoundCaseExpression>();
48 for (auto &case_check : case_expr.case_checks) {
49 callback(case_check.when_expr);
50 callback(case_check.then_expr);
51 }
52 callback(case_expr.else_expr);
53 break;
54 }
55 case ExpressionClass::BOUND_CAST: {
56 auto &cast_expr = expr.Cast<BoundCastExpression>();
57 callback(cast_expr.child);
58 break;
59 }
60 case ExpressionClass::BOUND_COMPARISON: {
61 auto &comp_expr = expr.Cast<BoundComparisonExpression>();
62 callback(comp_expr.left);
63 callback(comp_expr.right);
64 break;
65 }
66 case ExpressionClass::BOUND_CONJUNCTION: {
67 auto &conj_expr = expr.Cast<BoundConjunctionExpression>();
68 for (auto &child : conj_expr.children) {
69 callback(child);
70 }
71 break;
72 }
73 case ExpressionClass::BOUND_FUNCTION: {
74 auto &func_expr = expr.Cast<BoundFunctionExpression>();
75 for (auto &child : func_expr.children) {
76 callback(child);
77 }
78 break;
79 }
80 case ExpressionClass::BOUND_OPERATOR: {
81 auto &op_expr = expr.Cast<BoundOperatorExpression>();
82 for (auto &child : op_expr.children) {
83 callback(child);
84 }
85 break;
86 }
87 case ExpressionClass::BOUND_SUBQUERY: {
88 auto &subquery_expr = expr.Cast<BoundSubqueryExpression>();
89 if (subquery_expr.child) {
90 callback(subquery_expr.child);
91 }
92 break;
93 }
94 case ExpressionClass::BOUND_WINDOW: {
95 auto &window_expr = expr.Cast<BoundWindowExpression>();
96 for (auto &partition : window_expr.partitions) {
97 callback(partition);
98 }
99 for (auto &order : window_expr.orders) {
100 callback(order.expression);
101 }
102 for (auto &child : window_expr.children) {
103 callback(child);
104 }
105 if (window_expr.filter_expr) {
106 callback(window_expr.filter_expr);
107 }
108 if (window_expr.start_expr) {
109 callback(window_expr.start_expr);
110 }
111 if (window_expr.end_expr) {
112 callback(window_expr.end_expr);
113 }
114 if (window_expr.offset_expr) {
115 callback(window_expr.offset_expr);
116 }
117 if (window_expr.default_expr) {
118 callback(window_expr.default_expr);
119 }
120 break;
121 }
122 case ExpressionClass::BOUND_UNNEST: {
123 auto &unnest_expr = expr.Cast<BoundUnnestExpression>();
124 callback(unnest_expr.child);
125 break;
126 }
127 case ExpressionClass::BOUND_COLUMN_REF:
128 case ExpressionClass::BOUND_LAMBDA_REF:
129 case ExpressionClass::BOUND_CONSTANT:
130 case ExpressionClass::BOUND_DEFAULT:
131 case ExpressionClass::BOUND_PARAMETER:
132 case ExpressionClass::BOUND_REF:
133 // these node types have no children
134 break;
135 default:
136 throw InternalException("ExpressionIterator used on unbound expression");
137 }
138}
139
140void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr,
141 const std::function<void(Expression &child)> &callback) {
142 if (!expr) {
143 return;
144 }
145 callback(*expr);
146 ExpressionIterator::EnumerateChildren(expr&: *expr,
147 callback: [&](unique_ptr<Expression> &child) { EnumerateExpression(expr&: child, callback); });
148}
149
150void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref,
151 const std::function<void(Expression &child)> &callback) {
152 switch (ref.type) {
153 case TableReferenceType::EXPRESSION_LIST: {
154 auto &bound_expr_list = ref.Cast<BoundExpressionListRef>();
155 for (auto &expr_list : bound_expr_list.values) {
156 for (auto &expr : expr_list) {
157 EnumerateExpression(expr, callback);
158 }
159 }
160 break;
161 }
162 case TableReferenceType::JOIN: {
163 auto &bound_join = ref.Cast<BoundJoinRef>();
164 if (bound_join.condition) {
165 EnumerateExpression(expr&: bound_join.condition, callback);
166 }
167 EnumerateTableRefChildren(ref&: *bound_join.left, callback);
168 EnumerateTableRefChildren(ref&: *bound_join.right, callback);
169 break;
170 }
171 case TableReferenceType::SUBQUERY: {
172 auto &bound_subquery = ref.Cast<BoundSubqueryRef>();
173 EnumerateQueryNodeChildren(node&: *bound_subquery.subquery, callback);
174 break;
175 }
176 case TableReferenceType::TABLE_FUNCTION:
177 case TableReferenceType::EMPTY:
178 case TableReferenceType::BASE_TABLE:
179 case TableReferenceType::CTE:
180 break;
181 default:
182 throw NotImplementedException("Unimplemented table reference type in ExpressionIterator");
183 }
184}
185
186void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node,
187 const std::function<void(Expression &child)> &callback) {
188 switch (node.type) {
189 case QueryNodeType::SET_OPERATION_NODE: {
190 auto &bound_setop = node.Cast<BoundSetOperationNode>();
191 EnumerateQueryNodeChildren(node&: *bound_setop.left, callback);
192 EnumerateQueryNodeChildren(node&: *bound_setop.right, callback);
193 break;
194 }
195 case QueryNodeType::RECURSIVE_CTE_NODE: {
196 auto &cte_node = node.Cast<BoundRecursiveCTENode>();
197 EnumerateQueryNodeChildren(node&: *cte_node.left, callback);
198 EnumerateQueryNodeChildren(node&: *cte_node.right, callback);
199 break;
200 }
201 case QueryNodeType::SELECT_NODE: {
202 auto &bound_select = node.Cast<BoundSelectNode>();
203 for (auto &expr : bound_select.select_list) {
204 EnumerateExpression(expr, callback);
205 }
206 EnumerateExpression(expr&: bound_select.where_clause, callback);
207 for (auto &expr : bound_select.groups.group_expressions) {
208 EnumerateExpression(expr, callback);
209 }
210 EnumerateExpression(expr&: bound_select.having, callback);
211 for (auto &expr : bound_select.aggregates) {
212 EnumerateExpression(expr, callback);
213 }
214 for (auto &entry : bound_select.unnests) {
215 for (auto &expr : entry.second.expressions) {
216 EnumerateExpression(expr, callback);
217 }
218 }
219 for (auto &expr : bound_select.windows) {
220 EnumerateExpression(expr, callback);
221 }
222 if (bound_select.from_table) {
223 EnumerateTableRefChildren(ref&: *bound_select.from_table, callback);
224 }
225 break;
226 }
227 default:
228 throw NotImplementedException("Unimplemented query node in ExpressionIterator");
229 }
230 for (idx_t i = 0; i < node.modifiers.size(); i++) {
231 switch (node.modifiers[i]->type) {
232 case ResultModifierType::DISTINCT_MODIFIER:
233 for (auto &expr : node.modifiers[i]->Cast<BoundDistinctModifier>().target_distincts) {
234 EnumerateExpression(expr, callback);
235 }
236 break;
237 case ResultModifierType::ORDER_MODIFIER:
238 for (auto &order : node.modifiers[i]->Cast<BoundOrderModifier>().orders) {
239 EnumerateExpression(expr&: order.expression, callback);
240 }
241 break;
242 default:
243 break;
244 }
245 }
246}
247
248} // namespace duckdb
249