| 1 | #include "duckdb/optimizer/rule/move_constants.hpp" |
| 2 | |
| 3 | #include "duckdb/common/exception.hpp" |
| 4 | #include "duckdb/common/value_operations/value_operations.hpp" |
| 5 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
| 6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 7 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 8 | |
| 9 | using namespace duckdb; |
| 10 | using namespace std; |
| 11 | |
| 12 | MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
| 13 | auto op = make_unique<ComparisonExpressionMatcher>(); |
| 14 | op->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
| 15 | op->policy = SetMatcher::Policy::UNORDERED; |
| 16 | |
| 17 | auto arithmetic = make_unique<FunctionExpressionMatcher>(); |
| 18 | // we handle multiplication, addition and subtraction because those are "easy" |
| 19 | // integer division makes the division case difficult |
| 20 | // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules |
| 21 | arithmetic->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+" , "-" , "*" }); |
| 22 | // we match only on integral numeric types |
| 23 | arithmetic->type = make_unique<IntegerTypeMatcher>(); |
| 24 | arithmetic->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
| 25 | arithmetic->matchers.push_back(make_unique<ExpressionMatcher>()); |
| 26 | arithmetic->policy = SetMatcher::Policy::SOME; |
| 27 | op->matchers.push_back(move(arithmetic)); |
| 28 | root = move(op); |
| 29 | } |
| 30 | |
| 31 | unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, |
| 32 | bool &changes_made) { |
| 33 | auto comparison = (BoundComparisonExpression *)bindings[0]; |
| 34 | auto outer_constant = (BoundConstantExpression *)bindings[1]; |
| 35 | auto arithmetic = (BoundFunctionExpression *)bindings[2]; |
| 36 | auto inner_constant = (BoundConstantExpression *)bindings[3]; |
| 37 | |
| 38 | int arithmetic_child_index = arithmetic->children[0].get() == inner_constant ? 1 : 0; |
| 39 | auto &op_type = arithmetic->function.name; |
| 40 | if (op_type == "+" ) { |
| 41 | // [x + 1 COMP 10] OR [1 + x COMP 10] |
| 42 | // order does not matter in addition: |
| 43 | // simply change right side to 10-1 (outer_constant - inner_constant) |
| 44 | outer_constant->value = outer_constant->value - inner_constant->value; |
| 45 | } else if (op_type == "-" ) { |
| 46 | // [x - 1 COMP 10] O R [1 - x COMP 10] |
| 47 | // order matters in subtraction: |
| 48 | if (arithmetic_child_index == 0) { |
| 49 | // [x - 1 COMP 10] |
| 50 | // change right side to 10+1 (outer_constant + inner_constant) |
| 51 | outer_constant->value = outer_constant->value + inner_constant->value; |
| 52 | } else { |
| 53 | // [1 - x COMP 10] |
| 54 | // change right side to 1-10=-9 |
| 55 | outer_constant->value = inner_constant->value - outer_constant->value; |
| 56 | // in this case, we should also flip the comparison |
| 57 | // e.g. if we have [4 - x < 2] then we should have [x > 2] |
| 58 | comparison->type = FlipComparisionExpression(comparison->type); |
| 59 | } |
| 60 | } else { |
| 61 | assert(op_type == "*" ); |
| 62 | // [x * 2 COMP 10] OR [2 * x COMP 10] |
| 63 | // order does not matter in multiplication: |
| 64 | // change right side to 10/2 (outer_constant / inner_constant) |
| 65 | // but ONLY if outer_constant is cleanly divisible by the inner_constant |
| 66 | if (inner_constant->value == 0) { |
| 67 | // x * 0, the result is either 0 or NULL |
| 68 | // thus the final result will be either [TRUE, FALSE] or [NULL], depending |
| 69 | // on if 0 matches the comparison criteria with the RHS |
| 70 | // for now we don't fold, but we can fold to "ConstantOrNull" |
| 71 | return nullptr; |
| 72 | } |
| 73 | if (ValueOperations::Modulo(outer_constant->value, inner_constant->value) != 0) { |
| 74 | // not cleanly divisible, the result will be either FALSE or NULL |
| 75 | // for now, we don't do anything |
| 76 | return nullptr; |
| 77 | } |
| 78 | if (inner_constant->value < 0) { |
| 79 | // multiply by negative value, need to flip expression |
| 80 | comparison->type = FlipComparisionExpression(comparison->type); |
| 81 | } |
| 82 | // else divide the RHS by the LHS |
| 83 | outer_constant->value = outer_constant->value / inner_constant->value; |
| 84 | } |
| 85 | // replace left side with x |
| 86 | // first extract x from the arithmetic expression |
| 87 | auto arithmetic_child = move(arithmetic->children[arithmetic_child_index]); |
| 88 | // then place in the comparison |
| 89 | if (comparison->left.get() == outer_constant) { |
| 90 | comparison->right = move(arithmetic_child); |
| 91 | } else { |
| 92 | comparison->left = move(arithmetic_child); |
| 93 | } |
| 94 | changes_made = true; |
| 95 | return nullptr; |
| 96 | } |
| 97 | |