1#include "duckdb/function/cast/default_casts.hpp"
2#include "duckdb/function/cast/cast_function_set.hpp"
3#include "duckdb/function/cast/bound_cast_data.hpp"
4
5namespace duckdb {
6
7unique_ptr<BoundCastData> ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source,
8 const LogicalType &target) {
9 vector<BoundCastInfo> child_cast_info;
10 auto &source_child_type = ListType::GetChildType(type: source);
11 auto &result_child_type = ListType::GetChildType(type: target);
12 auto child_cast = input.GetCastFunction(source: source_child_type, target: result_child_type);
13 return make_uniq<ListBoundCastData>(args: std::move(child_cast));
14}
15
16unique_ptr<FunctionLocalState> ListBoundCastData::InitListLocalState(CastLocalStateParameters &parameters) {
17 auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>();
18 if (!cast_data.child_cast_info.init_local_state) {
19 return nullptr;
20 }
21 CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data);
22 return cast_data.child_cast_info.init_local_state(child_parameters);
23}
24
25bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
26 auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>();
27
28 // only handle constant and flat vectors here for now
29 if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) {
30 result.SetVectorType(source.GetVectorType());
31 ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source));
32
33 auto ldata = ConstantVector::GetData<list_entry_t>(vector&: source);
34 auto tdata = ConstantVector::GetData<list_entry_t>(vector&: result);
35 *tdata = *ldata;
36 } else {
37 source.Flatten(count);
38 result.SetVectorType(VectorType::FLAT_VECTOR);
39 FlatVector::SetValidity(vector&: result, new_validity&: FlatVector::Validity(vector&: source));
40
41 auto ldata = FlatVector::GetData<list_entry_t>(vector&: source);
42 auto tdata = FlatVector::GetData<list_entry_t>(vector&: result);
43 for (idx_t i = 0; i < count; i++) {
44 tdata[i] = ldata[i];
45 }
46 }
47 auto &source_cc = ListVector::GetEntry(vector&: source);
48 auto source_size = ListVector::GetListSize(vector: source);
49
50 ListVector::Reserve(vec&: result, required_capacity: source_size);
51 auto &append_vector = ListVector::GetEntry(vector&: result);
52
53 CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state);
54 bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters);
55 ListVector::SetListSize(vec&: result, size: source_size);
56 D_ASSERT(ListVector::GetListSize(result) == source_size);
57 return all_succeeded;
58}
59
60static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
61 auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR;
62 // first cast the child vector to varchar
63 Vector varchar_list(LogicalType::LIST(child: LogicalType::VARCHAR), count);
64 ListCast::ListToListCast(source, result&: varchar_list, count, parameters);
65
66 // now construct the actual varchar vector
67 varchar_list.Flatten(count);
68 auto &child = ListVector::GetEntry(vector&: varchar_list);
69 auto list_data = FlatVector::GetData<list_entry_t>(vector&: varchar_list);
70 auto &validity = FlatVector::Validity(vector&: varchar_list);
71
72 child.Flatten(count);
73 auto child_data = FlatVector::GetData<string_t>(vector&: child);
74 auto &child_validity = FlatVector::Validity(vector&: child);
75
76 auto result_data = FlatVector::GetData<string_t>(vector&: result);
77 static constexpr const idx_t SEP_LENGTH = 2;
78 static constexpr const idx_t NULL_LENGTH = 4;
79 for (idx_t i = 0; i < count; i++) {
80 if (!validity.RowIsValid(row_idx: i)) {
81 FlatVector::SetNull(vector&: result, idx: i, is_null: true);
82 continue;
83 }
84 auto list = list_data[i];
85 // figure out how long the result needs to be
86 idx_t list_length = 2; // "[" and "]"
87 for (idx_t list_idx = 0; list_idx < list.length; list_idx++) {
88 auto idx = list.offset + list_idx;
89 if (list_idx > 0) {
90 list_length += SEP_LENGTH; // ", "
91 }
92 // string length, or "NULL"
93 list_length += child_validity.RowIsValid(row_idx: idx) ? child_data[idx].GetSize() : NULL_LENGTH;
94 }
95 result_data[i] = StringVector::EmptyString(vector&: result, len: list_length);
96 auto dataptr = result_data[i].GetDataWriteable();
97 auto offset = 0;
98 dataptr[offset++] = '[';
99 for (idx_t list_idx = 0; list_idx < list.length; list_idx++) {
100 auto idx = list.offset + list_idx;
101 if (list_idx > 0) {
102 memcpy(dest: dataptr + offset, src: ", ", n: SEP_LENGTH);
103 offset += SEP_LENGTH;
104 }
105 if (child_validity.RowIsValid(row_idx: idx)) {
106 auto len = child_data[idx].GetSize();
107 memcpy(dest: dataptr + offset, src: child_data[idx].GetData(), n: len);
108 offset += len;
109 } else {
110 memcpy(dest: dataptr + offset, src: "NULL", n: NULL_LENGTH);
111 offset += NULL_LENGTH;
112 }
113 }
114 dataptr[offset] = ']';
115 result_data[i].Finalize();
116 }
117
118 if (constant) {
119 result.SetVectorType(VectorType::CONSTANT_VECTOR);
120 }
121 return true;
122}
123
124BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
125 switch (target.id()) {
126 case LogicalTypeId::LIST:
127 return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target),
128 ListBoundCastData::InitListLocalState);
129 case LogicalTypeId::VARCHAR:
130 return BoundCastInfo(
131 ListToVarcharCast,
132 ListBoundCastData::BindListToListCast(input, source, target: LogicalType::LIST(child: LogicalType::VARCHAR)),
133 ListBoundCastData::InitListLocalState);
134 default:
135 return DefaultCasts::TryVectorNullCast;
136 }
137}
138
139} // namespace duckdb
140