1 | #include "duckdb/function/cast/default_casts.hpp" |
2 | #include "duckdb/function/cast/cast_function_set.hpp" |
3 | #include "duckdb/function/cast/bound_cast_data.hpp" |
4 | |
5 | namespace duckdb { |
6 | |
7 | unique_ptr<BoundCastData> StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source, |
8 | const LogicalType &target) { |
9 | vector<BoundCastInfo> child_cast_info; |
10 | auto &source_child_types = StructType::GetChildTypes(type: source); |
11 | auto &result_child_types = StructType::GetChildTypes(type: target); |
12 | if (source_child_types.size() != result_child_types.size()) { |
13 | throw TypeMismatchException(source, target, "Cannot cast STRUCTs of different size" ); |
14 | } |
15 | for (idx_t i = 0; i < source_child_types.size(); i++) { |
16 | auto child_cast = input.GetCastFunction(source: source_child_types[i].second, target: result_child_types[i].second); |
17 | child_cast_info.push_back(x: std::move(child_cast)); |
18 | } |
19 | return make_uniq<StructBoundCastData>(args: std::move(child_cast_info), args: target); |
20 | } |
21 | |
22 | unique_ptr<FunctionLocalState> StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters ¶meters) { |
23 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
24 | auto result = make_uniq<StructCastLocalState>(); |
25 | |
26 | for (auto &entry : cast_data.child_cast_info) { |
27 | unique_ptr<FunctionLocalState> child_state; |
28 | if (entry.init_local_state) { |
29 | CastLocalStateParameters child_params(parameters, entry.cast_data); |
30 | child_state = entry.init_local_state(child_params); |
31 | } |
32 | result->local_states.push_back(x: std::move(child_state)); |
33 | } |
34 | return std::move(result); |
35 | } |
36 | |
37 | static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
38 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
39 | auto &lstate = parameters.local_state->Cast<StructCastLocalState>(); |
40 | auto &source_child_types = StructType::GetChildTypes(type: source.GetType()); |
41 | auto &source_children = StructVector::GetEntries(vector&: source); |
42 | D_ASSERT(source_children.size() == StructType::GetChildTypes(result.GetType()).size()); |
43 | |
44 | auto &result_children = StructVector::GetEntries(vector&: result); |
45 | bool all_converted = true; |
46 | for (idx_t c_idx = 0; c_idx < source_child_types.size(); c_idx++) { |
47 | auto &result_child_vector = *result_children[c_idx]; |
48 | auto &source_child_vector = *source_children[c_idx]; |
49 | CastParameters child_parameters(parameters, cast_data.child_cast_info[c_idx].cast_data, |
50 | lstate.local_states[c_idx]); |
51 | if (!cast_data.child_cast_info[c_idx].function(source_child_vector, result_child_vector, count, |
52 | child_parameters)) { |
53 | all_converted = false; |
54 | } |
55 | } |
56 | if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
57 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
58 | ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source)); |
59 | } else { |
60 | source.Flatten(count); |
61 | FlatVector::Validity(vector&: result) = FlatVector::Validity(vector&: source); |
62 | } |
63 | return all_converted; |
64 | } |
65 | |
66 | static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
67 | auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; |
68 | // first cast all child elements to varchar |
69 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
70 | Vector varchar_struct(cast_data.target, count); |
71 | StructToStructCast(source, result&: varchar_struct, count, parameters); |
72 | |
73 | // now construct the actual varchar vector |
74 | varchar_struct.Flatten(count); |
75 | auto &child_types = StructType::GetChildTypes(type: source.GetType()); |
76 | auto &children = StructVector::GetEntries(vector&: varchar_struct); |
77 | auto &validity = FlatVector::Validity(vector&: varchar_struct); |
78 | auto result_data = FlatVector::GetData<string_t>(vector&: result); |
79 | static constexpr const idx_t SEP_LENGTH = 2; |
80 | static constexpr const idx_t NAME_SEP_LENGTH = 4; |
81 | static constexpr const idx_t NULL_LENGTH = 4; |
82 | for (idx_t i = 0; i < count; i++) { |
83 | if (!validity.RowIsValid(row_idx: i)) { |
84 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
85 | continue; |
86 | } |
87 | idx_t string_length = 2; // {} |
88 | for (idx_t c = 0; c < children.size(); c++) { |
89 | if (c > 0) { |
90 | string_length += SEP_LENGTH; |
91 | } |
92 | children[c]->Flatten(count); |
93 | auto &child_validity = FlatVector::Validity(vector&: *children[c]); |
94 | auto data = FlatVector::GetData<string_t>(vector&: *children[c]); |
95 | auto &name = child_types[c].first; |
96 | string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': " |
97 | string_length += child_validity.RowIsValid(row_idx: i) ? data[i].GetSize() : NULL_LENGTH; |
98 | } |
99 | result_data[i] = StringVector::EmptyString(vector&: result, len: string_length); |
100 | auto dataptr = result_data[i].GetDataWriteable(); |
101 | idx_t offset = 0; |
102 | dataptr[offset++] = '{'; |
103 | for (idx_t c = 0; c < children.size(); c++) { |
104 | if (c > 0) { |
105 | memcpy(dest: dataptr + offset, src: ", " , n: SEP_LENGTH); |
106 | offset += SEP_LENGTH; |
107 | } |
108 | auto &child_validity = FlatVector::Validity(vector&: *children[c]); |
109 | auto data = FlatVector::GetData<string_t>(vector&: *children[c]); |
110 | auto &name = child_types[c].first; |
111 | // "'{name}': " |
112 | dataptr[offset++] = '\''; |
113 | memcpy(dest: dataptr + offset, src: name.c_str(), n: name.size()); |
114 | offset += name.size(); |
115 | dataptr[offset++] = '\''; |
116 | dataptr[offset++] = ':'; |
117 | dataptr[offset++] = ' '; |
118 | // value |
119 | if (child_validity.RowIsValid(row_idx: i)) { |
120 | auto len = data[i].GetSize(); |
121 | memcpy(dest: dataptr + offset, src: data[i].GetData(), n: len); |
122 | offset += len; |
123 | } else { |
124 | memcpy(dest: dataptr + offset, src: "NULL" , n: NULL_LENGTH); |
125 | offset += NULL_LENGTH; |
126 | } |
127 | } |
128 | dataptr[offset++] = '}'; |
129 | result_data[i].Finalize(); |
130 | } |
131 | |
132 | if (constant) { |
133 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
134 | } |
135 | return true; |
136 | } |
137 | |
138 | BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source, |
139 | const LogicalType &target) { |
140 | switch (target.id()) { |
141 | case LogicalTypeId::STRUCT: |
142 | return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target), |
143 | StructBoundCastData::InitStructCastLocalState); |
144 | case LogicalTypeId::VARCHAR: { |
145 | // bind a cast in which we convert all child entries to VARCHAR entries |
146 | auto &struct_children = StructType::GetChildTypes(type: source); |
147 | child_list_t<LogicalType> varchar_children; |
148 | for (auto &child_entry : struct_children) { |
149 | varchar_children.push_back(x: make_pair(x: child_entry.first, y: LogicalType::VARCHAR)); |
150 | } |
151 | auto varchar_type = LogicalType::STRUCT(children: varchar_children); |
152 | return BoundCastInfo(StructToVarcharCast, |
153 | StructBoundCastData::BindStructToStructCast(input, source, target: varchar_type), |
154 | StructBoundCastData::InitStructCastLocalState); |
155 | } |
156 | default: |
157 | return TryVectorNullCast; |
158 | } |
159 | } |
160 | |
161 | } // namespace duckdb |
162 | |