1 | #include "duckdb/planner/logical_operator_visitor.hpp" |
2 | |
3 | #include "duckdb/planner/expression/list.hpp" |
4 | #include "duckdb/planner/expression_iterator.hpp" |
5 | #include "duckdb/planner/operator/list.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | void LogicalOperatorVisitor::VisitOperator(LogicalOperator &op) { |
10 | VisitOperatorChildren(op); |
11 | VisitOperatorExpressions(op); |
12 | } |
13 | |
14 | void LogicalOperatorVisitor::VisitOperatorChildren(LogicalOperator &op) { |
15 | for (auto &child : op.children) { |
16 | VisitOperator(op&: *child); |
17 | } |
18 | } |
19 | |
20 | void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, |
21 | const std::function<void(unique_ptr<Expression> *child)> &callback) { |
22 | switch (op.type) { |
23 | case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { |
24 | auto &get = op.Cast<LogicalExpressionGet>(); |
25 | for (auto &expr_list : get.expressions) { |
26 | for (auto &expr : expr_list) { |
27 | callback(&expr); |
28 | } |
29 | } |
30 | break; |
31 | } |
32 | case LogicalOperatorType::LOGICAL_ORDER_BY: { |
33 | auto &order = op.Cast<LogicalOrder>(); |
34 | for (auto &node : order.orders) { |
35 | callback(&node.expression); |
36 | } |
37 | break; |
38 | } |
39 | case LogicalOperatorType::LOGICAL_TOP_N: { |
40 | auto &order = op.Cast<LogicalTopN>(); |
41 | for (auto &node : order.orders) { |
42 | callback(&node.expression); |
43 | } |
44 | break; |
45 | } |
46 | case LogicalOperatorType::LOGICAL_DISTINCT: { |
47 | auto &distinct = op.Cast<LogicalDistinct>(); |
48 | for (auto &target : distinct.distinct_targets) { |
49 | callback(&target); |
50 | } |
51 | if (distinct.order_by) { |
52 | for (auto &order : distinct.order_by->orders) { |
53 | callback(&order.expression); |
54 | } |
55 | } |
56 | break; |
57 | } |
58 | case LogicalOperatorType::LOGICAL_INSERT: { |
59 | auto &insert = op.Cast<LogicalInsert>(); |
60 | if (insert.on_conflict_condition) { |
61 | callback(&insert.on_conflict_condition); |
62 | } |
63 | if (insert.do_update_condition) { |
64 | callback(&insert.do_update_condition); |
65 | } |
66 | break; |
67 | } |
68 | case LogicalOperatorType::LOGICAL_ASOF_JOIN: |
69 | case LogicalOperatorType::LOGICAL_DELIM_JOIN: |
70 | case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { |
71 | if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { |
72 | auto &delim_join = op.Cast<LogicalDelimJoin>(); |
73 | for (auto &expr : delim_join.duplicate_eliminated_columns) { |
74 | callback(&expr); |
75 | } |
76 | } |
77 | auto &join = op.Cast<LogicalComparisonJoin>(); |
78 | for (auto &cond : join.conditions) { |
79 | callback(&cond.left); |
80 | callback(&cond.right); |
81 | } |
82 | break; |
83 | } |
84 | case LogicalOperatorType::LOGICAL_ANY_JOIN: { |
85 | auto &join = op.Cast<LogicalAnyJoin>(); |
86 | callback(&join.condition); |
87 | break; |
88 | } |
89 | case LogicalOperatorType::LOGICAL_LIMIT: { |
90 | auto &limit = op.Cast<LogicalLimit>(); |
91 | if (limit.limit) { |
92 | callback(&limit.limit); |
93 | } |
94 | if (limit.offset) { |
95 | callback(&limit.offset); |
96 | } |
97 | break; |
98 | } |
99 | case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: { |
100 | auto &limit = op.Cast<LogicalLimitPercent>(); |
101 | if (limit.limit) { |
102 | callback(&limit.limit); |
103 | } |
104 | if (limit.offset) { |
105 | callback(&limit.offset); |
106 | } |
107 | break; |
108 | } |
109 | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { |
110 | auto &aggr = op.Cast<LogicalAggregate>(); |
111 | for (auto &group : aggr.groups) { |
112 | callback(&group); |
113 | } |
114 | break; |
115 | } |
116 | default: |
117 | break; |
118 | } |
119 | for (auto &expression : op.expressions) { |
120 | callback(&expression); |
121 | } |
122 | } |
123 | |
124 | void LogicalOperatorVisitor::VisitOperatorExpressions(LogicalOperator &op) { |
125 | LogicalOperatorVisitor::EnumerateExpressions(op, callback: [&](unique_ptr<Expression> *child) { VisitExpression(expression: child); }); |
126 | } |
127 | |
128 | void LogicalOperatorVisitor::VisitExpression(unique_ptr<Expression> *expression) { |
129 | auto &expr = **expression; |
130 | unique_ptr<Expression> result; |
131 | switch (expr.GetExpressionClass()) { |
132 | case ExpressionClass::BOUND_AGGREGATE: |
133 | result = VisitReplace(expr&: expr.Cast<BoundAggregateExpression>(), expr_ptr: expression); |
134 | break; |
135 | case ExpressionClass::BOUND_BETWEEN: |
136 | result = VisitReplace(expr&: expr.Cast<BoundBetweenExpression>(), expr_ptr: expression); |
137 | break; |
138 | case ExpressionClass::BOUND_CASE: |
139 | result = VisitReplace(expr&: expr.Cast<BoundCaseExpression>(), expr_ptr: expression); |
140 | break; |
141 | case ExpressionClass::BOUND_CAST: |
142 | result = VisitReplace(expr&: expr.Cast<BoundCastExpression>(), expr_ptr: expression); |
143 | break; |
144 | case ExpressionClass::BOUND_COLUMN_REF: |
145 | result = VisitReplace(expr&: expr.Cast<BoundColumnRefExpression>(), expr_ptr: expression); |
146 | break; |
147 | case ExpressionClass::BOUND_COMPARISON: |
148 | result = VisitReplace(expr&: expr.Cast<BoundComparisonExpression>(), expr_ptr: expression); |
149 | break; |
150 | case ExpressionClass::BOUND_CONJUNCTION: |
151 | result = VisitReplace(expr&: expr.Cast<BoundConjunctionExpression>(), expr_ptr: expression); |
152 | break; |
153 | case ExpressionClass::BOUND_CONSTANT: |
154 | result = VisitReplace(expr&: expr.Cast<BoundConstantExpression>(), expr_ptr: expression); |
155 | break; |
156 | case ExpressionClass::BOUND_FUNCTION: |
157 | result = VisitReplace(expr&: expr.Cast<BoundFunctionExpression>(), expr_ptr: expression); |
158 | break; |
159 | case ExpressionClass::BOUND_SUBQUERY: |
160 | result = VisitReplace(expr&: expr.Cast<BoundSubqueryExpression>(), expr_ptr: expression); |
161 | break; |
162 | case ExpressionClass::BOUND_OPERATOR: |
163 | result = VisitReplace(expr&: expr.Cast<BoundOperatorExpression>(), expr_ptr: expression); |
164 | break; |
165 | case ExpressionClass::BOUND_PARAMETER: |
166 | result = VisitReplace(expr&: expr.Cast<BoundParameterExpression>(), expr_ptr: expression); |
167 | break; |
168 | case ExpressionClass::BOUND_REF: |
169 | result = VisitReplace(expr&: expr.Cast<BoundReferenceExpression>(), expr_ptr: expression); |
170 | break; |
171 | case ExpressionClass::BOUND_DEFAULT: |
172 | result = VisitReplace(expr&: expr.Cast<BoundDefaultExpression>(), expr_ptr: expression); |
173 | break; |
174 | case ExpressionClass::BOUND_WINDOW: |
175 | result = VisitReplace(expr&: expr.Cast<BoundWindowExpression>(), expr_ptr: expression); |
176 | break; |
177 | case ExpressionClass::BOUND_UNNEST: |
178 | result = VisitReplace(expr&: expr.Cast<BoundUnnestExpression>(), expr_ptr: expression); |
179 | break; |
180 | default: |
181 | throw InternalException("Unrecognized expression type in logical operator visitor" ); |
182 | } |
183 | if (result) { |
184 | *expression = std::move(result); |
185 | } else { |
186 | // visit the children of this node |
187 | VisitExpressionChildren(expression&: expr); |
188 | } |
189 | } |
190 | |
191 | void LogicalOperatorVisitor::VisitExpressionChildren(Expression &expr) { |
192 | ExpressionIterator::EnumerateChildren(expr, callback: [&](unique_ptr<Expression> &expr) { VisitExpression(expression: &expr); }); |
193 | } |
194 | |
195 | // these are all default methods that can be overriden |
196 | // we don't care about coverage here |
197 | // LCOV_EXCL_START |
198 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundAggregateExpression &expr, |
199 | unique_ptr<Expression> *expr_ptr) { |
200 | return nullptr; |
201 | } |
202 | |
203 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundBetweenExpression &expr, |
204 | unique_ptr<Expression> *expr_ptr) { |
205 | return nullptr; |
206 | } |
207 | |
208 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundCaseExpression &expr, |
209 | unique_ptr<Expression> *expr_ptr) { |
210 | return nullptr; |
211 | } |
212 | |
213 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundCastExpression &expr, |
214 | unique_ptr<Expression> *expr_ptr) { |
215 | return nullptr; |
216 | } |
217 | |
218 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundColumnRefExpression &expr, |
219 | unique_ptr<Expression> *expr_ptr) { |
220 | return nullptr; |
221 | } |
222 | |
223 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundComparisonExpression &expr, |
224 | unique_ptr<Expression> *expr_ptr) { |
225 | return nullptr; |
226 | } |
227 | |
228 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundConjunctionExpression &expr, |
229 | unique_ptr<Expression> *expr_ptr) { |
230 | return nullptr; |
231 | } |
232 | |
233 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundConstantExpression &expr, |
234 | unique_ptr<Expression> *expr_ptr) { |
235 | return nullptr; |
236 | } |
237 | |
238 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundDefaultExpression &expr, |
239 | unique_ptr<Expression> *expr_ptr) { |
240 | return nullptr; |
241 | } |
242 | |
243 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundFunctionExpression &expr, |
244 | unique_ptr<Expression> *expr_ptr) { |
245 | return nullptr; |
246 | } |
247 | |
248 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundOperatorExpression &expr, |
249 | unique_ptr<Expression> *expr_ptr) { |
250 | return nullptr; |
251 | } |
252 | |
253 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundParameterExpression &expr, |
254 | unique_ptr<Expression> *expr_ptr) { |
255 | return nullptr; |
256 | } |
257 | |
258 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundReferenceExpression &expr, |
259 | unique_ptr<Expression> *expr_ptr) { |
260 | return nullptr; |
261 | } |
262 | |
263 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundSubqueryExpression &expr, |
264 | unique_ptr<Expression> *expr_ptr) { |
265 | return nullptr; |
266 | } |
267 | |
268 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundWindowExpression &expr, |
269 | unique_ptr<Expression> *expr_ptr) { |
270 | return nullptr; |
271 | } |
272 | |
273 | unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundUnnestExpression &expr, |
274 | unique_ptr<Expression> *expr_ptr) { |
275 | return nullptr; |
276 | } |
277 | |
278 | // LCOV_EXCL_STOP |
279 | |
280 | } // namespace duckdb |
281 | |