| 1 | #include "duckdb/common/vector_operations/vector_operations.hpp" | 
| 2 | #include "duckdb/execution/expression_executor.hpp" | 
| 3 | #include "duckdb/planner/expression/bound_case_expression.hpp" | 
| 4 |  | 
| 5 | namespace duckdb { | 
| 6 |  | 
| 7 | struct CaseExpressionState : public ExpressionState { | 
| 8 | 	CaseExpressionState(const Expression &expr, ExpressionExecutorState &root) | 
| 9 | 	    : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) { | 
| 10 | 	} | 
| 11 |  | 
| 12 | 	SelectionVector true_sel; | 
| 13 | 	SelectionVector false_sel; | 
| 14 | }; | 
| 15 |  | 
| 16 | unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, | 
| 17 |                                                                 ExpressionExecutorState &root) { | 
| 18 | 	auto result = make_uniq<CaseExpressionState>(args: expr, args&: root); | 
| 19 | 	for (auto &case_check : expr.case_checks) { | 
| 20 | 		result->AddChild(expr: case_check.when_expr.get()); | 
| 21 | 		result->AddChild(expr: case_check.then_expr.get()); | 
| 22 | 	} | 
| 23 | 	result->AddChild(expr: expr.else_expr.get()); | 
| 24 | 	result->Finalize(); | 
| 25 | 	return std::move(result); | 
| 26 | } | 
| 27 |  | 
| 28 | void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel, | 
| 29 |                                  idx_t count, Vector &result) { | 
| 30 | 	auto &state = state_p->Cast<CaseExpressionState>(); | 
| 31 |  | 
| 32 | 	state.intermediate_chunk.Reset(); | 
| 33 |  | 
| 34 | 	// first execute the check expression | 
| 35 | 	auto current_true_sel = &state.true_sel; | 
| 36 | 	auto current_false_sel = &state.false_sel; | 
| 37 | 	auto current_sel = sel; | 
| 38 | 	idx_t current_count = count; | 
| 39 | 	for (idx_t i = 0; i < expr.case_checks.size(); i++) { | 
| 40 | 		auto &case_check = expr.case_checks[i]; | 
| 41 | 		auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; | 
| 42 | 		auto check_state = state.child_states[i * 2].get(); | 
| 43 | 		auto then_state = state.child_states[i * 2 + 1].get(); | 
| 44 |  | 
| 45 | 		idx_t tcount = | 
| 46 | 		    Select(expr: *case_check.when_expr, state: check_state, sel: current_sel, count: current_count, true_sel: current_true_sel, false_sel: current_false_sel); | 
| 47 | 		if (tcount == 0) { | 
| 48 | 			// everything is false: do nothing | 
| 49 | 			continue; | 
| 50 | 		} | 
| 51 | 		idx_t fcount = current_count - tcount; | 
| 52 | 		if (fcount == 0 && current_count == count) { | 
| 53 | 			// everything is true in the first CHECK statement | 
| 54 | 			// we can skip the entire case and only execute the TRUE side | 
| 55 | 			Execute(expr: *case_check.then_expr, state: then_state, sel, count, result); | 
| 56 | 			return; | 
| 57 | 		} else { | 
| 58 | 			// we need to execute and then fill in the desired tuples in the result | 
| 59 | 			Execute(expr: *case_check.then_expr, state: then_state, sel: current_true_sel, count: tcount, result&: intermediate_result); | 
| 60 | 			FillSwitch(vector&: intermediate_result, result, sel: *current_true_sel, count: tcount); | 
| 61 | 		} | 
| 62 | 		// continue with the false tuples | 
| 63 | 		current_sel = current_false_sel; | 
| 64 | 		current_count = fcount; | 
| 65 | 		if (fcount == 0) { | 
| 66 | 			// everything is true: we are done | 
| 67 | 			break; | 
| 68 | 		} | 
| 69 | 	} | 
| 70 | 	if (current_count > 0) { | 
| 71 | 		auto else_state = state.child_states.back().get(); | 
| 72 | 		if (current_count == count) { | 
| 73 | 			// everything was false, we can just evaluate the else expression directly | 
| 74 | 			Execute(expr: *expr.else_expr, state: else_state, sel, count, result); | 
| 75 | 			return; | 
| 76 | 		} else { | 
| 77 | 			auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; | 
| 78 |  | 
| 79 | 			D_ASSERT(current_sel); | 
| 80 | 			Execute(expr: *expr.else_expr, state: else_state, sel: current_sel, count: current_count, result&: intermediate_result); | 
| 81 | 			FillSwitch(vector&: intermediate_result, result, sel: *current_sel, count: current_count); | 
| 82 | 		} | 
| 83 | 	} | 
| 84 | 	if (sel) { | 
| 85 | 		result.Slice(sel: *sel, count); | 
| 86 | 	} | 
| 87 | } | 
| 88 |  | 
| 89 | template <class T> | 
| 90 | void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { | 
| 91 | 	result.SetVectorType(VectorType::FLAT_VECTOR); | 
| 92 | 	auto res = FlatVector::GetData<T>(result); | 
| 93 | 	auto &result_mask = FlatVector::Validity(vector&: result); | 
| 94 | 	if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { | 
| 95 | 		auto data = ConstantVector::GetData<T>(vector); | 
| 96 | 		if (ConstantVector::IsNull(vector)) { | 
| 97 | 			for (idx_t i = 0; i < count; i++) { | 
| 98 | 				result_mask.SetInvalid(sel.get_index(idx: i)); | 
| 99 | 			} | 
| 100 | 		} else { | 
| 101 | 			for (idx_t i = 0; i < count; i++) { | 
| 102 | 				res[sel.get_index(idx: i)] = *data; | 
| 103 | 			} | 
| 104 | 		} | 
| 105 | 	} else { | 
| 106 | 		UnifiedVectorFormat vdata; | 
| 107 | 		vector.ToUnifiedFormat(count, data&: vdata); | 
| 108 | 		auto data = UnifiedVectorFormat::GetData<T>(vdata); | 
| 109 | 		for (idx_t i = 0; i < count; i++) { | 
| 110 | 			auto source_idx = vdata.sel->get_index(idx: i); | 
| 111 | 			auto res_idx = sel.get_index(idx: i); | 
| 112 |  | 
| 113 | 			res[res_idx] = data[source_idx]; | 
| 114 | 			result_mask.Set(row_idx: res_idx, valid: vdata.validity.RowIsValid(row_idx: source_idx)); | 
| 115 | 		} | 
| 116 | 	} | 
| 117 | } | 
| 118 |  | 
| 119 | void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { | 
| 120 | 	result.SetVectorType(VectorType::FLAT_VECTOR); | 
| 121 | 	auto &result_mask = FlatVector::Validity(vector&: result); | 
| 122 | 	if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { | 
| 123 | 		if (ConstantVector::IsNull(vector)) { | 
| 124 | 			for (idx_t i = 0; i < count; i++) { | 
| 125 | 				result_mask.SetInvalid(sel.get_index(idx: i)); | 
| 126 | 			} | 
| 127 | 		} | 
| 128 | 	} else { | 
| 129 | 		UnifiedVectorFormat vdata; | 
| 130 | 		vector.ToUnifiedFormat(count, data&: vdata); | 
| 131 | 		if (vdata.validity.AllValid()) { | 
| 132 | 			return; | 
| 133 | 		} | 
| 134 | 		for (idx_t i = 0; i < count; i++) { | 
| 135 | 			auto source_idx = vdata.sel->get_index(idx: i); | 
| 136 | 			if (!vdata.validity.RowIsValid(row_idx: source_idx)) { | 
| 137 | 				result_mask.SetInvalid(sel.get_index(idx: i)); | 
| 138 | 			} | 
| 139 | 		} | 
| 140 | 	} | 
| 141 | } | 
| 142 |  | 
| 143 | void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { | 
| 144 | 	switch (result.GetType().InternalType()) { | 
| 145 | 	case PhysicalType::BOOL: | 
| 146 | 	case PhysicalType::INT8: | 
| 147 | 		TemplatedFillLoop<int8_t>(vector, result, sel, count); | 
| 148 | 		break; | 
| 149 | 	case PhysicalType::INT16: | 
| 150 | 		TemplatedFillLoop<int16_t>(vector, result, sel, count); | 
| 151 | 		break; | 
| 152 | 	case PhysicalType::INT32: | 
| 153 | 		TemplatedFillLoop<int32_t>(vector, result, sel, count); | 
| 154 | 		break; | 
| 155 | 	case PhysicalType::INT64: | 
| 156 | 		TemplatedFillLoop<int64_t>(vector, result, sel, count); | 
| 157 | 		break; | 
| 158 | 	case PhysicalType::UINT8: | 
| 159 | 		TemplatedFillLoop<uint8_t>(vector, result, sel, count); | 
| 160 | 		break; | 
| 161 | 	case PhysicalType::UINT16: | 
| 162 | 		TemplatedFillLoop<uint16_t>(vector, result, sel, count); | 
| 163 | 		break; | 
| 164 | 	case PhysicalType::UINT32: | 
| 165 | 		TemplatedFillLoop<uint32_t>(vector, result, sel, count); | 
| 166 | 		break; | 
| 167 | 	case PhysicalType::UINT64: | 
| 168 | 		TemplatedFillLoop<uint64_t>(vector, result, sel, count); | 
| 169 | 		break; | 
| 170 | 	case PhysicalType::INT128: | 
| 171 | 		TemplatedFillLoop<hugeint_t>(vector, result, sel, count); | 
| 172 | 		break; | 
| 173 | 	case PhysicalType::FLOAT: | 
| 174 | 		TemplatedFillLoop<float>(vector, result, sel, count); | 
| 175 | 		break; | 
| 176 | 	case PhysicalType::DOUBLE: | 
| 177 | 		TemplatedFillLoop<double>(vector, result, sel, count); | 
| 178 | 		break; | 
| 179 | 	case PhysicalType::INTERVAL: | 
| 180 | 		TemplatedFillLoop<interval_t>(vector, result, sel, count); | 
| 181 | 		break; | 
| 182 | 	case PhysicalType::VARCHAR: | 
| 183 | 		TemplatedFillLoop<string_t>(vector, result, sel, count); | 
| 184 | 		StringVector::AddHeapReference(vector&: result, other&: vector); | 
| 185 | 		break; | 
| 186 | 	case PhysicalType::STRUCT: { | 
| 187 | 		auto &vector_entries = StructVector::GetEntries(vector); | 
| 188 | 		auto &result_entries = StructVector::GetEntries(vector&: result); | 
| 189 | 		ValidityFillLoop(vector, result, sel, count); | 
| 190 | 		D_ASSERT(vector_entries.size() == result_entries.size()); | 
| 191 | 		for (idx_t i = 0; i < vector_entries.size(); i++) { | 
| 192 | 			FillSwitch(vector&: *vector_entries[i], result&: *result_entries[i], sel, count); | 
| 193 | 		} | 
| 194 | 		break; | 
| 195 | 	} | 
| 196 | 	case PhysicalType::LIST: { | 
| 197 | 		idx_t offset = ListVector::GetListSize(vector: result); | 
| 198 | 		auto &list_child = ListVector::GetEntry(vector); | 
| 199 | 		ListVector::Append(target&: result, source: list_child, source_size: ListVector::GetListSize(vector)); | 
| 200 |  | 
| 201 | 		// all the false offsets need to be incremented by true_child.count | 
| 202 | 		TemplatedFillLoop<list_entry_t>(vector, result, sel, count); | 
| 203 | 		if (offset == 0) { | 
| 204 | 			break; | 
| 205 | 		} | 
| 206 |  | 
| 207 | 		auto result_data = FlatVector::GetData<list_entry_t>(vector&: result); | 
| 208 | 		for (idx_t i = 0; i < count; i++) { | 
| 209 | 			auto result_idx = sel.get_index(idx: i); | 
| 210 | 			result_data[result_idx].offset += offset; | 
| 211 | 		} | 
| 212 |  | 
| 213 | 		Vector::Verify(vector&: result, sel, count); | 
| 214 | 		break; | 
| 215 | 	} | 
| 216 | 	default: | 
| 217 | 		throw NotImplementedException("Unimplemented type for case expression: %s" , result.GetType().ToString()); | 
| 218 | 	} | 
| 219 | } | 
| 220 |  | 
| 221 | } // namespace duckdb | 
| 222 |  |