1 | #include "duckdb/function/cast/default_casts.hpp" |
2 | #include "duckdb/function/cast/vector_cast_helpers.hpp" |
3 | #include "duckdb/common/pair.hpp" |
4 | #include "duckdb/common/vector.hpp" |
5 | #include "duckdb/function/scalar/nested_functions.hpp" |
6 | #include "duckdb/function/cast/bound_cast_data.hpp" |
7 | |
8 | namespace duckdb { |
9 | |
10 | template <class T> |
11 | bool StringEnumCastLoop(const string_t *source_data, ValidityMask &source_mask, const LogicalType &source_type, |
12 | T *result_data, ValidityMask &result_mask, const LogicalType &result_type, idx_t count, |
13 | string *error_message, const SelectionVector *sel) { |
14 | bool all_converted = true; |
15 | for (idx_t i = 0; i < count; i++) { |
16 | idx_t source_idx = i; |
17 | if (sel) { |
18 | source_idx = sel->get_index(idx: i); |
19 | } |
20 | if (source_mask.RowIsValid(row_idx: source_idx)) { |
21 | auto pos = EnumType::GetPos(type: result_type, key: source_data[source_idx]); |
22 | if (pos == -1) { |
23 | result_data[i] = |
24 | HandleVectorCastError::Operation<T>(CastExceptionText<string_t, T>(source_data[source_idx]), |
25 | result_mask, i, error_message, all_converted); |
26 | } else { |
27 | result_data[i] = pos; |
28 | } |
29 | } else { |
30 | result_mask.SetInvalid(i); |
31 | } |
32 | } |
33 | return all_converted; |
34 | } |
35 | |
36 | template <class T> |
37 | bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
38 | D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); |
39 | auto enum_name = EnumType::GetTypeName(type: result.GetType()); |
40 | switch (source.GetVectorType()) { |
41 | case VectorType::CONSTANT_VECTOR: { |
42 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
43 | |
44 | auto source_data = ConstantVector::GetData<string_t>(vector&: source); |
45 | auto source_mask = ConstantVector::Validity(vector&: source); |
46 | auto result_data = ConstantVector::GetData<T>(result); |
47 | auto &result_mask = ConstantVector::Validity(vector&: result); |
48 | |
49 | return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, |
50 | result.GetType(), 1, parameters.error_message, nullptr); |
51 | } |
52 | default: { |
53 | UnifiedVectorFormat vdata; |
54 | source.ToUnifiedFormat(count, data&: vdata); |
55 | |
56 | result.SetVectorType(VectorType::FLAT_VECTOR); |
57 | |
58 | auto source_data = UnifiedVectorFormat::GetData<string_t>(format: vdata); |
59 | auto source_sel = vdata.sel; |
60 | auto source_mask = vdata.validity; |
61 | auto result_data = FlatVector::GetData<T>(result); |
62 | auto &result_mask = FlatVector::Validity(vector&: result); |
63 | |
64 | return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, |
65 | result.GetType(), count, parameters.error_message, source_sel); |
66 | } |
67 | } |
68 | } |
69 | |
70 | static BoundCastInfo VectorStringCastNumericSwitch(BindCastInput &input, const LogicalType &source, |
71 | const LogicalType &target) { |
72 | // now switch on the result type |
73 | switch (target.id()) { |
74 | case LogicalTypeId::ENUM: { |
75 | switch (target.InternalType()) { |
76 | case PhysicalType::UINT8: |
77 | return StringEnumCast<uint8_t>; |
78 | case PhysicalType::UINT16: |
79 | return StringEnumCast<uint16_t>; |
80 | case PhysicalType::UINT32: |
81 | return StringEnumCast<uint32_t>; |
82 | default: |
83 | throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types" ); |
84 | } |
85 | } |
86 | case LogicalTypeId::BOOLEAN: |
87 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, bool, duckdb::TryCast>); |
88 | case LogicalTypeId::TINYINT: |
89 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int8_t, duckdb::TryCast>); |
90 | case LogicalTypeId::SMALLINT: |
91 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int16_t, duckdb::TryCast>); |
92 | case LogicalTypeId::INTEGER: |
93 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int32_t, duckdb::TryCast>); |
94 | case LogicalTypeId::BIGINT: |
95 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int64_t, duckdb::TryCast>); |
96 | case LogicalTypeId::UTINYINT: |
97 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint8_t, duckdb::TryCast>); |
98 | case LogicalTypeId::USMALLINT: |
99 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint16_t, duckdb::TryCast>); |
100 | case LogicalTypeId::UINTEGER: |
101 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint32_t, duckdb::TryCast>); |
102 | case LogicalTypeId::UBIGINT: |
103 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint64_t, duckdb::TryCast>); |
104 | case LogicalTypeId::HUGEINT: |
105 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, hugeint_t, duckdb::TryCast>); |
106 | case LogicalTypeId::FLOAT: |
107 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, float, duckdb::TryCast>); |
108 | case LogicalTypeId::DOUBLE: |
109 | return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, double, duckdb::TryCast>); |
110 | case LogicalTypeId::INTERVAL: |
111 | return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, interval_t, duckdb::TryCastErrorMessage>); |
112 | case LogicalTypeId::DECIMAL: |
113 | return BoundCastInfo(&VectorCastHelpers::ToDecimalCast<string_t>); |
114 | default: |
115 | return DefaultCasts::TryVectorNullCast; |
116 | } |
117 | } |
118 | |
119 | //===--------------------------------------------------------------------===// |
120 | // string -> list casting |
121 | //===--------------------------------------------------------------------===// |
122 | bool VectorStringToList::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, |
123 | Vector &result, ValidityMask &result_mask, idx_t count, |
124 | CastParameters ¶meters, const SelectionVector *sel) { |
125 | idx_t total_list_size = 0; |
126 | for (idx_t i = 0; i < count; i++) { |
127 | idx_t idx = i; |
128 | if (sel) { |
129 | idx = sel->get_index(idx: i); |
130 | } |
131 | if (!source_mask.RowIsValid(row_idx: idx)) { |
132 | continue; |
133 | } |
134 | total_list_size += VectorStringToList::CountPartsList(input: source_data[idx]); |
135 | } |
136 | |
137 | Vector varchar_vector(LogicalType::VARCHAR, total_list_size); |
138 | |
139 | ListVector::Reserve(vec&: result, required_capacity: total_list_size); |
140 | ListVector::SetListSize(vec&: result, size: total_list_size); |
141 | |
142 | auto list_data = ListVector::GetData(v&: result); |
143 | auto child_data = FlatVector::GetData<string_t>(vector&: varchar_vector); |
144 | |
145 | bool all_converted = true; |
146 | idx_t total = 0; |
147 | for (idx_t i = 0; i < count; i++) { |
148 | idx_t idx = i; |
149 | if (sel) { |
150 | idx = sel->get_index(idx: i); |
151 | } |
152 | if (!source_mask.RowIsValid(row_idx: idx)) { |
153 | result_mask.SetInvalid(i); |
154 | continue; |
155 | } |
156 | |
157 | list_data[i].offset = total; |
158 | if (!VectorStringToList::SplitStringList(input: source_data[idx], child_data, child_start&: total, child&: varchar_vector)) { |
159 | string text = "Type VARCHAR with value '" + source_data[idx].GetString() + |
160 | "' can't be cast to the destination type LIST" ; |
161 | HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted); |
162 | } |
163 | list_data[i].length = total - list_data[i].offset; // length is the amount of parts coming from this string |
164 | } |
165 | D_ASSERT(total_list_size == total); |
166 | |
167 | auto &result_child = ListVector::GetEntry(vector&: result); |
168 | auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); |
169 | CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); |
170 | return cast_data.child_cast_info.function(varchar_vector, result_child, total_list_size, child_parameters) && |
171 | all_converted; |
172 | } |
173 | |
174 | static LogicalType InitVarcharStructType(const LogicalType &target) { |
175 | child_list_t<LogicalType> child_types; |
176 | for (auto &child : StructType::GetChildTypes(type: target)) { |
177 | child_types.push_back(x: make_pair(x: child.first, y: LogicalType::VARCHAR)); |
178 | } |
179 | |
180 | return LogicalType::STRUCT(children: child_types); |
181 | } |
182 | |
183 | //===--------------------------------------------------------------------===// |
184 | // string -> struct casting |
185 | //===--------------------------------------------------------------------===// |
186 | bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, |
187 | Vector &result, ValidityMask &result_mask, idx_t count, |
188 | CastParameters ¶meters, const SelectionVector *sel) { |
189 | auto varchar_struct_type = InitVarcharStructType(target: result.GetType()); |
190 | Vector varchar_vector(varchar_struct_type, count); |
191 | auto &child_vectors = StructVector::GetEntries(vector&: varchar_vector); |
192 | auto &result_children = StructVector::GetEntries(vector&: result); |
193 | |
194 | string_map_t<idx_t> child_names; |
195 | vector<ValidityMask *> child_masks; |
196 | for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { |
197 | child_names.insert(x: {StructType::GetChildName(type: result.GetType(), index: child_idx), child_idx}); |
198 | child_masks.emplace_back(args: &FlatVector::Validity(vector&: *child_vectors[child_idx])); |
199 | child_masks[child_idx]->SetAllInvalid(count); |
200 | } |
201 | |
202 | bool all_converted = true; |
203 | for (idx_t i = 0; i < count; i++) { |
204 | idx_t idx = i; |
205 | if (sel) { |
206 | idx = sel->get_index(idx: i); |
207 | } |
208 | if (!source_mask.RowIsValid(row_idx: idx)) { |
209 | result_mask.SetInvalid(i); |
210 | continue; |
211 | } |
212 | if (!VectorStringToStruct::SplitStruct(input: source_data[idx], varchar_vectors&: child_vectors, row_idx&: i, child_names, child_masks)) { |
213 | string text = "Type VARCHAR with value '" + source_data[idx].GetString() + |
214 | "' can't be cast to the destination type STRUCT" ; |
215 | for (auto &child_mask : child_masks) { |
216 | child_mask->SetInvalid(idx); // some values may have already been found and set valid |
217 | } |
218 | HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted); |
219 | } |
220 | } |
221 | |
222 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
223 | auto &lstate = parameters.local_state->Cast<StructCastLocalState>(); |
224 | D_ASSERT(cast_data.child_cast_info.size() == result_children.size()); |
225 | |
226 | for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { |
227 | auto &child_varchar_vector = *child_vectors[child_idx]; |
228 | auto &result_child_vector = *result_children[child_idx]; |
229 | auto &child_cast_info = cast_data.child_cast_info[child_idx]; |
230 | CastParameters child_parameters(parameters, child_cast_info.cast_data, lstate.local_states[child_idx]); |
231 | if (!child_cast_info.function(child_varchar_vector, result_child_vector, count, child_parameters)) { |
232 | all_converted = false; |
233 | } |
234 | } |
235 | return all_converted; |
236 | } |
237 | |
238 | //===--------------------------------------------------------------------===// |
239 | // string -> map casting |
240 | //===--------------------------------------------------------------------===// |
241 | unique_ptr<FunctionLocalState> InitMapCastLocalState(CastLocalStateParameters ¶meters) { |
242 | auto &cast_data = parameters.cast_data->Cast<MapBoundCastData>(); |
243 | auto result = make_uniq<MapCastLocalState>(); |
244 | |
245 | if (cast_data.key_cast.init_local_state) { |
246 | CastLocalStateParameters child_params(parameters, cast_data.key_cast.cast_data); |
247 | result->key_state = cast_data.key_cast.init_local_state(child_params); |
248 | } |
249 | if (cast_data.value_cast.init_local_state) { |
250 | CastLocalStateParameters child_params(parameters, cast_data.value_cast.cast_data); |
251 | result->value_state = cast_data.value_cast.init_local_state(child_params); |
252 | } |
253 | return std::move(result); |
254 | } |
255 | |
256 | bool VectorStringToMap::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, |
257 | Vector &result, ValidityMask &result_mask, idx_t count, |
258 | CastParameters ¶meters, const SelectionVector *sel) { |
259 | idx_t total_elements = 0; |
260 | for (idx_t i = 0; i < count; i++) { |
261 | idx_t idx = i; |
262 | if (sel) { |
263 | idx = sel->get_index(idx: i); |
264 | } |
265 | if (!source_mask.RowIsValid(row_idx: idx)) { |
266 | continue; |
267 | } |
268 | total_elements += (VectorStringToMap::CountPartsMap(input: source_data[idx]) + 1) / 2; |
269 | } |
270 | |
271 | Vector varchar_key_vector(LogicalType::VARCHAR, total_elements); |
272 | Vector varchar_val_vector(LogicalType::VARCHAR, total_elements); |
273 | auto child_key_data = FlatVector::GetData<string_t>(vector&: varchar_key_vector); |
274 | auto child_val_data = FlatVector::GetData<string_t>(vector&: varchar_val_vector); |
275 | |
276 | ListVector::Reserve(vec&: result, required_capacity: total_elements); |
277 | ListVector::SetListSize(vec&: result, size: total_elements); |
278 | auto list_data = ListVector::GetData(v&: result); |
279 | |
280 | bool all_converted = true; |
281 | idx_t total = 0; |
282 | for (idx_t i = 0; i < count; i++) { |
283 | idx_t idx = i; |
284 | if (sel) { |
285 | idx = sel->get_index(idx: i); |
286 | } |
287 | if (!source_mask.RowIsValid(row_idx: idx)) { |
288 | result_mask.SetInvalid(idx); |
289 | continue; |
290 | } |
291 | |
292 | list_data[i].offset = total; |
293 | if (!VectorStringToMap::SplitStringMap(input: source_data[idx], child_key_data, child_val_data, child_start&: total, |
294 | varchar_key&: varchar_key_vector, varchar_val&: varchar_val_vector)) { |
295 | string text = "Type VARCHAR with value '" + source_data[idx].GetString() + |
296 | "' can't be cast to the destination type MAP" ; |
297 | FlatVector::SetNull(vector&: result, idx, is_null: true); |
298 | HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted); |
299 | } |
300 | list_data[i].length = total - list_data[i].offset; |
301 | } |
302 | D_ASSERT(total_elements == total); |
303 | |
304 | auto &result_key_child = MapVector::GetKeys(vector&: result); |
305 | auto &result_val_child = MapVector::GetValues(vector&: result); |
306 | auto &cast_data = parameters.cast_data->Cast<MapBoundCastData>(); |
307 | auto &lstate = parameters.local_state->Cast<MapCastLocalState>(); |
308 | |
309 | CastParameters key_params(parameters, cast_data.key_cast.cast_data, lstate.key_state); |
310 | if (!cast_data.key_cast.function(varchar_key_vector, result_key_child, total_elements, key_params)) { |
311 | all_converted = false; |
312 | } |
313 | CastParameters val_params(parameters, cast_data.value_cast.cast_data, lstate.value_state); |
314 | if (!cast_data.value_cast.function(varchar_val_vector, result_val_child, total_elements, val_params)) { |
315 | all_converted = false; |
316 | } |
317 | |
318 | auto &key_validity = FlatVector::Validity(vector&: result_key_child); |
319 | if (!all_converted) { |
320 | for (idx_t row_idx = 0; row_idx < count; row_idx++) { |
321 | if (!result_mask.RowIsValid(row_idx)) { |
322 | continue; |
323 | } |
324 | auto list = list_data[row_idx]; |
325 | for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { |
326 | auto idx = list.offset + list_idx; |
327 | if (!key_validity.RowIsValid(row_idx: idx)) { |
328 | result_mask.SetInvalid(row_idx); |
329 | } |
330 | } |
331 | } |
332 | } |
333 | MapVector::MapConversionVerify(vector&: result, count); |
334 | return all_converted; |
335 | } |
336 | |
337 | template <class T> |
338 | bool StringToNestedTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
339 | D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); |
340 | |
341 | switch (source.GetVectorType()) { |
342 | case VectorType::CONSTANT_VECTOR: { |
343 | auto source_data = ConstantVector::GetData<string_t>(vector&: source); |
344 | auto &source_mask = ConstantVector::Validity(vector&: source); |
345 | auto &result_mask = FlatVector::Validity(vector&: result); |
346 | auto ret = T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, 1, parameters, nullptr); |
347 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
348 | return ret; |
349 | } |
350 | default: { |
351 | UnifiedVectorFormat unified_source; |
352 | |
353 | source.ToUnifiedFormat(count, data&: unified_source); |
354 | auto source_sel = unified_source.sel; |
355 | auto source_data = UnifiedVectorFormat::GetData<string_t>(format: unified_source); |
356 | auto &source_mask = unified_source.validity; |
357 | auto &result_mask = FlatVector::Validity(vector&: result); |
358 | |
359 | return T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, count, parameters, |
360 | source_sel); |
361 | } |
362 | } |
363 | } |
364 | |
365 | BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const LogicalType &source, |
366 | const LogicalType &target) { |
367 | switch (target.id()) { |
368 | case LogicalTypeId::DATE: |
369 | return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, date_t, duckdb::TryCastErrorMessage>); |
370 | case LogicalTypeId::TIME: |
371 | case LogicalTypeId::TIME_TZ: |
372 | return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, dtime_t, duckdb::TryCastErrorMessage>); |
373 | case LogicalTypeId::TIMESTAMP: |
374 | case LogicalTypeId::TIMESTAMP_TZ: |
375 | return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, timestamp_t, duckdb::TryCastErrorMessage>); |
376 | case LogicalTypeId::TIMESTAMP_NS: |
377 | return BoundCastInfo( |
378 | &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampNS>); |
379 | case LogicalTypeId::TIMESTAMP_SEC: |
380 | return BoundCastInfo( |
381 | &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampSec>); |
382 | case LogicalTypeId::TIMESTAMP_MS: |
383 | return BoundCastInfo( |
384 | &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampMS>); |
385 | case LogicalTypeId::BLOB: |
386 | return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, string_t, duckdb::TryCastToBlob>); |
387 | case LogicalTypeId::BIT: |
388 | return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, string_t, duckdb::TryCastToBit>); |
389 | case LogicalTypeId::UUID: |
390 | return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, hugeint_t, duckdb::TryCastToUUID>); |
391 | case LogicalTypeId::SQLNULL: |
392 | return &DefaultCasts::TryVectorNullCast; |
393 | case LogicalTypeId::VARCHAR: |
394 | return &DefaultCasts::ReinterpretCast; |
395 | case LogicalTypeId::LIST: |
396 | // the second argument allows for a secondary casting function to be passed in the CastParameters |
397 | return BoundCastInfo( |
398 | &StringToNestedTypeCast<VectorStringToList>, |
399 | ListBoundCastData::BindListToListCast(input, source: LogicalType::LIST(child: LogicalType::VARCHAR), target), |
400 | ListBoundCastData::InitListLocalState); |
401 | case LogicalTypeId::STRUCT: |
402 | return BoundCastInfo(&StringToNestedTypeCast<VectorStringToStruct>, |
403 | StructBoundCastData::BindStructToStructCast(input, source: InitVarcharStructType(target), target), |
404 | StructBoundCastData::InitStructCastLocalState); |
405 | case LogicalTypeId::MAP: |
406 | return BoundCastInfo(&StringToNestedTypeCast<VectorStringToMap>, |
407 | MapBoundCastData::BindMapToMapCast( |
408 | input, source: LogicalType::MAP(key: LogicalType::VARCHAR, value: LogicalType::VARCHAR), target), |
409 | InitMapCastLocalState); |
410 | default: |
411 | return VectorStringCastNumericSwitch(input, source, target); |
412 | } |
413 | } |
414 | |
415 | } // namespace duckdb |
416 | |