1#include "duckdb/optimizer/deliminator.hpp"
2
3#include "duckdb/optimizer/join_order/join_order_optimizer.hpp"
4#include "duckdb/planner/expression/bound_cast_expression.hpp"
5#include "duckdb/planner/expression/bound_columnref_expression.hpp"
6#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
7#include "duckdb/planner/expression/bound_operator_expression.hpp"
8#include "duckdb/planner/operator/logical_aggregate.hpp"
9#include "duckdb/planner/operator/logical_delim_get.hpp"
10#include "duckdb/planner/operator/logical_delim_join.hpp"
11#include "duckdb/planner/operator/logical_filter.hpp"
12
13namespace duckdb {
14
15class DeliminatorPlanUpdater : LogicalOperatorVisitor {
16public:
17 explicit DeliminatorPlanUpdater(ClientContext &context) : context(context) {
18 }
19 //! Update the plan after a DelimGet has been removed
20 void VisitOperator(LogicalOperator &op) override;
21 void VisitExpression(unique_ptr<Expression> *expression) override;
22
23public:
24 ClientContext &context;
25
26 expression_map_t<Expression *> expr_map;
27 column_binding_map_t<bool> projection_map;
28 column_binding_map_t<Expression *> reverse_proj_or_agg_map;
29 unique_ptr<LogicalOperator> temp_ptr;
30};
31
32static idx_t DelimGetCount(LogicalOperator &op) {
33 if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) {
34 return 1;
35 }
36 idx_t child_count = 0;
37 for (auto &child : op.children) {
38 child_count += DelimGetCount(op&: *child);
39 }
40 return child_count;
41}
42
43static bool IsEqualityJoinCondition(JoinCondition &cond) {
44 switch (cond.comparison) {
45 case ExpressionType::COMPARE_EQUAL:
46 case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
47 return true;
48 default:
49 return false;
50 }
51}
52
53static bool InequalityDelimJoinCanBeEliminated(JoinType &join_type) {
54 switch (join_type) {
55 case JoinType::ANTI:
56 case JoinType::MARK:
57 case JoinType::SEMI:
58 case JoinType::SINGLE:
59 return true;
60 default:
61 return false;
62 }
63}
64
65void DeliminatorPlanUpdater::VisitOperator(LogicalOperator &op) {
66 VisitOperatorChildren(op);
67 VisitOperatorExpressions(op);
68 if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN && DelimGetCount(op) == 0) {
69 auto &delim_join = op.Cast<LogicalDelimJoin>();
70 auto &decs = delim_join.duplicate_eliminated_columns;
71 for (auto &cond : delim_join.conditions) {
72 if (!IsEqualityJoinCondition(cond)) {
73 continue;
74 }
75 auto rhs = cond.right.get();
76 while (rhs->type == ExpressionType::OPERATOR_CAST) {
77 auto &cast = rhs->Cast<BoundCastExpression>();
78 rhs = cast.child.get();
79 }
80 if (rhs->type != ExpressionType::BOUND_COLUMN_REF) {
81 throw InternalException("Error in Deliminator: expected a bound column reference");
82 }
83 auto &colref = rhs->Cast<BoundColumnRefExpression>();
84 if (projection_map.find(x: colref.binding) != projection_map.end()) {
85 // value on the right is a projection of removed DelimGet
86 for (idx_t i = 0; i < decs.size(); i++) {
87 if (decs[i]->Equals(other: *cond.left)) {
88 // the value on the left no longer needs to be a duplicate-eliminated column
89 decs.erase(position: decs.begin() + i);
90 break;
91 }
92 }
93 // whether we applied an IS NOT NULL filter
94 cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM;
95 }
96 }
97 // change type if there are no more duplicate-eliminated columns
98 if (decs.empty()) {
99 delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN;
100 // sub-plans with DelimGets are not re-orderable (yet), however, we removed all DelimGet of this DelimJoin
101 // the DelimGets are on the RHS of the DelimJoin, so we can call the JoinOrderOptimizer on the RHS now
102 JoinOrderOptimizer optimizer(context);
103 delim_join.children[1] = optimizer.Optimize(plan: std::move(delim_join.children[1]));
104 }
105 }
106}
107
108void DeliminatorPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) {
109 auto &expr = **expression;
110 auto entry = expr_map.find(x: expr);
111 if (entry != expr_map.end()) {
112 *expression = entry->second->Copy();
113 } else {
114 VisitExpressionChildren(expression&: **expression);
115 }
116}
117
118unique_ptr<LogicalOperator> Deliminator::Optimize(unique_ptr<LogicalOperator> op) {
119 vector<unique_ptr<LogicalOperator> *> candidates;
120 FindCandidates(op_ptr: &op, candidates);
121
122 for (auto &candidate : candidates) {
123 DeliminatorPlanUpdater updater(context);
124 if (RemoveCandidate(plan: &op, candidate, updater)) {
125 updater.VisitOperator(op&: *op);
126 }
127 }
128 return op;
129}
130
131void Deliminator::FindCandidates(unique_ptr<LogicalOperator> *op_ptr,
132 vector<unique_ptr<LogicalOperator> *> &candidates) {
133 auto op = op_ptr->get();
134 // search children before adding, so the deepest candidates get added first
135 for (auto &child : op->children) {
136 FindCandidates(op_ptr: &child, candidates);
137 }
138 // search for projection/aggregate
139 if (op->type != LogicalOperatorType::LOGICAL_PROJECTION &&
140 op->type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) {
141 return;
142 }
143 // followed by a join
144 if (op->children[0]->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) {
145 return;
146 }
147 auto &join = *op->children[0];
148 // with a DelimGet as a direct child (left or right)
149 if (join.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET ||
150 join.children[1]->type == LogicalOperatorType::LOGICAL_DELIM_GET) {
151 candidates.push_back(x: op_ptr);
152 return;
153 }
154 // or a filter followed by a DelimGet (left)
155 if (join.children[0]->type == LogicalOperatorType::LOGICAL_FILTER &&
156 join.children[0]->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) {
157 candidates.push_back(x: op_ptr);
158 return;
159 }
160 // filter followed by a DelimGet (right)
161 if (join.children[1]->type == LogicalOperatorType::LOGICAL_FILTER &&
162 join.children[1]->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) {
163 candidates.push_back(x: op_ptr);
164 return;
165 }
166}
167
168static bool OperatorIsDelimGet(LogicalOperator &op) {
169 if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) {
170 return true;
171 }
172 if (op.type == LogicalOperatorType::LOGICAL_FILTER &&
173 op.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) {
174 return true;
175 }
176 return false;
177}
178
179static bool ChildJoinTypeCanBeDeliminated(JoinType &join_type) {
180 switch (join_type) {
181 case JoinType::INNER:
182 case JoinType::SEMI:
183 return true;
184 default:
185 return false;
186 }
187}
188
189bool Deliminator::RemoveCandidate(unique_ptr<LogicalOperator> *plan, unique_ptr<LogicalOperator> *candidate,
190 DeliminatorPlanUpdater &updater) {
191 auto &proj_or_agg = **candidate;
192 auto &join = proj_or_agg.children[0]->Cast<LogicalComparisonJoin>();
193 if (!ChildJoinTypeCanBeDeliminated(join_type&: join.join_type)) {
194 return false;
195 }
196
197 // get the index (left or right) of the DelimGet side of the join
198 idx_t delim_idx = OperatorIsDelimGet(op&: *join.children[0]) ? 0 : 1;
199 D_ASSERT(OperatorIsDelimGet(*join.children[delim_idx]));
200 // get the filter (if any)
201 optional_ptr<LogicalFilter> filter;
202 if (join.children[delim_idx]->type == LogicalOperatorType::LOGICAL_FILTER) {
203 filter = &join.children[delim_idx]->Cast<LogicalFilter>();
204 }
205 auto &delim_get = (filter ? filter->children[0] : join.children[delim_idx])->Cast<LogicalDelimGet>();
206 if (join.conditions.size() != delim_get.chunk_types.size()) {
207 // joining with DelimGet adds new information
208 return false;
209 }
210 // check if joining with the DelimGet is redundant, and collect relevant column information
211 bool all_equality_conditions = true;
212 vector<reference<Expression>> nulls_are_not_equal_exprs;
213 for (auto &cond : join.conditions) {
214 all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(cond);
215 auto &delim_side = delim_idx == 0 ? *cond.left : *cond.right;
216 auto &other_side = delim_idx == 0 ? *cond.right : *cond.left;
217 if (delim_side.type != ExpressionType::BOUND_COLUMN_REF) {
218 // non-colref e.g. expression -(4, 1) in 4-i=j where i is from DelimGet
219 // FIXME: might be possible to also eliminate these
220 return false;
221 }
222 updater.expr_map[delim_side] = &other_side;
223 if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) {
224 nulls_are_not_equal_exprs.push_back(x: other_side);
225 }
226 }
227
228 // removed DelimGet columns are assigned a new ColumnBinding by Projection/Aggregation, keep track here
229 if (proj_or_agg.type == LogicalOperatorType::LOGICAL_PROJECTION) {
230 for (auto &cb : proj_or_agg.GetColumnBindings()) {
231 updater.projection_map[cb] = true;
232 updater.reverse_proj_or_agg_map[cb] = proj_or_agg.expressions[cb.column_index].get();
233 for (auto &expr : nulls_are_not_equal_exprs) {
234 if (proj_or_agg.expressions[cb.column_index]->Equals(other: expr.get())) {
235 updater.projection_map[cb] = false;
236 break;
237 }
238 }
239 }
240 } else {
241 auto &agg = proj_or_agg.Cast<LogicalAggregate>();
242
243 // Create a vector of all exprs in the agg
244 vector<Expression *> all_agg_exprs;
245 all_agg_exprs.reserve(n: agg.groups.size() + agg.expressions.size());
246 for (auto &expr : agg.groups) {
247 all_agg_exprs.push_back(x: expr.get());
248 }
249 for (auto &expr : agg.expressions) {
250 all_agg_exprs.push_back(x: expr.get());
251 }
252
253 for (auto &cb : agg.GetColumnBindings()) {
254 updater.projection_map[cb] = true;
255 updater.reverse_proj_or_agg_map[cb] = all_agg_exprs[cb.column_index];
256 for (auto &expr : nulls_are_not_equal_exprs) {
257 if ((cb.table_index == agg.group_index && agg.groups[cb.column_index]->Equals(other: expr.get())) ||
258 (cb.table_index == agg.aggregate_index && agg.expressions[cb.column_index]->Equals(other: expr.get()))) {
259 updater.projection_map[cb] = false;
260 break;
261 }
262 }
263 }
264 }
265
266 if (!all_equality_conditions) {
267 // we can get rid of an inequality join with a DelimGet, but only under specific circumstances
268 if (!RemoveInequalityCandidate(plan, candidate, updater)) {
269 return false;
270 }
271 }
272
273 // make a filter if needed
274 if (!nulls_are_not_equal_exprs.empty() || filter != nullptr) {
275 auto filter_op = make_uniq<LogicalFilter>();
276 if (!nulls_are_not_equal_exprs.empty()) {
277 // add an IS NOT NULL filter that was implicitly in JoinCondition::null_values_are_equal
278 for (auto &expr : nulls_are_not_equal_exprs) {
279 auto is_not_null_expr =
280 make_uniq<BoundOperatorExpression>(args: ExpressionType::OPERATOR_IS_NOT_NULL, args: LogicalType::BOOLEAN);
281 is_not_null_expr->children.push_back(x: expr.get().Copy());
282 filter_op->expressions.push_back(x: std::move(is_not_null_expr));
283 }
284 }
285 if (filter != nullptr) {
286 for (auto &expr : filter->expressions) {
287 filter_op->expressions.push_back(x: std::move(expr));
288 }
289 }
290 filter_op->children.push_back(x: std::move(join.children[1 - delim_idx]));
291 join.children[1 - delim_idx] = std::move(filter_op);
292 }
293 // temporarily save deleted operator so its expressions are still available
294 updater.temp_ptr = std::move(proj_or_agg.children[0]);
295 // replace the redundant join
296 proj_or_agg.children[0] = std::move(join.children[1 - delim_idx]);
297 return true;
298}
299
300static void GetDelimJoins(LogicalOperator &op, vector<LogicalOperator *> &delim_joins) {
301 for (auto &child : op.children) {
302 GetDelimJoins(op&: *child, delim_joins);
303 }
304 if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) {
305 delim_joins.push_back(x: &op);
306 }
307}
308
309static bool HasChild(LogicalOperator *haystack, LogicalOperator *needle, idx_t &side) {
310 if (haystack == needle) {
311 return true;
312 }
313 for (idx_t i = 0; i < haystack->children.size(); i++) {
314 auto &child = haystack->children[i];
315 idx_t dummy_side;
316 if (HasChild(haystack: child.get(), needle, side&: dummy_side)) {
317 side = i;
318 return true;
319 }
320 }
321 return false;
322}
323
324bool Deliminator::RemoveInequalityCandidate(unique_ptr<LogicalOperator> *plan, unique_ptr<LogicalOperator> *candidate,
325 DeliminatorPlanUpdater &updater) {
326 auto &proj_or_agg = **candidate;
327 // first, we find a DelimJoin in "plan" that has only one DelimGet as a child, which is in "candidate"
328 if (DelimGetCount(op&: proj_or_agg) != 1) {
329 // the candidate therefore must have only a single DelimGet in its children
330 return false;
331 }
332
333 vector<LogicalOperator *> delim_joins;
334 GetDelimJoins(op&: **plan, delim_joins);
335
336 LogicalOperator *parent = nullptr;
337 idx_t parent_delim_get_side = 0;
338 for (auto dj : delim_joins) {
339 D_ASSERT(dj->type == LogicalOperatorType::LOGICAL_DELIM_JOIN);
340 if (!HasChild(haystack: dj, needle: &proj_or_agg, side&: parent_delim_get_side)) {
341 continue;
342 }
343 // we found a parent DelimJoin
344 if (DelimGetCount(op&: *dj) != 1) {
345 // it has more than one DelimGet children
346 continue;
347 }
348
349 // we can only remove inequality join with a DelimGet if the parent DelimJoin has one of these join types
350 auto &delim_join = dj->Cast<LogicalDelimJoin>();
351 if (!InequalityDelimJoinCanBeEliminated(join_type&: delim_join.join_type)) {
352 continue;
353 }
354
355 parent = dj;
356 break;
357 }
358 if (!parent) {
359 return false;
360 }
361
362 // we found the parent delim join, and we may be able to remove the child DelimGet join
363 // but we need to make sure that their conditions refer to exactly the same columns
364 auto &parent_delim_join = parent->Cast<LogicalDelimJoin>();
365 auto &join = proj_or_agg.children[0]->Cast<LogicalComparisonJoin>();
366 if (parent_delim_join.conditions.size() != join.conditions.size()) {
367 // different number of conditions, can't replace
368 return false;
369 }
370
371 // we can only do this optimization under the following conditions:
372 // 1. all join expressions coming from the DelimGet side are colrefs
373 // 2. these expressions refer to colrefs coming from the proj/agg on top of the child DelimGet join
374 // 3. the expression (before it was proj/agg) can be found in the conditions of the child DelimGet join
375 for (auto &parent_cond : parent_delim_join.conditions) {
376 auto &parent_expr = parent_delim_get_side == 0 ? parent_cond.left : parent_cond.right;
377 if (parent_expr->type != ExpressionType::BOUND_COLUMN_REF) {
378 // can only deal with colrefs
379 return false;
380 }
381 auto &parent_colref = parent_expr->Cast<BoundColumnRefExpression>();
382 auto it = updater.reverse_proj_or_agg_map.find(x: parent_colref.binding);
383 if (it == updater.reverse_proj_or_agg_map.end()) {
384 // refers to a column that was not in the child DelimGet join
385 return false;
386 }
387 // try to find the corresponding child condition
388 // TODO: can be more flexible - allow CAST
389 auto &child_expr = *it->second;
390 bool found = false;
391 for (auto &child_cond : join.conditions) {
392 if (child_cond.left->Equals(other: child_expr) || child_cond.right->Equals(other: child_expr)) {
393 found = true;
394 break;
395 }
396 }
397 if (!found) {
398 // could not find the mapped expression in the child condition expressions
399 return false;
400 }
401 }
402
403 // TODO: we cannot perform the optimization here because our pure inequality joins don't implement
404 // JoinType::SINGLE yet
405 if (parent_delim_join.join_type == JoinType::SINGLE) {
406 bool has_one_equality = false;
407 for (auto &cond : join.conditions) {
408 has_one_equality = has_one_equality || IsEqualityJoinCondition(cond);
409 }
410 if (!has_one_equality) {
411 return false;
412 }
413 }
414
415 // we are now sure that we can remove the child DelimGet join, so we basically do the same loop as above
416 // this time without checks because we already did them, and replace the expressions
417 for (auto &parent_cond : parent_delim_join.conditions) {
418 auto &parent_expr = parent_delim_get_side == 0 ? parent_cond.left : parent_cond.right;
419 auto &parent_colref = parent_expr->Cast<BoundColumnRefExpression>();
420 auto it = updater.reverse_proj_or_agg_map.find(x: parent_colref.binding);
421 auto &child_expr = *it->second;
422 for (auto &child_cond : join.conditions) {
423 if (!child_cond.left->Equals(other: child_expr) && !child_cond.right->Equals(other: child_expr)) {
424 continue;
425 }
426 parent_expr = make_uniq<BoundColumnRefExpression>(args&: parent_expr->alias, args&: parent_expr->return_type, args: it->first);
427 parent_cond.comparison =
428 parent_delim_get_side == 0 ? child_cond.comparison : FlipComparisonExpression(type: child_cond.comparison);
429 break;
430 }
431 }
432
433 // no longer needs to be a delim join
434 parent_delim_join.duplicate_eliminated_columns.clear();
435 parent_delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN;
436
437 return true;
438}
439
440} // namespace duckdb
441