1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
2 | #include "duckdb/common/exception.hpp" |
3 | #include "duckdb/common/types/null_value.hpp" |
4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
5 | #include <string> |
6 | |
7 | using namespace std; |
8 | |
9 | namespace duckdb { |
10 | |
11 | struct string_agg_state_t { |
12 | char *dataptr; |
13 | idx_t size; |
14 | idx_t alloc_size; |
15 | }; |
16 | |
17 | struct StringAggFunction { |
18 | template <class STATE> static void Initialize(STATE *state) { |
19 | state->dataptr = nullptr; |
20 | state->alloc_size = 0; |
21 | state->size = 0; |
22 | } |
23 | |
24 | template <class A_TYPE, class B_TYPE, class STATE, class OP> |
25 | static void Operation(STATE *state, A_TYPE *str_data, B_TYPE *sep_data, nullmask_t &str_nullmask, |
26 | nullmask_t &sep_nullmask, idx_t str_idx, idx_t sep_idx) { |
27 | auto str = str_data[str_idx].GetData(); |
28 | auto sep = sep_data[sep_idx].GetData(); |
29 | auto str_size = str_data[str_idx].GetSize() + 1; |
30 | auto sep_size = sep_data[sep_idx].GetSize(); |
31 | |
32 | if (state->dataptr == nullptr) { |
33 | // first iteration: allocate space for the string and copy it into the state |
34 | state->alloc_size = std::max((idx_t)8, (idx_t)NextPowerOfTwo(str_size)); |
35 | state->dataptr = new char[state->alloc_size]; |
36 | state->size = str_size - 1; |
37 | memcpy(state->dataptr, str, str_size); |
38 | } else { |
39 | // subsequent iteration: first check if we have space to place the string and separator |
40 | idx_t required_size = state->size + str_size + sep_size; |
41 | if (required_size > state->alloc_size) { |
42 | // no space! allocate extra space |
43 | while (state->alloc_size < required_size) { |
44 | state->alloc_size *= 2; |
45 | } |
46 | auto new_data = new char[state->alloc_size]; |
47 | memcpy(new_data, state->dataptr, state->size); |
48 | delete[] state->dataptr; |
49 | state->dataptr = new_data; |
50 | } |
51 | // copy the separator |
52 | memcpy(state->dataptr + state->size, sep, sep_size); |
53 | state->size += sep_size; |
54 | // copy the string |
55 | memcpy(state->dataptr + state->size, str, str_size); |
56 | state->size += str_size - 1; |
57 | } |
58 | } |
59 | |
60 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
61 | throw NotImplementedException("String aggregate combine!" ); |
62 | } |
63 | |
64 | template <class T, class STATE> |
65 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
66 | if (!state->dataptr) { |
67 | nullmask[idx] = true; |
68 | } else { |
69 | target[idx] = StringVector::AddString(result, state->dataptr, state->size); |
70 | } |
71 | } |
72 | |
73 | template <class STATE> static void Destroy(STATE *state) { |
74 | if (state->dataptr) { |
75 | delete[] state->dataptr; |
76 | } |
77 | } |
78 | |
79 | static bool IgnoreNull() { |
80 | return true; |
81 | } |
82 | }; |
83 | |
84 | void StringAggFun::RegisterFunction(BuiltinFunctions &set) { |
85 | AggregateFunctionSet string_agg("string_agg" ); |
86 | string_agg.AddFunction(AggregateFunction::BinaryAggregateDestructor<string_agg_state_t, string_t, string_t, |
87 | string_t, StringAggFunction>( |
88 | SQLType::VARCHAR, SQLType::VARCHAR, SQLType::VARCHAR)); |
89 | set.AddFunction(string_agg); |
90 | } |
91 | |
92 | } // namespace duckdb |
93 | |