1 | #include "duckdb/function/cast/default_casts.hpp" |
2 | #include "duckdb/function/cast/cast_function_set.hpp" |
3 | #include "duckdb/function/cast/bound_cast_data.hpp" |
4 | |
5 | namespace duckdb { |
6 | |
7 | unique_ptr<BoundCastData> ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source, |
8 | const LogicalType &target) { |
9 | vector<BoundCastInfo> child_cast_info; |
10 | auto &source_child_type = ListType::GetChildType(type: source); |
11 | auto &result_child_type = ListType::GetChildType(type: target); |
12 | auto child_cast = input.GetCastFunction(source: source_child_type, target: result_child_type); |
13 | return make_uniq<ListBoundCastData>(args: std::move(child_cast)); |
14 | } |
15 | |
16 | unique_ptr<FunctionLocalState> ListBoundCastData::InitListLocalState(CastLocalStateParameters ¶meters) { |
17 | auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); |
18 | if (!cast_data.child_cast_info.init_local_state) { |
19 | return nullptr; |
20 | } |
21 | CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); |
22 | return cast_data.child_cast_info.init_local_state(child_parameters); |
23 | } |
24 | |
25 | bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
26 | auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); |
27 | |
28 | // only handle constant and flat vectors here for now |
29 | if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
30 | result.SetVectorType(source.GetVectorType()); |
31 | ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source)); |
32 | |
33 | auto ldata = ConstantVector::GetData<list_entry_t>(vector&: source); |
34 | auto tdata = ConstantVector::GetData<list_entry_t>(vector&: result); |
35 | *tdata = *ldata; |
36 | } else { |
37 | source.Flatten(count); |
38 | result.SetVectorType(VectorType::FLAT_VECTOR); |
39 | FlatVector::SetValidity(vector&: result, new_validity&: FlatVector::Validity(vector&: source)); |
40 | |
41 | auto ldata = FlatVector::GetData<list_entry_t>(vector&: source); |
42 | auto tdata = FlatVector::GetData<list_entry_t>(vector&: result); |
43 | for (idx_t i = 0; i < count; i++) { |
44 | tdata[i] = ldata[i]; |
45 | } |
46 | } |
47 | auto &source_cc = ListVector::GetEntry(vector&: source); |
48 | auto source_size = ListVector::GetListSize(vector: source); |
49 | |
50 | ListVector::Reserve(vec&: result, required_capacity: source_size); |
51 | auto &append_vector = ListVector::GetEntry(vector&: result); |
52 | |
53 | CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); |
54 | bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters); |
55 | ListVector::SetListSize(vec&: result, size: source_size); |
56 | D_ASSERT(ListVector::GetListSize(result) == source_size); |
57 | return all_succeeded; |
58 | } |
59 | |
60 | static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
61 | auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; |
62 | // first cast the child vector to varchar |
63 | Vector varchar_list(LogicalType::LIST(child: LogicalType::VARCHAR), count); |
64 | ListCast::ListToListCast(source, result&: varchar_list, count, parameters); |
65 | |
66 | // now construct the actual varchar vector |
67 | varchar_list.Flatten(count); |
68 | auto &child = ListVector::GetEntry(vector&: varchar_list); |
69 | auto list_data = FlatVector::GetData<list_entry_t>(vector&: varchar_list); |
70 | auto &validity = FlatVector::Validity(vector&: varchar_list); |
71 | |
72 | child.Flatten(count); |
73 | auto child_data = FlatVector::GetData<string_t>(vector&: child); |
74 | auto &child_validity = FlatVector::Validity(vector&: child); |
75 | |
76 | auto result_data = FlatVector::GetData<string_t>(vector&: result); |
77 | static constexpr const idx_t SEP_LENGTH = 2; |
78 | static constexpr const idx_t NULL_LENGTH = 4; |
79 | for (idx_t i = 0; i < count; i++) { |
80 | if (!validity.RowIsValid(row_idx: i)) { |
81 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
82 | continue; |
83 | } |
84 | auto list = list_data[i]; |
85 | // figure out how long the result needs to be |
86 | idx_t list_length = 2; // "[" and "]" |
87 | for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { |
88 | auto idx = list.offset + list_idx; |
89 | if (list_idx > 0) { |
90 | list_length += SEP_LENGTH; // ", " |
91 | } |
92 | // string length, or "NULL" |
93 | list_length += child_validity.RowIsValid(row_idx: idx) ? child_data[idx].GetSize() : NULL_LENGTH; |
94 | } |
95 | result_data[i] = StringVector::EmptyString(vector&: result, len: list_length); |
96 | auto dataptr = result_data[i].GetDataWriteable(); |
97 | auto offset = 0; |
98 | dataptr[offset++] = '['; |
99 | for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { |
100 | auto idx = list.offset + list_idx; |
101 | if (list_idx > 0) { |
102 | memcpy(dest: dataptr + offset, src: ", " , n: SEP_LENGTH); |
103 | offset += SEP_LENGTH; |
104 | } |
105 | if (child_validity.RowIsValid(row_idx: idx)) { |
106 | auto len = child_data[idx].GetSize(); |
107 | memcpy(dest: dataptr + offset, src: child_data[idx].GetData(), n: len); |
108 | offset += len; |
109 | } else { |
110 | memcpy(dest: dataptr + offset, src: "NULL" , n: NULL_LENGTH); |
111 | offset += NULL_LENGTH; |
112 | } |
113 | } |
114 | dataptr[offset] = ']'; |
115 | result_data[i].Finalize(); |
116 | } |
117 | |
118 | if (constant) { |
119 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { |
125 | switch (target.id()) { |
126 | case LogicalTypeId::LIST: |
127 | return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), |
128 | ListBoundCastData::InitListLocalState); |
129 | case LogicalTypeId::VARCHAR: |
130 | return BoundCastInfo( |
131 | ListToVarcharCast, |
132 | ListBoundCastData::BindListToListCast(input, source, target: LogicalType::LIST(child: LogicalType::VARCHAR)), |
133 | ListBoundCastData::InitListLocalState); |
134 | default: |
135 | return DefaultCasts::TryVectorNullCast; |
136 | } |
137 | } |
138 | |
139 | } // namespace duckdb |
140 | |