1#include "duckdb/planner/expression_iterator.hpp"
2
3#include "duckdb/planner/bound_query_node.hpp"
4#include "duckdb/planner/expression/list.hpp"
5#include "duckdb/planner/query_node/bound_select_node.hpp"
6#include "duckdb/planner/query_node/bound_set_operation_node.hpp"
7#include "duckdb/planner/tableref/list.hpp"
8
9using namespace duckdb;
10using namespace std;
11
12void ExpressionIterator::EnumerateChildren(const Expression &expr, function<void(const Expression &child)> callback) {
13 EnumerateChildren((Expression &)expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> {
14 callback(*child);
15 return move(child);
16 });
17}
18
19void ExpressionIterator::EnumerateChildren(Expression &expr, std::function<void(Expression &child)> callback) {
20 EnumerateChildren(expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> {
21 callback(*child);
22 return move(child);
23 });
24}
25
26void ExpressionIterator::EnumerateChildren(Expression &expr,
27 function<unique_ptr<Expression>(unique_ptr<Expression> child)> callback) {
28 switch (expr.expression_class) {
29 case ExpressionClass::BOUND_AGGREGATE: {
30 auto &aggr_expr = (BoundAggregateExpression &)expr;
31 for (auto &child : aggr_expr.children) {
32 child = callback(move(child));
33 }
34 break;
35 }
36 case ExpressionClass::BOUND_BETWEEN: {
37 auto &between_expr = (BoundBetweenExpression &)expr;
38 between_expr.input = callback(move(between_expr.input));
39 between_expr.lower = callback(move(between_expr.lower));
40 between_expr.upper = callback(move(between_expr.upper));
41 break;
42 }
43 case ExpressionClass::BOUND_CASE: {
44 auto &case_expr = (BoundCaseExpression &)expr;
45 case_expr.check = callback(move(case_expr.check));
46 case_expr.result_if_true = callback(move(case_expr.result_if_true));
47 case_expr.result_if_false = callback(move(case_expr.result_if_false));
48 break;
49 }
50 case ExpressionClass::BOUND_CAST: {
51 auto &cast_expr = (BoundCastExpression &)expr;
52 cast_expr.child = callback(move(cast_expr.child));
53 break;
54 }
55 case ExpressionClass::BOUND_COMPARISON: {
56 auto &comp_expr = (BoundComparisonExpression &)expr;
57 comp_expr.left = callback(move(comp_expr.left));
58 comp_expr.right = callback(move(comp_expr.right));
59 break;
60 }
61 case ExpressionClass::BOUND_CONJUNCTION: {
62 auto &conj_expr = (BoundConjunctionExpression &)expr;
63 for (auto &child : conj_expr.children) {
64 child = callback(move(child));
65 }
66 break;
67 }
68 case ExpressionClass::BOUND_FUNCTION: {
69 auto &func_expr = (BoundFunctionExpression &)expr;
70 for (auto &child : func_expr.children) {
71 child = callback(move(child));
72 }
73 break;
74 }
75 case ExpressionClass::BOUND_OPERATOR: {
76 auto &op_expr = (BoundOperatorExpression &)expr;
77 for (auto &child : op_expr.children) {
78 child = callback(move(child));
79 }
80 break;
81 }
82 case ExpressionClass::BOUND_SUBQUERY: {
83 auto &subquery_expr = (BoundSubqueryExpression &)expr;
84 if (subquery_expr.child) {
85 subquery_expr.child = callback(move(subquery_expr.child));
86 }
87 break;
88 }
89 case ExpressionClass::BOUND_WINDOW: {
90 auto &window_expr = (BoundWindowExpression &)expr;
91 for (auto &partition : window_expr.partitions) {
92 partition = callback(move(partition));
93 }
94 for (auto &order : window_expr.orders) {
95 order.expression = callback(move(order.expression));
96 }
97 for (auto &child : window_expr.children) {
98 child = callback(move(child));
99 }
100 if (window_expr.offset_expr) {
101 window_expr.offset_expr = callback(move(window_expr.offset_expr));
102 }
103 if (window_expr.default_expr) {
104 window_expr.default_expr = callback(move(window_expr.default_expr));
105 }
106 break;
107 }
108 case ExpressionClass::BOUND_UNNEST: {
109 auto &unnest_expr = (BoundUnnestExpression &)expr;
110 unnest_expr.child = callback(move(unnest_expr.child));
111 break;
112 }
113 case ExpressionClass::COMMON_SUBEXPRESSION: {
114 auto &cse_expr = (CommonSubExpression &)expr;
115 if (cse_expr.owned_child) {
116 cse_expr.owned_child = callback(move(cse_expr.owned_child));
117 }
118 break;
119 }
120 case ExpressionClass::BOUND_COLUMN_REF:
121 case ExpressionClass::BOUND_CONSTANT:
122 case ExpressionClass::BOUND_DEFAULT:
123 case ExpressionClass::BOUND_PARAMETER:
124 case ExpressionClass::BOUND_REF:
125 // these node types have no children
126 break;
127 default:
128 // called on non BoundExpression type!
129 assert(0);
130 break;
131 }
132}
133
134void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr,
135 std::function<void(Expression &child)> callback) {
136 if (!expr) {
137 return;
138 }
139 callback(*expr);
140 ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr<Expression> child) -> unique_ptr<Expression> {
141 EnumerateExpression(child, callback);
142 return move(child);
143 });
144}
145
146void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref,
147 std::function<void(Expression &child)> callback) {
148 switch (ref.type) {
149 case TableReferenceType::CROSS_PRODUCT: {
150 auto &bound_crossproduct = (BoundCrossProductRef &)ref;
151 EnumerateTableRefChildren(*bound_crossproduct.left, callback);
152 EnumerateTableRefChildren(*bound_crossproduct.right, callback);
153 break;
154 }
155 case TableReferenceType::JOIN: {
156 auto &bound_join = (BoundJoinRef &)ref;
157 EnumerateExpression(bound_join.condition, callback);
158 EnumerateTableRefChildren(*bound_join.left, callback);
159 EnumerateTableRefChildren(*bound_join.right, callback);
160 break;
161 }
162 case TableReferenceType::SUBQUERY: {
163 auto &bound_subquery = (BoundSubqueryRef &)ref;
164 EnumerateQueryNodeChildren(*bound_subquery.subquery, callback);
165 break;
166 }
167 default:
168 assert(ref.type == TableReferenceType::TABLE_FUNCTION || ref.type == TableReferenceType::BASE_TABLE ||
169 ref.type == TableReferenceType::EMPTY);
170 break;
171 }
172}
173
174void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node,
175 std::function<void(Expression &child)> callback) {
176 switch (node.type) {
177 case QueryNodeType::SET_OPERATION_NODE: {
178 auto &bound_setop = (BoundSetOperationNode &)node;
179 EnumerateQueryNodeChildren(*bound_setop.left, callback);
180 EnumerateQueryNodeChildren(*bound_setop.right, callback);
181 break;
182 }
183 default:
184 assert(node.type == QueryNodeType::SELECT_NODE);
185 auto &bound_select = (BoundSelectNode &)node;
186 for (idx_t i = 0; i < bound_select.select_list.size(); i++) {
187 EnumerateExpression(bound_select.select_list[i], callback);
188 }
189 EnumerateExpression(bound_select.where_clause, callback);
190 for (idx_t i = 0; i < bound_select.groups.size(); i++) {
191 EnumerateExpression(bound_select.groups[i], callback);
192 }
193 EnumerateExpression(bound_select.having, callback);
194 for (idx_t i = 0; i < bound_select.aggregates.size(); i++) {
195 EnumerateExpression(bound_select.aggregates[i], callback);
196 }
197 for (idx_t i = 0; i < bound_select.windows.size(); i++) {
198 EnumerateExpression(bound_select.windows[i], callback);
199 }
200 if (bound_select.from_table) {
201 EnumerateTableRefChildren(*bound_select.from_table, callback);
202 }
203 break;
204 }
205 for (idx_t i = 0; i < node.modifiers.size(); i++) {
206 switch (node.modifiers[i]->type) {
207 case ResultModifierType::DISTINCT_MODIFIER:
208 for (auto &expr : ((BoundDistinctModifier &)*node.modifiers[i]).target_distincts) {
209 EnumerateExpression(expr, callback);
210 }
211 break;
212 case ResultModifierType::ORDER_MODIFIER:
213 for (auto &order : ((BoundOrderModifier &)*node.modifiers[i]).orders) {
214 EnumerateExpression(order.expression, callback);
215 }
216 break;
217 default:
218 break;
219 }
220 }
221}
222