1#include "duckdb/common/vector_operations/vector_operations.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/planner/expression/bound_case_expression.hpp"
4
5namespace duckdb {
6
7struct CaseExpressionState : public ExpressionState {
8 CaseExpressionState(const Expression &expr, ExpressionExecutorState &root)
9 : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) {
10 }
11
12 SelectionVector true_sel;
13 SelectionVector false_sel;
14};
15
16unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const BoundCaseExpression &expr,
17 ExpressionExecutorState &root) {
18 auto result = make_uniq<CaseExpressionState>(args: expr, args&: root);
19 for (auto &case_check : expr.case_checks) {
20 result->AddChild(expr: case_check.when_expr.get());
21 result->AddChild(expr: case_check.then_expr.get());
22 }
23 result->AddChild(expr: expr.else_expr.get());
24 result->Finalize();
25 return std::move(result);
26}
27
28void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel,
29 idx_t count, Vector &result) {
30 auto &state = state_p->Cast<CaseExpressionState>();
31
32 state.intermediate_chunk.Reset();
33
34 // first execute the check expression
35 auto current_true_sel = &state.true_sel;
36 auto current_false_sel = &state.false_sel;
37 auto current_sel = sel;
38 idx_t current_count = count;
39 for (idx_t i = 0; i < expr.case_checks.size(); i++) {
40 auto &case_check = expr.case_checks[i];
41 auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1];
42 auto check_state = state.child_states[i * 2].get();
43 auto then_state = state.child_states[i * 2 + 1].get();
44
45 idx_t tcount =
46 Select(expr: *case_check.when_expr, state: check_state, sel: current_sel, count: current_count, true_sel: current_true_sel, false_sel: current_false_sel);
47 if (tcount == 0) {
48 // everything is false: do nothing
49 continue;
50 }
51 idx_t fcount = current_count - tcount;
52 if (fcount == 0 && current_count == count) {
53 // everything is true in the first CHECK statement
54 // we can skip the entire case and only execute the TRUE side
55 Execute(expr: *case_check.then_expr, state: then_state, sel, count, result);
56 return;
57 } else {
58 // we need to execute and then fill in the desired tuples in the result
59 Execute(expr: *case_check.then_expr, state: then_state, sel: current_true_sel, count: tcount, result&: intermediate_result);
60 FillSwitch(vector&: intermediate_result, result, sel: *current_true_sel, count: tcount);
61 }
62 // continue with the false tuples
63 current_sel = current_false_sel;
64 current_count = fcount;
65 if (fcount == 0) {
66 // everything is true: we are done
67 break;
68 }
69 }
70 if (current_count > 0) {
71 auto else_state = state.child_states.back().get();
72 if (current_count == count) {
73 // everything was false, we can just evaluate the else expression directly
74 Execute(expr: *expr.else_expr, state: else_state, sel, count, result);
75 return;
76 } else {
77 auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2];
78
79 D_ASSERT(current_sel);
80 Execute(expr: *expr.else_expr, state: else_state, sel: current_sel, count: current_count, result&: intermediate_result);
81 FillSwitch(vector&: intermediate_result, result, sel: *current_sel, count: current_count);
82 }
83 }
84 if (sel) {
85 result.Slice(sel: *sel, count);
86 }
87}
88
89template <class T>
90void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) {
91 result.SetVectorType(VectorType::FLAT_VECTOR);
92 auto res = FlatVector::GetData<T>(result);
93 auto &result_mask = FlatVector::Validity(vector&: result);
94 if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) {
95 auto data = ConstantVector::GetData<T>(vector);
96 if (ConstantVector::IsNull(vector)) {
97 for (idx_t i = 0; i < count; i++) {
98 result_mask.SetInvalid(sel.get_index(idx: i));
99 }
100 } else {
101 for (idx_t i = 0; i < count; i++) {
102 res[sel.get_index(idx: i)] = *data;
103 }
104 }
105 } else {
106 UnifiedVectorFormat vdata;
107 vector.ToUnifiedFormat(count, data&: vdata);
108 auto data = UnifiedVectorFormat::GetData<T>(vdata);
109 for (idx_t i = 0; i < count; i++) {
110 auto source_idx = vdata.sel->get_index(idx: i);
111 auto res_idx = sel.get_index(idx: i);
112
113 res[res_idx] = data[source_idx];
114 result_mask.Set(row_idx: res_idx, valid: vdata.validity.RowIsValid(row_idx: source_idx));
115 }
116 }
117}
118
119void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) {
120 result.SetVectorType(VectorType::FLAT_VECTOR);
121 auto &result_mask = FlatVector::Validity(vector&: result);
122 if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) {
123 if (ConstantVector::IsNull(vector)) {
124 for (idx_t i = 0; i < count; i++) {
125 result_mask.SetInvalid(sel.get_index(idx: i));
126 }
127 }
128 } else {
129 UnifiedVectorFormat vdata;
130 vector.ToUnifiedFormat(count, data&: vdata);
131 if (vdata.validity.AllValid()) {
132 return;
133 }
134 for (idx_t i = 0; i < count; i++) {
135 auto source_idx = vdata.sel->get_index(idx: i);
136 if (!vdata.validity.RowIsValid(row_idx: source_idx)) {
137 result_mask.SetInvalid(sel.get_index(idx: i));
138 }
139 }
140 }
141}
142
143void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) {
144 switch (result.GetType().InternalType()) {
145 case PhysicalType::BOOL:
146 case PhysicalType::INT8:
147 TemplatedFillLoop<int8_t>(vector, result, sel, count);
148 break;
149 case PhysicalType::INT16:
150 TemplatedFillLoop<int16_t>(vector, result, sel, count);
151 break;
152 case PhysicalType::INT32:
153 TemplatedFillLoop<int32_t>(vector, result, sel, count);
154 break;
155 case PhysicalType::INT64:
156 TemplatedFillLoop<int64_t>(vector, result, sel, count);
157 break;
158 case PhysicalType::UINT8:
159 TemplatedFillLoop<uint8_t>(vector, result, sel, count);
160 break;
161 case PhysicalType::UINT16:
162 TemplatedFillLoop<uint16_t>(vector, result, sel, count);
163 break;
164 case PhysicalType::UINT32:
165 TemplatedFillLoop<uint32_t>(vector, result, sel, count);
166 break;
167 case PhysicalType::UINT64:
168 TemplatedFillLoop<uint64_t>(vector, result, sel, count);
169 break;
170 case PhysicalType::INT128:
171 TemplatedFillLoop<hugeint_t>(vector, result, sel, count);
172 break;
173 case PhysicalType::FLOAT:
174 TemplatedFillLoop<float>(vector, result, sel, count);
175 break;
176 case PhysicalType::DOUBLE:
177 TemplatedFillLoop<double>(vector, result, sel, count);
178 break;
179 case PhysicalType::INTERVAL:
180 TemplatedFillLoop<interval_t>(vector, result, sel, count);
181 break;
182 case PhysicalType::VARCHAR:
183 TemplatedFillLoop<string_t>(vector, result, sel, count);
184 StringVector::AddHeapReference(vector&: result, other&: vector);
185 break;
186 case PhysicalType::STRUCT: {
187 auto &vector_entries = StructVector::GetEntries(vector);
188 auto &result_entries = StructVector::GetEntries(vector&: result);
189 ValidityFillLoop(vector, result, sel, count);
190 D_ASSERT(vector_entries.size() == result_entries.size());
191 for (idx_t i = 0; i < vector_entries.size(); i++) {
192 FillSwitch(vector&: *vector_entries[i], result&: *result_entries[i], sel, count);
193 }
194 break;
195 }
196 case PhysicalType::LIST: {
197 idx_t offset = ListVector::GetListSize(vector: result);
198 auto &list_child = ListVector::GetEntry(vector);
199 ListVector::Append(target&: result, source: list_child, source_size: ListVector::GetListSize(vector));
200
201 // all the false offsets need to be incremented by true_child.count
202 TemplatedFillLoop<list_entry_t>(vector, result, sel, count);
203 if (offset == 0) {
204 break;
205 }
206
207 auto result_data = FlatVector::GetData<list_entry_t>(vector&: result);
208 for (idx_t i = 0; i < count; i++) {
209 auto result_idx = sel.get_index(idx: i);
210 result_data[result_idx].offset += offset;
211 }
212
213 Vector::Verify(vector&: result, sel, count);
214 break;
215 }
216 default:
217 throw NotImplementedException("Unimplemented type for case expression: %s", result.GetType().ToString());
218 }
219}
220
221} // namespace duckdb
222