1#include "duckdb/optimizer/in_clause_rewriter.hpp"
2#include "duckdb/optimizer/optimizer.hpp"
3#include "duckdb/planner/binder.hpp"
4#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
6#include "duckdb/planner/expression/bound_operator_expression.hpp"
7#include "duckdb/planner/operator/logical_chunk_get.hpp"
8#include "duckdb/planner/operator/logical_comparison_join.hpp"
9#include "duckdb/execution/expression_executor.hpp"
10
11namespace duckdb {
12
13unique_ptr<LogicalOperator> InClauseRewriter::Rewrite(unique_ptr<LogicalOperator> op) {
14 if (op->children.size() == 1) {
15 root = move(op->children[0]);
16 VisitOperatorExpressions(*op);
17 op->children[0] = move(root);
18 }
19
20 for (auto &child : op->children) {
21 child = Rewrite(move(child));
22 }
23 return op;
24}
25
26unique_ptr<Expression> InClauseRewriter::VisitReplace(BoundOperatorExpression &expr, unique_ptr<Expression> *expr_ptr) {
27 if (expr.type != ExpressionType::COMPARE_IN && expr.type != ExpressionType::COMPARE_NOT_IN) {
28 return nullptr;
29 }
30 assert(root);
31 auto in_type = expr.children[0]->return_type;
32 bool is_regular_in = expr.type == ExpressionType::COMPARE_IN;
33 bool all_scalar = true;
34 // IN clause with many children: try to generate a mark join that replaces this IN expression
35 // we can only do this if the expressions in the expression list are scalar
36 for (idx_t i = 1; i < expr.children.size(); i++) {
37 assert(expr.children[i]->return_type == in_type);
38 if (!expr.children[i]->IsFoldable()) {
39 // non-scalar expression
40 all_scalar = false;
41 }
42 }
43 if (expr.children.size() == 2) {
44 // only one child
45 // IN: turn into X = 1
46 // NOT IN: turn into X <> 1
47 return make_unique<BoundComparisonExpression>(is_regular_in ? ExpressionType::COMPARE_EQUAL
48 : ExpressionType::COMPARE_NOTEQUAL,
49 move(expr.children[0]), move(expr.children[1]));
50 }
51 if (expr.children.size() < 6 || !all_scalar) {
52 // low amount of children or not all scalar
53 // IN: turn into (X = 1 OR X = 2 OR X = 3...)
54 // NOT IN: turn into (X <> 1 AND X <> 2 AND X <> 3 ...)
55 auto conjunction = make_unique<BoundConjunctionExpression>(is_regular_in ? ExpressionType::CONJUNCTION_OR
56 : ExpressionType::CONJUNCTION_AND);
57 for (idx_t i = 1; i < expr.children.size(); i++) {
58 conjunction->children.push_back(make_unique<BoundComparisonExpression>(
59 is_regular_in ? ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL,
60 expr.children[0]->Copy(), move(expr.children[i])));
61 }
62 return move(conjunction);
63 }
64 // IN clause with many constant children
65 // generate a mark join that replaces this IN expression
66 // first generate a ChunkCollection from the set of expressions
67 vector<TypeId> types = {in_type};
68 auto collection = make_unique<ChunkCollection>();
69 DataChunk chunk;
70 chunk.Initialize(types);
71 for (idx_t i = 1; i < expr.children.size(); i++) {
72 // reoslve this expression to a constant
73 auto value = ExpressionExecutor::EvaluateScalar(*expr.children[i]);
74 idx_t index = chunk.size();
75 chunk.SetCardinality(chunk.size() + 1);
76 chunk.SetValue(0, index, value);
77 if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.children.size()) {
78 // chunk full: append to chunk collection
79 collection->Append(chunk);
80 chunk.Reset();
81 }
82 }
83 // now generate a ChunkGet that scans this collection
84 auto chunk_index = optimizer.binder.GenerateTableIndex();
85 auto chunk_scan = make_unique<LogicalChunkGet>(chunk_index, types, move(collection));
86
87 // then we generate the MARK join with the chunk scan on the RHS
88 auto join = make_unique<LogicalComparisonJoin>(JoinType::MARK);
89 join->mark_index = chunk_index;
90 join->AddChild(move(root));
91 join->AddChild(move(chunk_scan));
92 // create the JOIN condition
93 JoinCondition cond;
94 cond.left = move(expr.children[0]);
95
96 cond.right = make_unique<BoundColumnRefExpression>(in_type, ColumnBinding(chunk_index, 0));
97 cond.comparison = ExpressionType::COMPARE_EQUAL;
98 join->conditions.push_back(move(cond));
99 root = move(join);
100
101 // we replace the original subquery with a BoundColumnRefExpression refering to the mark column
102 unique_ptr<Expression> result =
103 make_unique<BoundColumnRefExpression>("IN (...)", TypeId::BOOL, ColumnBinding(chunk_index, 0));
104 if (!is_regular_in) {
105 // NOT IN: invert
106 auto invert = make_unique<BoundOperatorExpression>(ExpressionType::OPERATOR_NOT, TypeId::BOOL);
107 invert->children.push_back(move(result));
108 result = move(invert);
109 }
110 return result;
111}
112
113} // namespace duckdb
114