1 | #include "duckdb/common/string_util.hpp" |
2 | #include "duckdb/execution/expression_executor.hpp" |
3 | #include "duckdb/function/scalar/nested_functions.hpp" |
4 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
6 | #include "duckdb/storage/statistics/struct_stats.hpp" |
7 | |
8 | namespace duckdb { |
9 | |
10 | struct : public FunctionData { |
11 | (string key, idx_t index, LogicalType type) |
12 | : key(std::move(key)), index(index), type(std::move(type)) { |
13 | } |
14 | |
15 | string ; |
16 | idx_t ; |
17 | LogicalType ; |
18 | |
19 | public: |
20 | unique_ptr<FunctionData> () const override { |
21 | return make_uniq<StructExtractBindData>(args: key, args: index, args: type); |
22 | } |
23 | bool (const FunctionData &other_p) const override { |
24 | auto &other = other_p.Cast<StructExtractBindData>(); |
25 | return key == other.key && index == other.index && type == other.type; |
26 | } |
27 | }; |
28 | |
29 | static void (DataChunk &args, ExpressionState &state, Vector &result) { |
30 | auto &func_expr = state.expr.Cast<BoundFunctionExpression>(); |
31 | auto &info = func_expr.bind_info->Cast<StructExtractBindData>(); |
32 | |
33 | // this should be guaranteed by the binder |
34 | auto &vec = args.data[0]; |
35 | |
36 | vec.Verify(count: args.size()); |
37 | auto &children = StructVector::GetEntries(vector&: vec); |
38 | D_ASSERT(info.index < children.size()); |
39 | auto &struct_child = children[info.index]; |
40 | result.Reference(other&: *struct_child); |
41 | result.Verify(count: args.size()); |
42 | } |
43 | |
44 | static unique_ptr<FunctionData> (ClientContext &context, ScalarFunction &bound_function, |
45 | vector<unique_ptr<Expression>> &arguments) { |
46 | D_ASSERT(bound_function.arguments.size() == 2); |
47 | if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { |
48 | throw ParameterNotResolvedException(); |
49 | } |
50 | D_ASSERT(LogicalTypeId::STRUCT == arguments[0]->return_type.id()); |
51 | auto &struct_children = StructType::GetChildTypes(type: arguments[0]->return_type); |
52 | if (struct_children.empty()) { |
53 | throw InternalException("Can't extract something from an empty struct" ); |
54 | } |
55 | bound_function.arguments[0] = arguments[0]->return_type; |
56 | |
57 | auto &key_child = arguments[1]; |
58 | if (key_child->HasParameter()) { |
59 | throw ParameterNotResolvedException(); |
60 | } |
61 | |
62 | if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { |
63 | throw BinderException("Key name for struct_extract needs to be a constant string" ); |
64 | } |
65 | Value key_val = ExpressionExecutor::EvaluateScalar(context, expr: *key_child); |
66 | D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); |
67 | auto &key_str = StringValue::Get(value: key_val); |
68 | if (key_val.IsNull() || key_str.empty()) { |
69 | throw BinderException("Key name for struct_extract needs to be neither NULL nor empty" ); |
70 | } |
71 | string key = StringUtil::Lower(str: key_str); |
72 | |
73 | LogicalType return_type; |
74 | idx_t key_index = 0; |
75 | bool found_key = false; |
76 | |
77 | for (size_t i = 0; i < struct_children.size(); i++) { |
78 | auto &child = struct_children[i]; |
79 | if (StringUtil::Lower(str: child.first) == key) { |
80 | found_key = true; |
81 | key_index = i; |
82 | return_type = child.second; |
83 | break; |
84 | } |
85 | } |
86 | |
87 | if (!found_key) { |
88 | vector<string> candidates; |
89 | candidates.reserve(n: struct_children.size()); |
90 | for (auto &struct_child : struct_children) { |
91 | candidates.push_back(x: struct_child.first); |
92 | } |
93 | auto closest_settings = StringUtil::TopNLevenshtein(strings: candidates, target: key); |
94 | auto message = StringUtil::CandidatesMessage(candidates: closest_settings, candidate: "Candidate Entries" ); |
95 | throw BinderException("Could not find key \"%s\" in struct\n%s" , key, message); |
96 | } |
97 | |
98 | bound_function.return_type = return_type; |
99 | return make_uniq<StructExtractBindData>(args: std::move(key), args&: key_index, args: std::move(return_type)); |
100 | } |
101 | |
102 | static unique_ptr<BaseStatistics> (ClientContext &context, FunctionStatisticsInput &input) { |
103 | auto &child_stats = input.child_stats; |
104 | auto &bind_data = input.bind_data; |
105 | |
106 | auto &info = bind_data->Cast<StructExtractBindData>(); |
107 | auto struct_child_stats = StructStats::GetChildStats(stats: child_stats[0]); |
108 | return struct_child_stats[info.index].ToUnique(); |
109 | } |
110 | |
111 | ScalarFunction StructExtractFun::() { |
112 | return ScalarFunction("struct_extract" , {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, |
113 | StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); |
114 | } |
115 | |
116 | void StructExtractFun::(BuiltinFunctions &set) { |
117 | // the arguments and return types are actually set in the binder function |
118 | auto fun = GetFunction(); |
119 | set.AddFunction(function: fun); |
120 | } |
121 | |
122 | } // namespace duckdb |
123 | |