| 1 | #include "duckdb/function/aggregate/nested_functions.hpp" |
| 2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 3 | #include "duckdb/common/types/chunk_collection.hpp" |
| 4 | |
| 5 | using namespace std; |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | struct list_agg_state_t { |
| 10 | ChunkCollection *cc; |
| 11 | }; |
| 12 | |
| 13 | struct ListFunction { |
| 14 | template <class STATE> static void Initialize(STATE *state) { |
| 15 | state->cc = nullptr; |
| 16 | } |
| 17 | |
| 18 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
| 19 | throw NotImplementedException("COMBINE not implemented for LIST" ); |
| 20 | } |
| 21 | |
| 22 | template <class STATE> static void Destroy(STATE *state) { |
| 23 | if (state->cc) { |
| 24 | delete state->cc; |
| 25 | } |
| 26 | } |
| 27 | static bool IgnoreNull() { |
| 28 | return true; |
| 29 | } |
| 30 | }; |
| 31 | |
| 32 | static void list_update(Vector inputs[], idx_t input_count, Vector &state_vector, idx_t count) { |
| 33 | assert(input_count == 1); |
| 34 | |
| 35 | auto &input = inputs[0]; |
| 36 | VectorData sdata; |
| 37 | state_vector.Orrify(count, sdata); |
| 38 | |
| 39 | DataChunk insert_chunk; |
| 40 | |
| 41 | vector<TypeId> chunk_types; |
| 42 | chunk_types.push_back(input.type); |
| 43 | insert_chunk.Initialize(chunk_types); |
| 44 | insert_chunk.SetCardinality(1); |
| 45 | |
| 46 | auto states = (list_agg_state_t **)sdata.data; |
| 47 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
| 48 | for (idx_t i = 0; i < count; i++) { |
| 49 | auto state = states[sdata.sel->get_index(i)]; |
| 50 | if (!state->cc) { |
| 51 | state->cc = new ChunkCollection(); |
| 52 | } |
| 53 | sel.set_index(0, i); |
| 54 | insert_chunk.data[0].Slice(input, sel, 1); |
| 55 | state->cc->Append(insert_chunk); |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | static void list_finalize(Vector &state_vector, Vector &result, idx_t count) { |
| 60 | VectorData sdata; |
| 61 | state_vector.Orrify(count, sdata); |
| 62 | auto states = (list_agg_state_t **)sdata.data; |
| 63 | |
| 64 | result.Initialize(TypeId::LIST); |
| 65 | auto list_struct_data = FlatVector::GetData<list_entry_t>(result); |
| 66 | |
| 67 | size_t total_len = 0; |
| 68 | for (idx_t i = 0; i < count; i++) { |
| 69 | auto state = states[sdata.sel->get_index(i)]; |
| 70 | assert(state->cc); |
| 71 | auto &state_cc = *state->cc; |
| 72 | assert(state_cc.types.size() == 1); |
| 73 | list_struct_data[i].length = state_cc.count; |
| 74 | list_struct_data[i].offset = total_len; |
| 75 | total_len += state_cc.count; |
| 76 | } |
| 77 | |
| 78 | auto list_child = make_unique<ChunkCollection>(); |
| 79 | for (idx_t i = 0; i < count; i++) { |
| 80 | auto state = states[sdata.sel->get_index(i)]; |
| 81 | auto &state_cc = *state->cc; |
| 82 | assert(state_cc.chunks[0]->column_count() == 1); |
| 83 | list_child->Append(state_cc); |
| 84 | } |
| 85 | assert(list_child->count == total_len); |
| 86 | ListVector::SetEntry(result, move(list_child)); |
| 87 | } |
| 88 | |
| 89 | unique_ptr<FunctionData> list_bind(BoundAggregateExpression &expr, ClientContext &context, SQLType &return_type) { |
| 90 | assert(expr.children.size() == 1); |
| 91 | return_type = SQLType::LIST; |
| 92 | return_type.child_type.push_back(make_pair("" , expr.arguments[0])); |
| 93 | return make_unique<ListBindData>(); // TODO atm this is not used anywhere but it might not be required after all |
| 94 | // except for sanity checking |
| 95 | } |
| 96 | |
| 97 | void ListFun::RegisterFunction(BuiltinFunctions &set) { |
| 98 | auto agg = AggregateFunction("list" , {SQLType::ANY}, SQLType::LIST, AggregateFunction::StateSize<list_agg_state_t>, |
| 99 | AggregateFunction::StateInitialize<list_agg_state_t, ListFunction>, list_update, |
| 100 | AggregateFunction::StateCombine<list_agg_state_t, ListFunction>, list_finalize, |
| 101 | nullptr, list_bind, AggregateFunction::StateDestroy<list_agg_state_t, ListFunction>); |
| 102 | set.AddFunction(agg); |
| 103 | } |
| 104 | |
| 105 | } // namespace duckdb |
| 106 | |