1 | #include "duckdb/optimizer/deliminator.hpp" |
2 | |
3 | #include "duckdb/optimizer/join_order/join_order_optimizer.hpp" |
4 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_conjunction_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_operator_expression.hpp" |
8 | #include "duckdb/planner/operator/logical_aggregate.hpp" |
9 | #include "duckdb/planner/operator/logical_delim_get.hpp" |
10 | #include "duckdb/planner/operator/logical_delim_join.hpp" |
11 | #include "duckdb/planner/operator/logical_filter.hpp" |
12 | |
13 | namespace duckdb { |
14 | |
15 | class DeliminatorPlanUpdater : LogicalOperatorVisitor { |
16 | public: |
17 | explicit DeliminatorPlanUpdater(ClientContext &context) : context(context) { |
18 | } |
19 | //! Update the plan after a DelimGet has been removed |
20 | void VisitOperator(LogicalOperator &op) override; |
21 | void VisitExpression(unique_ptr<Expression> *expression) override; |
22 | |
23 | public: |
24 | ClientContext &context; |
25 | |
26 | expression_map_t<Expression *> expr_map; |
27 | column_binding_map_t<bool> projection_map; |
28 | column_binding_map_t<Expression *> reverse_proj_or_agg_map; |
29 | unique_ptr<LogicalOperator> temp_ptr; |
30 | }; |
31 | |
32 | static idx_t DelimGetCount(LogicalOperator &op) { |
33 | if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
34 | return 1; |
35 | } |
36 | idx_t child_count = 0; |
37 | for (auto &child : op.children) { |
38 | child_count += DelimGetCount(op&: *child); |
39 | } |
40 | return child_count; |
41 | } |
42 | |
43 | static bool IsEqualityJoinCondition(JoinCondition &cond) { |
44 | switch (cond.comparison) { |
45 | case ExpressionType::COMPARE_EQUAL: |
46 | case ExpressionType::COMPARE_NOT_DISTINCT_FROM: |
47 | return true; |
48 | default: |
49 | return false; |
50 | } |
51 | } |
52 | |
53 | static bool InequalityDelimJoinCanBeEliminated(JoinType &join_type) { |
54 | switch (join_type) { |
55 | case JoinType::ANTI: |
56 | case JoinType::MARK: |
57 | case JoinType::SEMI: |
58 | case JoinType::SINGLE: |
59 | return true; |
60 | default: |
61 | return false; |
62 | } |
63 | } |
64 | |
65 | void DeliminatorPlanUpdater::VisitOperator(LogicalOperator &op) { |
66 | VisitOperatorChildren(op); |
67 | VisitOperatorExpressions(op); |
68 | if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN && DelimGetCount(op) == 0) { |
69 | auto &delim_join = op.Cast<LogicalDelimJoin>(); |
70 | auto &decs = delim_join.duplicate_eliminated_columns; |
71 | for (auto &cond : delim_join.conditions) { |
72 | if (!IsEqualityJoinCondition(cond)) { |
73 | continue; |
74 | } |
75 | auto rhs = cond.right.get(); |
76 | while (rhs->type == ExpressionType::OPERATOR_CAST) { |
77 | auto &cast = rhs->Cast<BoundCastExpression>(); |
78 | rhs = cast.child.get(); |
79 | } |
80 | if (rhs->type != ExpressionType::BOUND_COLUMN_REF) { |
81 | throw InternalException("Error in Deliminator: expected a bound column reference" ); |
82 | } |
83 | auto &colref = rhs->Cast<BoundColumnRefExpression>(); |
84 | if (projection_map.find(x: colref.binding) != projection_map.end()) { |
85 | // value on the right is a projection of removed DelimGet |
86 | for (idx_t i = 0; i < decs.size(); i++) { |
87 | if (decs[i]->Equals(other: *cond.left)) { |
88 | // the value on the left no longer needs to be a duplicate-eliminated column |
89 | decs.erase(position: decs.begin() + i); |
90 | break; |
91 | } |
92 | } |
93 | // whether we applied an IS NOT NULL filter |
94 | cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; |
95 | } |
96 | } |
97 | // change type if there are no more duplicate-eliminated columns |
98 | if (decs.empty()) { |
99 | delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; |
100 | // sub-plans with DelimGets are not re-orderable (yet), however, we removed all DelimGet of this DelimJoin |
101 | // the DelimGets are on the RHS of the DelimJoin, so we can call the JoinOrderOptimizer on the RHS now |
102 | JoinOrderOptimizer optimizer(context); |
103 | delim_join.children[1] = optimizer.Optimize(plan: std::move(delim_join.children[1])); |
104 | } |
105 | } |
106 | } |
107 | |
108 | void DeliminatorPlanUpdater::VisitExpression(unique_ptr<Expression> *expression) { |
109 | auto &expr = **expression; |
110 | auto entry = expr_map.find(x: expr); |
111 | if (entry != expr_map.end()) { |
112 | *expression = entry->second->Copy(); |
113 | } else { |
114 | VisitExpressionChildren(expression&: **expression); |
115 | } |
116 | } |
117 | |
118 | unique_ptr<LogicalOperator> Deliminator::Optimize(unique_ptr<LogicalOperator> op) { |
119 | vector<unique_ptr<LogicalOperator> *> candidates; |
120 | FindCandidates(op_ptr: &op, candidates); |
121 | |
122 | for (auto &candidate : candidates) { |
123 | DeliminatorPlanUpdater updater(context); |
124 | if (RemoveCandidate(plan: &op, candidate, updater)) { |
125 | updater.VisitOperator(op&: *op); |
126 | } |
127 | } |
128 | return op; |
129 | } |
130 | |
131 | void Deliminator::FindCandidates(unique_ptr<LogicalOperator> *op_ptr, |
132 | vector<unique_ptr<LogicalOperator> *> &candidates) { |
133 | auto op = op_ptr->get(); |
134 | // search children before adding, so the deepest candidates get added first |
135 | for (auto &child : op->children) { |
136 | FindCandidates(op_ptr: &child, candidates); |
137 | } |
138 | // search for projection/aggregate |
139 | if (op->type != LogicalOperatorType::LOGICAL_PROJECTION && |
140 | op->type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { |
141 | return; |
142 | } |
143 | // followed by a join |
144 | if (op->children[0]->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { |
145 | return; |
146 | } |
147 | auto &join = *op->children[0]; |
148 | // with a DelimGet as a direct child (left or right) |
149 | if (join.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET || |
150 | join.children[1]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
151 | candidates.push_back(x: op_ptr); |
152 | return; |
153 | } |
154 | // or a filter followed by a DelimGet (left) |
155 | if (join.children[0]->type == LogicalOperatorType::LOGICAL_FILTER && |
156 | join.children[0]->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
157 | candidates.push_back(x: op_ptr); |
158 | return; |
159 | } |
160 | // filter followed by a DelimGet (right) |
161 | if (join.children[1]->type == LogicalOperatorType::LOGICAL_FILTER && |
162 | join.children[1]->children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
163 | candidates.push_back(x: op_ptr); |
164 | return; |
165 | } |
166 | } |
167 | |
168 | static bool OperatorIsDelimGet(LogicalOperator &op) { |
169 | if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
170 | return true; |
171 | } |
172 | if (op.type == LogicalOperatorType::LOGICAL_FILTER && |
173 | op.children[0]->type == LogicalOperatorType::LOGICAL_DELIM_GET) { |
174 | return true; |
175 | } |
176 | return false; |
177 | } |
178 | |
179 | static bool ChildJoinTypeCanBeDeliminated(JoinType &join_type) { |
180 | switch (join_type) { |
181 | case JoinType::INNER: |
182 | case JoinType::SEMI: |
183 | return true; |
184 | default: |
185 | return false; |
186 | } |
187 | } |
188 | |
189 | bool Deliminator::RemoveCandidate(unique_ptr<LogicalOperator> *plan, unique_ptr<LogicalOperator> *candidate, |
190 | DeliminatorPlanUpdater &updater) { |
191 | auto &proj_or_agg = **candidate; |
192 | auto &join = proj_or_agg.children[0]->Cast<LogicalComparisonJoin>(); |
193 | if (!ChildJoinTypeCanBeDeliminated(join_type&: join.join_type)) { |
194 | return false; |
195 | } |
196 | |
197 | // get the index (left or right) of the DelimGet side of the join |
198 | idx_t delim_idx = OperatorIsDelimGet(op&: *join.children[0]) ? 0 : 1; |
199 | D_ASSERT(OperatorIsDelimGet(*join.children[delim_idx])); |
200 | // get the filter (if any) |
201 | optional_ptr<LogicalFilter> filter; |
202 | if (join.children[delim_idx]->type == LogicalOperatorType::LOGICAL_FILTER) { |
203 | filter = &join.children[delim_idx]->Cast<LogicalFilter>(); |
204 | } |
205 | auto &delim_get = (filter ? filter->children[0] : join.children[delim_idx])->Cast<LogicalDelimGet>(); |
206 | if (join.conditions.size() != delim_get.chunk_types.size()) { |
207 | // joining with DelimGet adds new information |
208 | return false; |
209 | } |
210 | // check if joining with the DelimGet is redundant, and collect relevant column information |
211 | bool all_equality_conditions = true; |
212 | vector<reference<Expression>> nulls_are_not_equal_exprs; |
213 | for (auto &cond : join.conditions) { |
214 | all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(cond); |
215 | auto &delim_side = delim_idx == 0 ? *cond.left : *cond.right; |
216 | auto &other_side = delim_idx == 0 ? *cond.right : *cond.left; |
217 | if (delim_side.type != ExpressionType::BOUND_COLUMN_REF) { |
218 | // non-colref e.g. expression -(4, 1) in 4-i=j where i is from DelimGet |
219 | // FIXME: might be possible to also eliminate these |
220 | return false; |
221 | } |
222 | updater.expr_map[delim_side] = &other_side; |
223 | if (cond.comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { |
224 | nulls_are_not_equal_exprs.push_back(x: other_side); |
225 | } |
226 | } |
227 | |
228 | // removed DelimGet columns are assigned a new ColumnBinding by Projection/Aggregation, keep track here |
229 | if (proj_or_agg.type == LogicalOperatorType::LOGICAL_PROJECTION) { |
230 | for (auto &cb : proj_or_agg.GetColumnBindings()) { |
231 | updater.projection_map[cb] = true; |
232 | updater.reverse_proj_or_agg_map[cb] = proj_or_agg.expressions[cb.column_index].get(); |
233 | for (auto &expr : nulls_are_not_equal_exprs) { |
234 | if (proj_or_agg.expressions[cb.column_index]->Equals(other: expr.get())) { |
235 | updater.projection_map[cb] = false; |
236 | break; |
237 | } |
238 | } |
239 | } |
240 | } else { |
241 | auto &agg = proj_or_agg.Cast<LogicalAggregate>(); |
242 | |
243 | // Create a vector of all exprs in the agg |
244 | vector<Expression *> all_agg_exprs; |
245 | all_agg_exprs.reserve(n: agg.groups.size() + agg.expressions.size()); |
246 | for (auto &expr : agg.groups) { |
247 | all_agg_exprs.push_back(x: expr.get()); |
248 | } |
249 | for (auto &expr : agg.expressions) { |
250 | all_agg_exprs.push_back(x: expr.get()); |
251 | } |
252 | |
253 | for (auto &cb : agg.GetColumnBindings()) { |
254 | updater.projection_map[cb] = true; |
255 | updater.reverse_proj_or_agg_map[cb] = all_agg_exprs[cb.column_index]; |
256 | for (auto &expr : nulls_are_not_equal_exprs) { |
257 | if ((cb.table_index == agg.group_index && agg.groups[cb.column_index]->Equals(other: expr.get())) || |
258 | (cb.table_index == agg.aggregate_index && agg.expressions[cb.column_index]->Equals(other: expr.get()))) { |
259 | updater.projection_map[cb] = false; |
260 | break; |
261 | } |
262 | } |
263 | } |
264 | } |
265 | |
266 | if (!all_equality_conditions) { |
267 | // we can get rid of an inequality join with a DelimGet, but only under specific circumstances |
268 | if (!RemoveInequalityCandidate(plan, candidate, updater)) { |
269 | return false; |
270 | } |
271 | } |
272 | |
273 | // make a filter if needed |
274 | if (!nulls_are_not_equal_exprs.empty() || filter != nullptr) { |
275 | auto filter_op = make_uniq<LogicalFilter>(); |
276 | if (!nulls_are_not_equal_exprs.empty()) { |
277 | // add an IS NOT NULL filter that was implicitly in JoinCondition::null_values_are_equal |
278 | for (auto &expr : nulls_are_not_equal_exprs) { |
279 | auto is_not_null_expr = |
280 | make_uniq<BoundOperatorExpression>(args: ExpressionType::OPERATOR_IS_NOT_NULL, args: LogicalType::BOOLEAN); |
281 | is_not_null_expr->children.push_back(x: expr.get().Copy()); |
282 | filter_op->expressions.push_back(x: std::move(is_not_null_expr)); |
283 | } |
284 | } |
285 | if (filter != nullptr) { |
286 | for (auto &expr : filter->expressions) { |
287 | filter_op->expressions.push_back(x: std::move(expr)); |
288 | } |
289 | } |
290 | filter_op->children.push_back(x: std::move(join.children[1 - delim_idx])); |
291 | join.children[1 - delim_idx] = std::move(filter_op); |
292 | } |
293 | // temporarily save deleted operator so its expressions are still available |
294 | updater.temp_ptr = std::move(proj_or_agg.children[0]); |
295 | // replace the redundant join |
296 | proj_or_agg.children[0] = std::move(join.children[1 - delim_idx]); |
297 | return true; |
298 | } |
299 | |
300 | static void GetDelimJoins(LogicalOperator &op, vector<LogicalOperator *> &delim_joins) { |
301 | for (auto &child : op.children) { |
302 | GetDelimJoins(op&: *child, delim_joins); |
303 | } |
304 | if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { |
305 | delim_joins.push_back(x: &op); |
306 | } |
307 | } |
308 | |
309 | static bool HasChild(LogicalOperator *haystack, LogicalOperator *needle, idx_t &side) { |
310 | if (haystack == needle) { |
311 | return true; |
312 | } |
313 | for (idx_t i = 0; i < haystack->children.size(); i++) { |
314 | auto &child = haystack->children[i]; |
315 | idx_t dummy_side; |
316 | if (HasChild(haystack: child.get(), needle, side&: dummy_side)) { |
317 | side = i; |
318 | return true; |
319 | } |
320 | } |
321 | return false; |
322 | } |
323 | |
324 | bool Deliminator::RemoveInequalityCandidate(unique_ptr<LogicalOperator> *plan, unique_ptr<LogicalOperator> *candidate, |
325 | DeliminatorPlanUpdater &updater) { |
326 | auto &proj_or_agg = **candidate; |
327 | // first, we find a DelimJoin in "plan" that has only one DelimGet as a child, which is in "candidate" |
328 | if (DelimGetCount(op&: proj_or_agg) != 1) { |
329 | // the candidate therefore must have only a single DelimGet in its children |
330 | return false; |
331 | } |
332 | |
333 | vector<LogicalOperator *> delim_joins; |
334 | GetDelimJoins(op&: **plan, delim_joins); |
335 | |
336 | LogicalOperator *parent = nullptr; |
337 | idx_t parent_delim_get_side = 0; |
338 | for (auto dj : delim_joins) { |
339 | D_ASSERT(dj->type == LogicalOperatorType::LOGICAL_DELIM_JOIN); |
340 | if (!HasChild(haystack: dj, needle: &proj_or_agg, side&: parent_delim_get_side)) { |
341 | continue; |
342 | } |
343 | // we found a parent DelimJoin |
344 | if (DelimGetCount(op&: *dj) != 1) { |
345 | // it has more than one DelimGet children |
346 | continue; |
347 | } |
348 | |
349 | // we can only remove inequality join with a DelimGet if the parent DelimJoin has one of these join types |
350 | auto &delim_join = dj->Cast<LogicalDelimJoin>(); |
351 | if (!InequalityDelimJoinCanBeEliminated(join_type&: delim_join.join_type)) { |
352 | continue; |
353 | } |
354 | |
355 | parent = dj; |
356 | break; |
357 | } |
358 | if (!parent) { |
359 | return false; |
360 | } |
361 | |
362 | // we found the parent delim join, and we may be able to remove the child DelimGet join |
363 | // but we need to make sure that their conditions refer to exactly the same columns |
364 | auto &parent_delim_join = parent->Cast<LogicalDelimJoin>(); |
365 | auto &join = proj_or_agg.children[0]->Cast<LogicalComparisonJoin>(); |
366 | if (parent_delim_join.conditions.size() != join.conditions.size()) { |
367 | // different number of conditions, can't replace |
368 | return false; |
369 | } |
370 | |
371 | // we can only do this optimization under the following conditions: |
372 | // 1. all join expressions coming from the DelimGet side are colrefs |
373 | // 2. these expressions refer to colrefs coming from the proj/agg on top of the child DelimGet join |
374 | // 3. the expression (before it was proj/agg) can be found in the conditions of the child DelimGet join |
375 | for (auto &parent_cond : parent_delim_join.conditions) { |
376 | auto &parent_expr = parent_delim_get_side == 0 ? parent_cond.left : parent_cond.right; |
377 | if (parent_expr->type != ExpressionType::BOUND_COLUMN_REF) { |
378 | // can only deal with colrefs |
379 | return false; |
380 | } |
381 | auto &parent_colref = parent_expr->Cast<BoundColumnRefExpression>(); |
382 | auto it = updater.reverse_proj_or_agg_map.find(x: parent_colref.binding); |
383 | if (it == updater.reverse_proj_or_agg_map.end()) { |
384 | // refers to a column that was not in the child DelimGet join |
385 | return false; |
386 | } |
387 | // try to find the corresponding child condition |
388 | // TODO: can be more flexible - allow CAST |
389 | auto &child_expr = *it->second; |
390 | bool found = false; |
391 | for (auto &child_cond : join.conditions) { |
392 | if (child_cond.left->Equals(other: child_expr) || child_cond.right->Equals(other: child_expr)) { |
393 | found = true; |
394 | break; |
395 | } |
396 | } |
397 | if (!found) { |
398 | // could not find the mapped expression in the child condition expressions |
399 | return false; |
400 | } |
401 | } |
402 | |
403 | // TODO: we cannot perform the optimization here because our pure inequality joins don't implement |
404 | // JoinType::SINGLE yet |
405 | if (parent_delim_join.join_type == JoinType::SINGLE) { |
406 | bool has_one_equality = false; |
407 | for (auto &cond : join.conditions) { |
408 | has_one_equality = has_one_equality || IsEqualityJoinCondition(cond); |
409 | } |
410 | if (!has_one_equality) { |
411 | return false; |
412 | } |
413 | } |
414 | |
415 | // we are now sure that we can remove the child DelimGet join, so we basically do the same loop as above |
416 | // this time without checks because we already did them, and replace the expressions |
417 | for (auto &parent_cond : parent_delim_join.conditions) { |
418 | auto &parent_expr = parent_delim_get_side == 0 ? parent_cond.left : parent_cond.right; |
419 | auto &parent_colref = parent_expr->Cast<BoundColumnRefExpression>(); |
420 | auto it = updater.reverse_proj_or_agg_map.find(x: parent_colref.binding); |
421 | auto &child_expr = *it->second; |
422 | for (auto &child_cond : join.conditions) { |
423 | if (!child_cond.left->Equals(other: child_expr) && !child_cond.right->Equals(other: child_expr)) { |
424 | continue; |
425 | } |
426 | parent_expr = make_uniq<BoundColumnRefExpression>(args&: parent_expr->alias, args&: parent_expr->return_type, args: it->first); |
427 | parent_cond.comparison = |
428 | parent_delim_get_side == 0 ? child_cond.comparison : FlipComparisonExpression(type: child_cond.comparison); |
429 | break; |
430 | } |
431 | } |
432 | |
433 | // no longer needs to be a delim join |
434 | parent_delim_join.duplicate_eliminated_columns.clear(); |
435 | parent_delim_join.type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; |
436 | |
437 | return true; |
438 | } |
439 | |
440 | } // namespace duckdb |
441 | |