1#include "duckdb/common/string_util.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/function/scalar/nested_functions.hpp"
4#include "duckdb/planner/expression/bound_function_expression.hpp"
5#include "duckdb/planner/expression/bound_parameter_expression.hpp"
6#include "duckdb/storage/statistics/struct_stats.hpp"
7
8namespace duckdb {
9
10struct StructExtractBindData : public FunctionData {
11 StructExtractBindData(string key, idx_t index, LogicalType type)
12 : key(std::move(key)), index(index), type(std::move(type)) {
13 }
14
15 string key;
16 idx_t index;
17 LogicalType type;
18
19public:
20 unique_ptr<FunctionData> Copy() const override {
21 return make_uniq<StructExtractBindData>(args: key, args: index, args: type);
22 }
23 bool Equals(const FunctionData &other_p) const override {
24 auto &other = other_p.Cast<StructExtractBindData>();
25 return key == other.key && index == other.index && type == other.type;
26 }
27};
28
29static void StructExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) {
30 auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
31 auto &info = func_expr.bind_info->Cast<StructExtractBindData>();
32
33 // this should be guaranteed by the binder
34 auto &vec = args.data[0];
35
36 vec.Verify(count: args.size());
37 auto &children = StructVector::GetEntries(vector&: vec);
38 D_ASSERT(info.index < children.size());
39 auto &struct_child = children[info.index];
40 result.Reference(other&: *struct_child);
41 result.Verify(count: args.size());
42}
43
44static unique_ptr<FunctionData> StructExtractBind(ClientContext &context, ScalarFunction &bound_function,
45 vector<unique_ptr<Expression>> &arguments) {
46 D_ASSERT(bound_function.arguments.size() == 2);
47 if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) {
48 throw ParameterNotResolvedException();
49 }
50 D_ASSERT(LogicalTypeId::STRUCT == arguments[0]->return_type.id());
51 auto &struct_children = StructType::GetChildTypes(type: arguments[0]->return_type);
52 if (struct_children.empty()) {
53 throw InternalException("Can't extract something from an empty struct");
54 }
55 bound_function.arguments[0] = arguments[0]->return_type;
56
57 auto &key_child = arguments[1];
58 if (key_child->HasParameter()) {
59 throw ParameterNotResolvedException();
60 }
61
62 if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) {
63 throw BinderException("Key name for struct_extract needs to be a constant string");
64 }
65 Value key_val = ExpressionExecutor::EvaluateScalar(context, expr: *key_child);
66 D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR);
67 auto &key_str = StringValue::Get(value: key_val);
68 if (key_val.IsNull() || key_str.empty()) {
69 throw BinderException("Key name for struct_extract needs to be neither NULL nor empty");
70 }
71 string key = StringUtil::Lower(str: key_str);
72
73 LogicalType return_type;
74 idx_t key_index = 0;
75 bool found_key = false;
76
77 for (size_t i = 0; i < struct_children.size(); i++) {
78 auto &child = struct_children[i];
79 if (StringUtil::Lower(str: child.first) == key) {
80 found_key = true;
81 key_index = i;
82 return_type = child.second;
83 break;
84 }
85 }
86
87 if (!found_key) {
88 vector<string> candidates;
89 candidates.reserve(n: struct_children.size());
90 for (auto &struct_child : struct_children) {
91 candidates.push_back(x: struct_child.first);
92 }
93 auto closest_settings = StringUtil::TopNLevenshtein(strings: candidates, target: key);
94 auto message = StringUtil::CandidatesMessage(candidates: closest_settings, candidate: "Candidate Entries");
95 throw BinderException("Could not find key \"%s\" in struct\n%s", key, message);
96 }
97
98 bound_function.return_type = return_type;
99 return make_uniq<StructExtractBindData>(args: std::move(key), args&: key_index, args: std::move(return_type));
100}
101
102static unique_ptr<BaseStatistics> PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) {
103 auto &child_stats = input.child_stats;
104 auto &bind_data = input.bind_data;
105
106 auto &info = bind_data->Cast<StructExtractBindData>();
107 auto struct_child_stats = StructStats::GetChildStats(stats: child_stats[0]);
108 return struct_child_stats[info.index].ToUnique();
109}
110
111ScalarFunction StructExtractFun::GetFunction() {
112 return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY,
113 StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats);
114}
115
116void StructExtractFun::RegisterFunction(BuiltinFunctions &set) {
117 // the arguments and return types are actually set in the binder function
118 auto fun = GetFunction();
119 set.AddFunction(function: fun);
120}
121
122} // namespace duckdb
123