1 | #include "duckdb/optimizer/rule/move_constants.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/common/value_operations/value_operations.hpp" |
5 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
8 | |
9 | using namespace duckdb; |
10 | using namespace std; |
11 | |
12 | MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
13 | auto op = make_unique<ComparisonExpressionMatcher>(); |
14 | op->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
15 | op->policy = SetMatcher::Policy::UNORDERED; |
16 | |
17 | auto arithmetic = make_unique<FunctionExpressionMatcher>(); |
18 | // we handle multiplication, addition and subtraction because those are "easy" |
19 | // integer division makes the division case difficult |
20 | // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules |
21 | arithmetic->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+" , "-" , "*" }); |
22 | // we match only on integral numeric types |
23 | arithmetic->type = make_unique<IntegerTypeMatcher>(); |
24 | arithmetic->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
25 | arithmetic->matchers.push_back(make_unique<ExpressionMatcher>()); |
26 | arithmetic->policy = SetMatcher::Policy::SOME; |
27 | op->matchers.push_back(move(arithmetic)); |
28 | root = move(op); |
29 | } |
30 | |
31 | unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, |
32 | bool &changes_made) { |
33 | auto comparison = (BoundComparisonExpression *)bindings[0]; |
34 | auto outer_constant = (BoundConstantExpression *)bindings[1]; |
35 | auto arithmetic = (BoundFunctionExpression *)bindings[2]; |
36 | auto inner_constant = (BoundConstantExpression *)bindings[3]; |
37 | |
38 | int arithmetic_child_index = arithmetic->children[0].get() == inner_constant ? 1 : 0; |
39 | auto &op_type = arithmetic->function.name; |
40 | if (op_type == "+" ) { |
41 | // [x + 1 COMP 10] OR [1 + x COMP 10] |
42 | // order does not matter in addition: |
43 | // simply change right side to 10-1 (outer_constant - inner_constant) |
44 | outer_constant->value = outer_constant->value - inner_constant->value; |
45 | } else if (op_type == "-" ) { |
46 | // [x - 1 COMP 10] O R [1 - x COMP 10] |
47 | // order matters in subtraction: |
48 | if (arithmetic_child_index == 0) { |
49 | // [x - 1 COMP 10] |
50 | // change right side to 10+1 (outer_constant + inner_constant) |
51 | outer_constant->value = outer_constant->value + inner_constant->value; |
52 | } else { |
53 | // [1 - x COMP 10] |
54 | // change right side to 1-10=-9 |
55 | outer_constant->value = inner_constant->value - outer_constant->value; |
56 | // in this case, we should also flip the comparison |
57 | // e.g. if we have [4 - x < 2] then we should have [x > 2] |
58 | comparison->type = FlipComparisionExpression(comparison->type); |
59 | } |
60 | } else { |
61 | assert(op_type == "*" ); |
62 | // [x * 2 COMP 10] OR [2 * x COMP 10] |
63 | // order does not matter in multiplication: |
64 | // change right side to 10/2 (outer_constant / inner_constant) |
65 | // but ONLY if outer_constant is cleanly divisible by the inner_constant |
66 | if (inner_constant->value == 0) { |
67 | // x * 0, the result is either 0 or NULL |
68 | // thus the final result will be either [TRUE, FALSE] or [NULL], depending |
69 | // on if 0 matches the comparison criteria with the RHS |
70 | // for now we don't fold, but we can fold to "ConstantOrNull" |
71 | return nullptr; |
72 | } |
73 | if (ValueOperations::Modulo(outer_constant->value, inner_constant->value) != 0) { |
74 | // not cleanly divisible, the result will be either FALSE or NULL |
75 | // for now, we don't do anything |
76 | return nullptr; |
77 | } |
78 | if (inner_constant->value < 0) { |
79 | // multiply by negative value, need to flip expression |
80 | comparison->type = FlipComparisionExpression(comparison->type); |
81 | } |
82 | // else divide the RHS by the LHS |
83 | outer_constant->value = outer_constant->value / inner_constant->value; |
84 | } |
85 | // replace left side with x |
86 | // first extract x from the arithmetic expression |
87 | auto arithmetic_child = move(arithmetic->children[arithmetic_child_index]); |
88 | // then place in the comparison |
89 | if (comparison->left.get() == outer_constant) { |
90 | comparison->right = move(arithmetic_child); |
91 | } else { |
92 | comparison->left = move(arithmetic_child); |
93 | } |
94 | changes_made = true; |
95 | return nullptr; |
96 | } |
97 | |