1 | #include "duckdb/function/cast/default_casts.hpp" |
2 | #include "duckdb/function/cast/vector_cast_helpers.hpp" |
3 | #include "duckdb/function/cast/cast_function_set.hpp" |
4 | |
5 | namespace duckdb { |
6 | |
7 | template <class SRC_TYPE, class RES_TYPE> |
8 | bool (Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
9 | bool all_converted = true; |
10 | result.SetVectorType(VectorType::FLAT_VECTOR); |
11 | |
12 | auto &str_vec = EnumType::GetValuesInsertOrder(type: source.GetType()); |
13 | auto str_vec_ptr = FlatVector::GetData<string_t>(vector: str_vec); |
14 | |
15 | auto res_enum_type = result.GetType(); |
16 | |
17 | UnifiedVectorFormat vdata; |
18 | source.ToUnifiedFormat(count, data&: vdata); |
19 | |
20 | auto source_data = UnifiedVectorFormat::GetData<SRC_TYPE>(vdata); |
21 | auto source_sel = vdata.sel; |
22 | auto source_mask = vdata.validity; |
23 | |
24 | auto result_data = FlatVector::GetData<RES_TYPE>(result); |
25 | auto &result_mask = FlatVector::Validity(vector&: result); |
26 | |
27 | for (idx_t i = 0; i < count; i++) { |
28 | auto src_idx = source_sel->get_index(idx: i); |
29 | if (!source_mask.RowIsValid(row_idx: src_idx)) { |
30 | result_mask.SetInvalid(i); |
31 | continue; |
32 | } |
33 | auto key = EnumType::GetPos(type: res_enum_type, key: str_vec_ptr[source_data[src_idx]]); |
34 | if (key == -1) { |
35 | // key doesn't exist on result enum |
36 | if (!parameters.error_message) { |
37 | result_data[i] = HandleVectorCastError::Operation<RES_TYPE>( |
38 | CastExceptionText<SRC_TYPE, RES_TYPE>(source_data[src_idx]), result_mask, i, |
39 | parameters.error_message, all_converted); |
40 | } else { |
41 | result_mask.SetInvalid(i); |
42 | } |
43 | continue; |
44 | } |
45 | result_data[i] = key; |
46 | } |
47 | return all_converted; |
48 | } |
49 | |
50 | template <class SRC_TYPE> |
51 | BoundCastInfo (BindCastInput &input, const LogicalType &source, const LogicalType &target) { |
52 | switch (target.InternalType()) { |
53 | case PhysicalType::UINT8: |
54 | return EnumEnumCast<SRC_TYPE, uint8_t>; |
55 | case PhysicalType::UINT16: |
56 | return EnumEnumCast<SRC_TYPE, uint16_t>; |
57 | case PhysicalType::UINT32: |
58 | return EnumEnumCast<SRC_TYPE, uint32_t>; |
59 | default: |
60 | throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types" ); |
61 | } |
62 | } |
63 | |
64 | template <class SRC> |
65 | static bool EnumToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
66 | auto &enum_dictionary = EnumType::GetValuesInsertOrder(type: source.GetType()); |
67 | auto dictionary_data = FlatVector::GetData<string_t>(vector: enum_dictionary); |
68 | auto result_data = FlatVector::GetData<string_t>(vector&: result); |
69 | auto &result_mask = FlatVector::Validity(vector&: result); |
70 | |
71 | UnifiedVectorFormat vdata; |
72 | source.ToUnifiedFormat(count, data&: vdata); |
73 | |
74 | auto source_data = UnifiedVectorFormat::GetData<SRC>(vdata); |
75 | for (idx_t i = 0; i < count; i++) { |
76 | auto source_idx = vdata.sel->get_index(idx: i); |
77 | if (!vdata.validity.RowIsValid(row_idx: source_idx)) { |
78 | result_mask.SetInvalid(i); |
79 | continue; |
80 | } |
81 | auto enum_idx = source_data[source_idx]; |
82 | result_data[i] = dictionary_data[enum_idx]; |
83 | } |
84 | if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
85 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
86 | } else { |
87 | result.SetVectorType(VectorType::FLAT_VECTOR); |
88 | } |
89 | return true; |
90 | } |
91 | |
92 | struct EnumBoundCastData : public BoundCastData { |
93 | EnumBoundCastData(BoundCastInfo to_varchar_cast, BoundCastInfo from_varchar_cast) |
94 | : to_varchar_cast(std::move(to_varchar_cast)), from_varchar_cast(std::move(from_varchar_cast)) { |
95 | } |
96 | |
97 | BoundCastInfo to_varchar_cast; |
98 | BoundCastInfo from_varchar_cast; |
99 | |
100 | public: |
101 | unique_ptr<BoundCastData> Copy() const override { |
102 | return make_uniq<EnumBoundCastData>(args: to_varchar_cast.Copy(), args: from_varchar_cast.Copy()); |
103 | } |
104 | }; |
105 | |
106 | unique_ptr<BoundCastData> BindEnumCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { |
107 | auto to_varchar_cast = input.GetCastFunction(source, target: LogicalType::VARCHAR); |
108 | auto from_varchar_cast = input.GetCastFunction(source: LogicalType::VARCHAR, target); |
109 | return make_uniq<EnumBoundCastData>(args: std::move(to_varchar_cast), args: std::move(from_varchar_cast)); |
110 | } |
111 | |
112 | struct EnumCastLocalState : public FunctionLocalState { |
113 | public: |
114 | unique_ptr<FunctionLocalState> to_varchar_local; |
115 | unique_ptr<FunctionLocalState> from_varchar_local; |
116 | }; |
117 | |
118 | static unique_ptr<FunctionLocalState> InitEnumCastLocalState(CastLocalStateParameters ¶meters) { |
119 | auto &cast_data = parameters.cast_data->Cast<EnumBoundCastData>(); |
120 | auto result = make_uniq<EnumCastLocalState>(); |
121 | |
122 | if (cast_data.from_varchar_cast.init_local_state) { |
123 | CastLocalStateParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data); |
124 | result->from_varchar_local = cast_data.from_varchar_cast.init_local_state(from_varchar_params); |
125 | } |
126 | if (cast_data.to_varchar_cast.init_local_state) { |
127 | CastLocalStateParameters from_varchar_params(parameters, cast_data.to_varchar_cast.cast_data); |
128 | result->from_varchar_local = cast_data.to_varchar_cast.init_local_state(from_varchar_params); |
129 | } |
130 | return std::move(result); |
131 | } |
132 | |
133 | static bool EnumToAnyCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
134 | auto &cast_data = parameters.cast_data->Cast<EnumBoundCastData>(); |
135 | auto &lstate = parameters.local_state->Cast<EnumCastLocalState>(); |
136 | |
137 | Vector varchar_cast(LogicalType::VARCHAR, count); |
138 | |
139 | // cast to varchar |
140 | CastParameters to_varchar_params(parameters, cast_data.to_varchar_cast.cast_data, lstate.to_varchar_local); |
141 | cast_data.to_varchar_cast.function(source, varchar_cast, count, to_varchar_params); |
142 | |
143 | // cast from varchar to the target |
144 | CastParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data, lstate.from_varchar_local); |
145 | cast_data.from_varchar_cast.function(varchar_cast, result, count, from_varchar_params); |
146 | return true; |
147 | } |
148 | |
149 | BoundCastInfo DefaultCasts::EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { |
150 | auto enum_physical_type = source.InternalType(); |
151 | switch (target.id()) { |
152 | case LogicalTypeId::ENUM: { |
153 | // This means they are both ENUMs, but of different types. |
154 | switch (enum_physical_type) { |
155 | case PhysicalType::UINT8: |
156 | return EnumEnumCastSwitch<uint8_t>(input, source, target); |
157 | case PhysicalType::UINT16: |
158 | return EnumEnumCastSwitch<uint16_t>(input, source, target); |
159 | case PhysicalType::UINT32: |
160 | return EnumEnumCastSwitch<uint32_t>(input, source, target); |
161 | default: |
162 | throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types" ); |
163 | } |
164 | } |
165 | case LogicalTypeId::VARCHAR: |
166 | switch (enum_physical_type) { |
167 | case PhysicalType::UINT8: |
168 | return EnumToVarcharCast<uint8_t>; |
169 | case PhysicalType::UINT16: |
170 | return EnumToVarcharCast<uint16_t>; |
171 | case PhysicalType::UINT32: |
172 | return EnumToVarcharCast<uint32_t>; |
173 | default: |
174 | throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types" ); |
175 | } |
176 | default: { |
177 | return BoundCastInfo(EnumToAnyCast, BindEnumCast(input, source, target), InitEnumCastLocalState); |
178 | } |
179 | } |
180 | } |
181 | |
182 | } // namespace duckdb |
183 | |