1 | #include "duckdb/optimizer/rule/case_simplification.hpp" |
---|---|
2 | |
3 | #include "duckdb/execution/expression_executor.hpp" |
4 | #include "duckdb/planner/expression/bound_case_expression.hpp" |
5 | |
6 | namespace duckdb { |
7 | |
8 | CaseSimplificationRule::CaseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { |
9 | // match on a CaseExpression that has a ConstantExpression as a check |
10 | auto op = make_uniq<CaseExpressionMatcher>(); |
11 | root = std::move(op); |
12 | } |
13 | |
14 | unique_ptr<Expression> CaseSimplificationRule::Apply(LogicalOperator &op, vector<reference<Expression>> &bindings, |
15 | bool &changes_made, bool is_root) { |
16 | auto &root = bindings[0].get().Cast<BoundCaseExpression>(); |
17 | for (idx_t i = 0; i < root.case_checks.size(); i++) { |
18 | auto &case_check = root.case_checks[i]; |
19 | if (case_check.when_expr->IsFoldable()) { |
20 | // the WHEN check is a foldable expression |
21 | // use an ExpressionExecutor to execute the expression |
22 | auto constant_value = ExpressionExecutor::EvaluateScalar(context&: GetContext(), expr: *case_check.when_expr); |
23 | |
24 | // fold based on the constant condition |
25 | auto condition = constant_value.DefaultCastAs(target_type: LogicalType::BOOLEAN); |
26 | if (condition.IsNull() || !BooleanValue::Get(value: condition)) { |
27 | // the condition is always false: remove this case check |
28 | root.case_checks.erase(position: root.case_checks.begin() + i); |
29 | i--; |
30 | } else { |
31 | // the condition is always true |
32 | // move the THEN clause to the ELSE of the case |
33 | root.else_expr = std::move(case_check.then_expr); |
34 | // remove this case check and any case checks after this one |
35 | root.case_checks.erase(first: root.case_checks.begin() + i, last: root.case_checks.end()); |
36 | break; |
37 | } |
38 | } |
39 | } |
40 | if (root.case_checks.empty()) { |
41 | // no case checks left: return the ELSE expression |
42 | return std::move(root.else_expr); |
43 | } |
44 | return nullptr; |
45 | } |
46 | |
47 | } // namespace duckdb |
48 |