| 1 | #include "duckdb/function/scalar/nested_functions.hpp" |
| 2 | #include "duckdb/execution/expression_executor.hpp" |
| 3 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
| 4 | #include "duckdb/common/string_util.hpp" |
| 5 | |
| 6 | using namespace std; |
| 7 | |
| 8 | namespace duckdb { |
| 9 | |
| 10 | static void (DataChunk &args, ExpressionState &state, Vector &result) { |
| 11 | auto &func_expr = (BoundFunctionExpression &)state.expr; |
| 12 | auto &info = (StructExtractBindData &)*func_expr.bind_info; |
| 13 | |
| 14 | // this should be guaranteed by the binder |
| 15 | assert(args.column_count() == 1); |
| 16 | auto &vec = args.data[0]; |
| 17 | |
| 18 | vec.Verify(args.size()); |
| 19 | if (vec.vector_type == VectorType::DICTIONARY_VECTOR) { |
| 20 | auto &child = DictionaryVector::Child(vec); |
| 21 | auto &dict_sel = DictionaryVector::SelVector(vec); |
| 22 | auto &children = StructVector::GetEntries(child); |
| 23 | if (info.index >= children.size()) { |
| 24 | throw Exception("Not enough struct entries for struct_extract" ); |
| 25 | } |
| 26 | auto &struct_child = children[info.index]; |
| 27 | if (struct_child.first != info.key || struct_child.second->type != info.type) { |
| 28 | throw Exception("Struct key or type mismatch" ); |
| 29 | } |
| 30 | result.Slice(*struct_child.second, dict_sel, args.size()); |
| 31 | } else { |
| 32 | auto &children = StructVector::GetEntries(vec); |
| 33 | if (info.index >= children.size()) { |
| 34 | throw Exception("Not enough struct entries for struct_extract" ); |
| 35 | } |
| 36 | auto &struct_child = children[info.index]; |
| 37 | if (struct_child.first != info.key || struct_child.second->type != info.type) { |
| 38 | throw Exception("Struct key or type mismatch" ); |
| 39 | } |
| 40 | result.Reference(*struct_child.second); |
| 41 | } |
| 42 | result.Verify(args.size()); |
| 43 | } |
| 44 | |
| 45 | static unique_ptr<FunctionData> (BoundFunctionExpression &expr, ClientContext &context) { |
| 46 | // the binder should fix this for us. |
| 47 | assert(expr.children.size() == 2); |
| 48 | assert(expr.arguments.size() == expr.children.size()); |
| 49 | assert(expr.arguments[0].id == SQLTypeId::STRUCT); |
| 50 | assert(expr.children[0]->return_type == TypeId::STRUCT); |
| 51 | if (expr.arguments[0].child_type.size() < 1) { |
| 52 | throw Exception("Can't extract something from an empty struct" ); |
| 53 | } |
| 54 | |
| 55 | auto &key_child = expr.children[1]; |
| 56 | |
| 57 | if (expr.arguments[1].id != SQLTypeId::VARCHAR || key_child->return_type != TypeId::VARCHAR || |
| 58 | !key_child->IsScalar()) { |
| 59 | throw Exception("Key name for struct_extract needs to be a constant string" ); |
| 60 | } |
| 61 | Value key_val = ExpressionExecutor::EvaluateScalar(*key_child.get()); |
| 62 | assert(key_val.type == TypeId::VARCHAR); |
| 63 | if (key_val.is_null || key_val.str_value.length() < 1) { |
| 64 | throw Exception("Key name for struct_extract needs to be neither NULL nor empty" ); |
| 65 | } |
| 66 | string key = StringUtil::Lower(key_val.str_value); |
| 67 | |
| 68 | SQLType return_type; |
| 69 | idx_t key_index = 0; |
| 70 | bool found_key = false; |
| 71 | |
| 72 | for (size_t i = 0; i < expr.arguments[0].child_type.size(); i++) { |
| 73 | auto &child = expr.arguments[0].child_type[i]; |
| 74 | if (child.first == key) { |
| 75 | found_key = true; |
| 76 | key_index = i; |
| 77 | return_type = child.second; |
| 78 | break; |
| 79 | } |
| 80 | } |
| 81 | if (!found_key) { |
| 82 | throw Exception("Could not find key in struct" ); |
| 83 | } |
| 84 | |
| 85 | expr.return_type = GetInternalType(return_type); |
| 86 | expr.sql_return_type = return_type; |
| 87 | expr.children.pop_back(); |
| 88 | return make_unique<StructExtractBindData>(key, key_index, GetInternalType(return_type)); |
| 89 | } |
| 90 | |
| 91 | void StructExtractFun::(BuiltinFunctions &set) { |
| 92 | // the arguments and return types are actually set in the binder function |
| 93 | ScalarFunction fun("struct_extract" , {SQLType::STRUCT, SQLType::VARCHAR}, SQLType::ANY, struct_extract_fun, false, |
| 94 | struct_extract_bind); |
| 95 | set.AddFunction(fun); |
| 96 | } |
| 97 | |
| 98 | } // namespace duckdb |
| 99 | |