1#include "duckdb/optimizer/rule/move_constants.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/common/value_operations/value_operations.hpp"
5#include "duckdb/planner/expression/bound_comparison_expression.hpp"
6#include "duckdb/planner/expression/bound_constant_expression.hpp"
7#include "duckdb/planner/expression/bound_function_expression.hpp"
8#include "duckdb/optimizer/expression_rewriter.hpp"
9
10namespace duckdb {
11
12MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
13 auto op = make_uniq<ComparisonExpressionMatcher>();
14 op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>());
15 op->policy = SetMatcher::Policy::UNORDERED;
16
17 auto arithmetic = make_uniq<FunctionExpressionMatcher>();
18 // we handle multiplication, addition and subtraction because those are "easy"
19 // integer division makes the division case difficult
20 // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules
21 arithmetic->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+", "-", "*"});
22 // we match only on integral numeric types
23 arithmetic->type = make_uniq<IntegerTypeMatcher>();
24 arithmetic->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>());
25 arithmetic->matchers.push_back(x: make_uniq<ExpressionMatcher>());
26 arithmetic->policy = SetMatcher::Policy::SOME;
27 op->matchers.push_back(x: std::move(arithmetic));
28 root = std::move(op);
29}
30
31unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
32 bool &changes_made, bool is_root) {
33 auto &comparison = bindings[0].get().Cast<BoundComparisonExpression>();
34 auto &outer_constant = bindings[1].get().Cast<BoundConstantExpression>();
35 auto &arithmetic = bindings[2].get().Cast<BoundFunctionExpression>();
36 auto &inner_constant = bindings[3].get().Cast<BoundConstantExpression>();
37 if (!TypeIsIntegral(type: arithmetic.return_type.InternalType())) {
38 return nullptr;
39 }
40 if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) {
41 return make_uniq<BoundConstantExpression>(args: Value(comparison.return_type));
42 }
43 auto &constant_type = outer_constant.return_type;
44 hugeint_t outer_value = IntegralValue::Get(value: outer_constant.value);
45 hugeint_t inner_value = IntegralValue::Get(value: inner_constant.value);
46
47 idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0;
48 auto &op_type = arithmetic.function.name;
49 if (op_type == "+") {
50 // [x + 1 COMP 10] OR [1 + x COMP 10]
51 // order does not matter in addition:
52 // simply change right side to 10-1 (outer_constant - inner_constant)
53 if (!Hugeint::SubtractInPlace(lhs&: outer_value, rhs: inner_value)) {
54 return nullptr;
55 }
56 auto result_value = Value::HUGEINT(value: outer_value);
57 if (!result_value.DefaultTryCastAs(target_type: constant_type)) {
58 if (comparison.type != ExpressionType::COMPARE_EQUAL) {
59 return nullptr;
60 }
61 // if the cast is not possible then the comparison is not possible
62 // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2
63 // since this is not possible we can remove the entire branch here
64 return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]),
65 value: Value::BOOLEAN(value: false));
66 }
67 outer_constant.value = std::move(result_value);
68 } else if (op_type == "-") {
69 // [x - 1 COMP 10] O R [1 - x COMP 10]
70 // order matters in subtraction:
71 if (arithmetic_child_index == 0) {
72 // [x - 1 COMP 10]
73 // change right side to 10+1 (outer_constant + inner_constant)
74 if (!Hugeint::AddInPlace(lhs&: outer_value, rhs: inner_value)) {
75 return nullptr;
76 }
77 auto result_value = Value::HUGEINT(value: outer_value);
78 if (!result_value.DefaultTryCastAs(target_type: constant_type)) {
79 // if the cast is not possible then an equality comparison is not possible
80 if (comparison.type != ExpressionType::COMPARE_EQUAL) {
81 return nullptr;
82 }
83 return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]),
84 value: Value::BOOLEAN(value: false));
85 }
86 outer_constant.value = std::move(result_value);
87 } else {
88 // [1 - x COMP 10]
89 // change right side to 1-10=-9
90 if (!Hugeint::SubtractInPlace(lhs&: inner_value, rhs: outer_value)) {
91 return nullptr;
92 }
93 auto result_value = Value::HUGEINT(value: inner_value);
94 if (!result_value.DefaultTryCastAs(target_type: constant_type)) {
95 // if the cast is not possible then an equality comparison is not possible
96 if (comparison.type != ExpressionType::COMPARE_EQUAL) {
97 return nullptr;
98 }
99 return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]),
100 value: Value::BOOLEAN(value: false));
101 }
102 outer_constant.value = std::move(result_value);
103 // in this case, we should also flip the comparison
104 // e.g. if we have [4 - x < 2] then we should have [x > 2]
105 comparison.type = FlipComparisonExpression(type: comparison.type);
106 }
107 } else {
108 D_ASSERT(op_type == "*");
109 // [x * 2 COMP 10] OR [2 * x COMP 10]
110 // order does not matter in multiplication:
111 // change right side to 10/2 (outer_constant / inner_constant)
112 // but ONLY if outer_constant is cleanly divisible by the inner_constant
113 if (inner_value == 0) {
114 // x * 0, the result is either 0 or NULL
115 // we let the arithmetic_simplification rule take care of simplifying this first
116 return nullptr;
117 }
118 if (outer_value % inner_value != 0) {
119 // not cleanly divisible
120 bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL;
121 bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL;
122 if (is_equality || is_inequality) {
123 // we know the values are not equal
124 // the result will be either FALSE or NULL (if COMPARE_EQUAL)
125 // or TRUE or NULL (if COMPARE_NOTEQUAL)
126 return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]),
127 value: Value::BOOLEAN(value: is_inequality));
128 } else {
129 // not cleanly divisible and we are doing > >= < <=, skip the simplification for now
130 return nullptr;
131 }
132 }
133 if (inner_value < 0) {
134 // multiply by negative value, need to flip expression
135 comparison.type = FlipComparisonExpression(type: comparison.type);
136 }
137 // else divide the RHS by the LHS
138 // we need to do a range check on the cast even though we do a division
139 // because e.g. -128 / -1 = 128, which is out of range
140 auto result_value = Value::HUGEINT(value: outer_value / inner_value);
141 if (!result_value.DefaultTryCastAs(target_type: constant_type)) {
142 return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]),
143 value: Value::BOOLEAN(value: false));
144 }
145 outer_constant.value = std::move(result_value);
146 }
147 // replace left side with x
148 // first extract x from the arithmetic expression
149 auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]);
150 // then place in the comparison
151 if (comparison.left.get() == &outer_constant) {
152 comparison.right = std::move(arithmetic_child);
153 } else {
154 comparison.left = std::move(arithmetic_child);
155 }
156 changes_made = true;
157 return nullptr;
158}
159
160} // namespace duckdb
161