1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include <string>
6
7using namespace std;
8
9namespace duckdb {
10
11struct string_agg_state_t {
12 char *dataptr;
13 idx_t size;
14 idx_t alloc_size;
15};
16
17struct StringAggFunction {
18 template <class STATE> static void Initialize(STATE *state) {
19 state->dataptr = nullptr;
20 state->alloc_size = 0;
21 state->size = 0;
22 }
23
24 template <class A_TYPE, class B_TYPE, class STATE, class OP>
25 static void Operation(STATE *state, A_TYPE *str_data, B_TYPE *sep_data, nullmask_t &str_nullmask,
26 nullmask_t &sep_nullmask, idx_t str_idx, idx_t sep_idx) {
27 auto str = str_data[str_idx].GetData();
28 auto sep = sep_data[sep_idx].GetData();
29 auto str_size = str_data[str_idx].GetSize() + 1;
30 auto sep_size = sep_data[sep_idx].GetSize();
31
32 if (state->dataptr == nullptr) {
33 // first iteration: allocate space for the string and copy it into the state
34 state->alloc_size = std::max((idx_t)8, (idx_t)NextPowerOfTwo(str_size));
35 state->dataptr = new char[state->alloc_size];
36 state->size = str_size - 1;
37 memcpy(state->dataptr, str, str_size);
38 } else {
39 // subsequent iteration: first check if we have space to place the string and separator
40 idx_t required_size = state->size + str_size + sep_size;
41 if (required_size > state->alloc_size) {
42 // no space! allocate extra space
43 while (state->alloc_size < required_size) {
44 state->alloc_size *= 2;
45 }
46 auto new_data = new char[state->alloc_size];
47 memcpy(new_data, state->dataptr, state->size);
48 delete[] state->dataptr;
49 state->dataptr = new_data;
50 }
51 // copy the separator
52 memcpy(state->dataptr + state->size, sep, sep_size);
53 state->size += sep_size;
54 // copy the string
55 memcpy(state->dataptr + state->size, str, str_size);
56 state->size += str_size - 1;
57 }
58 }
59
60 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
61 throw NotImplementedException("String aggregate combine!");
62 }
63
64 template <class T, class STATE>
65 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
66 if (!state->dataptr) {
67 nullmask[idx] = true;
68 } else {
69 target[idx] = StringVector::AddString(result, state->dataptr, state->size);
70 }
71 }
72
73 template <class STATE> static void Destroy(STATE *state) {
74 if (state->dataptr) {
75 delete[] state->dataptr;
76 }
77 }
78
79 static bool IgnoreNull() {
80 return true;
81 }
82};
83
84void StringAggFun::RegisterFunction(BuiltinFunctions &set) {
85 AggregateFunctionSet string_agg("string_agg");
86 string_agg.AddFunction(AggregateFunction::BinaryAggregateDestructor<string_agg_state_t, string_t, string_t,
87 string_t, StringAggFunction>(
88 SQLType::VARCHAR, SQLType::VARCHAR, SQLType::VARCHAR));
89 set.AddFunction(string_agg);
90}
91
92} // namespace duckdb
93