| 1 | #include "duckdb/optimizer/rule/move_constants.hpp" |
| 2 | |
| 3 | #include "duckdb/common/exception.hpp" |
| 4 | #include "duckdb/common/value_operations/value_operations.hpp" |
| 5 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
| 6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 7 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 8 | #include "duckdb/optimizer/expression_rewriter.hpp" |
| 9 | |
| 10 | namespace duckdb { |
| 11 | |
| 12 | MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
| 13 | auto op = make_uniq<ComparisonExpressionMatcher>(); |
| 14 | op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
| 15 | op->policy = SetMatcher::Policy::UNORDERED; |
| 16 | |
| 17 | auto arithmetic = make_uniq<FunctionExpressionMatcher>(); |
| 18 | // we handle multiplication, addition and subtraction because those are "easy" |
| 19 | // integer division makes the division case difficult |
| 20 | // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules |
| 21 | arithmetic->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+" , "-" , "*" }); |
| 22 | // we match only on integral numeric types |
| 23 | arithmetic->type = make_uniq<IntegerTypeMatcher>(); |
| 24 | arithmetic->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
| 25 | arithmetic->matchers.push_back(x: make_uniq<ExpressionMatcher>()); |
| 26 | arithmetic->policy = SetMatcher::Policy::SOME; |
| 27 | op->matchers.push_back(x: std::move(arithmetic)); |
| 28 | root = std::move(op); |
| 29 | } |
| 30 | |
| 31 | unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
| 32 | bool &changes_made, bool is_root) { |
| 33 | auto &comparison = bindings[0].get().Cast<BoundComparisonExpression>(); |
| 34 | auto &outer_constant = bindings[1].get().Cast<BoundConstantExpression>(); |
| 35 | auto &arithmetic = bindings[2].get().Cast<BoundFunctionExpression>(); |
| 36 | auto &inner_constant = bindings[3].get().Cast<BoundConstantExpression>(); |
| 37 | if (!TypeIsIntegral(type: arithmetic.return_type.InternalType())) { |
| 38 | return nullptr; |
| 39 | } |
| 40 | if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { |
| 41 | return make_uniq<BoundConstantExpression>(args: Value(comparison.return_type)); |
| 42 | } |
| 43 | auto &constant_type = outer_constant.return_type; |
| 44 | hugeint_t outer_value = IntegralValue::Get(value: outer_constant.value); |
| 45 | hugeint_t inner_value = IntegralValue::Get(value: inner_constant.value); |
| 46 | |
| 47 | idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; |
| 48 | auto &op_type = arithmetic.function.name; |
| 49 | if (op_type == "+" ) { |
| 50 | // [x + 1 COMP 10] OR [1 + x COMP 10] |
| 51 | // order does not matter in addition: |
| 52 | // simply change right side to 10-1 (outer_constant - inner_constant) |
| 53 | if (!Hugeint::SubtractInPlace(lhs&: outer_value, rhs: inner_value)) { |
| 54 | return nullptr; |
| 55 | } |
| 56 | auto result_value = Value::HUGEINT(value: outer_value); |
| 57 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
| 58 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
| 59 | return nullptr; |
| 60 | } |
| 61 | // if the cast is not possible then the comparison is not possible |
| 62 | // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 |
| 63 | // since this is not possible we can remove the entire branch here |
| 64 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
| 65 | value: Value::BOOLEAN(value: false)); |
| 66 | } |
| 67 | outer_constant.value = std::move(result_value); |
| 68 | } else if (op_type == "-" ) { |
| 69 | // [x - 1 COMP 10] O R [1 - x COMP 10] |
| 70 | // order matters in subtraction: |
| 71 | if (arithmetic_child_index == 0) { |
| 72 | // [x - 1 COMP 10] |
| 73 | // change right side to 10+1 (outer_constant + inner_constant) |
| 74 | if (!Hugeint::AddInPlace(lhs&: outer_value, rhs: inner_value)) { |
| 75 | return nullptr; |
| 76 | } |
| 77 | auto result_value = Value::HUGEINT(value: outer_value); |
| 78 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
| 79 | // if the cast is not possible then an equality comparison is not possible |
| 80 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
| 81 | return nullptr; |
| 82 | } |
| 83 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
| 84 | value: Value::BOOLEAN(value: false)); |
| 85 | } |
| 86 | outer_constant.value = std::move(result_value); |
| 87 | } else { |
| 88 | // [1 - x COMP 10] |
| 89 | // change right side to 1-10=-9 |
| 90 | if (!Hugeint::SubtractInPlace(lhs&: inner_value, rhs: outer_value)) { |
| 91 | return nullptr; |
| 92 | } |
| 93 | auto result_value = Value::HUGEINT(value: inner_value); |
| 94 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
| 95 | // if the cast is not possible then an equality comparison is not possible |
| 96 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
| 97 | return nullptr; |
| 98 | } |
| 99 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
| 100 | value: Value::BOOLEAN(value: false)); |
| 101 | } |
| 102 | outer_constant.value = std::move(result_value); |
| 103 | // in this case, we should also flip the comparison |
| 104 | // e.g. if we have [4 - x < 2] then we should have [x > 2] |
| 105 | comparison.type = FlipComparisonExpression(type: comparison.type); |
| 106 | } |
| 107 | } else { |
| 108 | D_ASSERT(op_type == "*" ); |
| 109 | // [x * 2 COMP 10] OR [2 * x COMP 10] |
| 110 | // order does not matter in multiplication: |
| 111 | // change right side to 10/2 (outer_constant / inner_constant) |
| 112 | // but ONLY if outer_constant is cleanly divisible by the inner_constant |
| 113 | if (inner_value == 0) { |
| 114 | // x * 0, the result is either 0 or NULL |
| 115 | // we let the arithmetic_simplification rule take care of simplifying this first |
| 116 | return nullptr; |
| 117 | } |
| 118 | if (outer_value % inner_value != 0) { |
| 119 | // not cleanly divisible |
| 120 | bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; |
| 121 | bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; |
| 122 | if (is_equality || is_inequality) { |
| 123 | // we know the values are not equal |
| 124 | // the result will be either FALSE or NULL (if COMPARE_EQUAL) |
| 125 | // or TRUE or NULL (if COMPARE_NOTEQUAL) |
| 126 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
| 127 | value: Value::BOOLEAN(value: is_inequality)); |
| 128 | } else { |
| 129 | // not cleanly divisible and we are doing > >= < <=, skip the simplification for now |
| 130 | return nullptr; |
| 131 | } |
| 132 | } |
| 133 | if (inner_value < 0) { |
| 134 | // multiply by negative value, need to flip expression |
| 135 | comparison.type = FlipComparisonExpression(type: comparison.type); |
| 136 | } |
| 137 | // else divide the RHS by the LHS |
| 138 | // we need to do a range check on the cast even though we do a division |
| 139 | // because e.g. -128 / -1 = 128, which is out of range |
| 140 | auto result_value = Value::HUGEINT(value: outer_value / inner_value); |
| 141 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
| 142 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
| 143 | value: Value::BOOLEAN(value: false)); |
| 144 | } |
| 145 | outer_constant.value = std::move(result_value); |
| 146 | } |
| 147 | // replace left side with x |
| 148 | // first extract x from the arithmetic expression |
| 149 | auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); |
| 150 | // then place in the comparison |
| 151 | if (comparison.left.get() == &outer_constant) { |
| 152 | comparison.right = std::move(arithmetic_child); |
| 153 | } else { |
| 154 | comparison.left = std::move(arithmetic_child); |
| 155 | } |
| 156 | changes_made = true; |
| 157 | return nullptr; |
| 158 | } |
| 159 | |
| 160 | } // namespace duckdb |
| 161 | |