1 | #include "duckdb/optimizer/matcher/expression_matcher.hpp" |
2 | |
3 | #include "duckdb/planner/expression/list.hpp" |
4 | |
5 | using namespace duckdb; |
6 | using namespace std; |
7 | |
8 | bool ExpressionMatcher::Match(Expression *expr, vector<Expression *> &bindings) { |
9 | if (type && !type->Match(expr->return_type)) { |
10 | return false; |
11 | } |
12 | if (expr_type && !expr_type->Match(expr->type)) { |
13 | return false; |
14 | } |
15 | if (expr_class != ExpressionClass::INVALID && expr_class != expr->GetExpressionClass()) { |
16 | return false; |
17 | } |
18 | bindings.push_back(expr); |
19 | return true; |
20 | } |
21 | |
22 | bool ExpressionEqualityMatcher::Match(Expression *expr, vector<Expression *> &bindings) { |
23 | if (!Expression::Equals(expression, expr)) { |
24 | return false; |
25 | } |
26 | bindings.push_back(expr); |
27 | return true; |
28 | } |
29 | |
30 | bool CaseExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
31 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
32 | return false; |
33 | } |
34 | auto expr = (BoundCaseExpression *)expr_; |
35 | if (check && !check->Match(expr->check.get(), bindings)) { |
36 | return false; |
37 | } |
38 | if (result_if_true && !result_if_true->Match(expr->result_if_true.get(), bindings)) { |
39 | return false; |
40 | } |
41 | if (result_if_false && !result_if_false->Match(expr->result_if_false.get(), bindings)) { |
42 | return false; |
43 | } |
44 | return true; |
45 | } |
46 | |
47 | bool CastExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
48 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
49 | return false; |
50 | } |
51 | auto expr = (BoundCastExpression *)expr_; |
52 | if (child && !child->Match(expr->child.get(), bindings)) { |
53 | return false; |
54 | } |
55 | return true; |
56 | } |
57 | |
58 | bool ComparisonExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
59 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
60 | return false; |
61 | } |
62 | auto expr = (BoundComparisonExpression *)expr_; |
63 | vector<Expression *> expressions = {expr->left.get(), expr->right.get()}; |
64 | return SetMatcher::Match(matchers, expressions, bindings, policy); |
65 | } |
66 | |
67 | bool ConjunctionExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
68 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
69 | return false; |
70 | } |
71 | auto expr = (BoundConjunctionExpression *)expr_; |
72 | if (!SetMatcher::Match(matchers, expr->children, bindings, policy)) { |
73 | return false; |
74 | } |
75 | return true; |
76 | } |
77 | |
78 | bool OperatorExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
79 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
80 | return false; |
81 | } |
82 | auto expr = (BoundOperatorExpression *)expr_; |
83 | return SetMatcher::Match(matchers, expr->children, bindings, policy); |
84 | } |
85 | |
86 | bool FunctionExpressionMatcher::Match(Expression *expr_, vector<Expression *> &bindings) { |
87 | if (!ExpressionMatcher::Match(expr_, bindings)) { |
88 | return false; |
89 | } |
90 | auto expr = (BoundFunctionExpression *)expr_; |
91 | if (!FunctionMatcher::Match(function, expr->function.name)) { |
92 | return false; |
93 | } |
94 | if (!SetMatcher::Match(matchers, expr->children, bindings, policy)) { |
95 | return false; |
96 | } |
97 | return true; |
98 | } |
99 | |
100 | bool FoldableConstantMatcher::Match(Expression *expr, vector<Expression *> &bindings) { |
101 | // we match on ANY expression that is a scalar expression |
102 | if (!expr->IsFoldable()) { |
103 | return false; |
104 | } |
105 | bindings.push_back(expr); |
106 | return true; |
107 | } |
108 | |