1#include "duckdb/parser/expression/case_expression.hpp"
2#include "duckdb/parser/expression/comparison_expression.hpp"
3#include "duckdb/parser/expression/constant_expression.hpp"
4#include "duckdb/parser/transformer.hpp"
5
6using namespace duckdb;
7using namespace std;
8
9unique_ptr<ParsedExpression> Transformer::TransformCase(PGCaseExpr *root) {
10 if (!root) {
11 return nullptr;
12 }
13 // CASE expression WHEN value THEN result [WHEN ...] ELSE result uses this,
14 // but we rewrite to CASE WHEN expression = value THEN result ... to only
15 // have to handle one case downstream.
16
17 unique_ptr<ParsedExpression> def_res;
18 if (root->defresult) {
19 def_res = TransformExpression(reinterpret_cast<PGNode *>(root->defresult));
20 } else {
21 def_res = make_unique<ConstantExpression>(SQLType::SQLNULL, Value());
22 }
23 // def_res will be the else part of the innermost case expression
24
25 // CASE WHEN e1 THEN r1 WHEN w2 THEN r2 ELSE r3 is rewritten to
26 // CASE WHEN e1 THEN r1 ELSE CASE WHEN e2 THEN r2 ELSE r3
27
28 auto exp_root = make_unique<CaseExpression>();
29 auto cur_root = exp_root.get();
30 for (auto cell = root->args->head; cell != nullptr; cell = cell->next) {
31 auto w = reinterpret_cast<PGCaseWhen *>(cell->data.ptr_value);
32 auto test_raw = TransformExpression(reinterpret_cast<PGNode *>(w->expr));
33 unique_ptr<ParsedExpression> test;
34 auto arg = TransformExpression(reinterpret_cast<PGNode *>(root->arg));
35 if (arg) {
36 test = make_unique<ComparisonExpression>(ExpressionType::COMPARE_EQUAL, move(arg), move(test_raw));
37 } else {
38 test = move(test_raw);
39 }
40
41 cur_root->check = move(test);
42 cur_root->result_if_true = TransformExpression(reinterpret_cast<PGNode *>(w->result));
43 if (cell->next == nullptr) {
44 // finished all cases
45 // res_false is the default result
46 cur_root->result_if_false = move(def_res);
47 } else {
48 // more cases remain, create a case statement within the FALSE branch
49 auto next_case = make_unique<CaseExpression>();
50 auto case_ptr = next_case.get();
51 cur_root->result_if_false = move(next_case);
52 cur_root = case_ptr;
53 }
54 }
55
56 return move(exp_root);
57}
58