| 1 | #include "duckdb/function/scalar/nested_functions.hpp" | 
|---|
| 2 | #include "duckdb/execution/expression_executor.hpp" | 
|---|
| 3 | #include "duckdb/planner/expression/bound_function_expression.hpp" | 
|---|
| 4 | #include "duckdb/common/string_util.hpp" | 
|---|
| 5 |  | 
|---|
| 6 | using namespace std; | 
|---|
| 7 |  | 
|---|
| 8 | namespace duckdb { | 
|---|
| 9 |  | 
|---|
| 10 | static void (DataChunk &args, ExpressionState &state, Vector &result) { | 
|---|
| 11 | auto &func_expr = (BoundFunctionExpression &)state.expr; | 
|---|
| 12 | auto &info = (StructExtractBindData &)*func_expr.bind_info; | 
|---|
| 13 |  | 
|---|
| 14 | // this should be guaranteed by the binder | 
|---|
| 15 | assert(args.column_count() == 1); | 
|---|
| 16 | auto &vec = args.data[0]; | 
|---|
| 17 |  | 
|---|
| 18 | vec.Verify(args.size()); | 
|---|
| 19 | if (vec.vector_type == VectorType::DICTIONARY_VECTOR) { | 
|---|
| 20 | auto &child = DictionaryVector::Child(vec); | 
|---|
| 21 | auto &dict_sel = DictionaryVector::SelVector(vec); | 
|---|
| 22 | auto &children = StructVector::GetEntries(child); | 
|---|
| 23 | if (info.index >= children.size()) { | 
|---|
| 24 | throw Exception( "Not enough struct entries for struct_extract"); | 
|---|
| 25 | } | 
|---|
| 26 | auto &struct_child = children[info.index]; | 
|---|
| 27 | if (struct_child.first != info.key || struct_child.second->type != info.type) { | 
|---|
| 28 | throw Exception( "Struct key or type mismatch"); | 
|---|
| 29 | } | 
|---|
| 30 | result.Slice(*struct_child.second, dict_sel, args.size()); | 
|---|
| 31 | } else { | 
|---|
| 32 | auto &children = StructVector::GetEntries(vec); | 
|---|
| 33 | if (info.index >= children.size()) { | 
|---|
| 34 | throw Exception( "Not enough struct entries for struct_extract"); | 
|---|
| 35 | } | 
|---|
| 36 | auto &struct_child = children[info.index]; | 
|---|
| 37 | if (struct_child.first != info.key || struct_child.second->type != info.type) { | 
|---|
| 38 | throw Exception( "Struct key or type mismatch"); | 
|---|
| 39 | } | 
|---|
| 40 | result.Reference(*struct_child.second); | 
|---|
| 41 | } | 
|---|
| 42 | result.Verify(args.size()); | 
|---|
| 43 | } | 
|---|
| 44 |  | 
|---|
| 45 | static unique_ptr<FunctionData> (BoundFunctionExpression &expr, ClientContext &context) { | 
|---|
| 46 | // the binder should fix this for us. | 
|---|
| 47 | assert(expr.children.size() == 2); | 
|---|
| 48 | assert(expr.arguments.size() == expr.children.size()); | 
|---|
| 49 | assert(expr.arguments[0].id == SQLTypeId::STRUCT); | 
|---|
| 50 | assert(expr.children[0]->return_type == TypeId::STRUCT); | 
|---|
| 51 | if (expr.arguments[0].child_type.size() < 1) { | 
|---|
| 52 | throw Exception( "Can't extract something from an empty struct"); | 
|---|
| 53 | } | 
|---|
| 54 |  | 
|---|
| 55 | auto &key_child = expr.children[1]; | 
|---|
| 56 |  | 
|---|
| 57 | if (expr.arguments[1].id != SQLTypeId::VARCHAR || key_child->return_type != TypeId::VARCHAR || | 
|---|
| 58 | !key_child->IsScalar()) { | 
|---|
| 59 | throw Exception( "Key name for struct_extract needs to be a constant string"); | 
|---|
| 60 | } | 
|---|
| 61 | Value key_val = ExpressionExecutor::EvaluateScalar(*key_child.get()); | 
|---|
| 62 | assert(key_val.type == TypeId::VARCHAR); | 
|---|
| 63 | if (key_val.is_null || key_val.str_value.length() < 1) { | 
|---|
| 64 | throw Exception( "Key name for struct_extract needs to be neither NULL nor empty"); | 
|---|
| 65 | } | 
|---|
| 66 | string key = StringUtil::Lower(key_val.str_value); | 
|---|
| 67 |  | 
|---|
| 68 | SQLType return_type; | 
|---|
| 69 | idx_t key_index = 0; | 
|---|
| 70 | bool found_key = false; | 
|---|
| 71 |  | 
|---|
| 72 | for (size_t i = 0; i < expr.arguments[0].child_type.size(); i++) { | 
|---|
| 73 | auto &child = expr.arguments[0].child_type[i]; | 
|---|
| 74 | if (child.first == key) { | 
|---|
| 75 | found_key = true; | 
|---|
| 76 | key_index = i; | 
|---|
| 77 | return_type = child.second; | 
|---|
| 78 | break; | 
|---|
| 79 | } | 
|---|
| 80 | } | 
|---|
| 81 | if (!found_key) { | 
|---|
| 82 | throw Exception( "Could not find key in struct"); | 
|---|
| 83 | } | 
|---|
| 84 |  | 
|---|
| 85 | expr.return_type = GetInternalType(return_type); | 
|---|
| 86 | expr.sql_return_type = return_type; | 
|---|
| 87 | expr.children.pop_back(); | 
|---|
| 88 | return make_unique<StructExtractBindData>(key, key_index, GetInternalType(return_type)); | 
|---|
| 89 | } | 
|---|
| 90 |  | 
|---|
| 91 | void StructExtractFun::(BuiltinFunctions &set) { | 
|---|
| 92 | // the arguments and return types are actually set in the binder function | 
|---|
| 93 | ScalarFunction fun( "struct_extract", {SQLType::STRUCT, SQLType::VARCHAR}, SQLType::ANY, struct_extract_fun, false, | 
|---|
| 94 | struct_extract_bind); | 
|---|
| 95 | set.AddFunction(fun); | 
|---|
| 96 | } | 
|---|
| 97 |  | 
|---|
| 98 | } // namespace duckdb | 
|---|
| 99 |  | 
|---|