| 1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
| 2 | #include "duckdb/common/exception.hpp" |
| 3 | #include "duckdb/common/types/null_value.hpp" |
| 4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 5 | #include <string> |
| 6 | |
| 7 | using namespace std; |
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | struct string_agg_state_t { |
| 12 | char *dataptr; |
| 13 | idx_t size; |
| 14 | idx_t alloc_size; |
| 15 | }; |
| 16 | |
| 17 | struct StringAggFunction { |
| 18 | template <class STATE> static void Initialize(STATE *state) { |
| 19 | state->dataptr = nullptr; |
| 20 | state->alloc_size = 0; |
| 21 | state->size = 0; |
| 22 | } |
| 23 | |
| 24 | template <class A_TYPE, class B_TYPE, class STATE, class OP> |
| 25 | static void Operation(STATE *state, A_TYPE *str_data, B_TYPE *sep_data, nullmask_t &str_nullmask, |
| 26 | nullmask_t &sep_nullmask, idx_t str_idx, idx_t sep_idx) { |
| 27 | auto str = str_data[str_idx].GetData(); |
| 28 | auto sep = sep_data[sep_idx].GetData(); |
| 29 | auto str_size = str_data[str_idx].GetSize() + 1; |
| 30 | auto sep_size = sep_data[sep_idx].GetSize(); |
| 31 | |
| 32 | if (state->dataptr == nullptr) { |
| 33 | // first iteration: allocate space for the string and copy it into the state |
| 34 | state->alloc_size = std::max((idx_t)8, (idx_t)NextPowerOfTwo(str_size)); |
| 35 | state->dataptr = new char[state->alloc_size]; |
| 36 | state->size = str_size - 1; |
| 37 | memcpy(state->dataptr, str, str_size); |
| 38 | } else { |
| 39 | // subsequent iteration: first check if we have space to place the string and separator |
| 40 | idx_t required_size = state->size + str_size + sep_size; |
| 41 | if (required_size > state->alloc_size) { |
| 42 | // no space! allocate extra space |
| 43 | while (state->alloc_size < required_size) { |
| 44 | state->alloc_size *= 2; |
| 45 | } |
| 46 | auto new_data = new char[state->alloc_size]; |
| 47 | memcpy(new_data, state->dataptr, state->size); |
| 48 | delete[] state->dataptr; |
| 49 | state->dataptr = new_data; |
| 50 | } |
| 51 | // copy the separator |
| 52 | memcpy(state->dataptr + state->size, sep, sep_size); |
| 53 | state->size += sep_size; |
| 54 | // copy the string |
| 55 | memcpy(state->dataptr + state->size, str, str_size); |
| 56 | state->size += str_size - 1; |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
| 61 | throw NotImplementedException("String aggregate combine!" ); |
| 62 | } |
| 63 | |
| 64 | template <class T, class STATE> |
| 65 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
| 66 | if (!state->dataptr) { |
| 67 | nullmask[idx] = true; |
| 68 | } else { |
| 69 | target[idx] = StringVector::AddString(result, state->dataptr, state->size); |
| 70 | } |
| 71 | } |
| 72 | |
| 73 | template <class STATE> static void Destroy(STATE *state) { |
| 74 | if (state->dataptr) { |
| 75 | delete[] state->dataptr; |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | static bool IgnoreNull() { |
| 80 | return true; |
| 81 | } |
| 82 | }; |
| 83 | |
| 84 | void StringAggFun::RegisterFunction(BuiltinFunctions &set) { |
| 85 | AggregateFunctionSet string_agg("string_agg" ); |
| 86 | string_agg.AddFunction(AggregateFunction::BinaryAggregateDestructor<string_agg_state_t, string_t, string_t, |
| 87 | string_t, StringAggFunction>( |
| 88 | SQLType::VARCHAR, SQLType::VARCHAR, SQLType::VARCHAR)); |
| 89 | set.AddFunction(string_agg); |
| 90 | } |
| 91 | |
| 92 | } // namespace duckdb |
| 93 | |