| 1 | #include "duckdb/parser/expression/case_expression.hpp" |
|---|---|
| 2 | #include "duckdb/parser/expression/comparison_expression.hpp" |
| 3 | #include "duckdb/parser/expression/constant_expression.hpp" |
| 4 | #include "duckdb/parser/transformer.hpp" |
| 5 | |
| 6 | using namespace duckdb; |
| 7 | using namespace std; |
| 8 | |
| 9 | unique_ptr<ParsedExpression> Transformer::TransformCase(PGCaseExpr *root) { |
| 10 | if (!root) { |
| 11 | return nullptr; |
| 12 | } |
| 13 | // CASE expression WHEN value THEN result [WHEN ...] ELSE result uses this, |
| 14 | // but we rewrite to CASE WHEN expression = value THEN result ... to only |
| 15 | // have to handle one case downstream. |
| 16 | |
| 17 | unique_ptr<ParsedExpression> def_res; |
| 18 | if (root->defresult) { |
| 19 | def_res = TransformExpression(reinterpret_cast<PGNode *>(root->defresult)); |
| 20 | } else { |
| 21 | def_res = make_unique<ConstantExpression>(SQLType::SQLNULL, Value()); |
| 22 | } |
| 23 | // def_res will be the else part of the innermost case expression |
| 24 | |
| 25 | // CASE WHEN e1 THEN r1 WHEN w2 THEN r2 ELSE r3 is rewritten to |
| 26 | // CASE WHEN e1 THEN r1 ELSE CASE WHEN e2 THEN r2 ELSE r3 |
| 27 | |
| 28 | auto exp_root = make_unique<CaseExpression>(); |
| 29 | auto cur_root = exp_root.get(); |
| 30 | for (auto cell = root->args->head; cell != nullptr; cell = cell->next) { |
| 31 | auto w = reinterpret_cast<PGCaseWhen *>(cell->data.ptr_value); |
| 32 | auto test_raw = TransformExpression(reinterpret_cast<PGNode *>(w->expr)); |
| 33 | unique_ptr<ParsedExpression> test; |
| 34 | auto arg = TransformExpression(reinterpret_cast<PGNode *>(root->arg)); |
| 35 | if (arg) { |
| 36 | test = make_unique<ComparisonExpression>(ExpressionType::COMPARE_EQUAL, move(arg), move(test_raw)); |
| 37 | } else { |
| 38 | test = move(test_raw); |
| 39 | } |
| 40 | |
| 41 | cur_root->check = move(test); |
| 42 | cur_root->result_if_true = TransformExpression(reinterpret_cast<PGNode *>(w->result)); |
| 43 | if (cell->next == nullptr) { |
| 44 | // finished all cases |
| 45 | // res_false is the default result |
| 46 | cur_root->result_if_false = move(def_res); |
| 47 | } else { |
| 48 | // more cases remain, create a case statement within the FALSE branch |
| 49 | auto next_case = make_unique<CaseExpression>(); |
| 50 | auto case_ptr = next_case.get(); |
| 51 | cur_root->result_if_false = move(next_case); |
| 52 | cur_root = case_ptr; |
| 53 | } |
| 54 | } |
| 55 | |
| 56 | return move(exp_root); |
| 57 | } |
| 58 |