1#include "duckdb/optimizer/in_clause_rewriter.hpp"
2#include "duckdb/optimizer/optimizer.hpp"
3#include "duckdb/planner/binder.hpp"
4#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
6#include "duckdb/planner/expression/bound_operator_expression.hpp"
7#include "duckdb/planner/operator/logical_column_data_get.hpp"
8#include "duckdb/planner/operator/logical_comparison_join.hpp"
9#include "duckdb/execution/expression_executor.hpp"
10
11namespace duckdb {
12
13unique_ptr<LogicalOperator> InClauseRewriter::Rewrite(unique_ptr<LogicalOperator> op) {
14 if (op->children.size() == 1) {
15 root = std::move(op->children[0]);
16 VisitOperatorExpressions(op&: *op);
17 op->children[0] = std::move(root);
18 }
19
20 for (auto &child : op->children) {
21 child = Rewrite(op: std::move(child));
22 }
23 return op;
24}
25
26unique_ptr<Expression> InClauseRewriter::VisitReplace(BoundOperatorExpression &expr, unique_ptr<Expression> *expr_ptr) {
27 if (expr.type != ExpressionType::COMPARE_IN && expr.type != ExpressionType::COMPARE_NOT_IN) {
28 return nullptr;
29 }
30 D_ASSERT(root);
31 auto in_type = expr.children[0]->return_type;
32 bool is_regular_in = expr.type == ExpressionType::COMPARE_IN;
33 bool all_scalar = true;
34 // IN clause with many children: try to generate a mark join that replaces this IN expression
35 // we can only do this if the expressions in the expression list are scalar
36 for (idx_t i = 1; i < expr.children.size(); i++) {
37 if (!expr.children[i]->IsFoldable()) {
38 // non-scalar expression
39 all_scalar = false;
40 }
41 }
42 if (expr.children.size() == 2) {
43 // only one child
44 // IN: turn into X = 1
45 // NOT IN: turn into X <> 1
46 return make_uniq<BoundComparisonExpression>(args: is_regular_in ? ExpressionType::COMPARE_EQUAL
47 : ExpressionType::COMPARE_NOTEQUAL,
48 args: std::move(expr.children[0]), args: std::move(expr.children[1]));
49 }
50 if (expr.children.size() < 6 || !all_scalar) {
51 // low amount of children or not all scalar
52 // IN: turn into (X = 1 OR X = 2 OR X = 3...)
53 // NOT IN: turn into (X <> 1 AND X <> 2 AND X <> 3 ...)
54 auto conjunction = make_uniq<BoundConjunctionExpression>(args: is_regular_in ? ExpressionType::CONJUNCTION_OR
55 : ExpressionType::CONJUNCTION_AND);
56 for (idx_t i = 1; i < expr.children.size(); i++) {
57 conjunction->children.push_back(x: make_uniq<BoundComparisonExpression>(
58 args: is_regular_in ? ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL,
59 args: expr.children[0]->Copy(), args: std::move(expr.children[i])));
60 }
61 return std::move(conjunction);
62 }
63 // IN clause with many constant children
64 // generate a mark join that replaces this IN expression
65 // first generate a ColumnDataCollection from the set of expressions
66 vector<LogicalType> types = {in_type};
67 auto collection = make_uniq<ColumnDataCollection>(args&: context, args&: types);
68 ColumnDataAppendState append_state;
69 collection->InitializeAppend(state&: append_state);
70
71 DataChunk chunk;
72 chunk.Initialize(context, types);
73 for (idx_t i = 1; i < expr.children.size(); i++) {
74 // resolve this expression to a constant
75 auto value = ExpressionExecutor::EvaluateScalar(context, expr: *expr.children[i]);
76 idx_t index = chunk.size();
77 chunk.SetCardinality(chunk.size() + 1);
78 chunk.SetValue(col_idx: 0, index, val: value);
79 if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.children.size()) {
80 // chunk full: append to chunk collection
81 collection->Append(state&: append_state, new_chunk&: chunk);
82 chunk.Reset();
83 }
84 }
85 // now generate a ChunkGet that scans this collection
86 auto chunk_index = optimizer.binder.GenerateTableIndex();
87 auto chunk_scan = make_uniq<LogicalColumnDataGet>(args&: chunk_index, args&: types, args: std::move(collection));
88
89 // then we generate the MARK join with the chunk scan on the RHS
90 auto join = make_uniq<LogicalComparisonJoin>(args: JoinType::MARK);
91 join->mark_index = chunk_index;
92 join->AddChild(child: std::move(root));
93 join->AddChild(child: std::move(chunk_scan));
94 // create the JOIN condition
95 JoinCondition cond;
96 cond.left = std::move(expr.children[0]);
97
98 cond.right = make_uniq<BoundColumnRefExpression>(args&: in_type, args: ColumnBinding(chunk_index, 0));
99 cond.comparison = ExpressionType::COMPARE_EQUAL;
100 join->conditions.push_back(x: std::move(cond));
101 root = std::move(join);
102
103 // we replace the original subquery with a BoundColumnRefExpression referring to the mark column
104 unique_ptr<Expression> result =
105 make_uniq<BoundColumnRefExpression>(args: "IN (...)", args: LogicalType::BOOLEAN, args: ColumnBinding(chunk_index, 0));
106 if (!is_regular_in) {
107 // NOT IN: invert
108 auto invert = make_uniq<BoundOperatorExpression>(args: ExpressionType::OPERATOR_NOT, args: LogicalType::BOOLEAN);
109 invert->children.push_back(x: std::move(result));
110 result = std::move(invert);
111 }
112 return result;
113}
114
115} // namespace duckdb
116