1#include "duckdb/function/cast/cast_function_set.hpp"
2
3#include "duckdb/common/pair.hpp"
4#include "duckdb/common/types/type_map.hpp"
5#include "duckdb/function/cast_rules.hpp"
6#include "duckdb/main/config.hpp"
7
8namespace duckdb {
9
10BindCastInput::BindCastInput(CastFunctionSet &function_set, optional_ptr<BindCastInfo> info,
11 optional_ptr<ClientContext> context)
12 : function_set(function_set), info(info), context(context) {
13}
14
15BoundCastInfo BindCastInput::GetCastFunction(const LogicalType &source, const LogicalType &target) {
16 GetCastFunctionInput input(context);
17 return function_set.GetCastFunction(source, target, input);
18}
19
20BindCastFunction::BindCastFunction(bind_cast_function_t function_p, unique_ptr<BindCastInfo> info_p)
21 : function(function_p), info(std::move(info_p)) {
22}
23
24CastFunctionSet::CastFunctionSet() : map_info(nullptr) {
25 bind_functions.emplace_back(args&: DefaultCasts::GetDefaultCastFunction);
26}
27
28CastFunctionSet &CastFunctionSet::Get(ClientContext &context) {
29 return DBConfig::GetConfig(context).GetCastFunctions();
30}
31
32CastFunctionSet &CastFunctionSet::Get(DatabaseInstance &db) {
33 return DBConfig::GetConfig(db).GetCastFunctions();
34}
35
36BoundCastInfo CastFunctionSet::GetCastFunction(const LogicalType &source, const LogicalType &target,
37 GetCastFunctionInput &get_input) {
38 if (source == target) {
39 return DefaultCasts::NopCast;
40 }
41 // the first function is the default
42 // we iterate the set of bind functions backwards
43 for (idx_t i = bind_functions.size(); i > 0; i--) {
44 auto &bind_function = bind_functions[i - 1];
45 BindCastInput input(*this, bind_function.info.get(), get_input.context);
46 auto result = bind_function.function(input, source, target);
47 if (result.function) {
48 // found a cast function! return it
49 return result;
50 }
51 }
52 // no cast found: return the default null cast
53 return DefaultCasts::TryVectorNullCast;
54}
55
56struct MapCastNode {
57 MapCastNode(BoundCastInfo info, int64_t implicit_cast_cost)
58 : cast_info(std::move(info)), bind_function(nullptr), implicit_cast_cost(implicit_cast_cost) {
59 }
60 MapCastNode(bind_cast_function_t func, int64_t implicit_cast_cost)
61 : cast_info(nullptr), bind_function(func), implicit_cast_cost(implicit_cast_cost) {
62 }
63
64 BoundCastInfo cast_info;
65 bind_cast_function_t bind_function;
66 int64_t implicit_cast_cost;
67};
68
69template <class MAP_VALUE_TYPE>
70static auto RelaxedTypeMatch(type_map_t<MAP_VALUE_TYPE> &map, const LogicalType &type) -> decltype(map.find(type)) {
71 D_ASSERT(map.find(type) == map.end()); // we shouldn't be here
72 switch (type.id()) {
73 case LogicalTypeId::LIST:
74 return map.find(LogicalType::LIST(child: LogicalType::ANY));
75 case LogicalTypeId::STRUCT:
76 return map.find(LogicalType::STRUCT(children: {{"any", LogicalType::ANY}}));
77 case LogicalTypeId::MAP:
78 for (auto it = map.begin(); it != map.end(); it++) {
79 const auto &entry_type = it->first;
80 if (entry_type.id() != LogicalTypeId::MAP) {
81 continue;
82 }
83 auto &entry_key_type = MapType::KeyType(type: entry_type);
84 auto &entry_val_type = MapType::ValueType(type: entry_type);
85 if ((entry_key_type == LogicalType::ANY || entry_key_type == MapType::KeyType(type)) &&
86 (entry_val_type == LogicalType::ANY || entry_val_type == MapType::ValueType(type))) {
87 return it;
88 }
89 }
90 return map.end();
91 case LogicalTypeId::UNION:
92 return map.find(LogicalType::UNION(members: {{"any", LogicalType::ANY}}));
93 default:
94 return map.find(LogicalType::ANY);
95 }
96}
97
98struct MapCastInfo : public BindCastInfo {
99public:
100 const optional_ptr<MapCastNode> GetEntry(const LogicalType &source, const LogicalType &target) {
101 auto source_type_id_entry = casts.find(x: source.id());
102 if (source_type_id_entry == casts.end()) {
103 source_type_id_entry = casts.find(x: LogicalTypeId::ANY);
104 if (source_type_id_entry == casts.end()) {
105 return nullptr;
106 }
107 }
108
109 auto &source_type_entries = source_type_id_entry->second;
110 auto source_type_entry = source_type_entries.find(x: source);
111 if (source_type_entry == source_type_entries.end()) {
112 source_type_entry = RelaxedTypeMatch(map&: source_type_entries, type: source);
113 if (source_type_entry == source_type_entries.end()) {
114 return nullptr;
115 }
116 }
117
118 auto &target_type_id_entries = source_type_entry->second;
119 auto target_type_id_entry = target_type_id_entries.find(x: target.id());
120 if (target_type_id_entry == target_type_id_entries.end()) {
121 target_type_id_entry = target_type_id_entries.find(x: LogicalTypeId::ANY);
122 if (target_type_id_entry == target_type_id_entries.end()) {
123 return nullptr;
124 }
125 }
126
127 auto &target_type_entries = target_type_id_entry->second;
128 auto target_type_entry = target_type_entries.find(x: target);
129 if (target_type_entry == target_type_entries.end()) {
130 target_type_entry = RelaxedTypeMatch(map&: target_type_entries, type: target);
131 if (target_type_entry == target_type_entries.end()) {
132 return nullptr;
133 }
134 }
135
136 return &target_type_entry->second;
137 }
138
139 void AddEntry(const LogicalType &source, const LogicalType &target, MapCastNode node) {
140 casts[source.id()][source][target.id()].insert(x: make_pair(x: target, y: std::move(node)));
141 }
142
143private:
144 type_id_map_t<type_map_t<type_id_map_t<type_map_t<MapCastNode>>>> casts;
145};
146
147int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const LogicalType &target) {
148 // check if a cast has been registered
149 if (map_info) {
150 auto entry = map_info->GetEntry(source, target);
151 if (entry) {
152 return entry->implicit_cast_cost;
153 }
154 }
155 // if not, fallback to the default implicit cast rules
156 return CastRules::ImplicitCast(from: source, to: target);
157}
158
159BoundCastInfo MapCastFunction(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
160 D_ASSERT(input.info);
161 auto &map_info = input.info->Cast<MapCastInfo>();
162 auto entry = map_info.GetEntry(source, target);
163 if (entry) {
164 if (entry->bind_function) {
165 return entry->bind_function(input, source, target);
166 }
167 return entry->cast_info.Copy();
168 }
169 return nullptr;
170}
171
172void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function,
173 int64_t implicit_cast_cost) {
174 RegisterCastFunction(source, target, node: MapCastNode(std::move(function), implicit_cast_cost));
175}
176
177void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target,
178 bind_cast_function_t bind_function, int64_t implicit_cast_cost) {
179 RegisterCastFunction(source, target, node: MapCastNode(bind_function, implicit_cast_cost));
180}
181
182void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, MapCastNode node) {
183 if (!map_info) {
184 // create the cast map and the cast map function
185 auto info = make_uniq<MapCastInfo>();
186 map_info = info.get();
187 bind_functions.emplace_back(args&: MapCastFunction, args: std::move(info));
188 }
189 map_info->AddEntry(source, target, node: std::move(node));
190}
191
192} // namespace duckdb
193