1#include "duckdb/function/cast/default_casts.hpp"
2#include "duckdb/function/cast/vector_cast_helpers.hpp"
3#include "duckdb/common/pair.hpp"
4#include "duckdb/common/vector.hpp"
5#include "duckdb/function/scalar/nested_functions.hpp"
6#include "duckdb/function/cast/bound_cast_data.hpp"
7
8namespace duckdb {
9
10template <class T>
11bool StringEnumCastLoop(const string_t *source_data, ValidityMask &source_mask, const LogicalType &source_type,
12 T *result_data, ValidityMask &result_mask, const LogicalType &result_type, idx_t count,
13 string *error_message, const SelectionVector *sel) {
14 bool all_converted = true;
15 for (idx_t i = 0; i < count; i++) {
16 idx_t source_idx = i;
17 if (sel) {
18 source_idx = sel->get_index(idx: i);
19 }
20 if (source_mask.RowIsValid(row_idx: source_idx)) {
21 auto pos = EnumType::GetPos(type: result_type, key: source_data[source_idx]);
22 if (pos == -1) {
23 result_data[i] =
24 HandleVectorCastError::Operation<T>(CastExceptionText<string_t, T>(source_data[source_idx]),
25 result_mask, i, error_message, all_converted);
26 } else {
27 result_data[i] = pos;
28 }
29 } else {
30 result_mask.SetInvalid(i);
31 }
32 }
33 return all_converted;
34}
35
36template <class T>
37bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
38 D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR);
39 auto enum_name = EnumType::GetTypeName(type: result.GetType());
40 switch (source.GetVectorType()) {
41 case VectorType::CONSTANT_VECTOR: {
42 result.SetVectorType(VectorType::CONSTANT_VECTOR);
43
44 auto source_data = ConstantVector::GetData<string_t>(vector&: source);
45 auto source_mask = ConstantVector::Validity(vector&: source);
46 auto result_data = ConstantVector::GetData<T>(result);
47 auto &result_mask = ConstantVector::Validity(vector&: result);
48
49 return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask,
50 result.GetType(), 1, parameters.error_message, nullptr);
51 }
52 default: {
53 UnifiedVectorFormat vdata;
54 source.ToUnifiedFormat(count, data&: vdata);
55
56 result.SetVectorType(VectorType::FLAT_VECTOR);
57
58 auto source_data = UnifiedVectorFormat::GetData<string_t>(format: vdata);
59 auto source_sel = vdata.sel;
60 auto source_mask = vdata.validity;
61 auto result_data = FlatVector::GetData<T>(result);
62 auto &result_mask = FlatVector::Validity(vector&: result);
63
64 return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask,
65 result.GetType(), count, parameters.error_message, source_sel);
66 }
67 }
68}
69
70static BoundCastInfo VectorStringCastNumericSwitch(BindCastInput &input, const LogicalType &source,
71 const LogicalType &target) {
72 // now switch on the result type
73 switch (target.id()) {
74 case LogicalTypeId::ENUM: {
75 switch (target.InternalType()) {
76 case PhysicalType::UINT8:
77 return StringEnumCast<uint8_t>;
78 case PhysicalType::UINT16:
79 return StringEnumCast<uint16_t>;
80 case PhysicalType::UINT32:
81 return StringEnumCast<uint32_t>;
82 default:
83 throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types");
84 }
85 }
86 case LogicalTypeId::BOOLEAN:
87 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, bool, duckdb::TryCast>);
88 case LogicalTypeId::TINYINT:
89 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int8_t, duckdb::TryCast>);
90 case LogicalTypeId::SMALLINT:
91 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int16_t, duckdb::TryCast>);
92 case LogicalTypeId::INTEGER:
93 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int32_t, duckdb::TryCast>);
94 case LogicalTypeId::BIGINT:
95 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, int64_t, duckdb::TryCast>);
96 case LogicalTypeId::UTINYINT:
97 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint8_t, duckdb::TryCast>);
98 case LogicalTypeId::USMALLINT:
99 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint16_t, duckdb::TryCast>);
100 case LogicalTypeId::UINTEGER:
101 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint32_t, duckdb::TryCast>);
102 case LogicalTypeId::UBIGINT:
103 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, uint64_t, duckdb::TryCast>);
104 case LogicalTypeId::HUGEINT:
105 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, hugeint_t, duckdb::TryCast>);
106 case LogicalTypeId::FLOAT:
107 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, float, duckdb::TryCast>);
108 case LogicalTypeId::DOUBLE:
109 return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop<string_t, double, duckdb::TryCast>);
110 case LogicalTypeId::INTERVAL:
111 return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, interval_t, duckdb::TryCastErrorMessage>);
112 case LogicalTypeId::DECIMAL:
113 return BoundCastInfo(&VectorCastHelpers::ToDecimalCast<string_t>);
114 default:
115 return DefaultCasts::TryVectorNullCast;
116 }
117}
118
119//===--------------------------------------------------------------------===//
120// string -> list casting
121//===--------------------------------------------------------------------===//
122bool VectorStringToList::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask,
123 Vector &result, ValidityMask &result_mask, idx_t count,
124 CastParameters &parameters, const SelectionVector *sel) {
125 idx_t total_list_size = 0;
126 for (idx_t i = 0; i < count; i++) {
127 idx_t idx = i;
128 if (sel) {
129 idx = sel->get_index(idx: i);
130 }
131 if (!source_mask.RowIsValid(row_idx: idx)) {
132 continue;
133 }
134 total_list_size += VectorStringToList::CountPartsList(input: source_data[idx]);
135 }
136
137 Vector varchar_vector(LogicalType::VARCHAR, total_list_size);
138
139 ListVector::Reserve(vec&: result, required_capacity: total_list_size);
140 ListVector::SetListSize(vec&: result, size: total_list_size);
141
142 auto list_data = ListVector::GetData(v&: result);
143 auto child_data = FlatVector::GetData<string_t>(vector&: varchar_vector);
144
145 bool all_converted = true;
146 idx_t total = 0;
147 for (idx_t i = 0; i < count; i++) {
148 idx_t idx = i;
149 if (sel) {
150 idx = sel->get_index(idx: i);
151 }
152 if (!source_mask.RowIsValid(row_idx: idx)) {
153 result_mask.SetInvalid(i);
154 continue;
155 }
156
157 list_data[i].offset = total;
158 if (!VectorStringToList::SplitStringList(input: source_data[idx], child_data, child_start&: total, child&: varchar_vector)) {
159 string text = "Type VARCHAR with value '" + source_data[idx].GetString() +
160 "' can't be cast to the destination type LIST";
161 HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted);
162 }
163 list_data[i].length = total - list_data[i].offset; // length is the amount of parts coming from this string
164 }
165 D_ASSERT(total_list_size == total);
166
167 auto &result_child = ListVector::GetEntry(vector&: result);
168 auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>();
169 CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state);
170 return cast_data.child_cast_info.function(varchar_vector, result_child, total_list_size, child_parameters) &&
171 all_converted;
172}
173
174static LogicalType InitVarcharStructType(const LogicalType &target) {
175 child_list_t<LogicalType> child_types;
176 for (auto &child : StructType::GetChildTypes(type: target)) {
177 child_types.push_back(x: make_pair(x: child.first, y: LogicalType::VARCHAR));
178 }
179
180 return LogicalType::STRUCT(children: child_types);
181}
182
183//===--------------------------------------------------------------------===//
184// string -> struct casting
185//===--------------------------------------------------------------------===//
186bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask,
187 Vector &result, ValidityMask &result_mask, idx_t count,
188 CastParameters &parameters, const SelectionVector *sel) {
189 auto varchar_struct_type = InitVarcharStructType(target: result.GetType());
190 Vector varchar_vector(varchar_struct_type, count);
191 auto &child_vectors = StructVector::GetEntries(vector&: varchar_vector);
192 auto &result_children = StructVector::GetEntries(vector&: result);
193
194 string_map_t<idx_t> child_names;
195 vector<ValidityMask *> child_masks;
196 for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) {
197 child_names.insert(x: {StructType::GetChildName(type: result.GetType(), index: child_idx), child_idx});
198 child_masks.emplace_back(args: &FlatVector::Validity(vector&: *child_vectors[child_idx]));
199 child_masks[child_idx]->SetAllInvalid(count);
200 }
201
202 bool all_converted = true;
203 for (idx_t i = 0; i < count; i++) {
204 idx_t idx = i;
205 if (sel) {
206 idx = sel->get_index(idx: i);
207 }
208 if (!source_mask.RowIsValid(row_idx: idx)) {
209 result_mask.SetInvalid(i);
210 continue;
211 }
212 if (!VectorStringToStruct::SplitStruct(input: source_data[idx], varchar_vectors&: child_vectors, row_idx&: i, child_names, child_masks)) {
213 string text = "Type VARCHAR with value '" + source_data[idx].GetString() +
214 "' can't be cast to the destination type STRUCT";
215 for (auto &child_mask : child_masks) {
216 child_mask->SetInvalid(idx); // some values may have already been found and set valid
217 }
218 HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted);
219 }
220 }
221
222 auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>();
223 auto &lstate = parameters.local_state->Cast<StructCastLocalState>();
224 D_ASSERT(cast_data.child_cast_info.size() == result_children.size());
225
226 for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) {
227 auto &child_varchar_vector = *child_vectors[child_idx];
228 auto &result_child_vector = *result_children[child_idx];
229 auto &child_cast_info = cast_data.child_cast_info[child_idx];
230 CastParameters child_parameters(parameters, child_cast_info.cast_data, lstate.local_states[child_idx]);
231 if (!child_cast_info.function(child_varchar_vector, result_child_vector, count, child_parameters)) {
232 all_converted = false;
233 }
234 }
235 return all_converted;
236}
237
238//===--------------------------------------------------------------------===//
239// string -> map casting
240//===--------------------------------------------------------------------===//
241unique_ptr<FunctionLocalState> InitMapCastLocalState(CastLocalStateParameters &parameters) {
242 auto &cast_data = parameters.cast_data->Cast<MapBoundCastData>();
243 auto result = make_uniq<MapCastLocalState>();
244
245 if (cast_data.key_cast.init_local_state) {
246 CastLocalStateParameters child_params(parameters, cast_data.key_cast.cast_data);
247 result->key_state = cast_data.key_cast.init_local_state(child_params);
248 }
249 if (cast_data.value_cast.init_local_state) {
250 CastLocalStateParameters child_params(parameters, cast_data.value_cast.cast_data);
251 result->value_state = cast_data.value_cast.init_local_state(child_params);
252 }
253 return std::move(result);
254}
255
256bool VectorStringToMap::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask,
257 Vector &result, ValidityMask &result_mask, idx_t count,
258 CastParameters &parameters, const SelectionVector *sel) {
259 idx_t total_elements = 0;
260 for (idx_t i = 0; i < count; i++) {
261 idx_t idx = i;
262 if (sel) {
263 idx = sel->get_index(idx: i);
264 }
265 if (!source_mask.RowIsValid(row_idx: idx)) {
266 continue;
267 }
268 total_elements += (VectorStringToMap::CountPartsMap(input: source_data[idx]) + 1) / 2;
269 }
270
271 Vector varchar_key_vector(LogicalType::VARCHAR, total_elements);
272 Vector varchar_val_vector(LogicalType::VARCHAR, total_elements);
273 auto child_key_data = FlatVector::GetData<string_t>(vector&: varchar_key_vector);
274 auto child_val_data = FlatVector::GetData<string_t>(vector&: varchar_val_vector);
275
276 ListVector::Reserve(vec&: result, required_capacity: total_elements);
277 ListVector::SetListSize(vec&: result, size: total_elements);
278 auto list_data = ListVector::GetData(v&: result);
279
280 bool all_converted = true;
281 idx_t total = 0;
282 for (idx_t i = 0; i < count; i++) {
283 idx_t idx = i;
284 if (sel) {
285 idx = sel->get_index(idx: i);
286 }
287 if (!source_mask.RowIsValid(row_idx: idx)) {
288 result_mask.SetInvalid(idx);
289 continue;
290 }
291
292 list_data[i].offset = total;
293 if (!VectorStringToMap::SplitStringMap(input: source_data[idx], child_key_data, child_val_data, child_start&: total,
294 varchar_key&: varchar_key_vector, varchar_val&: varchar_val_vector)) {
295 string text = "Type VARCHAR with value '" + source_data[idx].GetString() +
296 "' can't be cast to the destination type MAP";
297 FlatVector::SetNull(vector&: result, idx, is_null: true);
298 HandleVectorCastError::Operation<string_t>(error_message: text, mask&: result_mask, idx, error_message_ptr: parameters.error_message, all_converted);
299 }
300 list_data[i].length = total - list_data[i].offset;
301 }
302 D_ASSERT(total_elements == total);
303
304 auto &result_key_child = MapVector::GetKeys(vector&: result);
305 auto &result_val_child = MapVector::GetValues(vector&: result);
306 auto &cast_data = parameters.cast_data->Cast<MapBoundCastData>();
307 auto &lstate = parameters.local_state->Cast<MapCastLocalState>();
308
309 CastParameters key_params(parameters, cast_data.key_cast.cast_data, lstate.key_state);
310 if (!cast_data.key_cast.function(varchar_key_vector, result_key_child, total_elements, key_params)) {
311 all_converted = false;
312 }
313 CastParameters val_params(parameters, cast_data.value_cast.cast_data, lstate.value_state);
314 if (!cast_data.value_cast.function(varchar_val_vector, result_val_child, total_elements, val_params)) {
315 all_converted = false;
316 }
317
318 auto &key_validity = FlatVector::Validity(vector&: result_key_child);
319 if (!all_converted) {
320 for (idx_t row_idx = 0; row_idx < count; row_idx++) {
321 if (!result_mask.RowIsValid(row_idx)) {
322 continue;
323 }
324 auto list = list_data[row_idx];
325 for (idx_t list_idx = 0; list_idx < list.length; list_idx++) {
326 auto idx = list.offset + list_idx;
327 if (!key_validity.RowIsValid(row_idx: idx)) {
328 result_mask.SetInvalid(row_idx);
329 }
330 }
331 }
332 }
333 MapVector::MapConversionVerify(vector&: result, count);
334 return all_converted;
335}
336
337template <class T>
338bool StringToNestedTypeCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
339 D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR);
340
341 switch (source.GetVectorType()) {
342 case VectorType::CONSTANT_VECTOR: {
343 auto source_data = ConstantVector::GetData<string_t>(vector&: source);
344 auto &source_mask = ConstantVector::Validity(vector&: source);
345 auto &result_mask = FlatVector::Validity(vector&: result);
346 auto ret = T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, 1, parameters, nullptr);
347 result.SetVectorType(VectorType::CONSTANT_VECTOR);
348 return ret;
349 }
350 default: {
351 UnifiedVectorFormat unified_source;
352
353 source.ToUnifiedFormat(count, data&: unified_source);
354 auto source_sel = unified_source.sel;
355 auto source_data = UnifiedVectorFormat::GetData<string_t>(format: unified_source);
356 auto &source_mask = unified_source.validity;
357 auto &result_mask = FlatVector::Validity(vector&: result);
358
359 return T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, count, parameters,
360 source_sel);
361 }
362 }
363}
364
365BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const LogicalType &source,
366 const LogicalType &target) {
367 switch (target.id()) {
368 case LogicalTypeId::DATE:
369 return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, date_t, duckdb::TryCastErrorMessage>);
370 case LogicalTypeId::TIME:
371 case LogicalTypeId::TIME_TZ:
372 return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, dtime_t, duckdb::TryCastErrorMessage>);
373 case LogicalTypeId::TIMESTAMP:
374 case LogicalTypeId::TIMESTAMP_TZ:
375 return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop<string_t, timestamp_t, duckdb::TryCastErrorMessage>);
376 case LogicalTypeId::TIMESTAMP_NS:
377 return BoundCastInfo(
378 &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampNS>);
379 case LogicalTypeId::TIMESTAMP_SEC:
380 return BoundCastInfo(
381 &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampSec>);
382 case LogicalTypeId::TIMESTAMP_MS:
383 return BoundCastInfo(
384 &VectorCastHelpers::TryCastStrictLoop<string_t, timestamp_t, duckdb::TryCastToTimestampMS>);
385 case LogicalTypeId::BLOB:
386 return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, string_t, duckdb::TryCastToBlob>);
387 case LogicalTypeId::BIT:
388 return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, string_t, duckdb::TryCastToBit>);
389 case LogicalTypeId::UUID:
390 return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop<string_t, hugeint_t, duckdb::TryCastToUUID>);
391 case LogicalTypeId::SQLNULL:
392 return &DefaultCasts::TryVectorNullCast;
393 case LogicalTypeId::VARCHAR:
394 return &DefaultCasts::ReinterpretCast;
395 case LogicalTypeId::LIST:
396 // the second argument allows for a secondary casting function to be passed in the CastParameters
397 return BoundCastInfo(
398 &StringToNestedTypeCast<VectorStringToList>,
399 ListBoundCastData::BindListToListCast(input, source: LogicalType::LIST(child: LogicalType::VARCHAR), target),
400 ListBoundCastData::InitListLocalState);
401 case LogicalTypeId::STRUCT:
402 return BoundCastInfo(&StringToNestedTypeCast<VectorStringToStruct>,
403 StructBoundCastData::BindStructToStructCast(input, source: InitVarcharStructType(target), target),
404 StructBoundCastData::InitStructCastLocalState);
405 case LogicalTypeId::MAP:
406 return BoundCastInfo(&StringToNestedTypeCast<VectorStringToMap>,
407 MapBoundCastData::BindMapToMapCast(
408 input, source: LogicalType::MAP(key: LogicalType::VARCHAR, value: LogicalType::VARCHAR), target),
409 InitMapCastLocalState);
410 default:
411 return VectorStringCastNumericSwitch(input, source, target);
412 }
413}
414
415} // namespace duckdb
416