1 | #include "duckdb/planner/expression_iterator.hpp" |
2 | |
3 | #include "duckdb/planner/bound_query_node.hpp" |
4 | #include "duckdb/planner/expression/list.hpp" |
5 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
6 | #include "duckdb/planner/query_node/bound_set_operation_node.hpp" |
7 | #include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" |
8 | #include "duckdb/planner/tableref/list.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | void ExpressionIterator::EnumerateChildren(const Expression &expr, |
13 | const std::function<void(const Expression &child)> &callback) { |
14 | EnumerateChildren(expression&: (Expression &)expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); }); |
15 | } |
16 | |
17 | void ExpressionIterator::EnumerateChildren(Expression &expr, const std::function<void(Expression &child)> &callback) { |
18 | EnumerateChildren(expression&: expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); }); |
19 | } |
20 | |
21 | void ExpressionIterator::EnumerateChildren(Expression &expr, |
22 | const std::function<void(unique_ptr<Expression> &child)> &callback) { |
23 | switch (expr.expression_class) { |
24 | case ExpressionClass::BOUND_AGGREGATE: { |
25 | auto &aggr_expr = expr.Cast<BoundAggregateExpression>(); |
26 | for (auto &child : aggr_expr.children) { |
27 | callback(child); |
28 | } |
29 | if (aggr_expr.filter) { |
30 | callback(aggr_expr.filter); |
31 | } |
32 | if (aggr_expr.order_bys) { |
33 | for (auto &order : aggr_expr.order_bys->orders) { |
34 | callback(order.expression); |
35 | } |
36 | } |
37 | break; |
38 | } |
39 | case ExpressionClass::BOUND_BETWEEN: { |
40 | auto &between_expr = expr.Cast<BoundBetweenExpression>(); |
41 | callback(between_expr.input); |
42 | callback(between_expr.lower); |
43 | callback(between_expr.upper); |
44 | break; |
45 | } |
46 | case ExpressionClass::BOUND_CASE: { |
47 | auto &case_expr = expr.Cast<BoundCaseExpression>(); |
48 | for (auto &case_check : case_expr.case_checks) { |
49 | callback(case_check.when_expr); |
50 | callback(case_check.then_expr); |
51 | } |
52 | callback(case_expr.else_expr); |
53 | break; |
54 | } |
55 | case ExpressionClass::BOUND_CAST: { |
56 | auto &cast_expr = expr.Cast<BoundCastExpression>(); |
57 | callback(cast_expr.child); |
58 | break; |
59 | } |
60 | case ExpressionClass::BOUND_COMPARISON: { |
61 | auto &comp_expr = expr.Cast<BoundComparisonExpression>(); |
62 | callback(comp_expr.left); |
63 | callback(comp_expr.right); |
64 | break; |
65 | } |
66 | case ExpressionClass::BOUND_CONJUNCTION: { |
67 | auto &conj_expr = expr.Cast<BoundConjunctionExpression>(); |
68 | for (auto &child : conj_expr.children) { |
69 | callback(child); |
70 | } |
71 | break; |
72 | } |
73 | case ExpressionClass::BOUND_FUNCTION: { |
74 | auto &func_expr = expr.Cast<BoundFunctionExpression>(); |
75 | for (auto &child : func_expr.children) { |
76 | callback(child); |
77 | } |
78 | break; |
79 | } |
80 | case ExpressionClass::BOUND_OPERATOR: { |
81 | auto &op_expr = expr.Cast<BoundOperatorExpression>(); |
82 | for (auto &child : op_expr.children) { |
83 | callback(child); |
84 | } |
85 | break; |
86 | } |
87 | case ExpressionClass::BOUND_SUBQUERY: { |
88 | auto &subquery_expr = expr.Cast<BoundSubqueryExpression>(); |
89 | if (subquery_expr.child) { |
90 | callback(subquery_expr.child); |
91 | } |
92 | break; |
93 | } |
94 | case ExpressionClass::BOUND_WINDOW: { |
95 | auto &window_expr = expr.Cast<BoundWindowExpression>(); |
96 | for (auto &partition : window_expr.partitions) { |
97 | callback(partition); |
98 | } |
99 | for (auto &order : window_expr.orders) { |
100 | callback(order.expression); |
101 | } |
102 | for (auto &child : window_expr.children) { |
103 | callback(child); |
104 | } |
105 | if (window_expr.filter_expr) { |
106 | callback(window_expr.filter_expr); |
107 | } |
108 | if (window_expr.start_expr) { |
109 | callback(window_expr.start_expr); |
110 | } |
111 | if (window_expr.end_expr) { |
112 | callback(window_expr.end_expr); |
113 | } |
114 | if (window_expr.offset_expr) { |
115 | callback(window_expr.offset_expr); |
116 | } |
117 | if (window_expr.default_expr) { |
118 | callback(window_expr.default_expr); |
119 | } |
120 | break; |
121 | } |
122 | case ExpressionClass::BOUND_UNNEST: { |
123 | auto &unnest_expr = expr.Cast<BoundUnnestExpression>(); |
124 | callback(unnest_expr.child); |
125 | break; |
126 | } |
127 | case ExpressionClass::BOUND_COLUMN_REF: |
128 | case ExpressionClass::BOUND_LAMBDA_REF: |
129 | case ExpressionClass::BOUND_CONSTANT: |
130 | case ExpressionClass::BOUND_DEFAULT: |
131 | case ExpressionClass::BOUND_PARAMETER: |
132 | case ExpressionClass::BOUND_REF: |
133 | // these node types have no children |
134 | break; |
135 | default: |
136 | throw InternalException("ExpressionIterator used on unbound expression" ); |
137 | } |
138 | } |
139 | |
140 | void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr, |
141 | const std::function<void(Expression &child)> &callback) { |
142 | if (!expr) { |
143 | return; |
144 | } |
145 | callback(*expr); |
146 | ExpressionIterator::EnumerateChildren(expr&: *expr, |
147 | callback: [&](unique_ptr<Expression> &child) { EnumerateExpression(expr&: child, callback); }); |
148 | } |
149 | |
150 | void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref, |
151 | const std::function<void(Expression &child)> &callback) { |
152 | switch (ref.type) { |
153 | case TableReferenceType::EXPRESSION_LIST: { |
154 | auto &bound_expr_list = ref.Cast<BoundExpressionListRef>(); |
155 | for (auto &expr_list : bound_expr_list.values) { |
156 | for (auto &expr : expr_list) { |
157 | EnumerateExpression(expr, callback); |
158 | } |
159 | } |
160 | break; |
161 | } |
162 | case TableReferenceType::JOIN: { |
163 | auto &bound_join = ref.Cast<BoundJoinRef>(); |
164 | if (bound_join.condition) { |
165 | EnumerateExpression(expr&: bound_join.condition, callback); |
166 | } |
167 | EnumerateTableRefChildren(ref&: *bound_join.left, callback); |
168 | EnumerateTableRefChildren(ref&: *bound_join.right, callback); |
169 | break; |
170 | } |
171 | case TableReferenceType::SUBQUERY: { |
172 | auto &bound_subquery = ref.Cast<BoundSubqueryRef>(); |
173 | EnumerateQueryNodeChildren(node&: *bound_subquery.subquery, callback); |
174 | break; |
175 | } |
176 | case TableReferenceType::TABLE_FUNCTION: |
177 | case TableReferenceType::EMPTY: |
178 | case TableReferenceType::BASE_TABLE: |
179 | case TableReferenceType::CTE: |
180 | break; |
181 | default: |
182 | throw NotImplementedException("Unimplemented table reference type in ExpressionIterator" ); |
183 | } |
184 | } |
185 | |
186 | void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node, |
187 | const std::function<void(Expression &child)> &callback) { |
188 | switch (node.type) { |
189 | case QueryNodeType::SET_OPERATION_NODE: { |
190 | auto &bound_setop = node.Cast<BoundSetOperationNode>(); |
191 | EnumerateQueryNodeChildren(node&: *bound_setop.left, callback); |
192 | EnumerateQueryNodeChildren(node&: *bound_setop.right, callback); |
193 | break; |
194 | } |
195 | case QueryNodeType::RECURSIVE_CTE_NODE: { |
196 | auto &cte_node = node.Cast<BoundRecursiveCTENode>(); |
197 | EnumerateQueryNodeChildren(node&: *cte_node.left, callback); |
198 | EnumerateQueryNodeChildren(node&: *cte_node.right, callback); |
199 | break; |
200 | } |
201 | case QueryNodeType::SELECT_NODE: { |
202 | auto &bound_select = node.Cast<BoundSelectNode>(); |
203 | for (auto &expr : bound_select.select_list) { |
204 | EnumerateExpression(expr, callback); |
205 | } |
206 | EnumerateExpression(expr&: bound_select.where_clause, callback); |
207 | for (auto &expr : bound_select.groups.group_expressions) { |
208 | EnumerateExpression(expr, callback); |
209 | } |
210 | EnumerateExpression(expr&: bound_select.having, callback); |
211 | for (auto &expr : bound_select.aggregates) { |
212 | EnumerateExpression(expr, callback); |
213 | } |
214 | for (auto &entry : bound_select.unnests) { |
215 | for (auto &expr : entry.second.expressions) { |
216 | EnumerateExpression(expr, callback); |
217 | } |
218 | } |
219 | for (auto &expr : bound_select.windows) { |
220 | EnumerateExpression(expr, callback); |
221 | } |
222 | if (bound_select.from_table) { |
223 | EnumerateTableRefChildren(ref&: *bound_select.from_table, callback); |
224 | } |
225 | break; |
226 | } |
227 | default: |
228 | throw NotImplementedException("Unimplemented query node in ExpressionIterator" ); |
229 | } |
230 | for (idx_t i = 0; i < node.modifiers.size(); i++) { |
231 | switch (node.modifiers[i]->type) { |
232 | case ResultModifierType::DISTINCT_MODIFIER: |
233 | for (auto &expr : node.modifiers[i]->Cast<BoundDistinctModifier>().target_distincts) { |
234 | EnumerateExpression(expr, callback); |
235 | } |
236 | break; |
237 | case ResultModifierType::ORDER_MODIFIER: |
238 | for (auto &order : node.modifiers[i]->Cast<BoundOrderModifier>().orders) { |
239 | EnumerateExpression(expr&: order.expression, callback); |
240 | } |
241 | break; |
242 | default: |
243 | break; |
244 | } |
245 | } |
246 | } |
247 | |
248 | } // namespace duckdb |
249 | |