1#include "duckdb/function/aggregate/nested_functions.hpp"
2#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
3#include "duckdb/common/types/chunk_collection.hpp"
4
5using namespace std;
6
7namespace duckdb {
8
9struct list_agg_state_t {
10 ChunkCollection *cc;
11};
12
13struct ListFunction {
14 template <class STATE> static void Initialize(STATE *state) {
15 state->cc = nullptr;
16 }
17
18 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
19 throw NotImplementedException("COMBINE not implemented for LIST");
20 }
21
22 template <class STATE> static void Destroy(STATE *state) {
23 if (state->cc) {
24 delete state->cc;
25 }
26 }
27 static bool IgnoreNull() {
28 return true;
29 }
30};
31
32static void list_update(Vector inputs[], idx_t input_count, Vector &state_vector, idx_t count) {
33 assert(input_count == 1);
34
35 auto &input = inputs[0];
36 VectorData sdata;
37 state_vector.Orrify(count, sdata);
38
39 DataChunk insert_chunk;
40
41 vector<TypeId> chunk_types;
42 chunk_types.push_back(input.type);
43 insert_chunk.Initialize(chunk_types);
44 insert_chunk.SetCardinality(1);
45
46 auto states = (list_agg_state_t **)sdata.data;
47 SelectionVector sel(STANDARD_VECTOR_SIZE);
48 for (idx_t i = 0; i < count; i++) {
49 auto state = states[sdata.sel->get_index(i)];
50 if (!state->cc) {
51 state->cc = new ChunkCollection();
52 }
53 sel.set_index(0, i);
54 insert_chunk.data[0].Slice(input, sel, 1);
55 state->cc->Append(insert_chunk);
56 }
57}
58
59static void list_finalize(Vector &state_vector, Vector &result, idx_t count) {
60 VectorData sdata;
61 state_vector.Orrify(count, sdata);
62 auto states = (list_agg_state_t **)sdata.data;
63
64 result.Initialize(TypeId::LIST);
65 auto list_struct_data = FlatVector::GetData<list_entry_t>(result);
66
67 size_t total_len = 0;
68 for (idx_t i = 0; i < count; i++) {
69 auto state = states[sdata.sel->get_index(i)];
70 assert(state->cc);
71 auto &state_cc = *state->cc;
72 assert(state_cc.types.size() == 1);
73 list_struct_data[i].length = state_cc.count;
74 list_struct_data[i].offset = total_len;
75 total_len += state_cc.count;
76 }
77
78 auto list_child = make_unique<ChunkCollection>();
79 for (idx_t i = 0; i < count; i++) {
80 auto state = states[sdata.sel->get_index(i)];
81 auto &state_cc = *state->cc;
82 assert(state_cc.chunks[0]->column_count() == 1);
83 list_child->Append(state_cc);
84 }
85 assert(list_child->count == total_len);
86 ListVector::SetEntry(result, move(list_child));
87}
88
89unique_ptr<FunctionData> list_bind(BoundAggregateExpression &expr, ClientContext &context, SQLType &return_type) {
90 assert(expr.children.size() == 1);
91 return_type = SQLType::LIST;
92 return_type.child_type.push_back(make_pair("", expr.arguments[0]));
93 return make_unique<ListBindData>(); // TODO atm this is not used anywhere but it might not be required after all
94 // except for sanity checking
95}
96
97void ListFun::RegisterFunction(BuiltinFunctions &set) {
98 auto agg = AggregateFunction("list", {SQLType::ANY}, SQLType::LIST, AggregateFunction::StateSize<list_agg_state_t>,
99 AggregateFunction::StateInitialize<list_agg_state_t, ListFunction>, list_update,
100 AggregateFunction::StateCombine<list_agg_state_t, ListFunction>, list_finalize,
101 nullptr, list_bind, AggregateFunction::StateDestroy<list_agg_state_t, ListFunction>);
102 set.AddFunction(agg);
103}
104
105} // namespace duckdb
106