| 1 | #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" |
| 2 | |
| 3 | #include "duckdb/common/exception.hpp" |
| 4 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 6 | #include "duckdb/optimizer/expression_rewriter.hpp" |
| 7 | |
| 8 | namespace duckdb { |
| 9 | |
| 10 | ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
| 11 | // match on an OperatorExpression that has a ConstantExpression as child |
| 12 | auto op = make_uniq<FunctionExpressionMatcher>(); |
| 13 | op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
| 14 | op->matchers.push_back(x: make_uniq<ExpressionMatcher>()); |
| 15 | op->policy = SetMatcher::Policy::SOME; |
| 16 | // we only match on simple arithmetic expressions (+, -, *, /) |
| 17 | op->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+" , "-" , "*" , "//" }); |
| 18 | // and only with numeric results |
| 19 | op->type = make_uniq<IntegerTypeMatcher>(); |
| 20 | op->matchers[0]->type = make_uniq<IntegerTypeMatcher>(); |
| 21 | op->matchers[1]->type = make_uniq<IntegerTypeMatcher>(); |
| 22 | root = std::move(op); |
| 23 | } |
| 24 | |
| 25 | unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
| 26 | bool &changes_made, bool is_root) { |
| 27 | auto &root = bindings[0].get().Cast<BoundFunctionExpression>(); |
| 28 | auto &constant = bindings[1].get().Cast<BoundConstantExpression>(); |
| 29 | int constant_child = root.children[0].get() == &constant ? 0 : 1; |
| 30 | D_ASSERT(root.children.size() == 2); |
| 31 | (void)root; |
| 32 | // any arithmetic operator involving NULL is always NULL |
| 33 | if (constant.value.IsNull()) { |
| 34 | return make_uniq<BoundConstantExpression>(args: Value(root.return_type)); |
| 35 | } |
| 36 | auto &func_name = root.function.name; |
| 37 | if (func_name == "+" ) { |
| 38 | if (constant.value == 0) { |
| 39 | // addition with 0 |
| 40 | // we can remove the entire operator and replace it with the non-constant child |
| 41 | return std::move(root.children[1 - constant_child]); |
| 42 | } |
| 43 | } else if (func_name == "-" ) { |
| 44 | if (constant_child == 1 && constant.value == 0) { |
| 45 | // subtraction by 0 |
| 46 | // we can remove the entire operator and replace it with the non-constant child |
| 47 | return std::move(root.children[1 - constant_child]); |
| 48 | } |
| 49 | } else if (func_name == "*" ) { |
| 50 | if (constant.value == 1) { |
| 51 | // multiply with 1, replace with non-constant child |
| 52 | return std::move(root.children[1 - constant_child]); |
| 53 | } else if (constant.value == 0) { |
| 54 | // multiply by zero: replace with constant or null |
| 55 | return ExpressionRewriter::ConstantOrNull(child: std::move(root.children[1 - constant_child]), |
| 56 | value: Value::Numeric(type: root.return_type, value: 0)); |
| 57 | } |
| 58 | } else if (func_name == "//" ) { |
| 59 | if (constant_child == 1) { |
| 60 | if (constant.value == 1) { |
| 61 | // divide by 1, replace with non-constant child |
| 62 | return std::move(root.children[1 - constant_child]); |
| 63 | } else if (constant.value == 0) { |
| 64 | // divide by 0, replace with NULL |
| 65 | return make_uniq<BoundConstantExpression>(args: Value(root.return_type)); |
| 66 | } |
| 67 | } |
| 68 | } else { |
| 69 | throw InternalException("Unrecognized function name in ArithmeticSimplificationRule" ); |
| 70 | } |
| 71 | return nullptr; |
| 72 | } |
| 73 | } // namespace duckdb |
| 74 | |