1#include "duckdb/function/scalar/nested_functions.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/planner/expression/bound_function_expression.hpp"
4#include "duckdb/common/string_util.hpp"
5
6using namespace std;
7
8namespace duckdb {
9
10static void struct_extract_fun(DataChunk &args, ExpressionState &state, Vector &result) {
11 auto &func_expr = (BoundFunctionExpression &)state.expr;
12 auto &info = (StructExtractBindData &)*func_expr.bind_info;
13
14 // this should be guaranteed by the binder
15 assert(args.column_count() == 1);
16 auto &vec = args.data[0];
17
18 vec.Verify(args.size());
19 if (vec.vector_type == VectorType::DICTIONARY_VECTOR) {
20 auto &child = DictionaryVector::Child(vec);
21 auto &dict_sel = DictionaryVector::SelVector(vec);
22 auto &children = StructVector::GetEntries(child);
23 if (info.index >= children.size()) {
24 throw Exception("Not enough struct entries for struct_extract");
25 }
26 auto &struct_child = children[info.index];
27 if (struct_child.first != info.key || struct_child.second->type != info.type) {
28 throw Exception("Struct key or type mismatch");
29 }
30 result.Slice(*struct_child.second, dict_sel, args.size());
31 } else {
32 auto &children = StructVector::GetEntries(vec);
33 if (info.index >= children.size()) {
34 throw Exception("Not enough struct entries for struct_extract");
35 }
36 auto &struct_child = children[info.index];
37 if (struct_child.first != info.key || struct_child.second->type != info.type) {
38 throw Exception("Struct key or type mismatch");
39 }
40 result.Reference(*struct_child.second);
41 }
42 result.Verify(args.size());
43}
44
45static unique_ptr<FunctionData> struct_extract_bind(BoundFunctionExpression &expr, ClientContext &context) {
46 // the binder should fix this for us.
47 assert(expr.children.size() == 2);
48 assert(expr.arguments.size() == expr.children.size());
49 assert(expr.arguments[0].id == SQLTypeId::STRUCT);
50 assert(expr.children[0]->return_type == TypeId::STRUCT);
51 if (expr.arguments[0].child_type.size() < 1) {
52 throw Exception("Can't extract something from an empty struct");
53 }
54
55 auto &key_child = expr.children[1];
56
57 if (expr.arguments[1].id != SQLTypeId::VARCHAR || key_child->return_type != TypeId::VARCHAR ||
58 !key_child->IsScalar()) {
59 throw Exception("Key name for struct_extract needs to be a constant string");
60 }
61 Value key_val = ExpressionExecutor::EvaluateScalar(*key_child.get());
62 assert(key_val.type == TypeId::VARCHAR);
63 if (key_val.is_null || key_val.str_value.length() < 1) {
64 throw Exception("Key name for struct_extract needs to be neither NULL nor empty");
65 }
66 string key = StringUtil::Lower(key_val.str_value);
67
68 SQLType return_type;
69 idx_t key_index = 0;
70 bool found_key = false;
71
72 for (size_t i = 0; i < expr.arguments[0].child_type.size(); i++) {
73 auto &child = expr.arguments[0].child_type[i];
74 if (child.first == key) {
75 found_key = true;
76 key_index = i;
77 return_type = child.second;
78 break;
79 }
80 }
81 if (!found_key) {
82 throw Exception("Could not find key in struct");
83 }
84
85 expr.return_type = GetInternalType(return_type);
86 expr.sql_return_type = return_type;
87 expr.children.pop_back();
88 return make_unique<StructExtractBindData>(key, key_index, GetInternalType(return_type));
89}
90
91void StructExtractFun::RegisterFunction(BuiltinFunctions &set) {
92 // the arguments and return types are actually set in the binder function
93 ScalarFunction fun("struct_extract", {SQLType::STRUCT, SQLType::VARCHAR}, SQLType::ANY, struct_extract_fun, false,
94 struct_extract_bind);
95 set.AddFunction(fun);
96}
97
98} // namespace duckdb
99