1 | #include "duckdb/optimizer/rule/move_constants.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/common/value_operations/value_operations.hpp" |
5 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
8 | #include "duckdb/optimizer/expression_rewriter.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
13 | auto op = make_uniq<ComparisonExpressionMatcher>(); |
14 | op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
15 | op->policy = SetMatcher::Policy::UNORDERED; |
16 | |
17 | auto arithmetic = make_uniq<FunctionExpressionMatcher>(); |
18 | // we handle multiplication, addition and subtraction because those are "easy" |
19 | // integer division makes the division case difficult |
20 | // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules |
21 | arithmetic->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+" , "-" , "*" }); |
22 | // we match only on integral numeric types |
23 | arithmetic->type = make_uniq<IntegerTypeMatcher>(); |
24 | arithmetic->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
25 | arithmetic->matchers.push_back(x: make_uniq<ExpressionMatcher>()); |
26 | arithmetic->policy = SetMatcher::Policy::SOME; |
27 | op->matchers.push_back(x: std::move(arithmetic)); |
28 | root = std::move(op); |
29 | } |
30 | |
31 | unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
32 | bool &changes_made, bool is_root) { |
33 | auto &comparison = bindings[0].get().Cast<BoundComparisonExpression>(); |
34 | auto &outer_constant = bindings[1].get().Cast<BoundConstantExpression>(); |
35 | auto &arithmetic = bindings[2].get().Cast<BoundFunctionExpression>(); |
36 | auto &inner_constant = bindings[3].get().Cast<BoundConstantExpression>(); |
37 | if (!TypeIsIntegral(type: arithmetic.return_type.InternalType())) { |
38 | return nullptr; |
39 | } |
40 | if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { |
41 | return make_uniq<BoundConstantExpression>(args: Value(comparison.return_type)); |
42 | } |
43 | auto &constant_type = outer_constant.return_type; |
44 | hugeint_t outer_value = IntegralValue::Get(value: outer_constant.value); |
45 | hugeint_t inner_value = IntegralValue::Get(value: inner_constant.value); |
46 | |
47 | idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; |
48 | auto &op_type = arithmetic.function.name; |
49 | if (op_type == "+" ) { |
50 | // [x + 1 COMP 10] OR [1 + x COMP 10] |
51 | // order does not matter in addition: |
52 | // simply change right side to 10-1 (outer_constant - inner_constant) |
53 | if (!Hugeint::SubtractInPlace(lhs&: outer_value, rhs: inner_value)) { |
54 | return nullptr; |
55 | } |
56 | auto result_value = Value::HUGEINT(value: outer_value); |
57 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
58 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
59 | return nullptr; |
60 | } |
61 | // if the cast is not possible then the comparison is not possible |
62 | // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 |
63 | // since this is not possible we can remove the entire branch here |
64 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
65 | value: Value::BOOLEAN(value: false)); |
66 | } |
67 | outer_constant.value = std::move(result_value); |
68 | } else if (op_type == "-" ) { |
69 | // [x - 1 COMP 10] O R [1 - x COMP 10] |
70 | // order matters in subtraction: |
71 | if (arithmetic_child_index == 0) { |
72 | // [x - 1 COMP 10] |
73 | // change right side to 10+1 (outer_constant + inner_constant) |
74 | if (!Hugeint::AddInPlace(lhs&: outer_value, rhs: inner_value)) { |
75 | return nullptr; |
76 | } |
77 | auto result_value = Value::HUGEINT(value: outer_value); |
78 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
79 | // if the cast is not possible then an equality comparison is not possible |
80 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
81 | return nullptr; |
82 | } |
83 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
84 | value: Value::BOOLEAN(value: false)); |
85 | } |
86 | outer_constant.value = std::move(result_value); |
87 | } else { |
88 | // [1 - x COMP 10] |
89 | // change right side to 1-10=-9 |
90 | if (!Hugeint::SubtractInPlace(lhs&: inner_value, rhs: outer_value)) { |
91 | return nullptr; |
92 | } |
93 | auto result_value = Value::HUGEINT(value: inner_value); |
94 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
95 | // if the cast is not possible then an equality comparison is not possible |
96 | if (comparison.type != ExpressionType::COMPARE_EQUAL) { |
97 | return nullptr; |
98 | } |
99 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
100 | value: Value::BOOLEAN(value: false)); |
101 | } |
102 | outer_constant.value = std::move(result_value); |
103 | // in this case, we should also flip the comparison |
104 | // e.g. if we have [4 - x < 2] then we should have [x > 2] |
105 | comparison.type = FlipComparisonExpression(type: comparison.type); |
106 | } |
107 | } else { |
108 | D_ASSERT(op_type == "*" ); |
109 | // [x * 2 COMP 10] OR [2 * x COMP 10] |
110 | // order does not matter in multiplication: |
111 | // change right side to 10/2 (outer_constant / inner_constant) |
112 | // but ONLY if outer_constant is cleanly divisible by the inner_constant |
113 | if (inner_value == 0) { |
114 | // x * 0, the result is either 0 or NULL |
115 | // we let the arithmetic_simplification rule take care of simplifying this first |
116 | return nullptr; |
117 | } |
118 | if (outer_value % inner_value != 0) { |
119 | // not cleanly divisible |
120 | bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; |
121 | bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; |
122 | if (is_equality || is_inequality) { |
123 | // we know the values are not equal |
124 | // the result will be either FALSE or NULL (if COMPARE_EQUAL) |
125 | // or TRUE or NULL (if COMPARE_NOTEQUAL) |
126 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
127 | value: Value::BOOLEAN(value: is_inequality)); |
128 | } else { |
129 | // not cleanly divisible and we are doing > >= < <=, skip the simplification for now |
130 | return nullptr; |
131 | } |
132 | } |
133 | if (inner_value < 0) { |
134 | // multiply by negative value, need to flip expression |
135 | comparison.type = FlipComparisonExpression(type: comparison.type); |
136 | } |
137 | // else divide the RHS by the LHS |
138 | // we need to do a range check on the cast even though we do a division |
139 | // because e.g. -128 / -1 = 128, which is out of range |
140 | auto result_value = Value::HUGEINT(value: outer_value / inner_value); |
141 | if (!result_value.DefaultTryCastAs(target_type: constant_type)) { |
142 | return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), |
143 | value: Value::BOOLEAN(value: false)); |
144 | } |
145 | outer_constant.value = std::move(result_value); |
146 | } |
147 | // replace left side with x |
148 | // first extract x from the arithmetic expression |
149 | auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); |
150 | // then place in the comparison |
151 | if (comparison.left.get() == &outer_constant) { |
152 | comparison.right = std::move(arithmetic_child); |
153 | } else { |
154 | comparison.left = std::move(arithmetic_child); |
155 | } |
156 | changes_made = true; |
157 | return nullptr; |
158 | } |
159 | |
160 | } // namespace duckdb |
161 | |