1#include "duckdb/function/cast/default_casts.hpp"
2#include "duckdb/function/cast/cast_function_set.hpp"
3#include "duckdb/function/cast/bound_cast_data.hpp"
4
5namespace duckdb {
6
7unique_ptr<BoundCastData> StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source,
8 const LogicalType &target) {
9 vector<BoundCastInfo> child_cast_info;
10 auto &source_child_types = StructType::GetChildTypes(type: source);
11 auto &result_child_types = StructType::GetChildTypes(type: target);
12 if (source_child_types.size() != result_child_types.size()) {
13 throw TypeMismatchException(source, target, "Cannot cast STRUCTs of different size");
14 }
15 for (idx_t i = 0; i < source_child_types.size(); i++) {
16 auto child_cast = input.GetCastFunction(source: source_child_types[i].second, target: result_child_types[i].second);
17 child_cast_info.push_back(x: std::move(child_cast));
18 }
19 return make_uniq<StructBoundCastData>(args: std::move(child_cast_info), args: target);
20}
21
22unique_ptr<FunctionLocalState> StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters &parameters) {
23 auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>();
24 auto result = make_uniq<StructCastLocalState>();
25
26 for (auto &entry : cast_data.child_cast_info) {
27 unique_ptr<FunctionLocalState> child_state;
28 if (entry.init_local_state) {
29 CastLocalStateParameters child_params(parameters, entry.cast_data);
30 child_state = entry.init_local_state(child_params);
31 }
32 result->local_states.push_back(x: std::move(child_state));
33 }
34 return std::move(result);
35}
36
37static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
38 auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>();
39 auto &lstate = parameters.local_state->Cast<StructCastLocalState>();
40 auto &source_child_types = StructType::GetChildTypes(type: source.GetType());
41 auto &source_children = StructVector::GetEntries(vector&: source);
42 D_ASSERT(source_children.size() == StructType::GetChildTypes(result.GetType()).size());
43
44 auto &result_children = StructVector::GetEntries(vector&: result);
45 bool all_converted = true;
46 for (idx_t c_idx = 0; c_idx < source_child_types.size(); c_idx++) {
47 auto &result_child_vector = *result_children[c_idx];
48 auto &source_child_vector = *source_children[c_idx];
49 CastParameters child_parameters(parameters, cast_data.child_cast_info[c_idx].cast_data,
50 lstate.local_states[c_idx]);
51 if (!cast_data.child_cast_info[c_idx].function(source_child_vector, result_child_vector, count,
52 child_parameters)) {
53 all_converted = false;
54 }
55 }
56 if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) {
57 result.SetVectorType(VectorType::CONSTANT_VECTOR);
58 ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source));
59 } else {
60 source.Flatten(count);
61 FlatVector::Validity(vector&: result) = FlatVector::Validity(vector&: source);
62 }
63 return all_converted;
64}
65
66static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
67 auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR;
68 // first cast all child elements to varchar
69 auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>();
70 Vector varchar_struct(cast_data.target, count);
71 StructToStructCast(source, result&: varchar_struct, count, parameters);
72
73 // now construct the actual varchar vector
74 varchar_struct.Flatten(count);
75 auto &child_types = StructType::GetChildTypes(type: source.GetType());
76 auto &children = StructVector::GetEntries(vector&: varchar_struct);
77 auto &validity = FlatVector::Validity(vector&: varchar_struct);
78 auto result_data = FlatVector::GetData<string_t>(vector&: result);
79 static constexpr const idx_t SEP_LENGTH = 2;
80 static constexpr const idx_t NAME_SEP_LENGTH = 4;
81 static constexpr const idx_t NULL_LENGTH = 4;
82 for (idx_t i = 0; i < count; i++) {
83 if (!validity.RowIsValid(row_idx: i)) {
84 FlatVector::SetNull(vector&: result, idx: i, is_null: true);
85 continue;
86 }
87 idx_t string_length = 2; // {}
88 for (idx_t c = 0; c < children.size(); c++) {
89 if (c > 0) {
90 string_length += SEP_LENGTH;
91 }
92 children[c]->Flatten(count);
93 auto &child_validity = FlatVector::Validity(vector&: *children[c]);
94 auto data = FlatVector::GetData<string_t>(vector&: *children[c]);
95 auto &name = child_types[c].first;
96 string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': "
97 string_length += child_validity.RowIsValid(row_idx: i) ? data[i].GetSize() : NULL_LENGTH;
98 }
99 result_data[i] = StringVector::EmptyString(vector&: result, len: string_length);
100 auto dataptr = result_data[i].GetDataWriteable();
101 idx_t offset = 0;
102 dataptr[offset++] = '{';
103 for (idx_t c = 0; c < children.size(); c++) {
104 if (c > 0) {
105 memcpy(dest: dataptr + offset, src: ", ", n: SEP_LENGTH);
106 offset += SEP_LENGTH;
107 }
108 auto &child_validity = FlatVector::Validity(vector&: *children[c]);
109 auto data = FlatVector::GetData<string_t>(vector&: *children[c]);
110 auto &name = child_types[c].first;
111 // "'{name}': "
112 dataptr[offset++] = '\'';
113 memcpy(dest: dataptr + offset, src: name.c_str(), n: name.size());
114 offset += name.size();
115 dataptr[offset++] = '\'';
116 dataptr[offset++] = ':';
117 dataptr[offset++] = ' ';
118 // value
119 if (child_validity.RowIsValid(row_idx: i)) {
120 auto len = data[i].GetSize();
121 memcpy(dest: dataptr + offset, src: data[i].GetData(), n: len);
122 offset += len;
123 } else {
124 memcpy(dest: dataptr + offset, src: "NULL", n: NULL_LENGTH);
125 offset += NULL_LENGTH;
126 }
127 }
128 dataptr[offset++] = '}';
129 result_data[i].Finalize();
130 }
131
132 if (constant) {
133 result.SetVectorType(VectorType::CONSTANT_VECTOR);
134 }
135 return true;
136}
137
138BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source,
139 const LogicalType &target) {
140 switch (target.id()) {
141 case LogicalTypeId::STRUCT:
142 return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target),
143 StructBoundCastData::InitStructCastLocalState);
144 case LogicalTypeId::VARCHAR: {
145 // bind a cast in which we convert all child entries to VARCHAR entries
146 auto &struct_children = StructType::GetChildTypes(type: source);
147 child_list_t<LogicalType> varchar_children;
148 for (auto &child_entry : struct_children) {
149 varchar_children.push_back(x: make_pair(x: child_entry.first, y: LogicalType::VARCHAR));
150 }
151 auto varchar_type = LogicalType::STRUCT(children: varchar_children);
152 return BoundCastInfo(StructToVarcharCast,
153 StructBoundCastData::BindStructToStructCast(input, source, target: varchar_type),
154 StructBoundCastData::InitStructCastLocalState);
155 }
156 default:
157 return TryVectorNullCast;
158 }
159}
160
161} // namespace duckdb
162