1 | #include "duckdb/planner/expression_iterator.hpp" |
2 | |
3 | #include "duckdb/planner/bound_query_node.hpp" |
4 | #include "duckdb/planner/expression/list.hpp" |
5 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
6 | #include "duckdb/planner/query_node/bound_set_operation_node.hpp" |
7 | #include "duckdb/planner/tableref/list.hpp" |
8 | |
9 | using namespace duckdb; |
10 | using namespace std; |
11 | |
12 | void ExpressionIterator::EnumerateChildren(const Expression &expr, function<void(const Expression &child)> callback) { |
13 | EnumerateChildren((Expression &)expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> { |
14 | callback(*child); |
15 | return move(child); |
16 | }); |
17 | } |
18 | |
19 | void ExpressionIterator::EnumerateChildren(Expression &expr, std::function<void(Expression &child)> callback) { |
20 | EnumerateChildren(expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> { |
21 | callback(*child); |
22 | return move(child); |
23 | }); |
24 | } |
25 | |
// Core child visitor. Every direct child of `expr` is moved out of its slot and
// passed to `callback` by value; whatever `callback` returns is stored back into
// that same slot. This lets callers rewrite or replace children in place. The
// other EnumerateChildren overloads in this file are adapters over this one.
//
// Notes:
//  - Children are visited in the fixed per-class order written below; callers
//    such as EnumerateExpression inherit that traversal order.
//  - While `callback` runs, the slot the child was taken from is empty
//    (moved-from).
//  - Optional children (a subquery's child, a window's offset/default
//    expressions, a common subexpression's owned child) are only visited when
//    non-null; all other child slots are passed to `callback` unconditionally.
//  - Leaf classes (column refs, constants, defaults, parameters, refs) have no
//    children and are a no-op.
//  - Any other expression class trips assert(0); when built with NDEBUG the
//    call silently does nothing.
void ExpressionIterator::EnumerateChildren(Expression &expr,
                                           function<unique_ptr<Expression>(unique_ptr<Expression> child)> callback) {
	switch (expr.expression_class) {
	case ExpressionClass::BOUND_AGGREGATE: {
		auto &aggr_expr = (BoundAggregateExpression &)expr;
		// aggregate arguments, left to right
		for (auto &child : aggr_expr.children) {
			child = callback(move(child));
		}
		break;
	}
	case ExpressionClass::BOUND_BETWEEN: {
		auto &between_expr = (BoundBetweenExpression &)expr;
		// input, then lower bound, then upper bound
		between_expr.input = callback(move(between_expr.input));
		between_expr.lower = callback(move(between_expr.lower));
		between_expr.upper = callback(move(between_expr.upper));
		break;
	}
	case ExpressionClass::BOUND_CASE: {
		auto &case_expr = (BoundCaseExpression &)expr;
		// condition first, then the two result branches
		case_expr.check = callback(move(case_expr.check));
		case_expr.result_if_true = callback(move(case_expr.result_if_true));
		case_expr.result_if_false = callback(move(case_expr.result_if_false));
		break;
	}
	case ExpressionClass::BOUND_CAST: {
		auto &cast_expr = (BoundCastExpression &)expr;
		cast_expr.child = callback(move(cast_expr.child));
		break;
	}
	case ExpressionClass::BOUND_COMPARISON: {
		auto &comp_expr = (BoundComparisonExpression &)expr;
		comp_expr.left = callback(move(comp_expr.left));
		comp_expr.right = callback(move(comp_expr.right));
		break;
	}
	case ExpressionClass::BOUND_CONJUNCTION: {
		auto &conj_expr = (BoundConjunctionExpression &)expr;
		for (auto &child : conj_expr.children) {
			child = callback(move(child));
		}
		break;
	}
	case ExpressionClass::BOUND_FUNCTION: {
		auto &func_expr = (BoundFunctionExpression &)expr;
		for (auto &child : func_expr.children) {
			child = callback(move(child));
		}
		break;
	}
	case ExpressionClass::BOUND_OPERATOR: {
		auto &op_expr = (BoundOperatorExpression &)expr;
		for (auto &child : op_expr.children) {
			child = callback(move(child));
		}
		break;
	}
	case ExpressionClass::BOUND_SUBQUERY: {
		auto &subquery_expr = (BoundSubqueryExpression &)expr;
		// the child slot is optional; only the child expression is visited here,
		// not the subquery plan itself
		if (subquery_expr.child) {
			subquery_expr.child = callback(move(subquery_expr.child));
		}
		break;
	}
	case ExpressionClass::BOUND_WINDOW: {
		auto &window_expr = (BoundWindowExpression &)expr;
		// order: PARTITION BY keys, ORDER BY keys, `children`, then the optional
		// offset and default expressions
		for (auto &partition : window_expr.partitions) {
			partition = callback(move(partition));
		}
		for (auto &order : window_expr.orders) {
			order.expression = callback(move(order.expression));
		}
		for (auto &child : window_expr.children) {
			child = callback(move(child));
		}
		if (window_expr.offset_expr) {
			window_expr.offset_expr = callback(move(window_expr.offset_expr));
		}
		if (window_expr.default_expr) {
			window_expr.default_expr = callback(move(window_expr.default_expr));
		}
		break;
	}
	case ExpressionClass::BOUND_UNNEST: {
		auto &unnest_expr = (BoundUnnestExpression &)expr;
		unnest_expr.child = callback(move(unnest_expr.child));
		break;
	}
	case ExpressionClass::COMMON_SUBEXPRESSION: {
		auto &cse_expr = (CommonSubExpression &)expr;
		// only the owned child (when set) is visited
		if (cse_expr.owned_child) {
			cse_expr.owned_child = callback(move(cse_expr.owned_child));
		}
		break;
	}
	case ExpressionClass::BOUND_COLUMN_REF:
	case ExpressionClass::BOUND_CONSTANT:
	case ExpressionClass::BOUND_DEFAULT:
	case ExpressionClass::BOUND_PARAMETER:
	case ExpressionClass::BOUND_REF:
		// these node types have no children
		break;
	default:
		// called on non BoundExpression type!
		// (only caught in debug builds; with NDEBUG this is a silent no-op)
		assert(0);
		break;
	}
}
133 | |
134 | void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr, |
135 | std::function<void(Expression &child)> callback) { |
136 | if (!expr) { |
137 | return; |
138 | } |
139 | callback(*expr); |
140 | ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> { |
141 | EnumerateExpression(child, callback); |
142 | return move(child); |
143 | }); |
144 | } |
145 | |
146 | void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref, |
147 | std::function<void(Expression &child)> callback) { |
148 | switch (ref.type) { |
149 | case TableReferenceType::CROSS_PRODUCT: { |
150 | auto &bound_crossproduct = (BoundCrossProductRef &)ref; |
151 | EnumerateTableRefChildren(*bound_crossproduct.left, callback); |
152 | EnumerateTableRefChildren(*bound_crossproduct.right, callback); |
153 | break; |
154 | } |
155 | case TableReferenceType::JOIN: { |
156 | auto &bound_join = (BoundJoinRef &)ref; |
157 | EnumerateExpression(bound_join.condition, callback); |
158 | EnumerateTableRefChildren(*bound_join.left, callback); |
159 | EnumerateTableRefChildren(*bound_join.right, callback); |
160 | break; |
161 | } |
162 | case TableReferenceType::SUBQUERY: { |
163 | auto &bound_subquery = (BoundSubqueryRef &)ref; |
164 | EnumerateQueryNodeChildren(*bound_subquery.subquery, callback); |
165 | break; |
166 | } |
167 | default: |
168 | assert(ref.type == TableReferenceType::TABLE_FUNCTION || ref.type == TableReferenceType::BASE_TABLE || |
169 | ref.type == TableReferenceType::EMPTY); |
170 | break; |
171 | } |
172 | } |
173 | |
174 | void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node, |
175 | std::function<void(Expression &child)> callback) { |
176 | switch (node.type) { |
177 | case QueryNodeType::SET_OPERATION_NODE: { |
178 | auto &bound_setop = (BoundSetOperationNode &)node; |
179 | EnumerateQueryNodeChildren(*bound_setop.left, callback); |
180 | EnumerateQueryNodeChildren(*bound_setop.right, callback); |
181 | break; |
182 | } |
183 | default: |
184 | assert(node.type == QueryNodeType::SELECT_NODE); |
185 | auto &bound_select = (BoundSelectNode &)node; |
186 | for (idx_t i = 0; i < bound_select.select_list.size(); i++) { |
187 | EnumerateExpression(bound_select.select_list[i], callback); |
188 | } |
189 | EnumerateExpression(bound_select.where_clause, callback); |
190 | for (idx_t i = 0; i < bound_select.groups.size(); i++) { |
191 | EnumerateExpression(bound_select.groups[i], callback); |
192 | } |
193 | EnumerateExpression(bound_select.having, callback); |
194 | for (idx_t i = 0; i < bound_select.aggregates.size(); i++) { |
195 | EnumerateExpression(bound_select.aggregates[i], callback); |
196 | } |
197 | for (idx_t i = 0; i < bound_select.windows.size(); i++) { |
198 | EnumerateExpression(bound_select.windows[i], callback); |
199 | } |
200 | if (bound_select.from_table) { |
201 | EnumerateTableRefChildren(*bound_select.from_table, callback); |
202 | } |
203 | break; |
204 | } |
205 | for (idx_t i = 0; i < node.modifiers.size(); i++) { |
206 | switch (node.modifiers[i]->type) { |
207 | case ResultModifierType::DISTINCT_MODIFIER: |
208 | for (auto &expr : ((BoundDistinctModifier &)*node.modifiers[i]).target_distincts) { |
209 | EnumerateExpression(expr, callback); |
210 | } |
211 | break; |
212 | case ResultModifierType::ORDER_MODIFIER: |
213 | for (auto &order : ((BoundOrderModifier &)*node.modifiers[i]).orders) { |
214 | EnumerateExpression(order.expression, callback); |
215 | } |
216 | break; |
217 | default: |
218 | break; |
219 | } |
220 | } |
221 | } |
222 | |