1 | #include "duckdb/function/function_binder.hpp" |
2 | #include "duckdb/common/limits.hpp" |
3 | |
4 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
8 | #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" |
9 | |
10 | #include "duckdb/planner/expression_binder.hpp" |
11 | #include "duckdb/function/aggregate_function.hpp" |
12 | #include "duckdb/function/cast_rules.hpp" |
13 | #include "duckdb/catalog/catalog.hpp" |
14 | |
15 | namespace duckdb { |
16 | |
17 | FunctionBinder::FunctionBinder(ClientContext &context) : context(context) { |
18 | } |
19 | |
20 | int64_t FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, const vector<LogicalType> &arguments) { |
21 | if (arguments.size() < func.arguments.size()) { |
22 | // not enough arguments to fulfill the non-vararg part of the function |
23 | return -1; |
24 | } |
25 | int64_t cost = 0; |
26 | for (idx_t i = 0; i < arguments.size(); i++) { |
27 | LogicalType arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs; |
28 | if (arguments[i] == arg_type) { |
29 | // arguments match: do nothing |
30 | continue; |
31 | } |
32 | int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(source: arguments[i], target: arg_type); |
33 | if (cast_cost >= 0) { |
34 | // we can implicitly cast, add the cost to the total cost |
35 | cost += cast_cost; |
36 | } else { |
37 | // we can't implicitly cast: throw an error |
38 | return -1; |
39 | } |
40 | } |
41 | return cost; |
42 | } |
43 | |
44 | int64_t FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector<LogicalType> &arguments) { |
45 | if (func.HasVarArgs()) { |
46 | // special case varargs function |
47 | return BindVarArgsFunctionCost(func, arguments); |
48 | } |
49 | if (func.arguments.size() != arguments.size()) { |
50 | // invalid argument count: check the next function |
51 | return -1; |
52 | } |
53 | int64_t cost = 0; |
54 | for (idx_t i = 0; i < arguments.size(); i++) { |
55 | int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(source: arguments[i], target: func.arguments[i]); |
56 | if (cast_cost >= 0) { |
57 | // we can implicitly cast, add the cost to the total cost |
58 | cost += cast_cost; |
59 | } else { |
60 | // we can't implicitly cast: throw an error |
61 | return -1; |
62 | } |
63 | } |
64 | return cost; |
65 | } |
66 | |
67 | template <class T> |
68 | vector<idx_t> FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet<T> &functions, |
69 | const vector<LogicalType> &arguments, string &error) { |
70 | idx_t best_function = DConstants::INVALID_INDEX; |
71 | int64_t lowest_cost = NumericLimits<int64_t>::Maximum(); |
72 | vector<idx_t> candidate_functions; |
73 | for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) { |
74 | auto &func = functions.functions[f_idx]; |
75 | // check the arguments of the function |
76 | int64_t cost = BindFunctionCost(func, arguments); |
77 | if (cost < 0) { |
78 | // auto casting was not possible |
79 | continue; |
80 | } |
81 | if (cost == lowest_cost) { |
82 | candidate_functions.push_back(x: f_idx); |
83 | continue; |
84 | } |
85 | if (cost > lowest_cost) { |
86 | continue; |
87 | } |
88 | candidate_functions.clear(); |
89 | lowest_cost = cost; |
90 | best_function = f_idx; |
91 | } |
92 | if (best_function == DConstants::INVALID_INDEX) { |
93 | // no matching function was found, throw an error |
94 | string call_str = Function::CallToString(name, arguments); |
95 | string candidate_str = "" ; |
96 | for (auto &f : functions.functions) { |
97 | candidate_str += "\t" + f.ToString() + "\n" ; |
98 | } |
99 | error = StringUtil::Format(fmt_str: "No function matches the given name and argument types '%s'. You might need to add " |
100 | "explicit type casts.\n\tCandidate functions:\n%s" , |
101 | params: call_str, params: candidate_str); |
102 | return candidate_functions; |
103 | } |
104 | candidate_functions.push_back(x: best_function); |
105 | return candidate_functions; |
106 | } |
107 | |
108 | template <class T> |
109 | idx_t FunctionBinder::MultipleCandidateException(const string &name, FunctionSet<T> &functions, |
110 | vector<idx_t> &candidate_functions, |
111 | const vector<LogicalType> &arguments, string &error) { |
112 | D_ASSERT(functions.functions.size() > 1); |
113 | // there are multiple possible function definitions |
114 | // throw an exception explaining which overloads are there |
115 | string call_str = Function::CallToString(name, arguments); |
116 | string candidate_str = "" ; |
117 | for (auto &conf : candidate_functions) { |
118 | T f = functions.GetFunctionByOffset(conf); |
119 | candidate_str += "\t" + f.ToString() + "\n" ; |
120 | } |
121 | error = StringUtil::Format(fmt_str: "Could not choose a best candidate function for the function call \"%s\". In order to " |
122 | "select one, please add explicit type casts.\n\tCandidate functions:\n%s" , |
123 | params: call_str, params: candidate_str); |
124 | return DConstants::INVALID_INDEX; |
125 | } |
126 | |
127 | template <class T> |
128 | idx_t FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet<T> &functions, |
129 | const vector<LogicalType> &arguments, string &error) { |
130 | auto candidate_functions = BindFunctionsFromArguments<T>(name, functions, arguments, error); |
131 | if (candidate_functions.empty()) { |
132 | // no candidates |
133 | return DConstants::INVALID_INDEX; |
134 | } |
135 | if (candidate_functions.size() > 1) { |
136 | // multiple candidates, check if there are any unknown arguments |
137 | bool has_parameters = false; |
138 | for (auto &arg_type : arguments) { |
139 | if (arg_type.id() == LogicalTypeId::UNKNOWN) { |
140 | //! there are! we could not resolve parameters in this case |
141 | throw ParameterNotResolvedException(); |
142 | } |
143 | } |
144 | if (!has_parameters) { |
145 | return MultipleCandidateException(name, functions, candidate_functions, arguments, error); |
146 | } |
147 | } |
148 | return candidate_functions[0]; |
149 | } |
150 | |
151 | idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, |
152 | const vector<LogicalType> &arguments, string &error) { |
153 | return BindFunctionFromArguments(name, functions, arguments, error); |
154 | } |
155 | |
156 | idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, |
157 | const vector<LogicalType> &arguments, string &error) { |
158 | return BindFunctionFromArguments(name, functions, arguments, error); |
159 | } |
160 | |
161 | idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, |
162 | const vector<LogicalType> &arguments, string &error) { |
163 | return BindFunctionFromArguments(name, functions, arguments, error); |
164 | } |
165 | |
166 | idx_t FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, PragmaInfo &info, string &error) { |
167 | vector<LogicalType> types; |
168 | for (auto &value : info.parameters) { |
169 | types.push_back(x: value.type()); |
170 | } |
171 | idx_t entry = BindFunctionFromArguments(name, functions, arguments: types, error); |
172 | if (entry == DConstants::INVALID_INDEX) { |
173 | throw BinderException(error); |
174 | } |
175 | auto candidate_function = functions.GetFunctionByOffset(offset: entry); |
176 | // cast the input parameters |
177 | for (idx_t i = 0; i < info.parameters.size(); i++) { |
178 | auto target_type = |
179 | i < candidate_function.arguments.size() ? candidate_function.arguments[i] : candidate_function.varargs; |
180 | info.parameters[i] = info.parameters[i].CastAs(context, target_type); |
181 | } |
182 | return entry; |
183 | } |
184 | |
185 | vector<LogicalType> FunctionBinder::GetLogicalTypesFromExpressions(vector<unique_ptr<Expression>> &arguments) { |
186 | vector<LogicalType> types; |
187 | types.reserve(n: arguments.size()); |
188 | for (auto &argument : arguments) { |
189 | types.push_back(x: argument->return_type); |
190 | } |
191 | return types; |
192 | } |
193 | |
194 | idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, |
195 | vector<unique_ptr<Expression>> &arguments, string &error) { |
196 | auto types = GetLogicalTypesFromExpressions(arguments); |
197 | return BindFunction(name, functions, arguments: types, error); |
198 | } |
199 | |
200 | idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, |
201 | vector<unique_ptr<Expression>> &arguments, string &error) { |
202 | auto types = GetLogicalTypesFromExpressions(arguments); |
203 | return BindFunction(name, functions, arguments: types, error); |
204 | } |
205 | |
206 | idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, |
207 | vector<unique_ptr<Expression>> &arguments, string &error) { |
208 | auto types = GetLogicalTypesFromExpressions(arguments); |
209 | return BindFunction(name, functions, arguments: types, error); |
210 | } |
211 | |
212 | enum class LogicalTypeComparisonResult { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; |
213 | |
214 | LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { |
215 | if (target_type.id() == LogicalTypeId::ANY) { |
216 | return LogicalTypeComparisonResult::TARGET_IS_ANY; |
217 | } |
218 | if (source_type == target_type) { |
219 | return LogicalTypeComparisonResult::IDENTICAL_TYPE; |
220 | } |
221 | if (source_type.id() == LogicalTypeId::LIST && target_type.id() == LogicalTypeId::LIST) { |
222 | return RequiresCast(source_type: ListType::GetChildType(type: source_type), target_type: ListType::GetChildType(type: target_type)); |
223 | } |
224 | return LogicalTypeComparisonResult::DIFFERENT_TYPES; |
225 | } |
226 | |
227 | void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector<unique_ptr<Expression>> &children) { |
228 | for (idx_t i = 0; i < children.size(); i++) { |
229 | auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs; |
230 | target_type.Verify(); |
231 | // don't cast lambda children, they get removed anyways |
232 | if (children[i]->return_type.id() == LogicalTypeId::LAMBDA) { |
233 | continue; |
234 | } |
235 | // check if the type of child matches the type of function argument |
236 | // if not we need to add a cast |
237 | auto cast_result = RequiresCast(source_type: children[i]->return_type, target_type); |
238 | // except for one special case: if the function accepts ANY argument |
239 | // in that case we don't add a cast |
240 | if (cast_result == LogicalTypeComparisonResult::DIFFERENT_TYPES) { |
241 | children[i] = BoundCastExpression::AddCastToType(context, expr: std::move(children[i]), target_type); |
242 | } |
243 | } |
244 | } |
245 | |
246 | unique_ptr<Expression> FunctionBinder::BindScalarFunction(const string &schema, const string &name, |
247 | vector<unique_ptr<Expression>> children, string &error, |
248 | bool is_operator, Binder *binder) { |
249 | // bind the function |
250 | auto &function = |
251 | Catalog::GetSystemCatalog(context).GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY, schema, name); |
252 | D_ASSERT(function.type == CatalogType::SCALAR_FUNCTION_ENTRY); |
253 | return BindScalarFunction(function&: function.Cast<ScalarFunctionCatalogEntry>(), children: std::move(children), error, is_operator, |
254 | binder); |
255 | } |
256 | |
257 | unique_ptr<Expression> FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func, |
258 | vector<unique_ptr<Expression>> children, string &error, |
259 | bool is_operator, Binder *binder) { |
260 | // bind the function |
261 | idx_t best_function = BindFunction(name: func.name, functions&: func.functions, arguments&: children, error); |
262 | if (best_function == DConstants::INVALID_INDEX) { |
263 | return nullptr; |
264 | } |
265 | |
266 | // found a matching function! |
267 | auto bound_function = func.functions.GetFunctionByOffset(offset: best_function); |
268 | |
269 | if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { |
270 | for (auto &child : children) { |
271 | if (child->return_type == LogicalTypeId::SQLNULL) { |
272 | return make_uniq<BoundConstantExpression>(args: Value(LogicalType::SQLNULL)); |
273 | } |
274 | } |
275 | } |
276 | return BindScalarFunction(bound_function, children: std::move(children), is_operator); |
277 | } |
278 | |
279 | unique_ptr<BoundFunctionExpression> FunctionBinder::BindScalarFunction(ScalarFunction bound_function, |
280 | vector<unique_ptr<Expression>> children, |
281 | bool is_operator) { |
282 | unique_ptr<FunctionData> bind_info; |
283 | if (bound_function.bind) { |
284 | bind_info = bound_function.bind(context, bound_function, children); |
285 | } |
286 | // check if we need to add casts to the children |
287 | CastToFunctionArguments(function&: bound_function, children); |
288 | |
289 | // now create the function |
290 | auto return_type = bound_function.return_type; |
291 | return make_uniq<BoundFunctionExpression>(args: std::move(return_type), args: std::move(bound_function), args: std::move(children), |
292 | args: std::move(bind_info), args&: is_operator); |
293 | } |
294 | |
295 | unique_ptr<BoundAggregateExpression> FunctionBinder::BindAggregateFunction(AggregateFunction bound_function, |
296 | vector<unique_ptr<Expression>> children, |
297 | unique_ptr<Expression> filter, |
298 | AggregateType aggr_type) { |
299 | unique_ptr<FunctionData> bind_info; |
300 | if (bound_function.bind) { |
301 | bind_info = bound_function.bind(context, bound_function, children); |
302 | // we may have lost some arguments in the bind |
303 | children.resize(new_size: MinValue(a: bound_function.arguments.size(), b: children.size())); |
304 | } |
305 | |
306 | // check if we need to add casts to the children |
307 | CastToFunctionArguments(function&: bound_function, children); |
308 | |
309 | return make_uniq<BoundAggregateExpression>(args: std::move(bound_function), args: std::move(children), args: std::move(filter), |
310 | args: std::move(bind_info), args&: aggr_type); |
311 | } |
312 | |
313 | } // namespace duckdb |
314 | |