1#include "duckdb/function/cast/default_casts.hpp"
2#include "duckdb/function/cast/vector_cast_helpers.hpp"
3#include "duckdb/function/cast/cast_function_set.hpp"
4
5namespace duckdb {
6
7template <class SRC_TYPE, class RES_TYPE>
8bool EnumEnumCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
9 bool all_converted = true;
10 result.SetVectorType(VectorType::FLAT_VECTOR);
11
12 auto &str_vec = EnumType::GetValuesInsertOrder(type: source.GetType());
13 auto str_vec_ptr = FlatVector::GetData<string_t>(vector: str_vec);
14
15 auto res_enum_type = result.GetType();
16
17 UnifiedVectorFormat vdata;
18 source.ToUnifiedFormat(count, data&: vdata);
19
20 auto source_data = UnifiedVectorFormat::GetData<SRC_TYPE>(vdata);
21 auto source_sel = vdata.sel;
22 auto source_mask = vdata.validity;
23
24 auto result_data = FlatVector::GetData<RES_TYPE>(result);
25 auto &result_mask = FlatVector::Validity(vector&: result);
26
27 for (idx_t i = 0; i < count; i++) {
28 auto src_idx = source_sel->get_index(idx: i);
29 if (!source_mask.RowIsValid(row_idx: src_idx)) {
30 result_mask.SetInvalid(i);
31 continue;
32 }
33 auto key = EnumType::GetPos(type: res_enum_type, key: str_vec_ptr[source_data[src_idx]]);
34 if (key == -1) {
35 // key doesn't exist on result enum
36 if (!parameters.error_message) {
37 result_data[i] = HandleVectorCastError::Operation<RES_TYPE>(
38 CastExceptionText<SRC_TYPE, RES_TYPE>(source_data[src_idx]), result_mask, i,
39 parameters.error_message, all_converted);
40 } else {
41 result_mask.SetInvalid(i);
42 }
43 continue;
44 }
45 result_data[i] = key;
46 }
47 return all_converted;
48}
49
50template <class SRC_TYPE>
51BoundCastInfo EnumEnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
52 switch (target.InternalType()) {
53 case PhysicalType::UINT8:
54 return EnumEnumCast<SRC_TYPE, uint8_t>;
55 case PhysicalType::UINT16:
56 return EnumEnumCast<SRC_TYPE, uint16_t>;
57 case PhysicalType::UINT32:
58 return EnumEnumCast<SRC_TYPE, uint32_t>;
59 default:
60 throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types");
61 }
62}
63
64template <class SRC>
65static bool EnumToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
66 auto &enum_dictionary = EnumType::GetValuesInsertOrder(type: source.GetType());
67 auto dictionary_data = FlatVector::GetData<string_t>(vector: enum_dictionary);
68 auto result_data = FlatVector::GetData<string_t>(vector&: result);
69 auto &result_mask = FlatVector::Validity(vector&: result);
70
71 UnifiedVectorFormat vdata;
72 source.ToUnifiedFormat(count, data&: vdata);
73
74 auto source_data = UnifiedVectorFormat::GetData<SRC>(vdata);
75 for (idx_t i = 0; i < count; i++) {
76 auto source_idx = vdata.sel->get_index(idx: i);
77 if (!vdata.validity.RowIsValid(row_idx: source_idx)) {
78 result_mask.SetInvalid(i);
79 continue;
80 }
81 auto enum_idx = source_data[source_idx];
82 result_data[i] = dictionary_data[enum_idx];
83 }
84 if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) {
85 result.SetVectorType(VectorType::CONSTANT_VECTOR);
86 } else {
87 result.SetVectorType(VectorType::FLAT_VECTOR);
88 }
89 return true;
90}
91
92struct EnumBoundCastData : public BoundCastData {
93 EnumBoundCastData(BoundCastInfo to_varchar_cast, BoundCastInfo from_varchar_cast)
94 : to_varchar_cast(std::move(to_varchar_cast)), from_varchar_cast(std::move(from_varchar_cast)) {
95 }
96
97 BoundCastInfo to_varchar_cast;
98 BoundCastInfo from_varchar_cast;
99
100public:
101 unique_ptr<BoundCastData> Copy() const override {
102 return make_uniq<EnumBoundCastData>(args: to_varchar_cast.Copy(), args: from_varchar_cast.Copy());
103 }
104};
105
106unique_ptr<BoundCastData> BindEnumCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
107 auto to_varchar_cast = input.GetCastFunction(source, target: LogicalType::VARCHAR);
108 auto from_varchar_cast = input.GetCastFunction(source: LogicalType::VARCHAR, target);
109 return make_uniq<EnumBoundCastData>(args: std::move(to_varchar_cast), args: std::move(from_varchar_cast));
110}
111
112struct EnumCastLocalState : public FunctionLocalState {
113public:
114 unique_ptr<FunctionLocalState> to_varchar_local;
115 unique_ptr<FunctionLocalState> from_varchar_local;
116};
117
118static unique_ptr<FunctionLocalState> InitEnumCastLocalState(CastLocalStateParameters &parameters) {
119 auto &cast_data = parameters.cast_data->Cast<EnumBoundCastData>();
120 auto result = make_uniq<EnumCastLocalState>();
121
122 if (cast_data.from_varchar_cast.init_local_state) {
123 CastLocalStateParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data);
124 result->from_varchar_local = cast_data.from_varchar_cast.init_local_state(from_varchar_params);
125 }
126 if (cast_data.to_varchar_cast.init_local_state) {
127 CastLocalStateParameters from_varchar_params(parameters, cast_data.to_varchar_cast.cast_data);
128 result->from_varchar_local = cast_data.to_varchar_cast.init_local_state(from_varchar_params);
129 }
130 return std::move(result);
131}
132
133static bool EnumToAnyCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
134 auto &cast_data = parameters.cast_data->Cast<EnumBoundCastData>();
135 auto &lstate = parameters.local_state->Cast<EnumCastLocalState>();
136
137 Vector varchar_cast(LogicalType::VARCHAR, count);
138
139 // cast to varchar
140 CastParameters to_varchar_params(parameters, cast_data.to_varchar_cast.cast_data, lstate.to_varchar_local);
141 cast_data.to_varchar_cast.function(source, varchar_cast, count, to_varchar_params);
142
143 // cast from varchar to the target
144 CastParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data, lstate.from_varchar_local);
145 cast_data.from_varchar_cast.function(varchar_cast, result, count, from_varchar_params);
146 return true;
147}
148
149BoundCastInfo DefaultCasts::EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
150 auto enum_physical_type = source.InternalType();
151 switch (target.id()) {
152 case LogicalTypeId::ENUM: {
153 // This means they are both ENUMs, but of different types.
154 switch (enum_physical_type) {
155 case PhysicalType::UINT8:
156 return EnumEnumCastSwitch<uint8_t>(input, source, target);
157 case PhysicalType::UINT16:
158 return EnumEnumCastSwitch<uint16_t>(input, source, target);
159 case PhysicalType::UINT32:
160 return EnumEnumCastSwitch<uint32_t>(input, source, target);
161 default:
162 throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types");
163 }
164 }
165 case LogicalTypeId::VARCHAR:
166 switch (enum_physical_type) {
167 case PhysicalType::UINT8:
168 return EnumToVarcharCast<uint8_t>;
169 case PhysicalType::UINT16:
170 return EnumToVarcharCast<uint16_t>;
171 case PhysicalType::UINT32:
172 return EnumToVarcharCast<uint32_t>;
173 default:
174 throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types");
175 }
176 default: {
177 return BoundCastInfo(EnumToAnyCast, BindEnumCast(input, source, target), InitEnumCastLocalState);
178 }
179 }
180}
181
182} // namespace duckdb
183