1 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
2 | #include "duckdb/execution/expression_executor.hpp" |
3 | #include "duckdb/planner/expression/bound_case_expression.hpp" |
4 | |
5 | namespace duckdb { |
6 | |
7 | struct CaseExpressionState : public ExpressionState { |
8 | CaseExpressionState(const Expression &expr, ExpressionExecutorState &root) |
9 | : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) { |
10 | } |
11 | |
12 | SelectionVector true_sel; |
13 | SelectionVector false_sel; |
14 | }; |
15 | |
16 | unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, |
17 | ExpressionExecutorState &root) { |
18 | auto result = make_uniq<CaseExpressionState>(args: expr, args&: root); |
19 | for (auto &case_check : expr.case_checks) { |
20 | result->AddChild(expr: case_check.when_expr.get()); |
21 | result->AddChild(expr: case_check.then_expr.get()); |
22 | } |
23 | result->AddChild(expr: expr.else_expr.get()); |
24 | result->Finalize(); |
25 | return std::move(result); |
26 | } |
27 | |
28 | void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel, |
29 | idx_t count, Vector &result) { |
30 | auto &state = state_p->Cast<CaseExpressionState>(); |
31 | |
32 | state.intermediate_chunk.Reset(); |
33 | |
34 | // first execute the check expression |
35 | auto current_true_sel = &state.true_sel; |
36 | auto current_false_sel = &state.false_sel; |
37 | auto current_sel = sel; |
38 | idx_t current_count = count; |
39 | for (idx_t i = 0; i < expr.case_checks.size(); i++) { |
40 | auto &case_check = expr.case_checks[i]; |
41 | auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; |
42 | auto check_state = state.child_states[i * 2].get(); |
43 | auto then_state = state.child_states[i * 2 + 1].get(); |
44 | |
45 | idx_t tcount = |
46 | Select(expr: *case_check.when_expr, state: check_state, sel: current_sel, count: current_count, true_sel: current_true_sel, false_sel: current_false_sel); |
47 | if (tcount == 0) { |
48 | // everything is false: do nothing |
49 | continue; |
50 | } |
51 | idx_t fcount = current_count - tcount; |
52 | if (fcount == 0 && current_count == count) { |
53 | // everything is true in the first CHECK statement |
54 | // we can skip the entire case and only execute the TRUE side |
55 | Execute(expr: *case_check.then_expr, state: then_state, sel, count, result); |
56 | return; |
57 | } else { |
58 | // we need to execute and then fill in the desired tuples in the result |
59 | Execute(expr: *case_check.then_expr, state: then_state, sel: current_true_sel, count: tcount, result&: intermediate_result); |
60 | FillSwitch(vector&: intermediate_result, result, sel: *current_true_sel, count: tcount); |
61 | } |
62 | // continue with the false tuples |
63 | current_sel = current_false_sel; |
64 | current_count = fcount; |
65 | if (fcount == 0) { |
66 | // everything is true: we are done |
67 | break; |
68 | } |
69 | } |
70 | if (current_count > 0) { |
71 | auto else_state = state.child_states.back().get(); |
72 | if (current_count == count) { |
73 | // everything was false, we can just evaluate the else expression directly |
74 | Execute(expr: *expr.else_expr, state: else_state, sel, count, result); |
75 | return; |
76 | } else { |
77 | auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; |
78 | |
79 | D_ASSERT(current_sel); |
80 | Execute(expr: *expr.else_expr, state: else_state, sel: current_sel, count: current_count, result&: intermediate_result); |
81 | FillSwitch(vector&: intermediate_result, result, sel: *current_sel, count: current_count); |
82 | } |
83 | } |
84 | if (sel) { |
85 | result.Slice(sel: *sel, count); |
86 | } |
87 | } |
88 | |
89 | template <class T> |
90 | void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { |
91 | result.SetVectorType(VectorType::FLAT_VECTOR); |
92 | auto res = FlatVector::GetData<T>(result); |
93 | auto &result_mask = FlatVector::Validity(vector&: result); |
94 | if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
95 | auto data = ConstantVector::GetData<T>(vector); |
96 | if (ConstantVector::IsNull(vector)) { |
97 | for (idx_t i = 0; i < count; i++) { |
98 | result_mask.SetInvalid(sel.get_index(idx: i)); |
99 | } |
100 | } else { |
101 | for (idx_t i = 0; i < count; i++) { |
102 | res[sel.get_index(idx: i)] = *data; |
103 | } |
104 | } |
105 | } else { |
106 | UnifiedVectorFormat vdata; |
107 | vector.ToUnifiedFormat(count, data&: vdata); |
108 | auto data = UnifiedVectorFormat::GetData<T>(vdata); |
109 | for (idx_t i = 0; i < count; i++) { |
110 | auto source_idx = vdata.sel->get_index(idx: i); |
111 | auto res_idx = sel.get_index(idx: i); |
112 | |
113 | res[res_idx] = data[source_idx]; |
114 | result_mask.Set(row_idx: res_idx, valid: vdata.validity.RowIsValid(row_idx: source_idx)); |
115 | } |
116 | } |
117 | } |
118 | |
119 | void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { |
120 | result.SetVectorType(VectorType::FLAT_VECTOR); |
121 | auto &result_mask = FlatVector::Validity(vector&: result); |
122 | if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
123 | if (ConstantVector::IsNull(vector)) { |
124 | for (idx_t i = 0; i < count; i++) { |
125 | result_mask.SetInvalid(sel.get_index(idx: i)); |
126 | } |
127 | } |
128 | } else { |
129 | UnifiedVectorFormat vdata; |
130 | vector.ToUnifiedFormat(count, data&: vdata); |
131 | if (vdata.validity.AllValid()) { |
132 | return; |
133 | } |
134 | for (idx_t i = 0; i < count; i++) { |
135 | auto source_idx = vdata.sel->get_index(idx: i); |
136 | if (!vdata.validity.RowIsValid(row_idx: source_idx)) { |
137 | result_mask.SetInvalid(sel.get_index(idx: i)); |
138 | } |
139 | } |
140 | } |
141 | } |
142 | |
143 | void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { |
144 | switch (result.GetType().InternalType()) { |
145 | case PhysicalType::BOOL: |
146 | case PhysicalType::INT8: |
147 | TemplatedFillLoop<int8_t>(vector, result, sel, count); |
148 | break; |
149 | case PhysicalType::INT16: |
150 | TemplatedFillLoop<int16_t>(vector, result, sel, count); |
151 | break; |
152 | case PhysicalType::INT32: |
153 | TemplatedFillLoop<int32_t>(vector, result, sel, count); |
154 | break; |
155 | case PhysicalType::INT64: |
156 | TemplatedFillLoop<int64_t>(vector, result, sel, count); |
157 | break; |
158 | case PhysicalType::UINT8: |
159 | TemplatedFillLoop<uint8_t>(vector, result, sel, count); |
160 | break; |
161 | case PhysicalType::UINT16: |
162 | TemplatedFillLoop<uint16_t>(vector, result, sel, count); |
163 | break; |
164 | case PhysicalType::UINT32: |
165 | TemplatedFillLoop<uint32_t>(vector, result, sel, count); |
166 | break; |
167 | case PhysicalType::UINT64: |
168 | TemplatedFillLoop<uint64_t>(vector, result, sel, count); |
169 | break; |
170 | case PhysicalType::INT128: |
171 | TemplatedFillLoop<hugeint_t>(vector, result, sel, count); |
172 | break; |
173 | case PhysicalType::FLOAT: |
174 | TemplatedFillLoop<float>(vector, result, sel, count); |
175 | break; |
176 | case PhysicalType::DOUBLE: |
177 | TemplatedFillLoop<double>(vector, result, sel, count); |
178 | break; |
179 | case PhysicalType::INTERVAL: |
180 | TemplatedFillLoop<interval_t>(vector, result, sel, count); |
181 | break; |
182 | case PhysicalType::VARCHAR: |
183 | TemplatedFillLoop<string_t>(vector, result, sel, count); |
184 | StringVector::AddHeapReference(vector&: result, other&: vector); |
185 | break; |
186 | case PhysicalType::STRUCT: { |
187 | auto &vector_entries = StructVector::GetEntries(vector); |
188 | auto &result_entries = StructVector::GetEntries(vector&: result); |
189 | ValidityFillLoop(vector, result, sel, count); |
190 | D_ASSERT(vector_entries.size() == result_entries.size()); |
191 | for (idx_t i = 0; i < vector_entries.size(); i++) { |
192 | FillSwitch(vector&: *vector_entries[i], result&: *result_entries[i], sel, count); |
193 | } |
194 | break; |
195 | } |
196 | case PhysicalType::LIST: { |
197 | idx_t offset = ListVector::GetListSize(vector: result); |
198 | auto &list_child = ListVector::GetEntry(vector); |
199 | ListVector::Append(target&: result, source: list_child, source_size: ListVector::GetListSize(vector)); |
200 | |
201 | // all the false offsets need to be incremented by true_child.count |
202 | TemplatedFillLoop<list_entry_t>(vector, result, sel, count); |
203 | if (offset == 0) { |
204 | break; |
205 | } |
206 | |
207 | auto result_data = FlatVector::GetData<list_entry_t>(vector&: result); |
208 | for (idx_t i = 0; i < count; i++) { |
209 | auto result_idx = sel.get_index(idx: i); |
210 | result_data[result_idx].offset += offset; |
211 | } |
212 | |
213 | Vector::Verify(vector&: result, sel, count); |
214 | break; |
215 | } |
216 | default: |
217 | throw NotImplementedException("Unimplemented type for case expression: %s" , result.GetType().ToString()); |
218 | } |
219 | } |
220 | |
221 | } // namespace duckdb |
222 | |