| 1 | #pragma once |
| 2 | |
#include <Columns/ColumnArray.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <AggregateFunctions/IAggregateFunction.h>

#include <IO/WriteBuffer.h>
#include <IO/ReadBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>

#include <limits>
#include <string>
| 12 | |
| 13 | |
| 14 | namespace DB |
| 15 | { |
| 16 | |
/// Error codes thrown by this header. They are only declared here; the
/// definitions presumably live in the project's central ErrorCodes source — confirm.
namespace ErrorCodes
{
    extern const int PARAMETER_OUT_OF_BOUND;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
    /// Used by the AggregateFunctionForEach constructor for non-array arguments.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
| 23 | |
| 24 | |
/// State of the -ForEach combinator for one aggregation key: a buffer of
/// nested aggregate states laid out back to back, one per array position.
struct AggregateFunctionForEachData
{
    /// How many nested states the buffer currently holds.
    size_t dynamic_array_size{0};
    /// Beginning of the nested-state buffer; stays null until a state is needed.
    char * array_of_aggregate_datas{nullptr};
};
| 30 | |
| 31 | /** Adaptor for aggregate functions. |
| 32 | * Adding -ForEach suffix to aggregate function |
| 33 | * will convert that aggregate function to a function, accepting arrays, |
| 34 | * and applies aggregation for each corresponding elements of arrays independently, |
| 35 | * returning arrays of aggregated values on corresponding positions. |
| 36 | * |
| 37 | * Example: sumForEach of: |
| 38 | * [1, 2], |
| 39 | * [3, 4, 5], |
| 40 | * [6, 7] |
| 41 | * will return: |
| 42 | * [10, 13, 5] |
| 43 | * |
| 44 | * TODO Allow variable number of arguments. |
| 45 | */ |
| 46 | class AggregateFunctionForEach final : public IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach> |
| 47 | { |
| 48 | private: |
| 49 | AggregateFunctionPtr nested_func; |
| 50 | size_t nested_size_of_data = 0; |
| 51 | size_t num_arguments; |
| 52 | |
| 53 | AggregateFunctionForEachData & ensureAggregateData(AggregateDataPtr place, size_t new_size, Arena & arena) const |
| 54 | { |
| 55 | AggregateFunctionForEachData & state = data(place); |
| 56 | |
| 57 | /// Ensure we have aggreate states for new_size elements, allocate |
| 58 | /// from arena if needed. When reallocating, we can't copy the |
| 59 | /// states to new buffer with memcpy, because they may contain pointers |
| 60 | /// to themselves. In particular, this happens when a state contains |
| 61 | /// a PODArrayWithStackMemory, which stores small number of elements |
| 62 | /// inline. This is why we create new empty states in the new buffer, |
| 63 | /// and merge the old states to them. |
| 64 | size_t old_size = state.dynamic_array_size; |
| 65 | if (old_size < new_size) |
| 66 | { |
| 67 | char * old_state = state.array_of_aggregate_datas; |
| 68 | char * new_state = arena.alignedAlloc( |
| 69 | new_size * nested_size_of_data, |
| 70 | nested_func->alignOfData()); |
| 71 | |
| 72 | size_t i; |
| 73 | try |
| 74 | { |
| 75 | for (i = 0; i < new_size; ++i) |
| 76 | { |
| 77 | nested_func->create(&new_state[i * nested_size_of_data]); |
| 78 | } |
| 79 | } |
| 80 | catch (...) |
| 81 | { |
| 82 | size_t cleanup_size = i; |
| 83 | |
| 84 | for (i = 0; i < cleanup_size; ++i) |
| 85 | { |
| 86 | nested_func->destroy(&new_state[i * nested_size_of_data]); |
| 87 | } |
| 88 | |
| 89 | throw; |
| 90 | } |
| 91 | |
| 92 | for (i = 0; i < old_size; i++) |
| 93 | { |
| 94 | nested_func->merge(&new_state[i * nested_size_of_data], |
| 95 | &old_state[i * nested_size_of_data], |
| 96 | &arena); |
| 97 | } |
| 98 | |
| 99 | state.array_of_aggregate_datas = new_state; |
| 100 | state.dynamic_array_size = new_size; |
| 101 | } |
| 102 | |
| 103 | return state; |
| 104 | } |
| 105 | |
| 106 | public: |
| 107 | AggregateFunctionForEach(AggregateFunctionPtr nested_, const DataTypes & arguments) |
| 108 | : IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>(arguments, {}) |
| 109 | , nested_func(nested_), num_arguments(arguments.size()) |
| 110 | { |
| 111 | nested_size_of_data = nested_func->sizeOfData(); |
| 112 | |
| 113 | if (arguments.empty()) |
| 114 | throw Exception("Aggregate function " + getName() + " require at least one argument" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 115 | |
| 116 | for (const auto & type : arguments) |
| 117 | if (!isArray(type)) |
| 118 | throw Exception("All arguments for aggregate function " + getName() + " must be arrays" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 119 | } |
| 120 | |
| 121 | String getName() const override |
| 122 | { |
| 123 | return nested_func->getName() + "ForEach" ; |
| 124 | } |
| 125 | |
| 126 | DataTypePtr getReturnType() const override |
| 127 | { |
| 128 | return std::make_shared<DataTypeArray>(nested_func->getReturnType()); |
| 129 | } |
| 130 | |
| 131 | void destroy(AggregateDataPtr place) const noexcept override |
| 132 | { |
| 133 | AggregateFunctionForEachData & state = data(place); |
| 134 | |
| 135 | char * nested_state = state.array_of_aggregate_datas; |
| 136 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
| 137 | { |
| 138 | nested_func->destroy(nested_state); |
| 139 | nested_state += nested_size_of_data; |
| 140 | } |
| 141 | } |
| 142 | |
| 143 | bool hasTrivialDestructor() const override |
| 144 | { |
| 145 | return nested_func->hasTrivialDestructor(); |
| 146 | } |
| 147 | |
| 148 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
| 149 | { |
| 150 | const IColumn * nested[num_arguments]; |
| 151 | |
| 152 | for (size_t i = 0; i < num_arguments; ++i) |
| 153 | nested[i] = &assert_cast<const ColumnArray &>(*columns[i]).getData(); |
| 154 | |
| 155 | const ColumnArray & first_array_column = assert_cast<const ColumnArray &>(*columns[0]); |
| 156 | const IColumn::Offsets & offsets = first_array_column.getOffsets(); |
| 157 | |
| 158 | size_t begin = offsets[row_num - 1]; |
| 159 | size_t end = offsets[row_num]; |
| 160 | |
| 161 | /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. |
| 162 | for (size_t i = 1; i < num_arguments; ++i) |
| 163 | { |
| 164 | const ColumnArray & ith_column = assert_cast<const ColumnArray &>(*columns[i]); |
| 165 | const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); |
| 166 | |
| 167 | if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) |
| 168 | throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes" , ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); |
| 169 | } |
| 170 | |
| 171 | AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena); |
| 172 | |
| 173 | char * nested_state = state.array_of_aggregate_datas; |
| 174 | for (size_t i = begin; i < end; ++i) |
| 175 | { |
| 176 | nested_func->add(nested_state, nested, i, arena); |
| 177 | nested_state += nested_size_of_data; |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
| 182 | { |
| 183 | const AggregateFunctionForEachData & rhs_state = data(rhs); |
| 184 | AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena); |
| 185 | |
| 186 | const char * rhs_nested_state = rhs_state.array_of_aggregate_datas; |
| 187 | char * nested_state = state.array_of_aggregate_datas; |
| 188 | |
| 189 | for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) |
| 190 | { |
| 191 | nested_func->merge(nested_state, rhs_nested_state, arena); |
| 192 | |
| 193 | rhs_nested_state += nested_size_of_data; |
| 194 | nested_state += nested_size_of_data; |
| 195 | } |
| 196 | } |
| 197 | |
| 198 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
| 199 | { |
| 200 | const AggregateFunctionForEachData & state = data(place); |
| 201 | writeBinary(state.dynamic_array_size, buf); |
| 202 | |
| 203 | const char * nested_state = state.array_of_aggregate_datas; |
| 204 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
| 205 | { |
| 206 | nested_func->serialize(nested_state, buf); |
| 207 | nested_state += nested_size_of_data; |
| 208 | } |
| 209 | } |
| 210 | |
| 211 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
| 212 | { |
| 213 | AggregateFunctionForEachData & state = data(place); |
| 214 | |
| 215 | size_t new_size = 0; |
| 216 | readBinary(new_size, buf); |
| 217 | |
| 218 | ensureAggregateData(place, new_size, *arena); |
| 219 | |
| 220 | char * nested_state = state.array_of_aggregate_datas; |
| 221 | for (size_t i = 0; i < new_size; ++i) |
| 222 | { |
| 223 | nested_func->deserialize(nested_state, buf, arena); |
| 224 | nested_state += nested_size_of_data; |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
| 229 | { |
| 230 | const AggregateFunctionForEachData & state = data(place); |
| 231 | |
| 232 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
| 233 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
| 234 | IColumn & elems_to = arr_to.getData(); |
| 235 | |
| 236 | const char * nested_state = state.array_of_aggregate_datas; |
| 237 | for (size_t i = 0; i < state.dynamic_array_size; ++i) |
| 238 | { |
| 239 | nested_func->insertResultInto(nested_state, elems_to); |
| 240 | nested_state += nested_size_of_data; |
| 241 | } |
| 242 | |
| 243 | offsets_to.push_back(offsets_to.back() + state.dynamic_array_size); |
| 244 | } |
| 245 | |
| 246 | bool allocatesMemoryInArena() const override |
| 247 | { |
| 248 | return true; |
| 249 | } |
| 250 | }; |
| 251 | |
| 252 | |
| 253 | } |
| 254 | |