1#include "duckdb/common/vector_operations/vector_operations.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/planner/expression/bound_case_expression.hpp"
4#include "duckdb/common/types/chunk_collection.hpp"
5
6using namespace duckdb;
7using namespace std;
8
//! Forward declaration: Case() is defined at the bottom of this file (after the fill_loop/case_loop templates it
//! relies on), but it is already needed by ExpressionExecutor::Execute below.
void Case(Vector &res_true, Vector &res_false, Vector &result, SelectionVector &tside, idx_t tcount,
          SelectionVector &fside, idx_t fcount);
11
12unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(BoundCaseExpression &expr,
13 ExpressionExecutorState &root) {
14 auto result = make_unique<ExpressionState>(expr, root);
15 result->AddChild(expr.check.get());
16 result->AddChild(expr.result_if_true.get());
17 result->AddChild(expr.result_if_false.get());
18 return result;
19}
20
21void ExpressionExecutor::Execute(BoundCaseExpression &expr, ExpressionState *state, const SelectionVector *sel,
22 idx_t count, Vector &result) {
23 Vector res_true(expr.result_if_true->return_type), res_false(expr.result_if_false->return_type);
24
25 auto check_state = state->child_states[0].get();
26 auto res_true_state = state->child_states[1].get();
27 auto res_false_state = state->child_states[2].get();
28
29 // first execute the check expression
30 SelectionVector true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE);
31 idx_t tcount = Select(*expr.check, check_state, sel, count, &true_sel, &false_sel);
32 idx_t fcount = count - tcount;
33 if (fcount == 0) {
34 // everything is true, only execute TRUE side
35 Execute(*expr.result_if_true, res_true_state, sel, count, result);
36 } else if (tcount == 0) {
37 // everything is false, only execute FALSE side
38 Execute(*expr.result_if_false, res_false_state, sel, count, result);
39 } else {
40 // have to execute both and mix and match
41 Execute(*expr.result_if_true, res_true_state, &true_sel, tcount, res_true);
42 Execute(*expr.result_if_false, res_false_state, &false_sel, fcount, res_false);
43
44 Case(res_true, res_false, result, true_sel, tcount, false_sel, fcount);
45 if (sel) {
46 result.Slice(*sel, count);
47 }
48 }
49}
50
51template <class T> void fill_loop(Vector &vector, Vector &result, SelectionVector &sel, sel_t count) {
52 auto res = FlatVector::GetData<T>(result);
53 auto &result_nullmask = FlatVector::Nullmask(result);
54 if (vector.vector_type == VectorType::CONSTANT_VECTOR) {
55 auto data = ConstantVector::GetData<T>(vector);
56 if (ConstantVector::IsNull(vector)) {
57 for (idx_t i = 0; i < count; i++) {
58 result_nullmask[sel.get_index(i)] = true;
59 }
60 } else {
61 for (idx_t i = 0; i < count; i++) {
62 res[sel.get_index(i)] = *data;
63 }
64 }
65 } else {
66 VectorData vdata;
67 vector.Orrify(count, vdata);
68 auto data = (T *)vdata.data;
69 for (idx_t i = 0; i < count; i++) {
70 auto source_idx = vdata.sel->get_index(i);
71 auto res_idx = sel.get_index(i);
72
73 res[res_idx] = data[source_idx];
74 result_nullmask[res_idx] = (*vdata.nullmask)[source_idx];
75 }
76 }
77}
78
79template <class T>
80void case_loop(Vector &res_true, Vector &res_false, Vector &result, SelectionVector &tside, idx_t tcount,
81 SelectionVector &fside, idx_t fcount) {
82 fill_loop<T>(res_true, result, tside, tcount);
83 fill_loop<T>(res_false, result, fside, fcount);
84}
85
86void Case(Vector &res_true, Vector &res_false, Vector &result, SelectionVector &tside, idx_t tcount,
87 SelectionVector &fside, idx_t fcount) {
88 assert(res_true.type == res_false.type && res_true.type == result.type);
89
90 switch (result.type) {
91 case TypeId::BOOL:
92 case TypeId::INT8:
93 case_loop<int8_t>(res_true, res_false, result, tside, tcount, fside, fcount);
94 break;
95 case TypeId::INT16:
96 case_loop<int16_t>(res_true, res_false, result, tside, tcount, fside, fcount);
97 break;
98 case TypeId::INT32:
99 case_loop<int32_t>(res_true, res_false, result, tside, tcount, fside, fcount);
100 break;
101 case TypeId::INT64:
102 case_loop<int64_t>(res_true, res_false, result, tside, tcount, fside, fcount);
103 break;
104 case TypeId::FLOAT:
105 case_loop<float>(res_true, res_false, result, tside, tcount, fside, fcount);
106 break;
107 case TypeId::DOUBLE:
108 case_loop<double>(res_true, res_false, result, tside, tcount, fside, fcount);
109 break;
110 case TypeId::VARCHAR:
111 case_loop<string_t>(res_true, res_false, result, tside, tcount, fside, fcount);
112 StringVector::AddHeapReference(result, res_true);
113 StringVector::AddHeapReference(result, res_false);
114 break;
115 case TypeId::LIST: {
116 auto result_cc = make_unique<ChunkCollection>();
117 ListVector::SetEntry(result, move(result_cc));
118
119 auto &result_child = ListVector::GetEntry(result);
120 idx_t offset = 0;
121 if (ListVector::HasEntry(res_true)) {
122 auto &true_child = ListVector::GetEntry(res_true);
123 assert(true_child.types.size() == 1);
124 offset += true_child.count;
125 result_child.Append(true_child);
126 }
127 if (ListVector::HasEntry(res_false)) {
128 auto &false_child = ListVector::GetEntry(res_false);
129 assert(false_child.types.size() == 1);
130 result_child.Append(false_child);
131 }
132
133 // all the false offsets need to be incremented by true_child.count
134 fill_loop<list_entry_t>(res_true, result, tside, tcount);
135
136 // FIXME the nullmask here is likely borked
137 // TODO uuugly
138 VectorData fdata;
139 res_false.Orrify(fcount, fdata);
140
141 auto data = (list_entry_t *)fdata.data;
142 auto res = FlatVector::GetData<list_entry_t>(result);
143 auto &mask = FlatVector::Nullmask(result);
144
145 for (idx_t i = 0; i < fcount; i++) {
146 auto fidx = fdata.sel->get_index(i);
147 auto res_idx = fside.get_index(i);
148 auto list_entry = data[fidx];
149 list_entry.offset += offset;
150 res[res_idx] = list_entry;
151 mask[res_idx] = (*fdata.nullmask)[fidx];
152 }
153
154 result.Verify(tcount + fcount);
155 break;
156 }
157 default:
158 throw NotImplementedException("Unimplemented type for case expression: %s",
159 TypeIdToString(result.type).c_str());
160 }
161}
162