1#include "duckdb/optimizer/rule/enum_comparison.hpp"
2
3#include "duckdb/execution/expression_executor.hpp"
4#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5#include "duckdb/planner/expression/bound_cast_expression.hpp"
6#include "duckdb/optimizer/matcher/type_matcher_id.hpp"
7#include "duckdb/optimizer/expression_rewriter.hpp"
8#include "duckdb/common/types.hpp"
9
10namespace duckdb {
11
12EnumComparisonRule::EnumComparisonRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
13 // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children
14 auto op = make_uniq<ComparisonExpressionMatcher>();
15 // Enum requires expression to be root
16 op->expr_type = make_uniq<SpecificExpressionTypeMatcher>(args: ExpressionType::COMPARE_EQUAL);
17 for (idx_t i = 0; i < 2; i++) {
18 auto child = make_uniq<CastExpressionMatcher>();
19 child->type = make_uniq<TypeMatcherId>(args: LogicalTypeId::VARCHAR);
20 child->matcher = make_uniq<ExpressionMatcher>();
21 child->matcher->type = make_uniq<TypeMatcherId>(args: LogicalTypeId::ENUM);
22 op->matchers.push_back(x: std::move(child));
23 }
24 root = std::move(op);
25}
26
27bool AreMatchesPossible(LogicalType &left, LogicalType &right) {
28 LogicalType *small_enum, *big_enum;
29 if (EnumType::GetSize(type: left) < EnumType::GetSize(type: right)) {
30 small_enum = &left;
31 big_enum = &right;
32 } else {
33 small_enum = &right;
34 big_enum = &left;
35 }
36 auto &string_vec = EnumType::GetValuesInsertOrder(type: *small_enum);
37 auto string_vec_ptr = FlatVector::GetData<string_t>(vector: string_vec);
38 auto size = EnumType::GetSize(type: *small_enum);
39 for (idx_t i = 0; i < size; i++) {
40 auto key = string_vec_ptr[i].GetString();
41 if (EnumType::GetPos(type: *big_enum, key) != -1) {
42 return true;
43 }
44 }
45 return false;
46}
47unique_ptr<Expression> EnumComparisonRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings,
48 bool &changes_made, bool is_root) {
49
50 auto &root = bindings[0].get().Cast<BoundComparisonExpression>();
51 auto &left_child = bindings[1].get().Cast<BoundCastExpression>();
52 auto &right_child = bindings[3].get().Cast<BoundCastExpression>();
53
54 if (!AreMatchesPossible(left&: left_child.child->return_type, right&: right_child.child->return_type)) {
55 vector<unique_ptr<Expression>> children;
56 children.push_back(x: std::move(root.left));
57 children.push_back(x: std::move(root.right));
58 return ExpressionRewriter::ConstantOrNull(children: std::move(children), value: Value::BOOLEAN(value: false));
59 }
60
61 if (!is_root || op.type != LogicalOperatorType::LOGICAL_FILTER) {
62 return nullptr;
63 }
64
65 auto cast_left_to_right =
66 BoundCastExpression::AddDefaultCastToType(expr: std::move(left_child.child), target_type: right_child.child->return_type, try_cast: true);
67 return make_uniq<BoundComparisonExpression>(args&: root.type, args: std::move(cast_left_to_right), args: std::move(right_child.child));
68}
69
70} // namespace duckdb
71