1#include "duckdb/function/function.hpp"
2
3#include "duckdb/catalog/catalog.hpp"
4#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
5#include "duckdb/common/string_util.hpp"
6#include "duckdb/function/aggregate_function.hpp"
7#include "duckdb/function/cast_rules.hpp"
8#include "duckdb/function/scalar/string_functions.hpp"
9#include "duckdb/function/scalar_function.hpp"
10#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
11#include "duckdb/parser/parsed_data/create_collation_info.hpp"
12#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
13#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
14#include "duckdb/planner/expression/bound_cast_expression.hpp"
15#include "duckdb/planner/expression/bound_function_expression.hpp"
16
17using namespace duckdb;
18using namespace std;
19
20// add your initializer for new functions here
21void BuiltinFunctions::Initialize() {
22 RegisterSQLiteFunctions();
23 RegisterReadFunctions();
24
25 RegisterAlgebraicAggregates();
26 RegisterDistributiveAggregates();
27 RegisterNestedAggregates();
28
29 RegisterDateFunctions();
30 RegisterMathFunctions();
31 RegisterOperators();
32 RegisterSequenceFunctions();
33 RegisterStringFunctions();
34 RegisterNestedFunctions();
35 RegisterTrigonometricsFunctions();
36
37 // initialize collations
38 AddCollation("nocase", LowerFun::GetFunction(), true);
39 AddCollation("noaccent", StripAccentsFun::GetFunction());
40}
41
42BuiltinFunctions::BuiltinFunctions(ClientContext &context, Catalog &catalog) : context(context), catalog(catalog) {
43}
44
45void BuiltinFunctions::AddCollation(string name, ScalarFunction function, bool combinable,
46 bool not_required_for_equality) {
47 CreateCollationInfo info(move(name), move(function), combinable, not_required_for_equality);
48 catalog.CreateCollation(context, &info);
49}
50
51void BuiltinFunctions::AddFunction(AggregateFunctionSet set) {
52 CreateAggregateFunctionInfo info(set);
53 catalog.CreateFunction(context, &info);
54}
55
56void BuiltinFunctions::AddFunction(AggregateFunction function) {
57 CreateAggregateFunctionInfo info(function);
58 catalog.CreateFunction(context, &info);
59}
60
61void BuiltinFunctions::AddFunction(ScalarFunction function) {
62 CreateScalarFunctionInfo info(function);
63 catalog.CreateFunction(context, &info);
64}
65
66void BuiltinFunctions::AddFunction(vector<string> names, ScalarFunction function) {
67 for (auto &name : names) {
68 function.name = name;
69 AddFunction(function);
70 }
71}
72
73void BuiltinFunctions::AddFunction(ScalarFunctionSet set) {
74 CreateScalarFunctionInfo info(set);
75 catalog.CreateFunction(context, &info);
76}
77
78void BuiltinFunctions::AddFunction(TableFunction function) {
79 CreateTableFunctionInfo info(function);
80 catalog.CreateTableFunction(context, &info);
81}
82
83string Function::CallToString(string name, vector<SQLType> arguments) {
84 string result = name + "(";
85 result += StringUtil::Join(arguments, arguments.size(), ", ",
86 [](const SQLType &argument) { return SQLTypeToString(argument); });
87 return result + ")";
88}
89
90string Function::CallToString(string name, vector<SQLType> arguments, SQLType return_type) {
91 string result = CallToString(name, arguments);
92 result += " -> " + SQLTypeToString(return_type);
93 return result;
94}
95
96static int64_t BindVarArgsFunctionCost(SimpleFunction &func, vector<SQLType> &arguments) {
97 if (arguments.size() < func.arguments.size()) {
98 // not enough arguments to fulfill the non-vararg part of the function
99 return -1;
100 }
101 int64_t cost = 0;
102 for (idx_t i = 0; i < arguments.size(); i++) {
103 SQLType arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs;
104 if (arguments[i] == arg_type) {
105 // arguments match: do nothing
106 continue;
107 }
108 int64_t cast_cost = CastRules::ImplicitCast(arguments[i], arg_type);
109 if (cast_cost >= 0) {
110 // we can implicitly cast, add the cost to the total cost
111 cost += cast_cost;
112 } else {
113 // we can't implicitly cast: throw an error
114 return -1;
115 }
116 }
117 return cost;
118}
119
120static int64_t BindFunctionCost(SimpleFunction &func, vector<SQLType> &arguments) {
121 if (func.HasVarArgs()) {
122 // special case varargs function
123 return BindVarArgsFunctionCost(func, arguments);
124 }
125 if (func.arguments.size() != arguments.size()) {
126 // invalid argument count: check the next function
127 return -1;
128 }
129 int64_t cost = 0;
130 for (idx_t i = 0; i < arguments.size(); i++) {
131 if (arguments[i] == func.arguments[i]) {
132 // arguments match: do nothing
133 continue;
134 }
135 int64_t cast_cost = CastRules::ImplicitCast(arguments[i], func.arguments[i]);
136 if (cast_cost >= 0) {
137 // we can implicitly cast, add the cost to the total cost
138 cost += cast_cost;
139 } else {
140 // we can't implicitly cast: throw an error
141 return -1;
142 }
143 }
144 return cost;
145}
146
147template <class T>
148static idx_t BindFunctionFromArguments(string name, vector<T> &functions, vector<SQLType> &arguments) {
149 idx_t best_function = INVALID_INDEX;
150 int64_t lowest_cost = numeric_limits<int64_t>::max();
151 vector<idx_t> conflicting_functions;
152 for (idx_t f_idx = 0; f_idx < functions.size(); f_idx++) {
153 auto &func = functions[f_idx];
154 // check the arguments of the function
155 int64_t cost = BindFunctionCost(func, arguments);
156 if (cost < 0) {
157 // auto casting was not possible
158 continue;
159 }
160 if (cost == lowest_cost) {
161 conflicting_functions.push_back(f_idx);
162 continue;
163 }
164 if (cost > lowest_cost) {
165 continue;
166 }
167 conflicting_functions.clear();
168 lowest_cost = cost;
169 best_function = f_idx;
170 }
171 if (conflicting_functions.size() > 0) {
172 // there are multiple possible function definitions
173 // throw an exception explaining which overloads are there
174 conflicting_functions.push_back(best_function);
175 string call_str = Function::CallToString(name, arguments);
176 string candidate_str = "";
177 for (auto &conf : conflicting_functions) {
178 auto &f = functions[conf];
179 candidate_str += "\t" + f.ToString() + "\n";
180 }
181 throw BinderException("Could not choose a best candidate function for the function call \"%s\". In order to "
182 "select one, please add explicit type casts.\n\tCandidate functions:\n%s",
183 call_str.c_str(), candidate_str.c_str());
184 }
185 if (best_function == INVALID_INDEX) {
186 // no matching function was found, throw an error
187 string call_str = Function::CallToString(name, arguments);
188 string candidate_str = "";
189 for (auto &f : functions) {
190 candidate_str += "\t" + f.ToString() + "\n";
191 }
192 throw BinderException("No function matches the given name and argument types '%s'. You might need to add "
193 "explicit type casts.\n\tCandidate functions:\n%s",
194 call_str.c_str(), candidate_str.c_str());
195 }
196 return best_function;
197}
198
199idx_t Function::BindFunction(string name, vector<ScalarFunction> &functions, vector<SQLType> &arguments) {
200 return BindFunctionFromArguments(name, functions, arguments);
201}
202
203idx_t Function::BindFunction(string name, vector<AggregateFunction> &functions, vector<SQLType> &arguments) {
204 return BindFunctionFromArguments(name, functions, arguments);
205}
206
207void SimpleFunction::CastToFunctionArguments(vector<unique_ptr<Expression>> &children, vector<SQLType> &types) {
208 for (idx_t i = 0; i < types.size(); i++) {
209 auto target_type = i < this->arguments.size() ? this->arguments[i] : this->varargs;
210 if (target_type.id != SQLTypeId::ANY && types[i] != target_type) {
211 // type of child does not match type of function argument: add a cast
212 children[i] = BoundCastExpression::AddCastToType(move(children[i]), types[i], target_type);
213 }
214 }
215}
216
217unique_ptr<BoundFunctionExpression> ScalarFunction::BindScalarFunction(ClientContext &context, string schema,
218 string name, vector<SQLType> &arguments,
219 vector<unique_ptr<Expression>> children,
220 bool is_operator) {
221 // bind the function
222 auto function = Catalog::GetCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION, schema, name);
223 assert(function && function->type == CatalogType::SCALAR_FUNCTION);
224 return ScalarFunction::BindScalarFunction(context, (ScalarFunctionCatalogEntry &)*function, arguments,
225 move(children), is_operator);
226}
227
228unique_ptr<BoundFunctionExpression>
229ScalarFunction::BindScalarFunction(ClientContext &context, ScalarFunctionCatalogEntry &func, vector<SQLType> &arguments,
230 vector<unique_ptr<Expression>> children, bool is_operator) {
231 // bind the function
232 idx_t best_function = Function::BindFunction(func.name, func.functions, arguments);
233 // found a matching function!
234 auto &bound_function = func.functions[best_function];
235 // check if we need to add casts to the children
236 bound_function.CastToFunctionArguments(children, arguments);
237
238 // now create the function
239 auto result =
240 make_unique<BoundFunctionExpression>(GetInternalType(bound_function.return_type), bound_function, is_operator);
241 result->children = move(children);
242 result->arguments = arguments;
243 result->sql_return_type = bound_function.return_type;
244 if (bound_function.bind) {
245 result->bind_info = bound_function.bind(*result, context);
246 }
247 return result;
248}
249