1 | #include "duckdb/optimizer/rule/arithmetic_simplification.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
6 | #include "duckdb/optimizer/expression_rewriter.hpp" |
7 | |
8 | namespace duckdb { |
9 | |
10 | ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
11 | // match on an OperatorExpression that has a ConstantExpression as child |
12 | auto op = make_uniq<FunctionExpressionMatcher>(); |
13 | op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>()); |
14 | op->matchers.push_back(x: make_uniq<ExpressionMatcher>()); |
15 | op->policy = SetMatcher::Policy::SOME; |
16 | // we only match on simple arithmetic expressions (+, -, *, /) |
17 | op->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+" , "-" , "*" , "//" }); |
18 | // and only with numeric results |
19 | op->type = make_uniq<IntegerTypeMatcher>(); |
20 | op->matchers[0]->type = make_uniq<IntegerTypeMatcher>(); |
21 | op->matchers[1]->type = make_uniq<IntegerTypeMatcher>(); |
22 | root = std::move(op); |
23 | } |
24 | |
25 | unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
26 | bool &changes_made, bool is_root) { |
27 | auto &root = bindings[0].get().Cast<BoundFunctionExpression>(); |
28 | auto &constant = bindings[1].get().Cast<BoundConstantExpression>(); |
29 | int constant_child = root.children[0].get() == &constant ? 0 : 1; |
30 | D_ASSERT(root.children.size() == 2); |
31 | (void)root; |
32 | // any arithmetic operator involving NULL is always NULL |
33 | if (constant.value.IsNull()) { |
34 | return make_uniq<BoundConstantExpression>(args: Value(root.return_type)); |
35 | } |
36 | auto &func_name = root.function.name; |
37 | if (func_name == "+" ) { |
38 | if (constant.value == 0) { |
39 | // addition with 0 |
40 | // we can remove the entire operator and replace it with the non-constant child |
41 | return std::move(root.children[1 - constant_child]); |
42 | } |
43 | } else if (func_name == "-" ) { |
44 | if (constant_child == 1 && constant.value == 0) { |
45 | // subtraction by 0 |
46 | // we can remove the entire operator and replace it with the non-constant child |
47 | return std::move(root.children[1 - constant_child]); |
48 | } |
49 | } else if (func_name == "*" ) { |
50 | if (constant.value == 1) { |
51 | // multiply with 1, replace with non-constant child |
52 | return std::move(root.children[1 - constant_child]); |
53 | } else if (constant.value == 0) { |
54 | // multiply by zero: replace with constant or null |
55 | return ExpressionRewriter::ConstantOrNull(child: std::move(root.children[1 - constant_child]), |
56 | value: Value::Numeric(type: root.return_type, value: 0)); |
57 | } |
58 | } else if (func_name == "//" ) { |
59 | if (constant_child == 1) { |
60 | if (constant.value == 1) { |
61 | // divide by 1, replace with non-constant child |
62 | return std::move(root.children[1 - constant_child]); |
63 | } else if (constant.value == 0) { |
64 | // divide by 0, replace with NULL |
65 | return make_uniq<BoundConstantExpression>(args: Value(root.return_type)); |
66 | } |
67 | } |
68 | } else { |
69 | throw InternalException("Unrecognized function name in ArithmeticSimplificationRule" ); |
70 | } |
71 | return nullptr; |
72 | } |
73 | } // namespace duckdb |
74 | |