| 1 | #include "duckdb/function/cast/default_casts.hpp" |
| 2 | |
| 3 | #include "duckdb/common/likely.hpp" |
| 4 | #include "duckdb/common/limits.hpp" |
| 5 | #include "duckdb/common/operator/cast_operators.hpp" |
| 6 | #include "duckdb/common/string_util.hpp" |
| 7 | #include "duckdb/common/types/cast_helpers.hpp" |
| 8 | #include "duckdb/common/types/chunk_collection.hpp" |
| 9 | #include "duckdb/common/types/null_value.hpp" |
| 10 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 11 | #include "duckdb/function/cast/vector_cast_helpers.hpp" |
| 12 | |
| 13 | namespace duckdb { |
| 14 | |
| 15 | BindCastInfo::~BindCastInfo() { |
| 16 | } |
| 17 | |
| 18 | BoundCastData::~BoundCastData() { |
| 19 | } |
| 20 | |
| 21 | BoundCastInfo::BoundCastInfo(cast_function_t function_p, unique_ptr<BoundCastData> cast_data_p, |
| 22 | init_cast_local_state_t init_local_state_p) |
| 23 | : function(function_p), init_local_state(init_local_state_p), cast_data(std::move(cast_data_p)) { |
| 24 | } |
| 25 | |
| 26 | BoundCastInfo BoundCastInfo::Copy() const { |
| 27 | return BoundCastInfo(function, cast_data ? cast_data->Copy() : nullptr, init_local_state); |
| 28 | } |
| 29 | |
| 30 | bool DefaultCasts::NopCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 31 | result.Reference(other&: source); |
| 32 | return true; |
| 33 | } |
| 34 | |
| 35 | static string UnimplementedCastMessage(const LogicalType &source_type, const LogicalType &target_type) { |
| 36 | return StringUtil::Format(fmt_str: "Unimplemented type for cast (%s -> %s)" , params: source_type.ToString(), params: target_type.ToString()); |
| 37 | } |
| 38 | |
| 39 | // NULL cast only works if all values in source are NULL, otherwise an unimplemented cast exception is thrown |
| 40 | bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 41 | bool success = true; |
| 42 | if (VectorOperations::HasNotNull(input&: source, count)) { |
| 43 | HandleCastError::AssignError(error_message: UnimplementedCastMessage(source_type: source.GetType(), target_type: result.GetType()), |
| 44 | error_message_ptr: parameters.error_message); |
| 45 | success = false; |
| 46 | } |
| 47 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 48 | ConstantVector::SetNull(vector&: result, is_null: true); |
| 49 | return success; |
| 50 | } |
| 51 | |
| 52 | bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 53 | result.Reinterpret(other&: source); |
| 54 | return true; |
| 55 | } |
| 56 | |
| 57 | static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 58 | if (result.GetType().id() != LogicalTypeId::BLOB) { |
| 59 | throw TypeMismatchException(source.GetType(), result.GetType(), |
| 60 | "Cannot cast AGGREGATE_STATE to anything but BLOB" ); |
| 61 | } |
| 62 | result.Reinterpret(other&: source); |
| 63 | return true; |
| 64 | } |
| 65 | |
| 66 | static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 67 | // cast a NULL to another type, just copy the properties and change the type |
| 68 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 69 | ConstantVector::SetNull(vector&: result, is_null: true); |
| 70 | return true; |
| 71 | } |
| 72 | |
| 73 | BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const LogicalType &source, |
| 74 | const LogicalType &target) { |
| 75 | D_ASSERT(source != target); |
| 76 | |
| 77 | // first check if were casting to a union |
| 78 | if (source.id() != LogicalTypeId::UNION && source.id() != LogicalTypeId::SQLNULL && |
| 79 | target.id() == LogicalTypeId::UNION) { |
| 80 | return ImplicitToUnionCast(input, source, target); |
| 81 | } |
| 82 | |
| 83 | // else, switch on source type |
| 84 | switch (source.id()) { |
| 85 | case LogicalTypeId::BOOLEAN: |
| 86 | case LogicalTypeId::TINYINT: |
| 87 | case LogicalTypeId::SMALLINT: |
| 88 | case LogicalTypeId::INTEGER: |
| 89 | case LogicalTypeId::BIGINT: |
| 90 | case LogicalTypeId::UTINYINT: |
| 91 | case LogicalTypeId::USMALLINT: |
| 92 | case LogicalTypeId::UINTEGER: |
| 93 | case LogicalTypeId::UBIGINT: |
| 94 | case LogicalTypeId::HUGEINT: |
| 95 | case LogicalTypeId::FLOAT: |
| 96 | case LogicalTypeId::DOUBLE: |
| 97 | return NumericCastSwitch(input, source, target); |
| 98 | case LogicalTypeId::POINTER: |
| 99 | return PointerCastSwitch(input, source, target); |
| 100 | case LogicalTypeId::UUID: |
| 101 | return UUIDCastSwitch(input, source, target); |
| 102 | case LogicalTypeId::DECIMAL: |
| 103 | return DecimalCastSwitch(input, source, target); |
| 104 | case LogicalTypeId::DATE: |
| 105 | return DateCastSwitch(input, source, target); |
| 106 | case LogicalTypeId::TIME: |
| 107 | return TimeCastSwitch(input, source, target); |
| 108 | case LogicalTypeId::TIME_TZ: |
| 109 | return TimeTzCastSwitch(input, source, target); |
| 110 | case LogicalTypeId::TIMESTAMP: |
| 111 | return TimestampCastSwitch(input, source, target); |
| 112 | case LogicalTypeId::TIMESTAMP_TZ: |
| 113 | return TimestampTzCastSwitch(input, source, target); |
| 114 | case LogicalTypeId::TIMESTAMP_NS: |
| 115 | return TimestampNsCastSwitch(input, source, target); |
| 116 | case LogicalTypeId::TIMESTAMP_MS: |
| 117 | return TimestampMsCastSwitch(input, source, target); |
| 118 | case LogicalTypeId::TIMESTAMP_SEC: |
| 119 | return TimestampSecCastSwitch(input, source, target); |
| 120 | case LogicalTypeId::INTERVAL: |
| 121 | return IntervalCastSwitch(input, source, target); |
| 122 | case LogicalTypeId::VARCHAR: |
| 123 | return StringCastSwitch(input, source, target); |
| 124 | case LogicalTypeId::BLOB: |
| 125 | return BlobCastSwitch(input, source, target); |
| 126 | case LogicalTypeId::BIT: |
| 127 | return BitCastSwitch(input, source, target); |
| 128 | case LogicalTypeId::SQLNULL: |
| 129 | return NullTypeCast; |
| 130 | case LogicalTypeId::MAP: |
| 131 | return MapCastSwitch(input, source, target); |
| 132 | case LogicalTypeId::STRUCT: |
| 133 | return StructCastSwitch(input, source, target); |
| 134 | case LogicalTypeId::LIST: |
| 135 | return ListCastSwitch(input, source, target); |
| 136 | case LogicalTypeId::UNION: |
| 137 | return UnionCastSwitch(input, source, target); |
| 138 | case LogicalTypeId::ENUM: |
| 139 | return EnumCastSwitch(input, source, target); |
| 140 | case LogicalTypeId::AGGREGATE_STATE: |
| 141 | return AggregateStateToBlobCast; |
| 142 | default: |
| 143 | return nullptr; |
| 144 | } |
| 145 | } |
| 146 | |
| 147 | } // namespace duckdb |
| 148 | |