1 | #include "duckdb/function/cast_rules.hpp" |
2 | |
3 | namespace duckdb { |
4 | |
5 | //! The target type determines the preferred implicit casts |
6 | static int64_t TargetTypeCost(const LogicalType &type) { |
7 | switch (type.id()) { |
8 | case LogicalTypeId::INTEGER: |
9 | return 103; |
10 | case LogicalTypeId::BIGINT: |
11 | return 101; |
12 | case LogicalTypeId::DOUBLE: |
13 | return 102; |
14 | case LogicalTypeId::HUGEINT: |
15 | return 120; |
16 | case LogicalTypeId::TIMESTAMP: |
17 | return 120; |
18 | case LogicalTypeId::VARCHAR: |
19 | return 149; |
20 | case LogicalTypeId::DECIMAL: |
21 | return 104; |
22 | case LogicalTypeId::STRUCT: |
23 | case LogicalTypeId::MAP: |
24 | case LogicalTypeId::LIST: |
25 | case LogicalTypeId::UNION: |
26 | return 160; |
27 | default: |
28 | return 110; |
29 | } |
30 | } |
31 | |
32 | static int64_t ImplicitCastTinyint(const LogicalType &to) { |
33 | switch (to.id()) { |
34 | case LogicalTypeId::SMALLINT: |
35 | case LogicalTypeId::INTEGER: |
36 | case LogicalTypeId::BIGINT: |
37 | case LogicalTypeId::HUGEINT: |
38 | case LogicalTypeId::FLOAT: |
39 | case LogicalTypeId::DOUBLE: |
40 | case LogicalTypeId::DECIMAL: |
41 | return TargetTypeCost(type: to); |
42 | default: |
43 | return -1; |
44 | } |
45 | } |
46 | |
47 | static int64_t ImplicitCastSmallint(const LogicalType &to) { |
48 | switch (to.id()) { |
49 | case LogicalTypeId::INTEGER: |
50 | case LogicalTypeId::BIGINT: |
51 | case LogicalTypeId::HUGEINT: |
52 | case LogicalTypeId::FLOAT: |
53 | case LogicalTypeId::DOUBLE: |
54 | case LogicalTypeId::DECIMAL: |
55 | return TargetTypeCost(type: to); |
56 | default: |
57 | return -1; |
58 | } |
59 | } |
60 | |
61 | static int64_t ImplicitCastInteger(const LogicalType &to) { |
62 | switch (to.id()) { |
63 | case LogicalTypeId::BIGINT: |
64 | case LogicalTypeId::HUGEINT: |
65 | case LogicalTypeId::FLOAT: |
66 | case LogicalTypeId::DOUBLE: |
67 | case LogicalTypeId::DECIMAL: |
68 | return TargetTypeCost(type: to); |
69 | default: |
70 | return -1; |
71 | } |
72 | } |
73 | |
74 | static int64_t ImplicitCastBigint(const LogicalType &to) { |
75 | switch (to.id()) { |
76 | case LogicalTypeId::FLOAT: |
77 | case LogicalTypeId::DOUBLE: |
78 | case LogicalTypeId::HUGEINT: |
79 | case LogicalTypeId::DECIMAL: |
80 | return TargetTypeCost(type: to); |
81 | default: |
82 | return -1; |
83 | } |
84 | } |
85 | |
86 | static int64_t ImplicitCastUTinyint(const LogicalType &to) { |
87 | switch (to.id()) { |
88 | case LogicalTypeId::USMALLINT: |
89 | case LogicalTypeId::UINTEGER: |
90 | case LogicalTypeId::UBIGINT: |
91 | case LogicalTypeId::SMALLINT: |
92 | case LogicalTypeId::INTEGER: |
93 | case LogicalTypeId::BIGINT: |
94 | case LogicalTypeId::HUGEINT: |
95 | case LogicalTypeId::FLOAT: |
96 | case LogicalTypeId::DOUBLE: |
97 | case LogicalTypeId::DECIMAL: |
98 | return TargetTypeCost(type: to); |
99 | default: |
100 | return -1; |
101 | } |
102 | } |
103 | |
104 | static int64_t ImplicitCastUSmallint(const LogicalType &to) { |
105 | switch (to.id()) { |
106 | case LogicalTypeId::UINTEGER: |
107 | case LogicalTypeId::UBIGINT: |
108 | case LogicalTypeId::INTEGER: |
109 | case LogicalTypeId::BIGINT: |
110 | case LogicalTypeId::HUGEINT: |
111 | case LogicalTypeId::FLOAT: |
112 | case LogicalTypeId::DOUBLE: |
113 | case LogicalTypeId::DECIMAL: |
114 | return TargetTypeCost(type: to); |
115 | default: |
116 | return -1; |
117 | } |
118 | } |
119 | |
120 | static int64_t ImplicitCastUInteger(const LogicalType &to) { |
121 | switch (to.id()) { |
122 | |
123 | case LogicalTypeId::UBIGINT: |
124 | case LogicalTypeId::BIGINT: |
125 | case LogicalTypeId::HUGEINT: |
126 | case LogicalTypeId::FLOAT: |
127 | case LogicalTypeId::DOUBLE: |
128 | case LogicalTypeId::DECIMAL: |
129 | return TargetTypeCost(type: to); |
130 | default: |
131 | return -1; |
132 | } |
133 | } |
134 | |
135 | static int64_t ImplicitCastUBigint(const LogicalType &to) { |
136 | switch (to.id()) { |
137 | case LogicalTypeId::FLOAT: |
138 | case LogicalTypeId::DOUBLE: |
139 | case LogicalTypeId::HUGEINT: |
140 | case LogicalTypeId::DECIMAL: |
141 | return TargetTypeCost(type: to); |
142 | default: |
143 | return -1; |
144 | } |
145 | } |
146 | |
147 | static int64_t ImplicitCastFloat(const LogicalType &to) { |
148 | switch (to.id()) { |
149 | case LogicalTypeId::DOUBLE: |
150 | return TargetTypeCost(type: to); |
151 | default: |
152 | return -1; |
153 | } |
154 | } |
155 | |
156 | static int64_t ImplicitCastDouble(const LogicalType &to) { |
157 | switch (to.id()) { |
158 | default: |
159 | return -1; |
160 | } |
161 | } |
162 | |
163 | static int64_t ImplicitCastDecimal(const LogicalType &to) { |
164 | switch (to.id()) { |
165 | case LogicalTypeId::FLOAT: |
166 | case LogicalTypeId::DOUBLE: |
167 | return TargetTypeCost(type: to); |
168 | default: |
169 | return -1; |
170 | } |
171 | } |
172 | |
173 | static int64_t ImplicitCastHugeint(const LogicalType &to) { |
174 | switch (to.id()) { |
175 | case LogicalTypeId::FLOAT: |
176 | case LogicalTypeId::DOUBLE: |
177 | case LogicalTypeId::DECIMAL: |
178 | return TargetTypeCost(type: to); |
179 | default: |
180 | return -1; |
181 | } |
182 | } |
183 | |
184 | static int64_t ImplicitCastDate(const LogicalType &to) { |
185 | switch (to.id()) { |
186 | case LogicalTypeId::TIMESTAMP: |
187 | return TargetTypeCost(type: to); |
188 | default: |
189 | return -1; |
190 | } |
191 | } |
192 | |
193 | int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) { |
194 | if (from.id() == LogicalTypeId::SQLNULL) { |
195 | // NULL expression can be cast to anything |
196 | return TargetTypeCost(type: to); |
197 | } |
198 | if (from.id() == LogicalTypeId::UNKNOWN) { |
199 | // parameter expression can be cast to anything for no cost |
200 | return 0; |
201 | } |
202 | if (to.id() == LogicalTypeId::ANY) { |
203 | // anything can be cast to ANY type for (almost no) cost |
204 | return 1; |
205 | } |
206 | if (from.GetAlias() != to.GetAlias()) { |
207 | // if aliases are different, an implicit cast is not possible |
208 | return -1; |
209 | } |
210 | if (from.id() == LogicalTypeId::LIST && to.id() == LogicalTypeId::LIST) { |
211 | // Lists can be cast if their child types can be cast |
212 | auto child_cost = ImplicitCast(from: ListType::GetChildType(type: from), to: ListType::GetChildType(type: to)); |
213 | if (child_cost >= 100) { |
214 | // subtract one from the cost because we prefer LIST[X] -> LIST[VARCHAR] over LIST[X] -> VARCHAR |
215 | child_cost--; |
216 | } |
217 | return child_cost; |
218 | } |
219 | if (from.id() == to.id()) { |
220 | // arguments match: do nothing |
221 | return 0; |
222 | } |
223 | if (from.id() == LogicalTypeId::BLOB && to.id() == LogicalTypeId::VARCHAR) { |
224 | // Implicit cast not allowed from BLOB to VARCHAR |
225 | return -1; |
226 | } |
227 | if (to.id() == LogicalTypeId::VARCHAR) { |
228 | // everything can be cast to VARCHAR, but this cast has a high cost |
229 | return TargetTypeCost(type: to); |
230 | } |
231 | |
232 | if (from.id() == LogicalTypeId::UNION && to.id() == LogicalTypeId::UNION) { |
233 | // Unions can be cast if the source tags are a subset of the target tags |
234 | // in which case the most expensive cost is used |
235 | int cost = -1; |
236 | for (idx_t from_member_idx = 0; from_member_idx < UnionType::GetMemberCount(type: from); from_member_idx++) { |
237 | auto &from_member_name = UnionType::GetMemberName(type: from, index: from_member_idx); |
238 | |
239 | bool found = false; |
240 | for (idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(type: to); to_member_idx++) { |
241 | auto &to_member_name = UnionType::GetMemberName(type: to, index: to_member_idx); |
242 | |
243 | if (from_member_name == to_member_name) { |
244 | auto &from_member_type = UnionType::GetMemberType(type: from, index: from_member_idx); |
245 | auto &to_member_type = UnionType::GetMemberType(type: to, index: to_member_idx); |
246 | |
247 | int child_cost = ImplicitCast(from: from_member_type, to: to_member_type); |
248 | if (child_cost > cost) { |
249 | cost = child_cost; |
250 | } |
251 | found = true; |
252 | break; |
253 | } |
254 | } |
255 | if (!found) { |
256 | return -1; |
257 | } |
258 | } |
259 | return cost; |
260 | } |
261 | |
262 | if (to.id() == LogicalTypeId::UNION) { |
263 | // check that the union type is fully resolved. |
264 | if (to.AuxInfo() == nullptr) { |
265 | return -1; |
266 | } |
267 | // every type can be implicitly be cast to a union if the source type is a member of the union |
268 | for (idx_t i = 0; i < UnionType::GetMemberCount(type: to); i++) { |
269 | auto member = UnionType::GetMemberType(type: to, index: i); |
270 | if (from == member) { |
271 | return 0; |
272 | } |
273 | } |
274 | } |
275 | |
276 | if ((from.id() == LogicalTypeId::TIMESTAMP_SEC || from.id() == LogicalTypeId::TIMESTAMP_MS || |
277 | from.id() == LogicalTypeId::TIMESTAMP_NS) && |
278 | to.id() == LogicalTypeId::TIMESTAMP) { |
279 | //! Any timestamp type can be converted to the default (us) type at low cost |
280 | return 101; |
281 | } |
282 | if ((to.id() == LogicalTypeId::TIMESTAMP_SEC || to.id() == LogicalTypeId::TIMESTAMP_MS || |
283 | to.id() == LogicalTypeId::TIMESTAMP_NS) && |
284 | from.id() == LogicalTypeId::TIMESTAMP) { |
285 | //! Any timestamp type can be converted to the default (us) type at low cost |
286 | return 100; |
287 | } |
288 | switch (from.id()) { |
289 | case LogicalTypeId::TINYINT: |
290 | return ImplicitCastTinyint(to); |
291 | case LogicalTypeId::SMALLINT: |
292 | return ImplicitCastSmallint(to); |
293 | case LogicalTypeId::INTEGER: |
294 | return ImplicitCastInteger(to); |
295 | case LogicalTypeId::BIGINT: |
296 | return ImplicitCastBigint(to); |
297 | case LogicalTypeId::UTINYINT: |
298 | return ImplicitCastUTinyint(to); |
299 | case LogicalTypeId::USMALLINT: |
300 | return ImplicitCastUSmallint(to); |
301 | case LogicalTypeId::UINTEGER: |
302 | return ImplicitCastUInteger(to); |
303 | case LogicalTypeId::UBIGINT: |
304 | return ImplicitCastUBigint(to); |
305 | case LogicalTypeId::HUGEINT: |
306 | return ImplicitCastHugeint(to); |
307 | case LogicalTypeId::FLOAT: |
308 | return ImplicitCastFloat(to); |
309 | case LogicalTypeId::DOUBLE: |
310 | return ImplicitCastDouble(to); |
311 | case LogicalTypeId::DATE: |
312 | return ImplicitCastDate(to); |
313 | case LogicalTypeId::DECIMAL: |
314 | return ImplicitCastDecimal(to); |
315 | default: |
316 | return -1; |
317 | } |
318 | } |
319 | |
320 | } // namespace duckdb |
321 | |