1#include "duckdb/optimizer/rule/arithmetic_simplification.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/planner/expression/bound_constant_expression.hpp"
5#include "duckdb/planner/expression/bound_function_expression.hpp"
6#include "duckdb/optimizer/expression_rewriter.hpp"
7
8namespace duckdb {
9
10ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
11 // match on an OperatorExpression that has a ConstantExpression as child
12 auto op = make_uniq<FunctionExpressionMatcher>();
13 op->matchers.push_back(x: make_uniq<ConstantExpressionMatcher>());
14 op->matchers.push_back(x: make_uniq<ExpressionMatcher>());
15 op->policy = SetMatcher::Policy::SOME;
16 // we only match on simple arithmetic expressions (+, -, *, /)
17 op->function = make_uniq<ManyFunctionMatcher>(args: unordered_set<string> {"+", "-", "*", "//"});
18 // and only with numeric results
19 op->type = make_uniq<IntegerTypeMatcher>();
20 op->matchers[0]->type = make_uniq<IntegerTypeMatcher>();
21 op->matchers[1]->type = make_uniq<IntegerTypeMatcher>();
22 root = std::move(op);
23}
24
25unique_ptr<Expression> ArithmeticSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
26 bool &changes_made, bool is_root) {
27 auto &root = bindings[0].get().Cast<BoundFunctionExpression>();
28 auto &constant = bindings[1].get().Cast<BoundConstantExpression>();
29 int constant_child = root.children[0].get() == &constant ? 0 : 1;
30 D_ASSERT(root.children.size() == 2);
31 (void)root;
32 // any arithmetic operator involving NULL is always NULL
33 if (constant.value.IsNull()) {
34 return make_uniq<BoundConstantExpression>(args: Value(root.return_type));
35 }
36 auto &func_name = root.function.name;
37 if (func_name == "+") {
38 if (constant.value == 0) {
39 // addition with 0
40 // we can remove the entire operator and replace it with the non-constant child
41 return std::move(root.children[1 - constant_child]);
42 }
43 } else if (func_name == "-") {
44 if (constant_child == 1 && constant.value == 0) {
45 // subtraction by 0
46 // we can remove the entire operator and replace it with the non-constant child
47 return std::move(root.children[1 - constant_child]);
48 }
49 } else if (func_name == "*") {
50 if (constant.value == 1) {
51 // multiply with 1, replace with non-constant child
52 return std::move(root.children[1 - constant_child]);
53 } else if (constant.value == 0) {
54 // multiply by zero: replace with constant or null
55 return ExpressionRewriter::ConstantOrNull(child: std::move(root.children[1 - constant_child]),
56 value: Value::Numeric(type: root.return_type, value: 0));
57 }
58 } else if (func_name == "//") {
59 if (constant_child == 1) {
60 if (constant.value == 1) {
61 // divide by 1, replace with non-constant child
62 return std::move(root.children[1 - constant_child]);
63 } else if (constant.value == 0) {
64 // divide by 0, replace with NULL
65 return make_uniq<BoundConstantExpression>(args: Value(root.return_type));
66 }
67 }
68 } else {
69 throw InternalException("Unrecognized function name in ArithmeticSimplificationRule");
70 }
71 return nullptr;
72}
73} // namespace duckdb
74