1#include "duckdb/optimizer/rule/arithmetic_simplification.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/planner/expression/bound_constant_expression.hpp"
5#include "duckdb/planner/expression/bound_function_expression.hpp"
6
7using namespace duckdb;
8using namespace std;
9
10ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
11 // match on an OperatorExpression that has a ConstantExpression as child
12 auto op = make_unique<FunctionExpressionMatcher>();
13 op->matchers.push_back(make_unique<ConstantExpressionMatcher>());
14 op->matchers.push_back(make_unique<ExpressionMatcher>());
15 op->policy = SetMatcher::Policy::SOME;
16 // we only match on simple arithmetic expressions (+, -, *, /)
17 op->function = make_unique<ManyFunctionMatcher>(unordered_set<string>{"+", "-", "*", "/"});
18 // and only with numeric results
19 op->type = make_unique<IntegerTypeMatcher>();
20 root = move(op);
21}
22
23unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<Expression *> &bindings,
24 bool &changes_made) {
25 auto root = (BoundFunctionExpression *)bindings[0];
26 auto constant = (BoundConstantExpression *)bindings[1];
27 int constant_child = root->children[0].get() == constant ? 0 : 1;
28 assert(root->children.size() == 2);
29 // any arithmetic operator involving NULL is always NULL
30 if (constant->value.is_null) {
31 return make_unique<BoundConstantExpression>(Value(root->return_type));
32 }
33 auto &func_name = root->function.name;
34 if (func_name == "+") {
35 if (constant->value == 0) {
36 // addition with 0
37 // we can remove the entire operator and replace it with the non-constant child
38 return move(root->children[1 - constant_child]);
39 }
40 } else if (func_name == "-") {
41 if (constant_child == 1 && constant->value == 0) {
42 // subtraction by 0
43 // we can remove the entire operator and replace it with the non-constant child
44 return move(root->children[1 - constant_child]);
45 }
46 } else if (func_name == "*") {
47 if (constant->value == 1) {
48 // multiply with 1, replace with non-constant child
49 return move(root->children[1 - constant_child]);
50 }
51 } else {
52 assert(func_name == "/");
53 if (constant_child == 1) {
54 if (constant->value == 1) {
55 // divide by 1, replace with non-constant child
56 return move(root->children[1 - constant_child]);
57 } else if (constant->value == 0) {
58 // divide by 0, replace with NULL
59 return make_unique<BoundConstantExpression>(Value(root->return_type));
60 }
61 }
62 }
63 return nullptr;
64}
65