| 1 | #include "duckdb/optimizer/rule/move_constants.hpp" | 
| 2 |  | 
| 3 | #include "duckdb/common/exception.hpp" | 
| 4 | #include "duckdb/common/value_operations/value_operations.hpp" | 
| 5 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" | 
| 6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
| 7 | #include "duckdb/planner/expression/bound_function_expression.hpp" | 
| 8 | #include "duckdb/optimizer/expression_rewriter.hpp" | 
| 9 |  | 
| 10 | namespace duckdb { | 
| 11 |  | 
| 12 | MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { | 
| 13 | 	auto op = make_uniq<ComparisonExpressionMatcher>(); | 
| 14 | 	op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); | 
| 15 | 	op->policy = SetMatcher::Policy::UNORDERED; | 
| 16 |  | 
| 17 | 	auto arithmetic = make_uniq<FunctionExpressionMatcher>(); | 
| 18 | 	// we handle multiplication, addition and subtraction because those are "easy" | 
| 19 | 	// integer division makes the division case difficult | 
| 20 | 	// e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules | 
| 21 | 	arithmetic->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+" , "-" , "*" }); | 
| 22 | 	// we match only on integral numeric types | 
| 23 | 	arithmetic->type = make_uniq<IntegerTypeMatcher>(); | 
| 24 | 	arithmetic->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); | 
| 25 | 	arithmetic->matchers.push_back(x: make_uniq<ExpressionMatcher>()); | 
| 26 | 	arithmetic->policy = SetMatcher::Policy::SOME; | 
| 27 | 	op->matchers.push_back(x: std::move(arithmetic)); | 
| 28 | 	root = std::move(op); | 
| 29 | } | 
| 30 |  | 
| 31 | unique_ptr<Expression> MoveConstantsRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, | 
| 32 |                                                 bool &changes_made, bool is_root) { | 
| 33 | 	auto &comparison = bindings[0].get().Cast<BoundComparisonExpression>(); | 
| 34 | 	auto &outer_constant = bindings[1].get().Cast<BoundConstantExpression>(); | 
| 35 | 	auto &arithmetic = bindings[2].get().Cast<BoundFunctionExpression>(); | 
| 36 | 	auto &inner_constant = bindings[3].get().Cast<BoundConstantExpression>(); | 
| 37 | 	if (!TypeIsIntegral(type: arithmetic.return_type.InternalType())) { | 
| 38 | 		return nullptr; | 
| 39 | 	} | 
| 40 | 	if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { | 
| 41 | 		return make_uniq<BoundConstantExpression>(args: Value(comparison.return_type)); | 
| 42 | 	} | 
| 43 | 	auto &constant_type = outer_constant.return_type; | 
| 44 | 	hugeint_t outer_value = IntegralValue::Get(value: outer_constant.value); | 
| 45 | 	hugeint_t inner_value = IntegralValue::Get(value: inner_constant.value); | 
| 46 |  | 
| 47 | 	idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; | 
| 48 | 	auto &op_type = arithmetic.function.name; | 
| 49 | 	if (op_type == "+" ) { | 
| 50 | 		// [x + 1 COMP 10] OR [1 + x COMP 10] | 
| 51 | 		// order does not matter in addition: | 
| 52 | 		// simply change right side to 10-1 (outer_constant - inner_constant) | 
| 53 | 		if (!Hugeint::SubtractInPlace(lhs&: outer_value, rhs: inner_value)) { | 
| 54 | 			return nullptr; | 
| 55 | 		} | 
| 56 | 		auto result_value = Value::HUGEINT(value: outer_value); | 
| 57 | 		if (!result_value.DefaultTryCastAs(target_type: constant_type)) { | 
| 58 | 			if (comparison.type != ExpressionType::COMPARE_EQUAL) { | 
| 59 | 				return nullptr; | 
| 60 | 			} | 
| 61 | 			// if the cast is not possible then the comparison is not possible | 
| 62 | 			// for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 | 
| 63 | 			// since this is not possible we can remove the entire branch here | 
| 64 | 			return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), | 
| 65 | 			                                          value: Value::BOOLEAN(value: false)); | 
| 66 | 		} | 
| 67 | 		outer_constant.value = std::move(result_value); | 
| 68 | 	} else if (op_type == "-" ) { | 
| 69 | 		// [x - 1 COMP 10] O R [1 - x COMP 10] | 
| 70 | 		// order matters in subtraction: | 
| 71 | 		if (arithmetic_child_index == 0) { | 
| 72 | 			// [x - 1 COMP 10] | 
| 73 | 			// change right side to 10+1 (outer_constant + inner_constant) | 
| 74 | 			if (!Hugeint::AddInPlace(lhs&: outer_value, rhs: inner_value)) { | 
| 75 | 				return nullptr; | 
| 76 | 			} | 
| 77 | 			auto result_value = Value::HUGEINT(value: outer_value); | 
| 78 | 			if (!result_value.DefaultTryCastAs(target_type: constant_type)) { | 
| 79 | 				// if the cast is not possible then an equality comparison is not possible | 
| 80 | 				if (comparison.type != ExpressionType::COMPARE_EQUAL) { | 
| 81 | 					return nullptr; | 
| 82 | 				} | 
| 83 | 				return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), | 
| 84 | 				                                          value: Value::BOOLEAN(value: false)); | 
| 85 | 			} | 
| 86 | 			outer_constant.value = std::move(result_value); | 
| 87 | 		} else { | 
| 88 | 			// [1 - x COMP 10] | 
| 89 | 			// change right side to 1-10=-9 | 
| 90 | 			if (!Hugeint::SubtractInPlace(lhs&: inner_value, rhs: outer_value)) { | 
| 91 | 				return nullptr; | 
| 92 | 			} | 
| 93 | 			auto result_value = Value::HUGEINT(value: inner_value); | 
| 94 | 			if (!result_value.DefaultTryCastAs(target_type: constant_type)) { | 
| 95 | 				// if the cast is not possible then an equality comparison is not possible | 
| 96 | 				if (comparison.type != ExpressionType::COMPARE_EQUAL) { | 
| 97 | 					return nullptr; | 
| 98 | 				} | 
| 99 | 				return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), | 
| 100 | 				                                          value: Value::BOOLEAN(value: false)); | 
| 101 | 			} | 
| 102 | 			outer_constant.value = std::move(result_value); | 
| 103 | 			// in this case, we should also flip the comparison | 
| 104 | 			// e.g. if we have [4 - x < 2] then we should have [x > 2] | 
| 105 | 			comparison.type = FlipComparisonExpression(type: comparison.type); | 
| 106 | 		} | 
| 107 | 	} else { | 
| 108 | 		D_ASSERT(op_type == "*" ); | 
| 109 | 		// [x * 2 COMP 10] OR [2 * x COMP 10] | 
| 110 | 		// order does not matter in multiplication: | 
| 111 | 		// change right side to 10/2 (outer_constant / inner_constant) | 
| 112 | 		// but ONLY if outer_constant is cleanly divisible by the inner_constant | 
| 113 | 		if (inner_value == 0) { | 
| 114 | 			// x * 0, the result is either 0 or NULL | 
| 115 | 			// we let the arithmetic_simplification rule take care of simplifying this first | 
| 116 | 			return nullptr; | 
| 117 | 		} | 
| 118 | 		if (outer_value % inner_value != 0) { | 
| 119 | 			// not cleanly divisible | 
| 120 | 			bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; | 
| 121 | 			bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; | 
| 122 | 			if (is_equality || is_inequality) { | 
| 123 | 				// we know the values are not equal | 
| 124 | 				// the result will be either FALSE or NULL (if COMPARE_EQUAL) | 
| 125 | 				// or TRUE or NULL (if COMPARE_NOTEQUAL) | 
| 126 | 				return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), | 
| 127 | 				                                          value: Value::BOOLEAN(value: is_inequality)); | 
| 128 | 			} else { | 
| 129 | 				// not cleanly divisible and we are doing > >= < <=, skip the simplification for now | 
| 130 | 				return nullptr; | 
| 131 | 			} | 
| 132 | 		} | 
| 133 | 		if (inner_value < 0) { | 
| 134 | 			// multiply by negative value, need to flip expression | 
| 135 | 			comparison.type = FlipComparisonExpression(type: comparison.type); | 
| 136 | 		} | 
| 137 | 		// else divide the RHS by the LHS | 
| 138 | 		// we need to do a range check on the cast even though we do a division | 
| 139 | 		// because e.g. -128 / -1 = 128, which is out of range | 
| 140 | 		auto result_value = Value::HUGEINT(value: outer_value / inner_value); | 
| 141 | 		if (!result_value.DefaultTryCastAs(target_type: constant_type)) { | 
| 142 | 			return ExpressionRewriter::ConstantOrNull(child: std::move(arithmetic.children[arithmetic_child_index]), | 
| 143 | 			                                          value: Value::BOOLEAN(value: false)); | 
| 144 | 		} | 
| 145 | 		outer_constant.value = std::move(result_value); | 
| 146 | 	} | 
| 147 | 	// replace left side with x | 
| 148 | 	// first extract x from the arithmetic expression | 
| 149 | 	auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); | 
| 150 | 	// then place in the comparison | 
| 151 | 	if (comparison.left.get() == &outer_constant) { | 
| 152 | 		comparison.right = std::move(arithmetic_child); | 
| 153 | 	} else { | 
| 154 | 		comparison.left = std::move(arithmetic_child); | 
| 155 | 	} | 
| 156 | 	changes_made = true; | 
| 157 | 	return nullptr; | 
| 158 | } | 
| 159 |  | 
| 160 | } // namespace duckdb | 
| 161 |  |