| 1 | #include "duckdb/planner/expression/list.hpp" | 
| 2 | #include "duckdb/optimizer/rule/comparison_simplification.hpp" | 
| 3 |  | 
| 4 | #include "duckdb/execution/expression_executor.hpp" | 
| 5 | #include "duckdb/planner/expression/bound_constant_expression.hpp" | 
| 6 |  | 
| 7 | namespace duckdb { | 
| 8 |  | 
| 9 | ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { | 
| 10 | 	// match on a ComparisonExpression that has a ConstantExpression as a check | 
| 11 | 	auto op = make_uniq<ComparisonExpressionMatcher>(); | 
| 12 | 	op->matchers.push_back(x: make_uniq<FoldableConstantMatcher>()); | 
| 13 | 	op->policy = SetMatcher::Policy::SOME; | 
| 14 | 	root = std::move(op); | 
| 15 | } | 
| 16 |  | 
| 17 | unique_ptr<Expression> ComparisonSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, | 
| 18 |                                                            bool &changes_made, bool is_root) { | 
| 19 | 	auto &expr = bindings[0].get().Cast<BoundComparisonExpression>(); | 
| 20 | 	auto &constant_expr = bindings[1].get(); | 
| 21 | 	bool column_ref_left = expr.left.get() != &constant_expr; | 
| 22 | 	auto column_ref_expr = !column_ref_left ? expr.right.get() : expr.left.get(); | 
| 23 | 	// the constant_expr is a scalar expression that we have to fold | 
| 24 | 	// use an ExpressionExecutor to execute the expression | 
| 25 | 	D_ASSERT(constant_expr.IsFoldable()); | 
| 26 | 	Value constant_value; | 
| 27 | 	if (!ExpressionExecutor::TryEvaluateScalar(context&: GetContext(), expr: constant_expr, result&: constant_value)) { | 
| 28 | 		return nullptr; | 
| 29 | 	} | 
| 30 | 	if (constant_value.IsNull() && !(expr.type == ExpressionType::COMPARE_NOT_DISTINCT_FROM || | 
| 31 | 	                                 expr.type == ExpressionType::COMPARE_DISTINCT_FROM)) { | 
| 32 | 		// comparison with constant NULL, return NULL | 
| 33 | 		return make_uniq<BoundConstantExpression>(args: Value(LogicalType::BOOLEAN)); | 
| 34 | 	} | 
| 35 | 	if (column_ref_expr->expression_class == ExpressionClass::BOUND_CAST) { | 
| 36 | 		//! Here we check if we can apply the expression on the constant side | 
| 37 | 		//! We can do this if the cast itself is invertible and casting the constant is | 
| 38 | 		//! invertible in practice. | 
| 39 | 		auto &cast_expression = column_ref_expr->Cast<BoundCastExpression>(); | 
| 40 | 		auto target_type = cast_expression.source_type(); | 
| 41 | 		if (!BoundCastExpression::CastIsInvertible(source_type: target_type, target_type: cast_expression.return_type)) { | 
| 42 | 			return nullptr; | 
| 43 | 		} | 
| 44 |  | 
| 45 | 		// Can we cast the constant at all? | 
| 46 | 		string error_message; | 
| 47 | 		Value cast_constant; | 
| 48 | 		auto new_constant = constant_value.DefaultTryCastAs(target_type, new_value&: cast_constant, error_message: &error_message, strict: true); | 
| 49 | 		if (!new_constant) { | 
| 50 | 			return nullptr; | 
| 51 | 		} | 
| 52 |  | 
| 53 | 		// Is the constant cast invertible? | 
| 54 | 		if (!cast_constant.IsNull() && | 
| 55 | 		    !BoundCastExpression::CastIsInvertible(source_type: cast_expression.return_type, target_type)) { | 
| 56 | 			// Is it actually invertible? | 
| 57 | 			Value uncast_constant; | 
| 58 | 			if (!cast_constant.DefaultTryCastAs(target_type: constant_value.type(), new_value&: uncast_constant, error_message: &error_message, strict: true) || | 
| 59 | 			    uncast_constant != constant_value) { | 
| 60 | 				return nullptr; | 
| 61 | 			} | 
| 62 | 		} | 
| 63 |  | 
| 64 | 		//! We can cast, now we change our column_ref_expression from an operator cast to a column reference | 
| 65 | 		auto child_expression = std::move(cast_expression.child); | 
| 66 | 		auto new_constant_expr = make_uniq<BoundConstantExpression>(args&: cast_constant); | 
| 67 | 		if (column_ref_left) { | 
| 68 | 			expr.left = std::move(child_expression); | 
| 69 | 			expr.right = std::move(new_constant_expr); | 
| 70 | 		} else { | 
| 71 | 			expr.left = std::move(new_constant_expr); | 
| 72 | 			expr.right = std::move(child_expression); | 
| 73 | 		} | 
| 74 | 	} | 
| 75 | 	return nullptr; | 
| 76 | } | 
| 77 |  | 
| 78 | } // namespace duckdb | 
| 79 |  |