1#include "duckdb/optimizer/rule/move_constants.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/common/value_operations/value_operations.hpp"
5#include "duckdb/planner/expression/bound_comparison_expression.hpp"
6#include "duckdb/planner/expression/bound_constant_expression.hpp"
7#include "duckdb/planner/expression/bound_function_expression.hpp"
8
9using namespace duckdb;
10using namespace std;
11
12MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
13 auto op = make_unique<ComparisonExpressionMatcher>();
14 op->matchers.push_back(make_unique<ConstantExpressionMatcher>());
15 op->policy = SetMatcher::Policy::UNORDERED;
16
17 auto arithmetic = make_unique<FunctionExpressionMatcher>();
18 // we handle multiplication, addition and subtraction because those are "easy"
19 // integer division makes the division case difficult
20 // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules
21 arithmetic->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+", "-", "*"});
22 // we match only on integral numeric types
23 arithmetic->type = make_unique<IntegerTypeMatcher>();
24 arithmetic->matchers.push_back(make_unique<ConstantExpressionMatcher>());
25 arithmetic->matchers.push_back(make_unique<ExpressionMatcher>());
26 arithmetic->policy = SetMatcher::Policy::SOME;
27 op->matchers.push_back(move(arithmetic));
28 root = move(op);
29}
30
31unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<Expression *> &bindings,
32 bool &changes_made) {
33 auto comparison = (BoundComparisonExpression *)bindings[0];
34 auto outer_constant = (BoundConstantExpression *)bindings[1];
35 auto arithmetic = (BoundFunctionExpression *)bindings[2];
36 auto inner_constant = (BoundConstantExpression *)bindings[3];
37
38 int arithmetic_child_index = arithmetic->children[0].get() == inner_constant ? 1 : 0;
39 auto &op_type = arithmetic->function.name;
40 if (op_type == "+") {
41 // [x + 1 COMP 10] OR [1 + x COMP 10]
42 // order does not matter in addition:
43 // simply change right side to 10-1 (outer_constant - inner_constant)
44 outer_constant->value = outer_constant->value - inner_constant->value;
45 } else if (op_type == "-") {
46 // [x - 1 COMP 10] O R [1 - x COMP 10]
47 // order matters in subtraction:
48 if (arithmetic_child_index == 0) {
49 // [x - 1 COMP 10]
50 // change right side to 10+1 (outer_constant + inner_constant)
51 outer_constant->value = outer_constant->value + inner_constant->value;
52 } else {
53 // [1 - x COMP 10]
54 // change right side to 1-10=-9
55 outer_constant->value = inner_constant->value - outer_constant->value;
56 // in this case, we should also flip the comparison
57 // e.g. if we have [4 - x < 2] then we should have [x > 2]
58 comparison->type = FlipComparisionExpression(comparison->type);
59 }
60 } else {
61 assert(op_type == "*");
62 // [x * 2 COMP 10] OR [2 * x COMP 10]
63 // order does not matter in multiplication:
64 // change right side to 10/2 (outer_constant / inner_constant)
65 // but ONLY if outer_constant is cleanly divisible by the inner_constant
66 if (inner_constant->value == 0) {
67 // x * 0, the result is either 0 or NULL
68 // thus the final result will be either [TRUE, FALSE] or [NULL], depending
69 // on if 0 matches the comparison criteria with the RHS
70 // for now we don't fold, but we can fold to "ConstantOrNull"
71 return nullptr;
72 }
73 if (ValueOperations::Modulo(outer_constant->value, inner_constant->value) != 0) {
74 // not cleanly divisible, the result will be either FALSE or NULL
75 // for now, we don't do anything
76 return nullptr;
77 }
78 if (inner_constant->value < 0) {
79 // multiply by negative value, need to flip expression
80 comparison->type = FlipComparisionExpression(comparison->type);
81 }
82 // else divide the RHS by the LHS
83 outer_constant->value = outer_constant->value / inner_constant->value;
84 }
85 // replace left side with x
86 // first extract x from the arithmetic expression
87 auto arithmetic_child = move(arithmetic->children[arithmetic_child_index]);
88 // then place in the comparison
89 if (comparison->left.get() == outer_constant) {
90 comparison->right = move(arithmetic_child);
91 } else {
92 comparison->left = move(arithmetic_child);
93 }
94 changes_made = true;
95 return nullptr;
96}
97