1 | #include "duckdb/function/aggregate/nested_functions.hpp" |
2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
3 | #include "duckdb/common/types/chunk_collection.hpp" |
4 | |
5 | using namespace std; |
6 | |
7 | namespace duckdb { |
8 | |
9 | struct list_agg_state_t { |
10 | ChunkCollection *cc; |
11 | }; |
12 | |
13 | struct ListFunction { |
14 | template <class STATE> static void Initialize(STATE *state) { |
15 | state->cc = nullptr; |
16 | } |
17 | |
18 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
19 | throw NotImplementedException("COMBINE not implemented for LIST" ); |
20 | } |
21 | |
22 | template <class STATE> static void Destroy(STATE *state) { |
23 | if (state->cc) { |
24 | delete state->cc; |
25 | } |
26 | } |
27 | static bool IgnoreNull() { |
28 | return true; |
29 | } |
30 | }; |
31 | |
32 | static void list_update(Vector inputs[], idx_t input_count, Vector &state_vector, idx_t count) { |
33 | assert(input_count == 1); |
34 | |
35 | auto &input = inputs[0]; |
36 | VectorData sdata; |
37 | state_vector.Orrify(count, sdata); |
38 | |
39 | DataChunk insert_chunk; |
40 | |
41 | vector<TypeId> chunk_types; |
42 | chunk_types.push_back(input.type); |
43 | insert_chunk.Initialize(chunk_types); |
44 | insert_chunk.SetCardinality(1); |
45 | |
46 | auto states = (list_agg_state_t **)sdata.data; |
47 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
48 | for (idx_t i = 0; i < count; i++) { |
49 | auto state = states[sdata.sel->get_index(i)]; |
50 | if (!state->cc) { |
51 | state->cc = new ChunkCollection(); |
52 | } |
53 | sel.set_index(0, i); |
54 | insert_chunk.data[0].Slice(input, sel, 1); |
55 | state->cc->Append(insert_chunk); |
56 | } |
57 | } |
58 | |
59 | static void list_finalize(Vector &state_vector, Vector &result, idx_t count) { |
60 | VectorData sdata; |
61 | state_vector.Orrify(count, sdata); |
62 | auto states = (list_agg_state_t **)sdata.data; |
63 | |
64 | result.Initialize(TypeId::LIST); |
65 | auto list_struct_data = FlatVector::GetData<list_entry_t>(result); |
66 | |
67 | size_t total_len = 0; |
68 | for (idx_t i = 0; i < count; i++) { |
69 | auto state = states[sdata.sel->get_index(i)]; |
70 | assert(state->cc); |
71 | auto &state_cc = *state->cc; |
72 | assert(state_cc.types.size() == 1); |
73 | list_struct_data[i].length = state_cc.count; |
74 | list_struct_data[i].offset = total_len; |
75 | total_len += state_cc.count; |
76 | } |
77 | |
78 | auto list_child = make_unique<ChunkCollection>(); |
79 | for (idx_t i = 0; i < count; i++) { |
80 | auto state = states[sdata.sel->get_index(i)]; |
81 | auto &state_cc = *state->cc; |
82 | assert(state_cc.chunks[0]->column_count() == 1); |
83 | list_child->Append(state_cc); |
84 | } |
85 | assert(list_child->count == total_len); |
86 | ListVector::SetEntry(result, move(list_child)); |
87 | } |
88 | |
89 | unique_ptr<FunctionData> list_bind(BoundAggregateExpression &expr, ClientContext &context, SQLType &return_type) { |
90 | assert(expr.children.size() == 1); |
91 | return_type = SQLType::LIST; |
92 | return_type.child_type.push_back(make_pair("" , expr.arguments[0])); |
93 | return make_unique<ListBindData>(); // TODO atm this is not used anywhere but it might not be required after all |
94 | // except for sanity checking |
95 | } |
96 | |
97 | void ListFun::RegisterFunction(BuiltinFunctions &set) { |
98 | auto agg = AggregateFunction("list" , {SQLType::ANY}, SQLType::LIST, AggregateFunction::StateSize<list_agg_state_t>, |
99 | AggregateFunction::StateInitialize<list_agg_state_t, ListFunction>, list_update, |
100 | AggregateFunction::StateCombine<list_agg_state_t, ListFunction>, list_finalize, |
101 | nullptr, list_bind, AggregateFunction::StateDestroy<list_agg_state_t, ListFunction>); |
102 | set.AddFunction(agg); |
103 | } |
104 | |
105 | } // namespace duckdb |
106 | |