1#include "duckdb/common/field_writer.hpp"
2#include "duckdb/planner/expression/bound_cast_expression.hpp"
3#include "duckdb/planner/expression/bound_default_expression.hpp"
4#include "duckdb/planner/expression/bound_parameter_expression.hpp"
5#include "duckdb/function/cast_rules.hpp"
6#include "duckdb/function/cast/cast_function_set.hpp"
7#include "duckdb/main/config.hpp"
8
9namespace duckdb {
10
11BoundCastExpression::BoundCastExpression(unique_ptr<Expression> child_p, LogicalType target_type_p,
12 BoundCastInfo bound_cast_p, bool try_cast_p)
13 : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)),
14 child(std::move(child_p)), try_cast(try_cast_p), bound_cast(std::move(bound_cast_p)) {
15}
16
17unique_ptr<Expression> AddCastExpressionInternal(unique_ptr<Expression> expr, const LogicalType &target_type,
18 BoundCastInfo bound_cast, bool try_cast) {
19 if (expr->return_type == target_type) {
20 return expr;
21 }
22 auto &expr_type = expr->return_type;
23 if (target_type.id() == LogicalTypeId::LIST && expr_type.id() == LogicalTypeId::LIST) {
24 auto &target_list = ListType::GetChildType(type: target_type);
25 auto &expr_list = ListType::GetChildType(type: expr_type);
26 if (target_list.id() == LogicalTypeId::ANY || expr_list == target_list) {
27 return expr;
28 }
29 }
30 return make_uniq<BoundCastExpression>(args: std::move(expr), args: target_type, args: std::move(bound_cast), args&: try_cast);
31}
32
33static BoundCastInfo BindCastFunction(ClientContext &context, const LogicalType &source, const LogicalType &target) {
34 auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions();
35 GetCastFunctionInput input(context);
36 return cast_functions.GetCastFunction(source, target, input);
37}
38
39unique_ptr<Expression> AddCastToTypeInternal(unique_ptr<Expression> expr, const LogicalType &target_type,
40 CastFunctionSet &cast_functions, GetCastFunctionInput &get_input,
41 bool try_cast) {
42 D_ASSERT(expr);
43 if (expr->expression_class == ExpressionClass::BOUND_PARAMETER) {
44 auto &parameter = expr->Cast<BoundParameterExpression>();
45 if (!target_type.IsValid()) {
46 // invalidate the parameter
47 parameter.parameter_data->return_type = LogicalType::INVALID;
48 parameter.return_type = target_type;
49 return expr;
50 }
51 if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) {
52 // we don't know the type of this parameter
53 parameter.return_type = target_type;
54 return expr;
55 }
56 if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) {
57 // prepared statement parameter cast - but there is no type, convert the type
58 parameter.parameter_data->return_type = target_type;
59 parameter.return_type = target_type;
60 return expr;
61 }
62 // prepared statement parameter already has a type
63 if (parameter.parameter_data->return_type == target_type) {
64 // this type! we are done
65 parameter.return_type = parameter.parameter_data->return_type;
66 return expr;
67 }
68 // invalidate the type
69 parameter.parameter_data->return_type = LogicalType::INVALID;
70 parameter.return_type = target_type;
71 return expr;
72 } else if (expr->expression_class == ExpressionClass::BOUND_DEFAULT) {
73 D_ASSERT(target_type.IsValid());
74 auto &def = expr->Cast<BoundDefaultExpression>();
75 def.return_type = target_type;
76 }
77 if (!target_type.IsValid()) {
78 return expr;
79 }
80
81 auto cast_function = cast_functions.GetCastFunction(source: expr->return_type, target: target_type, input&: get_input);
82 return AddCastExpressionInternal(expr: std::move(expr), target_type, bound_cast: std::move(cast_function), try_cast);
83}
84
85unique_ptr<Expression> BoundCastExpression::AddDefaultCastToType(unique_ptr<Expression> expr,
86 const LogicalType &target_type, bool try_cast) {
87 CastFunctionSet default_set;
88 GetCastFunctionInput get_input;
89 return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions&: default_set, get_input, try_cast);
90}
91
92unique_ptr<Expression> BoundCastExpression::AddCastToType(ClientContext &context, unique_ptr<Expression> expr,
93 const LogicalType &target_type, bool try_cast) {
94 auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions();
95 GetCastFunctionInput get_input(context);
96 return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions, get_input, try_cast);
97}
98
99bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type) {
100 D_ASSERT(source_type.IsValid() && target_type.IsValid());
101 if (source_type.id() == LogicalTypeId::BOOLEAN || target_type.id() == LogicalTypeId::BOOLEAN) {
102 return false;
103 }
104 if (source_type.id() == LogicalTypeId::FLOAT || target_type.id() == LogicalTypeId::FLOAT) {
105 return false;
106 }
107 if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) {
108 return false;
109 }
110 if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) {
111 uint8_t source_width, target_width;
112 uint8_t source_scale, target_scale;
113 // cast to or from decimal
114 // cast is only invertible if the cast is strictly widening
115 if (!source_type.GetDecimalProperties(width&: source_width, scale&: source_scale)) {
116 return false;
117 }
118 if (!target_type.GetDecimalProperties(width&: target_width, scale&: target_scale)) {
119 return false;
120 }
121 if (target_scale < source_scale) {
122 return false;
123 }
124 return true;
125 }
126 if (source_type.id() == LogicalTypeId::TIMESTAMP || source_type.id() == LogicalTypeId::TIMESTAMP_TZ) {
127 switch (target_type.id()) {
128 case LogicalTypeId::DATE:
129 case LogicalTypeId::TIME:
130 case LogicalTypeId::TIME_TZ:
131 return false;
132 default:
133 break;
134 }
135 }
136 if (source_type.id() == LogicalTypeId::VARCHAR) {
137 switch (target_type.id()) {
138 case LogicalTypeId::TIME:
139 case LogicalTypeId::TIMESTAMP:
140 case LogicalTypeId::TIMESTAMP_NS:
141 case LogicalTypeId::TIMESTAMP_MS:
142 case LogicalTypeId::TIMESTAMP_SEC:
143 case LogicalTypeId::TIME_TZ:
144 case LogicalTypeId::TIMESTAMP_TZ:
145 return true;
146 default:
147 return false;
148 }
149 }
150 if (target_type.id() == LogicalTypeId::VARCHAR) {
151 switch (source_type.id()) {
152 case LogicalTypeId::DATE:
153 case LogicalTypeId::TIME:
154 case LogicalTypeId::TIMESTAMP:
155 case LogicalTypeId::TIMESTAMP_NS:
156 case LogicalTypeId::TIMESTAMP_MS:
157 case LogicalTypeId::TIMESTAMP_SEC:
158 case LogicalTypeId::TIME_TZ:
159 case LogicalTypeId::TIMESTAMP_TZ:
160 return true;
161 default:
162 return false;
163 }
164 }
165 return true;
166}
167
168string BoundCastExpression::ToString() const {
169 return (try_cast ? "TRY_CAST(" : "CAST(") + child->GetName() + " AS " + return_type.ToString() + ")";
170}
171
172bool BoundCastExpression::Equals(const BaseExpression &other_p) const {
173 if (!Expression::Equals(other: other_p)) {
174 return false;
175 }
176 auto &other = other_p.Cast<BoundCastExpression>();
177 if (!Expression::Equals(left: *child, right: *other.child)) {
178 return false;
179 }
180 if (try_cast != other.try_cast) {
181 return false;
182 }
183 return true;
184}
185
186unique_ptr<Expression> BoundCastExpression::Copy() {
187 auto copy = make_uniq<BoundCastExpression>(args: child->Copy(), args&: return_type, args: bound_cast.Copy(), args&: try_cast);
188 copy->CopyProperties(other&: *this);
189 return std::move(copy);
190}
191
192void BoundCastExpression::Serialize(FieldWriter &writer) const {
193 writer.WriteSerializable(element: *child);
194 writer.WriteSerializable(element: return_type);
195 writer.WriteField(element: try_cast);
196}
197
198unique_ptr<Expression> BoundCastExpression::Deserialize(ExpressionDeserializationState &state, FieldReader &reader) {
199 auto child = reader.ReadRequiredSerializable<Expression>(args&: state.gstate);
200 auto target_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>();
201 auto try_cast = reader.ReadRequired<bool>();
202 auto cast_function = BindCastFunction(context&: state.gstate.context, source: child->return_type, target: target_type);
203 return make_uniq<BoundCastExpression>(args: std::move(child), args: std::move(target_type), args: std::move(cast_function), args&: try_cast);
204}
205
206} // namespace duckdb
207