1#include "duckdb/planner/logical_operator_visitor.hpp"
2
3#include "duckdb/planner/expression/list.hpp"
4#include "duckdb/planner/expression_iterator.hpp"
5#include "duckdb/planner/operator/list.hpp"
6
7namespace duckdb {
8
9void LogicalOperatorVisitor::VisitOperator(LogicalOperator &op) {
10 VisitOperatorChildren(op);
11 VisitOperatorExpressions(op);
12}
13
14void LogicalOperatorVisitor::VisitOperatorChildren(LogicalOperator &op) {
15 for (auto &child : op.children) {
16 VisitOperator(op&: *child);
17 }
18}
19
20void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op,
21 const std::function<void(unique_ptr<Expression> *child)> &callback) {
22 switch (op.type) {
23 case LogicalOperatorType::LOGICAL_EXPRESSION_GET: {
24 auto &get = op.Cast<LogicalExpressionGet>();
25 for (auto &expr_list : get.expressions) {
26 for (auto &expr : expr_list) {
27 callback(&expr);
28 }
29 }
30 break;
31 }
32 case LogicalOperatorType::LOGICAL_ORDER_BY: {
33 auto &order = op.Cast<LogicalOrder>();
34 for (auto &node : order.orders) {
35 callback(&node.expression);
36 }
37 break;
38 }
39 case LogicalOperatorType::LOGICAL_TOP_N: {
40 auto &order = op.Cast<LogicalTopN>();
41 for (auto &node : order.orders) {
42 callback(&node.expression);
43 }
44 break;
45 }
46 case LogicalOperatorType::LOGICAL_DISTINCT: {
47 auto &distinct = op.Cast<LogicalDistinct>();
48 for (auto &target : distinct.distinct_targets) {
49 callback(&target);
50 }
51 if (distinct.order_by) {
52 for (auto &order : distinct.order_by->orders) {
53 callback(&order.expression);
54 }
55 }
56 break;
57 }
58 case LogicalOperatorType::LOGICAL_INSERT: {
59 auto &insert = op.Cast<LogicalInsert>();
60 if (insert.on_conflict_condition) {
61 callback(&insert.on_conflict_condition);
62 }
63 if (insert.do_update_condition) {
64 callback(&insert.do_update_condition);
65 }
66 break;
67 }
68 case LogicalOperatorType::LOGICAL_ASOF_JOIN:
69 case LogicalOperatorType::LOGICAL_DELIM_JOIN:
70 case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: {
71 if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) {
72 auto &delim_join = op.Cast<LogicalDelimJoin>();
73 for (auto &expr : delim_join.duplicate_eliminated_columns) {
74 callback(&expr);
75 }
76 }
77 auto &join = op.Cast<LogicalComparisonJoin>();
78 for (auto &cond : join.conditions) {
79 callback(&cond.left);
80 callback(&cond.right);
81 }
82 break;
83 }
84 case LogicalOperatorType::LOGICAL_ANY_JOIN: {
85 auto &join = op.Cast<LogicalAnyJoin>();
86 callback(&join.condition);
87 break;
88 }
89 case LogicalOperatorType::LOGICAL_LIMIT: {
90 auto &limit = op.Cast<LogicalLimit>();
91 if (limit.limit) {
92 callback(&limit.limit);
93 }
94 if (limit.offset) {
95 callback(&limit.offset);
96 }
97 break;
98 }
99 case LogicalOperatorType::LOGICAL_LIMIT_PERCENT: {
100 auto &limit = op.Cast<LogicalLimitPercent>();
101 if (limit.limit) {
102 callback(&limit.limit);
103 }
104 if (limit.offset) {
105 callback(&limit.offset);
106 }
107 break;
108 }
109 case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: {
110 auto &aggr = op.Cast<LogicalAggregate>();
111 for (auto &group : aggr.groups) {
112 callback(&group);
113 }
114 break;
115 }
116 default:
117 break;
118 }
119 for (auto &expression : op.expressions) {
120 callback(&expression);
121 }
122}
123
124void LogicalOperatorVisitor::VisitOperatorExpressions(LogicalOperator &op) {
125 LogicalOperatorVisitor::EnumerateExpressions(op, callback: [&](unique_ptr<Expression> *child) { VisitExpression(expression: child); });
126}
127
128void LogicalOperatorVisitor::VisitExpression(unique_ptr<Expression> *expression) {
129 auto &expr = **expression;
130 unique_ptr<Expression> result;
131 switch (expr.GetExpressionClass()) {
132 case ExpressionClass::BOUND_AGGREGATE:
133 result = VisitReplace(expr&: expr.Cast<BoundAggregateExpression>(), expr_ptr: expression);
134 break;
135 case ExpressionClass::BOUND_BETWEEN:
136 result = VisitReplace(expr&: expr.Cast<BoundBetweenExpression>(), expr_ptr: expression);
137 break;
138 case ExpressionClass::BOUND_CASE:
139 result = VisitReplace(expr&: expr.Cast<BoundCaseExpression>(), expr_ptr: expression);
140 break;
141 case ExpressionClass::BOUND_CAST:
142 result = VisitReplace(expr&: expr.Cast<BoundCastExpression>(), expr_ptr: expression);
143 break;
144 case ExpressionClass::BOUND_COLUMN_REF:
145 result = VisitReplace(expr&: expr.Cast<BoundColumnRefExpression>(), expr_ptr: expression);
146 break;
147 case ExpressionClass::BOUND_COMPARISON:
148 result = VisitReplace(expr&: expr.Cast<BoundComparisonExpression>(), expr_ptr: expression);
149 break;
150 case ExpressionClass::BOUND_CONJUNCTION:
151 result = VisitReplace(expr&: expr.Cast<BoundConjunctionExpression>(), expr_ptr: expression);
152 break;
153 case ExpressionClass::BOUND_CONSTANT:
154 result = VisitReplace(expr&: expr.Cast<BoundConstantExpression>(), expr_ptr: expression);
155 break;
156 case ExpressionClass::BOUND_FUNCTION:
157 result = VisitReplace(expr&: expr.Cast<BoundFunctionExpression>(), expr_ptr: expression);
158 break;
159 case ExpressionClass::BOUND_SUBQUERY:
160 result = VisitReplace(expr&: expr.Cast<BoundSubqueryExpression>(), expr_ptr: expression);
161 break;
162 case ExpressionClass::BOUND_OPERATOR:
163 result = VisitReplace(expr&: expr.Cast<BoundOperatorExpression>(), expr_ptr: expression);
164 break;
165 case ExpressionClass::BOUND_PARAMETER:
166 result = VisitReplace(expr&: expr.Cast<BoundParameterExpression>(), expr_ptr: expression);
167 break;
168 case ExpressionClass::BOUND_REF:
169 result = VisitReplace(expr&: expr.Cast<BoundReferenceExpression>(), expr_ptr: expression);
170 break;
171 case ExpressionClass::BOUND_DEFAULT:
172 result = VisitReplace(expr&: expr.Cast<BoundDefaultExpression>(), expr_ptr: expression);
173 break;
174 case ExpressionClass::BOUND_WINDOW:
175 result = VisitReplace(expr&: expr.Cast<BoundWindowExpression>(), expr_ptr: expression);
176 break;
177 case ExpressionClass::BOUND_UNNEST:
178 result = VisitReplace(expr&: expr.Cast<BoundUnnestExpression>(), expr_ptr: expression);
179 break;
180 default:
181 throw InternalException("Unrecognized expression type in logical operator visitor");
182 }
183 if (result) {
184 *expression = std::move(result);
185 } else {
186 // visit the children of this node
187 VisitExpressionChildren(expression&: expr);
188 }
189}
190
191void LogicalOperatorVisitor::VisitExpressionChildren(Expression &expr) {
192 ExpressionIterator::EnumerateChildren(expr, callback: [&](unique_ptr<Expression> &expr) { VisitExpression(expression: &expr); });
193}
194
// These are all default implementations that subclasses can override.
// They are trivial, so they are excluded from coverage reporting.
197// LCOV_EXCL_START
198unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundAggregateExpression &expr,
199 unique_ptr<Expression> *expr_ptr) {
200 return nullptr;
201}
202
203unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundBetweenExpression &expr,
204 unique_ptr<Expression> *expr_ptr) {
205 return nullptr;
206}
207
208unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundCaseExpression &expr,
209 unique_ptr<Expression> *expr_ptr) {
210 return nullptr;
211}
212
213unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundCastExpression &expr,
214 unique_ptr<Expression> *expr_ptr) {
215 return nullptr;
216}
217
218unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundColumnRefExpression &expr,
219 unique_ptr<Expression> *expr_ptr) {
220 return nullptr;
221}
222
223unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundComparisonExpression &expr,
224 unique_ptr<Expression> *expr_ptr) {
225 return nullptr;
226}
227
228unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundConjunctionExpression &expr,
229 unique_ptr<Expression> *expr_ptr) {
230 return nullptr;
231}
232
233unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundConstantExpression &expr,
234 unique_ptr<Expression> *expr_ptr) {
235 return nullptr;
236}
237
238unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundDefaultExpression &expr,
239 unique_ptr<Expression> *expr_ptr) {
240 return nullptr;
241}
242
243unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundFunctionExpression &expr,
244 unique_ptr<Expression> *expr_ptr) {
245 return nullptr;
246}
247
248unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundOperatorExpression &expr,
249 unique_ptr<Expression> *expr_ptr) {
250 return nullptr;
251}
252
253unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundParameterExpression &expr,
254 unique_ptr<Expression> *expr_ptr) {
255 return nullptr;
256}
257
258unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundReferenceExpression &expr,
259 unique_ptr<Expression> *expr_ptr) {
260 return nullptr;
261}
262
263unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundSubqueryExpression &expr,
264 unique_ptr<Expression> *expr_ptr) {
265 return nullptr;
266}
267
268unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundWindowExpression &expr,
269 unique_ptr<Expression> *expr_ptr) {
270 return nullptr;
271}
272
273unique_ptr<Expression> LogicalOperatorVisitor::VisitReplace(BoundUnnestExpression &expr,
274 unique_ptr<Expression> *expr_ptr) {
275 return nullptr;
276}
277
278// LCOV_EXCL_STOP
279
280} // namespace duckdb
281