1#include "duckdb/function/function_binder.hpp"
2#include "duckdb/common/limits.hpp"
3
4#include "duckdb/planner/expression/bound_cast_expression.hpp"
5#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
6#include "duckdb/planner/expression/bound_function_expression.hpp"
7#include "duckdb/planner/expression/bound_constant_expression.hpp"
8#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
9
10#include "duckdb/planner/expression_binder.hpp"
11#include "duckdb/function/aggregate_function.hpp"
12#include "duckdb/function/cast_rules.hpp"
13#include "duckdb/catalog/catalog.hpp"
14
15namespace duckdb {
16
17FunctionBinder::FunctionBinder(ClientContext &context) : context(context) {
18}
19
20int64_t FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, const vector<LogicalType> &arguments) {
21 if (arguments.size() < func.arguments.size()) {
22 // not enough arguments to fulfill the non-vararg part of the function
23 return -1;
24 }
25 int64_t cost = 0;
26 for (idx_t i = 0; i < arguments.size(); i++) {
27 LogicalType arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs;
28 if (arguments[i] == arg_type) {
29 // arguments match: do nothing
30 continue;
31 }
32 int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(source: arguments[i], target: arg_type);
33 if (cast_cost >= 0) {
34 // we can implicitly cast, add the cost to the total cost
35 cost += cast_cost;
36 } else {
37 // we can't implicitly cast: throw an error
38 return -1;
39 }
40 }
41 return cost;
42}
43
44int64_t FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector<LogicalType> &arguments) {
45 if (func.HasVarArgs()) {
46 // special case varargs function
47 return BindVarArgsFunctionCost(func, arguments);
48 }
49 if (func.arguments.size() != arguments.size()) {
50 // invalid argument count: check the next function
51 return -1;
52 }
53 int64_t cost = 0;
54 for (idx_t i = 0; i < arguments.size(); i++) {
55 int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(source: arguments[i], target: func.arguments[i]);
56 if (cast_cost >= 0) {
57 // we can implicitly cast, add the cost to the total cost
58 cost += cast_cost;
59 } else {
60 // we can't implicitly cast: throw an error
61 return -1;
62 }
63 }
64 return cost;
65}
66
67template <class T>
68vector<idx_t> FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet<T> &functions,
69 const vector<LogicalType> &arguments, string &error) {
70 idx_t best_function = DConstants::INVALID_INDEX;
71 int64_t lowest_cost = NumericLimits<int64_t>::Maximum();
72 vector<idx_t> candidate_functions;
73 for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) {
74 auto &func = functions.functions[f_idx];
75 // check the arguments of the function
76 int64_t cost = BindFunctionCost(func, arguments);
77 if (cost < 0) {
78 // auto casting was not possible
79 continue;
80 }
81 if (cost == lowest_cost) {
82 candidate_functions.push_back(x: f_idx);
83 continue;
84 }
85 if (cost > lowest_cost) {
86 continue;
87 }
88 candidate_functions.clear();
89 lowest_cost = cost;
90 best_function = f_idx;
91 }
92 if (best_function == DConstants::INVALID_INDEX) {
93 // no matching function was found, throw an error
94 string call_str = Function::CallToString(name, arguments);
95 string candidate_str = "";
96 for (auto &f : functions.functions) {
97 candidate_str += "\t" + f.ToString() + "\n";
98 }
99 error = StringUtil::Format(fmt_str: "No function matches the given name and argument types '%s'. You might need to add "
100 "explicit type casts.\n\tCandidate functions:\n%s",
101 params: call_str, params: candidate_str);
102 return candidate_functions;
103 }
104 candidate_functions.push_back(x: best_function);
105 return candidate_functions;
106}
107
108template <class T>
109idx_t FunctionBinder::MultipleCandidateException(const string &name, FunctionSet<T> &functions,
110 vector<idx_t> &candidate_functions,
111 const vector<LogicalType> &arguments, string &error) {
112 D_ASSERT(functions.functions.size() > 1);
113 // there are multiple possible function definitions
114 // throw an exception explaining which overloads are there
115 string call_str = Function::CallToString(name, arguments);
116 string candidate_str = "";
117 for (auto &conf : candidate_functions) {
118 T f = functions.GetFunctionByOffset(conf);
119 candidate_str += "\t" + f.ToString() + "\n";
120 }
121 error = StringUtil::Format(fmt_str: "Could not choose a best candidate function for the function call \"%s\". In order to "
122 "select one, please add explicit type casts.\n\tCandidate functions:\n%s",
123 params: call_str, params: candidate_str);
124 return DConstants::INVALID_INDEX;
125}
126
127template <class T>
128idx_t FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet<T> &functions,
129 const vector<LogicalType> &arguments, string &error) {
130 auto candidate_functions = BindFunctionsFromArguments<T>(name, functions, arguments, error);
131 if (candidate_functions.empty()) {
132 // no candidates
133 return DConstants::INVALID_INDEX;
134 }
135 if (candidate_functions.size() > 1) {
136 // multiple candidates, check if there are any unknown arguments
137 bool has_parameters = false;
138 for (auto &arg_type : arguments) {
139 if (arg_type.id() == LogicalTypeId::UNKNOWN) {
140 //! there are! we could not resolve parameters in this case
141 throw ParameterNotResolvedException();
142 }
143 }
144 if (!has_parameters) {
145 return MultipleCandidateException(name, functions, candidate_functions, arguments, error);
146 }
147 }
148 return candidate_functions[0];
149}
150
151idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions,
152 const vector<LogicalType> &arguments, string &error) {
153 return BindFunctionFromArguments(name, functions, arguments, error);
154}
155
156idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions,
157 const vector<LogicalType> &arguments, string &error) {
158 return BindFunctionFromArguments(name, functions, arguments, error);
159}
160
161idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions,
162 const vector<LogicalType> &arguments, string &error) {
163 return BindFunctionFromArguments(name, functions, arguments, error);
164}
165
166idx_t FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, PragmaInfo &info, string &error) {
167 vector<LogicalType> types;
168 for (auto &value : info.parameters) {
169 types.push_back(x: value.type());
170 }
171 idx_t entry = BindFunctionFromArguments(name, functions, arguments: types, error);
172 if (entry == DConstants::INVALID_INDEX) {
173 throw BinderException(error);
174 }
175 auto candidate_function = functions.GetFunctionByOffset(offset: entry);
176 // cast the input parameters
177 for (idx_t i = 0; i < info.parameters.size(); i++) {
178 auto target_type =
179 i < candidate_function.arguments.size() ? candidate_function.arguments[i] : candidate_function.varargs;
180 info.parameters[i] = info.parameters[i].CastAs(context, target_type);
181 }
182 return entry;
183}
184
185vector<LogicalType> FunctionBinder::GetLogicalTypesFromExpressions(vector<unique_ptr<Expression>> &arguments) {
186 vector<LogicalType> types;
187 types.reserve(n: arguments.size());
188 for (auto &argument : arguments) {
189 types.push_back(x: argument->return_type);
190 }
191 return types;
192}
193
194idx_t FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions,
195 vector<unique_ptr<Expression>> &arguments, string &error) {
196 auto types = GetLogicalTypesFromExpressions(arguments);
197 return BindFunction(name, functions, arguments: types, error);
198}
199
200idx_t FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions,
201 vector<unique_ptr<Expression>> &arguments, string &error) {
202 auto types = GetLogicalTypesFromExpressions(arguments);
203 return BindFunction(name, functions, arguments: types, error);
204}
205
206idx_t FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions,
207 vector<unique_ptr<Expression>> &arguments, string &error) {
208 auto types = GetLogicalTypesFromExpressions(arguments);
209 return BindFunction(name, functions, arguments: types, error);
210}
211
// Outcome of comparing an argument's type with the type a function parameter expects.
enum class LogicalTypeComparisonResult {
	// the types are equal: no cast is added
	IDENTICAL_TYPE,
	// the parameter accepts ANY type: no cast is added
	TARGET_IS_ANY,
	// the types differ: a cast to the parameter type is added
	DIFFERENT_TYPES
};
213
214LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) {
215 if (target_type.id() == LogicalTypeId::ANY) {
216 return LogicalTypeComparisonResult::TARGET_IS_ANY;
217 }
218 if (source_type == target_type) {
219 return LogicalTypeComparisonResult::IDENTICAL_TYPE;
220 }
221 if (source_type.id() == LogicalTypeId::LIST && target_type.id() == LogicalTypeId::LIST) {
222 return RequiresCast(source_type: ListType::GetChildType(type: source_type), target_type: ListType::GetChildType(type: target_type));
223 }
224 return LogicalTypeComparisonResult::DIFFERENT_TYPES;
225}
226
227void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector<unique_ptr<Expression>> &children) {
228 for (idx_t i = 0; i < children.size(); i++) {
229 auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs;
230 target_type.Verify();
231 // don't cast lambda children, they get removed anyways
232 if (children[i]->return_type.id() == LogicalTypeId::LAMBDA) {
233 continue;
234 }
235 // check if the type of child matches the type of function argument
236 // if not we need to add a cast
237 auto cast_result = RequiresCast(source_type: children[i]->return_type, target_type);
238 // except for one special case: if the function accepts ANY argument
239 // in that case we don't add a cast
240 if (cast_result == LogicalTypeComparisonResult::DIFFERENT_TYPES) {
241 children[i] = BoundCastExpression::AddCastToType(context, expr: std::move(children[i]), target_type);
242 }
243 }
244}
245
246unique_ptr<Expression> FunctionBinder::BindScalarFunction(const string &schema, const string &name,
247 vector<unique_ptr<Expression>> children, string &error,
248 bool is_operator, Binder *binder) {
249 // bind the function
250 auto &function =
251 Catalog::GetSystemCatalog(context).GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY, schema, name);
252 D_ASSERT(function.type == CatalogType::SCALAR_FUNCTION_ENTRY);
253 return BindScalarFunction(function&: function.Cast<ScalarFunctionCatalogEntry>(), children: std::move(children), error, is_operator,
254 binder);
255}
256
257unique_ptr<Expression> FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func,
258 vector<unique_ptr<Expression>> children, string &error,
259 bool is_operator, Binder *binder) {
260 // bind the function
261 idx_t best_function = BindFunction(name: func.name, functions&: func.functions, arguments&: children, error);
262 if (best_function == DConstants::INVALID_INDEX) {
263 return nullptr;
264 }
265
266 // found a matching function!
267 auto bound_function = func.functions.GetFunctionByOffset(offset: best_function);
268
269 if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) {
270 for (auto &child : children) {
271 if (child->return_type == LogicalTypeId::SQLNULL) {
272 return make_uniq<BoundConstantExpression>(args: Value(LogicalType::SQLNULL));
273 }
274 }
275 }
276 return BindScalarFunction(bound_function, children: std::move(children), is_operator);
277}
278
279unique_ptr<BoundFunctionExpression> FunctionBinder::BindScalarFunction(ScalarFunction bound_function,
280 vector<unique_ptr<Expression>> children,
281 bool is_operator) {
282 unique_ptr<FunctionData> bind_info;
283 if (bound_function.bind) {
284 bind_info = bound_function.bind(context, bound_function, children);
285 }
286 // check if we need to add casts to the children
287 CastToFunctionArguments(function&: bound_function, children);
288
289 // now create the function
290 auto return_type = bound_function.return_type;
291 return make_uniq<BoundFunctionExpression>(args: std::move(return_type), args: std::move(bound_function), args: std::move(children),
292 args: std::move(bind_info), args&: is_operator);
293}
294
295unique_ptr<BoundAggregateExpression> FunctionBinder::BindAggregateFunction(AggregateFunction bound_function,
296 vector<unique_ptr<Expression>> children,
297 unique_ptr<Expression> filter,
298 AggregateType aggr_type) {
299 unique_ptr<FunctionData> bind_info;
300 if (bound_function.bind) {
301 bind_info = bound_function.bind(context, bound_function, children);
302 // we may have lost some arguments in the bind
303 children.resize(new_size: MinValue(a: bound_function.arguments.size(), b: children.size()));
304 }
305
306 // check if we need to add casts to the children
307 CastToFunctionArguments(function&: bound_function, children);
308
309 return make_uniq<BoundAggregateExpression>(args: std::move(bound_function), args: std::move(children), args: std::move(filter),
310 args: std::move(bind_info), args&: aggr_type);
311}
312
313} // namespace duckdb
314