| 1 | #include "duckdb/common/string_util.hpp" |
| 2 | #include "duckdb/execution/expression_executor.hpp" |
| 3 | #include "duckdb/function/scalar/nested_functions.hpp" |
| 4 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 5 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
| 6 | #include "duckdb/storage/statistics/struct_stats.hpp" |
| 7 | |
| 8 | namespace duckdb { |
| 9 | |
| 10 | struct : public FunctionData { |
| 11 | (string key, idx_t index, LogicalType type) |
| 12 | : key(std::move(key)), index(index), type(std::move(type)) { |
| 13 | } |
| 14 | |
| 15 | string ; |
| 16 | idx_t ; |
| 17 | LogicalType ; |
| 18 | |
| 19 | public: |
| 20 | unique_ptr<FunctionData> () const override { |
| 21 | return make_uniq<StructExtractBindData>(args: key, args: index, args: type); |
| 22 | } |
| 23 | bool (const FunctionData &other_p) const override { |
| 24 | auto &other = other_p.Cast<StructExtractBindData>(); |
| 25 | return key == other.key && index == other.index && type == other.type; |
| 26 | } |
| 27 | }; |
| 28 | |
| 29 | static void (DataChunk &args, ExpressionState &state, Vector &result) { |
| 30 | auto &func_expr = state.expr.Cast<BoundFunctionExpression>(); |
| 31 | auto &info = func_expr.bind_info->Cast<StructExtractBindData>(); |
| 32 | |
| 33 | // this should be guaranteed by the binder |
| 34 | auto &vec = args.data[0]; |
| 35 | |
| 36 | vec.Verify(count: args.size()); |
| 37 | auto &children = StructVector::GetEntries(vector&: vec); |
| 38 | D_ASSERT(info.index < children.size()); |
| 39 | auto &struct_child = children[info.index]; |
| 40 | result.Reference(other&: *struct_child); |
| 41 | result.Verify(count: args.size()); |
| 42 | } |
| 43 | |
| 44 | static unique_ptr<FunctionData> (ClientContext &context, ScalarFunction &bound_function, |
| 45 | vector<unique_ptr<Expression>> &arguments) { |
| 46 | D_ASSERT(bound_function.arguments.size() == 2); |
| 47 | if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { |
| 48 | throw ParameterNotResolvedException(); |
| 49 | } |
| 50 | D_ASSERT(LogicalTypeId::STRUCT == arguments[0]->return_type.id()); |
| 51 | auto &struct_children = StructType::GetChildTypes(type: arguments[0]->return_type); |
| 52 | if (struct_children.empty()) { |
| 53 | throw InternalException("Can't extract something from an empty struct" ); |
| 54 | } |
| 55 | bound_function.arguments[0] = arguments[0]->return_type; |
| 56 | |
| 57 | auto &key_child = arguments[1]; |
| 58 | if (key_child->HasParameter()) { |
| 59 | throw ParameterNotResolvedException(); |
| 60 | } |
| 61 | |
| 62 | if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { |
| 63 | throw BinderException("Key name for struct_extract needs to be a constant string" ); |
| 64 | } |
| 65 | Value key_val = ExpressionExecutor::EvaluateScalar(context, expr: *key_child); |
| 66 | D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); |
| 67 | auto &key_str = StringValue::Get(value: key_val); |
| 68 | if (key_val.IsNull() || key_str.empty()) { |
| 69 | throw BinderException("Key name for struct_extract needs to be neither NULL nor empty" ); |
| 70 | } |
| 71 | string key = StringUtil::Lower(str: key_str); |
| 72 | |
| 73 | LogicalType return_type; |
| 74 | idx_t key_index = 0; |
| 75 | bool found_key = false; |
| 76 | |
| 77 | for (size_t i = 0; i < struct_children.size(); i++) { |
| 78 | auto &child = struct_children[i]; |
| 79 | if (StringUtil::Lower(str: child.first) == key) { |
| 80 | found_key = true; |
| 81 | key_index = i; |
| 82 | return_type = child.second; |
| 83 | break; |
| 84 | } |
| 85 | } |
| 86 | |
| 87 | if (!found_key) { |
| 88 | vector<string> candidates; |
| 89 | candidates.reserve(n: struct_children.size()); |
| 90 | for (auto &struct_child : struct_children) { |
| 91 | candidates.push_back(x: struct_child.first); |
| 92 | } |
| 93 | auto closest_settings = StringUtil::TopNLevenshtein(strings: candidates, target: key); |
| 94 | auto message = StringUtil::CandidatesMessage(candidates: closest_settings, candidate: "Candidate Entries" ); |
| 95 | throw BinderException("Could not find key \"%s\" in struct\n%s" , key, message); |
| 96 | } |
| 97 | |
| 98 | bound_function.return_type = return_type; |
| 99 | return make_uniq<StructExtractBindData>(args: std::move(key), args&: key_index, args: std::move(return_type)); |
| 100 | } |
| 101 | |
| 102 | static unique_ptr<BaseStatistics> (ClientContext &context, FunctionStatisticsInput &input) { |
| 103 | auto &child_stats = input.child_stats; |
| 104 | auto &bind_data = input.bind_data; |
| 105 | |
| 106 | auto &info = bind_data->Cast<StructExtractBindData>(); |
| 107 | auto struct_child_stats = StructStats::GetChildStats(stats: child_stats[0]); |
| 108 | return struct_child_stats[info.index].ToUnique(); |
| 109 | } |
| 110 | |
| 111 | ScalarFunction StructExtractFun::() { |
| 112 | return ScalarFunction("struct_extract" , {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, |
| 113 | StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); |
| 114 | } |
| 115 | |
| 116 | void StructExtractFun::(BuiltinFunctions &set) { |
| 117 | // the arguments and return types are actually set in the binder function |
| 118 | auto fun = GetFunction(); |
| 119 | set.AddFunction(function: fun); |
| 120 | } |
| 121 | |
| 122 | } // namespace duckdb |
| 123 | |