1 | #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
6 | |
7 | using namespace duckdb; |
8 | using namespace std; |
9 | |
10 | ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
11 | // match on an OperatorExpression that has a ConstantExpression as child |
12 | auto op = make_unique<FunctionExpressionMatcher>(); |
13 | op->matchers.push_back(make_unique<ConstantExpressionMatcher>()); |
14 | op->matchers.push_back(make_unique<ExpressionMatcher>()); |
15 | op->policy = SetMatcher::Policy::SOME; |
16 | // we only match on simple arithmetic expressions (+, -, *, /) |
17 | op->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+" , "-" , "*" , "/" }); |
18 | // and only with numeric results |
19 | op->type = make_unique<IntegerTypeMatcher>(); |
20 | root = move(op); |
21 | } |
22 | |
23 | unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<Expression *> &bindings, |
24 | bool &changes_made) { |
25 | auto root = (BoundFunctionExpression *)bindings[0]; |
26 | auto constant = (BoundConstantExpression *)bindings[1]; |
27 | int constant_child = root->children[0].get() == constant ? 0 : 1; |
28 | assert(root->children.size() == 2); |
29 | // any arithmetic operator involving NULL is always NULL |
30 | if (constant->value.is_null) { |
31 | return make_unique<BoundConstantExpression>(Value(root->return_type)); |
32 | } |
33 | auto &func_name = root->function.name; |
34 | if (func_name == "+" ) { |
35 | if (constant->value == 0) { |
36 | // addition with 0 |
37 | // we can remove the entire operator and replace it with the non-constant child |
38 | return move(root->children[1 - constant_child]); |
39 | } |
40 | } else if (func_name == "-" ) { |
41 | if (constant_child == 1 && constant->value == 0) { |
42 | // subtraction by 0 |
43 | // we can remove the entire operator and replace it with the non-constant child |
44 | return move(root->children[1 - constant_child]); |
45 | } |
46 | } else if (func_name == "*" ) { |
47 | if (constant->value == 1) { |
48 | // multiply with 1, replace with non-constant child |
49 | return move(root->children[1 - constant_child]); |
50 | } |
51 | } else { |
52 | assert(func_name == "/" ); |
53 | if (constant_child == 1) { |
54 | if (constant->value == 1) { |
55 | // divide by 1, replace with non-constant child |
56 | return move(root->children[1 - constant_child]); |
57 | } else if (constant->value == 0) { |
58 | // divide by 0, replace with NULL |
59 | return make_unique<BoundConstantExpression>(Value(root->return_type)); |
60 | } |
61 | } |
62 | } |
63 | return nullptr; |
64 | } |
65 | |