1#include "duckdb/planner/expression/list.hpp"
2#include "duckdb/optimizer/rule/comparison_simplification.hpp"
3
4#include "duckdb/execution/expression_executor.hpp"
5#include "duckdb/planner/expression/bound_constant_expression.hpp"
6
7using namespace duckdb;
8using namespace std;
9
10ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
11 // match on a ComparisonExpression that has a ConstantExpression as a check
12 auto op = make_unique<ComparisonExpressionMatcher>();
13 op->matchers.push_back(make_unique<FoldableConstantMatcher>());
14 op->policy = SetMatcher::Policy::SOME;
15 root = move(op);
16}
17
18unique_ptr<Expression> ComparisonSimplificationRule::Apply(LogicalOperator &op, vector<Expression *> &bindings,
19 bool &changes_made) {
20 assert(bindings[0]->expression_class == ExpressionClass::BOUND_COMPARISON);
21 auto expr = (BoundComparisonExpression *)bindings[0];
22 auto constant_expr = bindings[1];
23 bool column_ref_left = expr->left.get() != constant_expr;
24 auto column_ref_expr = !column_ref_left ? expr->right.get() : expr->left.get();
25 // the constant_expr is a scalar expression that we have to fold
26 // use an ExpressionExecutor to execute the expression
27 assert(constant_expr->IsFoldable());
28 auto constant_value = ExpressionExecutor::EvaluateScalar(*constant_expr);
29 if (constant_value.is_null) {
30 // comparison with constant NULL, return NULL
31 return make_unique<BoundConstantExpression>(Value(TypeId::BOOL));
32 }
33 if (column_ref_expr->expression_class == ExpressionClass::BOUND_CAST &&
34 constant_expr->expression_class == ExpressionClass::BOUND_CONSTANT) {
35 //! Here we check if we can apply the expression on the constant side
36 auto cast_expression = (BoundCastExpression *)column_ref_expr;
37 if (!BoundCastExpression::CastIsInvertible(cast_expression->source_type, cast_expression->target_type)) {
38 return nullptr;
39 }
40 auto bound_const_expr = (BoundConstantExpression *)constant_expr;
41 auto new_constant =
42 bound_const_expr->value.TryCastAs(cast_expression->target_type.id, cast_expression->source_type.id);
43 if (new_constant) {
44 auto child_expression = move(cast_expression->child);
45 constant_expr->return_type = bound_const_expr->value.type;
46 //! We can cast, now we change our column_ref_expression from an operator cast to a column reference
47 if (column_ref_left) {
48 expr->left = move(child_expression);
49 } else {
50 expr->right = move(child_expression);
51 }
52 }
53 }
54 return nullptr;
55}
56