1 | #include "duckdb/common/field_writer.hpp" |
2 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
3 | #include "duckdb/planner/expression/bound_default_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
5 | #include "duckdb/function/cast_rules.hpp" |
6 | #include "duckdb/function/cast/cast_function_set.hpp" |
7 | #include "duckdb/main/config.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | BoundCastExpression::BoundCastExpression(unique_ptr<Expression> child_p, LogicalType target_type_p, |
12 | BoundCastInfo bound_cast_p, bool try_cast_p) |
13 | : Expression(ExpressionType::OPERATOR_CAST, ExpressionClass::BOUND_CAST, std::move(target_type_p)), |
14 | child(std::move(child_p)), try_cast(try_cast_p), bound_cast(std::move(bound_cast_p)) { |
15 | } |
16 | |
17 | unique_ptr<Expression> AddCastExpressionInternal(unique_ptr<Expression> expr, const LogicalType &target_type, |
18 | BoundCastInfo bound_cast, bool try_cast) { |
19 | if (expr->return_type == target_type) { |
20 | return expr; |
21 | } |
22 | auto &expr_type = expr->return_type; |
23 | if (target_type.id() == LogicalTypeId::LIST && expr_type.id() == LogicalTypeId::LIST) { |
24 | auto &target_list = ListType::GetChildType(type: target_type); |
25 | auto &expr_list = ListType::GetChildType(type: expr_type); |
26 | if (target_list.id() == LogicalTypeId::ANY || expr_list == target_list) { |
27 | return expr; |
28 | } |
29 | } |
30 | return make_uniq<BoundCastExpression>(args: std::move(expr), args: target_type, args: std::move(bound_cast), args&: try_cast); |
31 | } |
32 | |
33 | static BoundCastInfo BindCastFunction(ClientContext &context, const LogicalType &source, const LogicalType &target) { |
34 | auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); |
35 | GetCastFunctionInput input(context); |
36 | return cast_functions.GetCastFunction(source, target, input); |
37 | } |
38 | |
39 | unique_ptr<Expression> AddCastToTypeInternal(unique_ptr<Expression> expr, const LogicalType &target_type, |
40 | CastFunctionSet &cast_functions, GetCastFunctionInput &get_input, |
41 | bool try_cast) { |
42 | D_ASSERT(expr); |
43 | if (expr->expression_class == ExpressionClass::BOUND_PARAMETER) { |
44 | auto ¶meter = expr->Cast<BoundParameterExpression>(); |
45 | if (!target_type.IsValid()) { |
46 | // invalidate the parameter |
47 | parameter.parameter_data->return_type = LogicalType::INVALID; |
48 | parameter.return_type = target_type; |
49 | return expr; |
50 | } |
51 | if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) { |
52 | // we don't know the type of this parameter |
53 | parameter.return_type = target_type; |
54 | return expr; |
55 | } |
56 | if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) { |
57 | // prepared statement parameter cast - but there is no type, convert the type |
58 | parameter.parameter_data->return_type = target_type; |
59 | parameter.return_type = target_type; |
60 | return expr; |
61 | } |
62 | // prepared statement parameter already has a type |
63 | if (parameter.parameter_data->return_type == target_type) { |
64 | // this type! we are done |
65 | parameter.return_type = parameter.parameter_data->return_type; |
66 | return expr; |
67 | } |
68 | // invalidate the type |
69 | parameter.parameter_data->return_type = LogicalType::INVALID; |
70 | parameter.return_type = target_type; |
71 | return expr; |
72 | } else if (expr->expression_class == ExpressionClass::BOUND_DEFAULT) { |
73 | D_ASSERT(target_type.IsValid()); |
74 | auto &def = expr->Cast<BoundDefaultExpression>(); |
75 | def.return_type = target_type; |
76 | } |
77 | if (!target_type.IsValid()) { |
78 | return expr; |
79 | } |
80 | |
81 | auto cast_function = cast_functions.GetCastFunction(source: expr->return_type, target: target_type, input&: get_input); |
82 | return AddCastExpressionInternal(expr: std::move(expr), target_type, bound_cast: std::move(cast_function), try_cast); |
83 | } |
84 | |
85 | unique_ptr<Expression> BoundCastExpression::AddDefaultCastToType(unique_ptr<Expression> expr, |
86 | const LogicalType &target_type, bool try_cast) { |
87 | CastFunctionSet default_set; |
88 | GetCastFunctionInput get_input; |
89 | return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions&: default_set, get_input, try_cast); |
90 | } |
91 | |
92 | unique_ptr<Expression> BoundCastExpression::AddCastToType(ClientContext &context, unique_ptr<Expression> expr, |
93 | const LogicalType &target_type, bool try_cast) { |
94 | auto &cast_functions = DBConfig::GetConfig(context).GetCastFunctions(); |
95 | GetCastFunctionInput get_input(context); |
96 | return AddCastToTypeInternal(expr: std::move(expr), target_type, cast_functions, get_input, try_cast); |
97 | } |
98 | |
99 | bool BoundCastExpression::CastIsInvertible(const LogicalType &source_type, const LogicalType &target_type) { |
100 | D_ASSERT(source_type.IsValid() && target_type.IsValid()); |
101 | if (source_type.id() == LogicalTypeId::BOOLEAN || target_type.id() == LogicalTypeId::BOOLEAN) { |
102 | return false; |
103 | } |
104 | if (source_type.id() == LogicalTypeId::FLOAT || target_type.id() == LogicalTypeId::FLOAT) { |
105 | return false; |
106 | } |
107 | if (source_type.id() == LogicalTypeId::DOUBLE || target_type.id() == LogicalTypeId::DOUBLE) { |
108 | return false; |
109 | } |
110 | if (source_type.id() == LogicalTypeId::DECIMAL || target_type.id() == LogicalTypeId::DECIMAL) { |
111 | uint8_t source_width, target_width; |
112 | uint8_t source_scale, target_scale; |
113 | // cast to or from decimal |
114 | // cast is only invertible if the cast is strictly widening |
115 | if (!source_type.GetDecimalProperties(width&: source_width, scale&: source_scale)) { |
116 | return false; |
117 | } |
118 | if (!target_type.GetDecimalProperties(width&: target_width, scale&: target_scale)) { |
119 | return false; |
120 | } |
121 | if (target_scale < source_scale) { |
122 | return false; |
123 | } |
124 | return true; |
125 | } |
126 | if (source_type.id() == LogicalTypeId::TIMESTAMP || source_type.id() == LogicalTypeId::TIMESTAMP_TZ) { |
127 | switch (target_type.id()) { |
128 | case LogicalTypeId::DATE: |
129 | case LogicalTypeId::TIME: |
130 | case LogicalTypeId::TIME_TZ: |
131 | return false; |
132 | default: |
133 | break; |
134 | } |
135 | } |
136 | if (source_type.id() == LogicalTypeId::VARCHAR) { |
137 | switch (target_type.id()) { |
138 | case LogicalTypeId::TIME: |
139 | case LogicalTypeId::TIMESTAMP: |
140 | case LogicalTypeId::TIMESTAMP_NS: |
141 | case LogicalTypeId::TIMESTAMP_MS: |
142 | case LogicalTypeId::TIMESTAMP_SEC: |
143 | case LogicalTypeId::TIME_TZ: |
144 | case LogicalTypeId::TIMESTAMP_TZ: |
145 | return true; |
146 | default: |
147 | return false; |
148 | } |
149 | } |
150 | if (target_type.id() == LogicalTypeId::VARCHAR) { |
151 | switch (source_type.id()) { |
152 | case LogicalTypeId::DATE: |
153 | case LogicalTypeId::TIME: |
154 | case LogicalTypeId::TIMESTAMP: |
155 | case LogicalTypeId::TIMESTAMP_NS: |
156 | case LogicalTypeId::TIMESTAMP_MS: |
157 | case LogicalTypeId::TIMESTAMP_SEC: |
158 | case LogicalTypeId::TIME_TZ: |
159 | case LogicalTypeId::TIMESTAMP_TZ: |
160 | return true; |
161 | default: |
162 | return false; |
163 | } |
164 | } |
165 | return true; |
166 | } |
167 | |
168 | string BoundCastExpression::ToString() const { |
169 | return (try_cast ? "TRY_CAST(" : "CAST(" ) + child->GetName() + " AS " + return_type.ToString() + ")" ; |
170 | } |
171 | |
172 | bool BoundCastExpression::Equals(const BaseExpression &other_p) const { |
173 | if (!Expression::Equals(other: other_p)) { |
174 | return false; |
175 | } |
176 | auto &other = other_p.Cast<BoundCastExpression>(); |
177 | if (!Expression::Equals(left: *child, right: *other.child)) { |
178 | return false; |
179 | } |
180 | if (try_cast != other.try_cast) { |
181 | return false; |
182 | } |
183 | return true; |
184 | } |
185 | |
186 | unique_ptr<Expression> BoundCastExpression::Copy() { |
187 | auto copy = make_uniq<BoundCastExpression>(args: child->Copy(), args&: return_type, args: bound_cast.Copy(), args&: try_cast); |
188 | copy->CopyProperties(other&: *this); |
189 | return std::move(copy); |
190 | } |
191 | |
192 | void BoundCastExpression::Serialize(FieldWriter &writer) const { |
193 | writer.WriteSerializable(element: *child); |
194 | writer.WriteSerializable(element: return_type); |
195 | writer.WriteField(element: try_cast); |
196 | } |
197 | |
198 | unique_ptr<Expression> BoundCastExpression::Deserialize(ExpressionDeserializationState &state, FieldReader &reader) { |
199 | auto child = reader.ReadRequiredSerializable<Expression>(args&: state.gstate); |
200 | auto target_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
201 | auto try_cast = reader.ReadRequired<bool>(); |
202 | auto cast_function = BindCastFunction(context&: state.gstate.context, source: child->return_type, target: target_type); |
203 | return make_uniq<BoundCastExpression>(args: std::move(child), args: std::move(target_type), args: std::move(cast_function), args&: try_cast); |
204 | } |
205 | |
206 | } // namespace duckdb |
207 | |