1 | #include "duckdb/function/cast/default_casts.hpp" |
2 | |
3 | #include "duckdb/common/likely.hpp" |
4 | #include "duckdb/common/limits.hpp" |
5 | #include "duckdb/common/operator/cast_operators.hpp" |
6 | #include "duckdb/common/string_util.hpp" |
7 | #include "duckdb/common/types/cast_helpers.hpp" |
8 | #include "duckdb/common/types/chunk_collection.hpp" |
9 | #include "duckdb/common/types/null_value.hpp" |
10 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
11 | #include "duckdb/function/cast/vector_cast_helpers.hpp" |
12 | |
13 | namespace duckdb { |
14 | |
15 | BindCastInfo::~BindCastInfo() { |
16 | } |
17 | |
18 | BoundCastData::~BoundCastData() { |
19 | } |
20 | |
21 | BoundCastInfo::BoundCastInfo(cast_function_t function_p, unique_ptr<BoundCastData> cast_data_p, |
22 | init_cast_local_state_t init_local_state_p) |
23 | : function(function_p), init_local_state(init_local_state_p), cast_data(std::move(cast_data_p)) { |
24 | } |
25 | |
26 | BoundCastInfo BoundCastInfo::Copy() const { |
27 | return BoundCastInfo(function, cast_data ? cast_data->Copy() : nullptr, init_local_state); |
28 | } |
29 | |
30 | bool DefaultCasts::NopCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
31 | result.Reference(other&: source); |
32 | return true; |
33 | } |
34 | |
35 | static string UnimplementedCastMessage(const LogicalType &source_type, const LogicalType &target_type) { |
36 | return StringUtil::Format(fmt_str: "Unimplemented type for cast (%s -> %s)" , params: source_type.ToString(), params: target_type.ToString()); |
37 | } |
38 | |
39 | // NULL cast only works if all values in source are NULL, otherwise an unimplemented cast exception is thrown |
40 | bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
41 | bool success = true; |
42 | if (VectorOperations::HasNotNull(input&: source, count)) { |
43 | HandleCastError::AssignError(error_message: UnimplementedCastMessage(source_type: source.GetType(), target_type: result.GetType()), |
44 | error_message_ptr: parameters.error_message); |
45 | success = false; |
46 | } |
47 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
48 | ConstantVector::SetNull(vector&: result, is_null: true); |
49 | return success; |
50 | } |
51 | |
52 | bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
53 | result.Reinterpret(other&: source); |
54 | return true; |
55 | } |
56 | |
57 | static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
58 | if (result.GetType().id() != LogicalTypeId::BLOB) { |
59 | throw TypeMismatchException(source.GetType(), result.GetType(), |
60 | "Cannot cast AGGREGATE_STATE to anything but BLOB" ); |
61 | } |
62 | result.Reinterpret(other&: source); |
63 | return true; |
64 | } |
65 | |
66 | static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
67 | // cast a NULL to another type, just copy the properties and change the type |
68 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
69 | ConstantVector::SetNull(vector&: result, is_null: true); |
70 | return true; |
71 | } |
72 | |
73 | BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const LogicalType &source, |
74 | const LogicalType &target) { |
75 | D_ASSERT(source != target); |
76 | |
77 | // first check if were casting to a union |
78 | if (source.id() != LogicalTypeId::UNION && source.id() != LogicalTypeId::SQLNULL && |
79 | target.id() == LogicalTypeId::UNION) { |
80 | return ImplicitToUnionCast(input, source, target); |
81 | } |
82 | |
83 | // else, switch on source type |
84 | switch (source.id()) { |
85 | case LogicalTypeId::BOOLEAN: |
86 | case LogicalTypeId::TINYINT: |
87 | case LogicalTypeId::SMALLINT: |
88 | case LogicalTypeId::INTEGER: |
89 | case LogicalTypeId::BIGINT: |
90 | case LogicalTypeId::UTINYINT: |
91 | case LogicalTypeId::USMALLINT: |
92 | case LogicalTypeId::UINTEGER: |
93 | case LogicalTypeId::UBIGINT: |
94 | case LogicalTypeId::HUGEINT: |
95 | case LogicalTypeId::FLOAT: |
96 | case LogicalTypeId::DOUBLE: |
97 | return NumericCastSwitch(input, source, target); |
98 | case LogicalTypeId::POINTER: |
99 | return PointerCastSwitch(input, source, target); |
100 | case LogicalTypeId::UUID: |
101 | return UUIDCastSwitch(input, source, target); |
102 | case LogicalTypeId::DECIMAL: |
103 | return DecimalCastSwitch(input, source, target); |
104 | case LogicalTypeId::DATE: |
105 | return DateCastSwitch(input, source, target); |
106 | case LogicalTypeId::TIME: |
107 | return TimeCastSwitch(input, source, target); |
108 | case LogicalTypeId::TIME_TZ: |
109 | return TimeTzCastSwitch(input, source, target); |
110 | case LogicalTypeId::TIMESTAMP: |
111 | return TimestampCastSwitch(input, source, target); |
112 | case LogicalTypeId::TIMESTAMP_TZ: |
113 | return TimestampTzCastSwitch(input, source, target); |
114 | case LogicalTypeId::TIMESTAMP_NS: |
115 | return TimestampNsCastSwitch(input, source, target); |
116 | case LogicalTypeId::TIMESTAMP_MS: |
117 | return TimestampMsCastSwitch(input, source, target); |
118 | case LogicalTypeId::TIMESTAMP_SEC: |
119 | return TimestampSecCastSwitch(input, source, target); |
120 | case LogicalTypeId::INTERVAL: |
121 | return IntervalCastSwitch(input, source, target); |
122 | case LogicalTypeId::VARCHAR: |
123 | return StringCastSwitch(input, source, target); |
124 | case LogicalTypeId::BLOB: |
125 | return BlobCastSwitch(input, source, target); |
126 | case LogicalTypeId::BIT: |
127 | return BitCastSwitch(input, source, target); |
128 | case LogicalTypeId::SQLNULL: |
129 | return NullTypeCast; |
130 | case LogicalTypeId::MAP: |
131 | return MapCastSwitch(input, source, target); |
132 | case LogicalTypeId::STRUCT: |
133 | return StructCastSwitch(input, source, target); |
134 | case LogicalTypeId::LIST: |
135 | return ListCastSwitch(input, source, target); |
136 | case LogicalTypeId::UNION: |
137 | return UnionCastSwitch(input, source, target); |
138 | case LogicalTypeId::ENUM: |
139 | return EnumCastSwitch(input, source, target); |
140 | case LogicalTypeId::AGGREGATE_STATE: |
141 | return AggregateStateToBlobCast; |
142 | default: |
143 | return nullptr; |
144 | } |
145 | } |
146 | |
147 | } // namespace duckdb |
148 | |