| 1 | #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" |
| 2 | |
| 3 | #include "duckdb/common/exception.hpp" |
| 4 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 6 | |
| 7 | using namespace duckdb; |
| 8 | using namespace std; |
| 9 | |
| 10 | ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
| 11 | // match on an OperatorExpression that has a ConstantExpression as child |
| 12 | auto op = make_unique<FunctionExpressionMatcher>(); |
| 13 | op->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
| 14 | op->matchers.push_back(make_unique<ExpressionMatcher>()); |
| 15 | op->policy = SetMatcher::Policy::SOME; |
| 16 | // we only match on simple arithmetic expressions (+, -, *, /) |
| 17 | op->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+" , "-" , "*" , "/" }); |
| 18 | // and only with numeric results |
| 19 | op->type = make_unique<IntegerTypeMatcher>(); |
| 20 | root = move(op); |
| 21 | } |
| 22 | |
| 23 | unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, |
| 24 | bool &changes_made) { |
| 25 | auto root = (BoundFunctionExpression *)bindings[0]; |
| 26 | auto constant = (BoundConstantExpression *)bindings[1]; |
| 27 | int constant_child = root->children[0].get() == constant ? 0 : 1; |
| 28 | assert(root->children.size() == 2); |
| 29 | // any arithmetic operator involving NULL is always NULL |
| 30 | if (constant->value.is_null) { |
| 31 | return make_unique<BoundConstantExpression>(Value(root->return_type)); |
| 32 | } |
| 33 | auto &func_name = root->function.name; |
| 34 | if (func_name == "+" ) { |
| 35 | if (constant->value == 0) { |
| 36 | // addition with 0 |
| 37 | // we can remove the entire operator and replace it with the non-constant child |
| 38 | return move(root->children[1 - constant_child]); |
| 39 | } |
| 40 | } else if (func_name == "-" ) { |
| 41 | if (constant_child == 1 && constant->value == 0) { |
| 42 | // subtraction by 0 |
| 43 | // we can remove the entire operator and replace it with the non-constant child |
| 44 | return move(root->children[1 - constant_child]); |
| 45 | } |
| 46 | } else if (func_name == "*" ) { |
| 47 | if (constant->value == 1) { |
| 48 | // multiply with 1, replace with non-constant child |
| 49 | return move(root->children[1 - constant_child]); |
| 50 | } |
| 51 | } else { |
| 52 | assert(func_name == "/" ); |
| 53 | if (constant_child == 1) { |
| 54 | if (constant->value == 1) { |
| 55 | // divide by 1, replace with non-constant child |
| 56 | return move(root->children[1 - constant_child]); |
| 57 | } else if (constant->value == 0) { |
| 58 | // divide by 0, replace with NULL |
| 59 | return make_unique<BoundConstantExpression>(Value(root->return_type)); |
| 60 | } |
| 61 | } |
| 62 | } |
| 63 | return nullptr; |
| 64 | } |
| 65 | |