1#include "duckdb/function/cast/cast_function_set.hpp"
2#include "duckdb/function/cast/default_casts.hpp"
3#include "duckdb/function/cast/bound_cast_data.hpp"
4
5#include <algorithm> // for std::sort
6
7namespace duckdb {
8
9//--------------------------------------------------------------------------------------------------
10// ??? -> UNION
11//--------------------------------------------------------------------------------------------------
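// the source is cast to the member of the target union that it can be implicitly cast to with the lowest cost.
// The cast is invalid if no member is implicitly castable, or if several members tie for the lowest cost (ambiguous).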
13
14struct ToUnionBoundCastData : public BoundCastData {
15 ToUnionBoundCastData(union_tag_t member_idx, string name, LogicalType type, int64_t cost,
16 BoundCastInfo member_cast_info)
17 : tag(member_idx), name(std::move(name)), type(std::move(type)), cost(cost),
18 member_cast_info(std::move(member_cast_info)) {
19 }
20
21 union_tag_t tag;
22 string name;
23 LogicalType type;
24 int64_t cost;
25 BoundCastInfo member_cast_info;
26
27public:
28 unique_ptr<BoundCastData> Copy() const override {
29 return make_uniq<ToUnionBoundCastData>(args: tag, args: name, args: type, args: cost, args: member_cast_info.Copy());
30 }
31
32 static bool SortByCostAscending(const ToUnionBoundCastData &left, const ToUnionBoundCastData &right) {
33 return left.cost < right.cost;
34 }
35};
36
37unique_ptr<BoundCastData> BindToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
38 D_ASSERT(target.id() == LogicalTypeId::UNION);
39
40 vector<ToUnionBoundCastData> candidates;
41
42 for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(type: target); member_idx++) {
43 auto member_type = UnionType::GetMemberType(type: target, index: member_idx);
44 auto member_name = UnionType::GetMemberName(type: target, index: member_idx);
45 auto member_cast_cost = input.function_set.ImplicitCastCost(source, target: member_type);
46 if (member_cast_cost != -1) {
47 auto member_cast_info = input.GetCastFunction(source, target: member_type);
48 candidates.emplace_back(args&: member_idx, args&: member_name, args&: member_type, args&: member_cast_cost,
49 args: std::move(member_cast_info));
50 }
51 };
52
53 // no possible casts found!
54 if (candidates.empty()) {
55 auto message = StringUtil::Format(
56 fmt_str: "Type %s can't be cast as %s. %s can't be implicitly cast to any of the union member types: ",
57 params: source.ToString(), params: target.ToString(), params: source.ToString());
58
59 auto member_count = UnionType::GetMemberCount(type: target);
60 for (idx_t member_idx = 0; member_idx < member_count; member_idx++) {
61 auto member_type = UnionType::GetMemberType(type: target, index: member_idx);
62 message += member_type.ToString();
63 if (member_idx < member_count - 1) {
64 message += ", ";
65 }
66 }
67 throw CastException(message);
68 }
69
70 // sort the candidate casts by cost
71 std::sort(first: candidates.begin(), last: candidates.end(), comp: ToUnionBoundCastData::SortByCostAscending);
72
73 // select the lowest possible cost cast
74 auto &selected_cast = candidates[0];
75 auto selected_cost = candidates[0].cost;
76
77 // check if the cast is ambiguous (2 or more casts have the same cost)
78 if (candidates.size() > 1 && candidates[1].cost == selected_cost) {
79
80 // collect all the ambiguous types
81 auto message = StringUtil::Format(
82 fmt_str: "Type %s can't be cast as %s. The cast is ambiguous, multiple possible members in target: ", params: source,
83 params: target);
84 for (size_t i = 0; i < candidates.size(); i++) {
85 if (candidates[i].cost == selected_cost) {
86 message += StringUtil::Format(fmt_str: "'%s (%s)'", params: candidates[i].name, params: candidates[i].type.ToString());
87 if (i < candidates.size() - 1) {
88 message += ", ";
89 }
90 }
91 }
92 message += ". Disambiguate the target type by using the 'union_value(<tag> := <arg>)' function to promote the "
93 "source value to a single member union before casting.";
94 throw CastException(message);
95 }
96
97 // otherwise, return the selected cast
98 return make_uniq<ToUnionBoundCastData>(args: std::move(selected_cast));
99}
100
101unique_ptr<FunctionLocalState> InitToUnionLocalState(CastLocalStateParameters &parameters) {
102 auto &cast_data = parameters.cast_data->Cast<ToUnionBoundCastData>();
103 if (!cast_data.member_cast_info.init_local_state) {
104 return nullptr;
105 }
106 CastLocalStateParameters child_parameters(parameters, cast_data.member_cast_info.cast_data);
107 return cast_data.member_cast_info.init_local_state(child_parameters);
108}
109
110static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
111 D_ASSERT(result.GetType().id() == LogicalTypeId::UNION);
112 auto &cast_data = parameters.cast_data->Cast<ToUnionBoundCastData>();
113 auto &selected_member_vector = UnionVector::GetMember(vector&: result, member_index: cast_data.tag);
114
115 CastParameters child_parameters(parameters, cast_data.member_cast_info.cast_data, parameters.local_state);
116 if (!cast_data.member_cast_info.function(source, selected_member_vector, count, child_parameters)) {
117 return false;
118 }
119
120 // cast succeeded, create union vector
121 UnionVector::SetToMember(vector&: result, tag: cast_data.tag, member_vector&: selected_member_vector, count, keep_tags_for_null: true);
122
123 result.Verify(count);
124
125 return true;
126}
127
128BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source,
129 const LogicalType &target) {
130 return BoundCastInfo(&ToUnionCast, BindToUnionCast(input, source, target), InitToUnionLocalState);
131}
132
133//--------------------------------------------------------------------------------------------------
134// UNION -> UNION
135//--------------------------------------------------------------------------------------------------
// if the source member tags are a subset of the target member tags (members are matched by tag name), and each
// source member type can be cast to the type of the target member with the same tag, the cast is valid.
//
// VALID: UNION(a A, b B) -> UNION(a A, b B, c C)
// VALID: UNION(a A, b B) -> UNION(a A, b C) if B can be cast to C
//
// INVALID: UNION(a A, b B, c C) -> UNION(a A, b B)
// INVALID: UNION(a A, b B) -> UNION(a A, b C) if B can't be cast to C
// INVALID: UNION(a A, b B, d D) -> UNION(a A, b B, c C)
145
146struct UnionToUnionBoundCastData : public BoundCastData {
147
148 // mapping from source member index to target member index
149 // these are always the same size as the source member count
150 // (since all source members must be present in the target)
151 vector<idx_t> tag_map;
152 vector<BoundCastInfo> member_casts;
153
154 LogicalType target_type;
155
156 UnionToUnionBoundCastData(vector<idx_t> tag_map, vector<BoundCastInfo> member_casts, LogicalType target_type)
157 : tag_map(std::move(tag_map)), member_casts(std::move(member_casts)), target_type(std::move(target_type)) {
158 }
159
160public:
161 unique_ptr<BoundCastData> Copy() const override {
162 vector<BoundCastInfo> member_casts_copy;
163 for (auto &member_cast : member_casts) {
164 member_casts_copy.push_back(x: member_cast.Copy());
165 }
166 return make_uniq<UnionToUnionBoundCastData>(args: tag_map, args: std::move(member_casts_copy), args: target_type);
167 }
168};
169
170unique_ptr<BoundCastData> BindUnionToUnionCast(BindCastInput &input, const LogicalType &source,
171 const LogicalType &target) {
172 D_ASSERT(source.id() == LogicalTypeId::UNION);
173 D_ASSERT(target.id() == LogicalTypeId::UNION);
174
175 auto source_member_count = UnionType::GetMemberCount(type: source);
176
177 auto tag_map = vector<idx_t>(source_member_count);
178 auto member_casts = vector<BoundCastInfo>();
179
180 for (idx_t source_idx = 0; source_idx < source_member_count; source_idx++) {
181 auto &source_member_type = UnionType::GetMemberType(type: source, index: source_idx);
182 auto &source_member_name = UnionType::GetMemberName(type: source, index: source_idx);
183
184 bool found = false;
185 for (idx_t target_idx = 0; target_idx < UnionType::GetMemberCount(type: target); target_idx++) {
186 auto &target_member_name = UnionType::GetMemberName(type: target, index: target_idx);
187
188 // found a matching member
189 if (source_member_name == target_member_name) {
190 auto &target_member_type = UnionType::GetMemberType(type: target, index: target_idx);
191 tag_map[source_idx] = target_idx;
192 member_casts.push_back(x: input.GetCastFunction(source: source_member_type, target: target_member_type));
193 found = true;
194 break;
195 }
196 }
197 if (!found) {
198 // no matching member tag found in the target set
199 auto message =
200 StringUtil::Format(fmt_str: "Type %s can't be cast as %s. The member '%s' is not present in target union",
201 params: source.ToString(), params: target.ToString(), params: source_member_name);
202 throw CastException(message);
203 }
204 }
205
206 return make_uniq<UnionToUnionBoundCastData>(args&: tag_map, args: std::move(member_casts), args: target);
207}
208
209unique_ptr<FunctionLocalState> InitUnionToUnionLocalState(CastLocalStateParameters &parameters) {
210 auto &cast_data = parameters.cast_data->Cast<UnionToUnionBoundCastData>();
211 auto result = make_uniq<StructCastLocalState>();
212
213 for (auto &entry : cast_data.member_casts) {
214 unique_ptr<FunctionLocalState> child_state;
215 if (entry.init_local_state) {
216 CastLocalStateParameters child_params(parameters, entry.cast_data);
217 child_state = entry.init_local_state(child_params);
218 }
219 result->local_states.push_back(x: std::move(child_state));
220 }
221 return std::move(result);
222}
223
224static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
225 auto &cast_data = parameters.cast_data->Cast<UnionToUnionBoundCastData>();
226 auto &lstate = parameters.local_state->Cast<StructCastLocalState>();
227
228 auto source_member_count = UnionType::GetMemberCount(type: source.GetType());
229 auto target_member_count = UnionType::GetMemberCount(type: result.GetType());
230
231 auto target_member_is_mapped = vector<bool>(target_member_count);
232
233 // Perform the casts from source to target members
234 for (idx_t member_idx = 0; member_idx < source_member_count; member_idx++) {
235 auto target_member_idx = cast_data.tag_map[member_idx];
236
237 auto &source_member_vector = UnionVector::GetMember(vector&: source, member_index: member_idx);
238 auto &target_member_vector = UnionVector::GetMember(vector&: result, member_index: target_member_idx);
239 auto &member_cast = cast_data.member_casts[member_idx];
240
241 CastParameters child_parameters(parameters, member_cast.cast_data, lstate.local_states[member_idx]);
242 if (!member_cast.function(source_member_vector, target_member_vector, count, child_parameters)) {
243 return false;
244 }
245
246 target_member_is_mapped[target_member_idx] = true;
247 }
248
249 // All member casts succeeded!
250
251 // Set the unmapped target members to constant NULL.
252 // If we cast UNION(A, B) -> UNION(A, B, C) we need to invalidate C so that
253 // the invariants of the result union hold. (only member columns "selected"
254 // by the rowwise corresponding tag in the tag vector should be valid)
255 for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) {
256 if (!target_member_is_mapped[target_member_idx]) {
257 auto &target_member_vector = UnionVector::GetMember(vector&: result, member_index: target_member_idx);
258 target_member_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
259 ConstantVector::SetNull(vector&: target_member_vector, is_null: true);
260 }
261 }
262
263 // Update the tags in the result vector
264 auto &source_tag_vector = UnionVector::GetTags(v&: source);
265 auto &result_tag_vector = UnionVector::GetTags(v&: result);
266
267 if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) {
268 // Constant vector case optimization
269 result.SetVectorType(VectorType::CONSTANT_VECTOR);
270 if (ConstantVector::IsNull(vector: source)) {
271 ConstantVector::SetNull(vector&: result, is_null: true);
272 } else {
273 // map the tag
274 auto source_tag = ConstantVector::GetData<union_tag_t>(vector&: source_tag_vector)[0];
275 auto mapped_tag = cast_data.tag_map[source_tag];
276 ConstantVector::GetData<union_tag_t>(vector&: result_tag_vector)[0] = mapped_tag;
277 }
278 } else {
279 // Otherwise, use the unified vector format to access the source vector.
280
281 // Ensure that all the result members are flat vectors
282 // This is not always the case, e.g. when a member is cast using the default TryNullCast function
283 // the resulting member vector will be a constant null vector.
284 for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) {
285 UnionVector::GetMember(vector&: result, member_index: target_member_idx).Flatten(count);
286 }
287
288 // We assume that a union tag vector validity matches the union vector validity.
289 UnifiedVectorFormat source_tag_format;
290 source_tag_vector.ToUnifiedFormat(count, data&: source_tag_format);
291
292 for (idx_t row_idx = 0; row_idx < count; row_idx++) {
293 auto source_row_idx = source_tag_format.sel->get_index(idx: row_idx);
294 if (source_tag_format.validity.RowIsValid(row_idx: source_row_idx)) {
295 // map the tag
296 auto source_tag = (UnifiedVectorFormat::GetData<union_tag_t>(format: source_tag_format))[source_row_idx];
297 auto target_tag = cast_data.tag_map[source_tag];
298 FlatVector::GetData<union_tag_t>(vector&: result_tag_vector)[row_idx] = target_tag;
299 } else {
300
301 // Issue: The members of the result is not always flatvectors
302 // In the case of TryNullCast, the result member is constant.
303 FlatVector::SetNull(vector&: result, idx: row_idx, is_null: true);
304 }
305 }
306 }
307
308 result.Verify(count);
309
310 return true;
311}
312
313static bool UnionToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
314 auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR;
315 // first cast all union members to varchar
316 auto &cast_data = parameters.cast_data->Cast<UnionToUnionBoundCastData>();
317 Vector varchar_union(cast_data.target_type, count);
318
319 UnionToUnionCast(source, result&: varchar_union, count, parameters);
320
321 // now construct the actual varchar vector
322 varchar_union.Flatten(count);
323 auto &tag_vector = UnionVector::GetTags(v&: source);
324 auto tags = FlatVector::GetData<union_tag_t>(vector&: tag_vector);
325
326 auto &validity = FlatVector::Validity(vector&: varchar_union);
327 auto result_data = FlatVector::GetData<string_t>(vector&: result);
328
329 for (idx_t i = 0; i < count; i++) {
330 if (!validity.RowIsValid(row_idx: i)) {
331 FlatVector::SetNull(vector&: result, idx: i, is_null: true);
332 continue;
333 }
334
335 auto &member = UnionVector::GetMember(vector&: varchar_union, member_index: tags[i]);
336 UnifiedVectorFormat member_vdata;
337 member.ToUnifiedFormat(count, data&: member_vdata);
338
339 auto mapped_idx = member_vdata.sel->get_index(idx: i);
340 auto member_valid = member_vdata.validity.RowIsValid(row_idx: mapped_idx);
341 if (member_valid) {
342 auto member_str = (UnifiedVectorFormat::GetData<string_t>(format: member_vdata))[mapped_idx];
343 result_data[i] = StringVector::AddString(vector&: result, data: member_str);
344 } else {
345 result_data[i] = StringVector::AddString(vector&: result, data: "NULL");
346 }
347 }
348
349 if (constant) {
350 result.SetVectorType(VectorType::CONSTANT_VECTOR);
351 }
352
353 result.Verify(count);
354 return true;
355}
356
357BoundCastInfo DefaultCasts::UnionCastSwitch(BindCastInput &input, const LogicalType &source,
358 const LogicalType &target) {
359 switch (target.id()) {
360 case LogicalTypeId::VARCHAR: {
361 // bind a cast in which we convert all members to VARCHAR first
362 child_list_t<LogicalType> varchar_members;
363 for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(type: source); member_idx++) {
364 varchar_members.push_back(x: make_pair(x: UnionType::GetMemberName(type: source, index: member_idx), y: LogicalType::VARCHAR));
365 }
366 auto varchar_type = LogicalType::UNION(members: std::move(varchar_members));
367 return BoundCastInfo(UnionToVarcharCast, BindUnionToUnionCast(input, source, target: varchar_type),
368 InitUnionToUnionLocalState);
369 }
370 case LogicalTypeId::UNION:
371 return BoundCastInfo(UnionToUnionCast, BindUnionToUnionCast(input, source, target), InitUnionToUnionLocalState);
372 default:
373 return TryVectorNullCast;
374 }
375}
376
377} // namespace duckdb
378