| 1 | #pragma once |
| 2 | |
| 3 | #include <IO/WriteHelpers.h> |
| 4 | #include <IO/ReadHelpers.h> |
| 5 | |
| 6 | #include <DataTypes/DataTypeArray.h> |
| 7 | #include <DataTypes/DataTypeTuple.h> |
| 8 | |
| 9 | #include <Columns/ColumnArray.h> |
| 10 | #include <Columns/ColumnTuple.h> |
| 11 | #include <Columns/ColumnVector.h> |
| 12 | #include <Columns/ColumnDecimal.h> |
| 13 | |
| 14 | #include <Common/FieldVisitors.h> |
| 15 | #include <Common/assert_cast.h> |
| 16 | #include <AggregateFunctions/IAggregateFunction.h> |
| 17 | #include <map> |
| 18 | |
| 19 | |
| 20 | namespace DB |
| 21 | { |
| 22 | |
/// Error code constants. These are `extern` declarations only; the definitions are not part of this file.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
| 29 | |
| 30 | template <typename T> |
| 31 | struct AggregateFunctionSumMapData |
| 32 | { |
| 33 | // Map needs to be ordered to maintain function properties |
| 34 | std::map<T, Array> merged_maps; |
| 35 | }; |
| 36 | |
| 37 | /** Aggregate function that takes at least two arguments, keys and values, and as a result builds a tuple of at least 2 arrays - |
| 38 | * ordered keys and variable number of argument values summed up by corresponding keys. |
| 39 | * |
| 40 | * This function is most useful when using SummingMergeTree to sum Nested columns whose names end in "Map". |
| 41 | * |
| 42 | * Example: sumMap(k, v...) of: |
| 43 | * k v |
| 44 | * [1,2,3] [10,10,10] |
| 45 | * [3,4,5] [10,10,10] |
| 46 | * [4,5,6] [10,10,10] |
| 47 | * [6,7,8] [10,10,10] |
| 48 | * [7,5,3] [5,15,25] |
| 49 | * [8,9,10] [20,20,20] |
| 50 | * will return: |
| 51 | * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) |
| 52 | */ |
| 53 | |
| 54 | template <typename T, typename Derived, typename OverflowPolicy> |
| 55 | class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< |
| 56 | AggregateFunctionSumMapData<NearestFieldType<T>>, Derived> |
| 57 | { |
| 58 | private: |
| 59 | using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; |
| 60 | |
| 61 | DataTypePtr keys_type; |
| 62 | DataTypes values_types; |
| 63 | |
| 64 | public: |
| 65 | AggregateFunctionSumMapBase( |
| 66 | const DataTypePtr & keys_type_, const DataTypes & values_types_, |
| 67 | const DataTypes & argument_types_, const Array & params_) |
| 68 | : IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types_, params_) |
| 69 | , keys_type(keys_type_), values_types(values_types_) {} |
| 70 | |
| 71 | String getName() const override { return "sumMap" ; } |
| 72 | |
| 73 | DataTypePtr getReturnType() const override |
| 74 | { |
| 75 | DataTypes types; |
| 76 | types.emplace_back(std::make_shared<DataTypeArray>(keys_type)); |
| 77 | |
| 78 | for (const auto & value_type : values_types) |
| 79 | types.emplace_back(std::make_shared<DataTypeArray>(OverflowPolicy::promoteType(value_type))); |
| 80 | |
| 81 | return std::make_shared<DataTypeTuple>(types); |
| 82 | } |
| 83 | |
| 84 | void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override |
| 85 | { |
| 86 | // Column 0 contains array of keys of known type |
| 87 | const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]); |
| 88 | const IColumn::Offsets & offsets0 = array_column0.getOffsets(); |
| 89 | const auto & keys_vec = static_cast<const ColVecType &>(array_column0.getData()); |
| 90 | const size_t keys_vec_offset = offsets0[row_num - 1]; |
| 91 | const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset); |
| 92 | |
| 93 | // Columns 1..n contain arrays of numeric values to sum |
| 94 | auto & merged_maps = this->data(place).merged_maps; |
| 95 | for (size_t col = 0, size = values_types.size(); col < size; ++col) |
| 96 | { |
| 97 | Field value; |
| 98 | const ColumnArray & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]); |
| 99 | const IColumn::Offsets & offsets = array_column.getOffsets(); |
| 100 | const size_t values_vec_offset = offsets[row_num - 1]; |
| 101 | const size_t values_vec_size = (offsets[row_num] - values_vec_offset); |
| 102 | |
| 103 | // Expect key and value arrays to be of same length |
| 104 | if (keys_vec_size != values_vec_size) |
| 105 | throw Exception("Sizes of keys and values arrays do not match" , ErrorCodes::LOGICAL_ERROR); |
| 106 | |
| 107 | // Insert column values for all keys |
| 108 | for (size_t i = 0; i < keys_vec_size; ++i) |
| 109 | { |
| 110 | using MapType = std::decay_t<decltype(merged_maps)>; |
| 111 | using IteratorType = typename MapType::iterator; |
| 112 | |
| 113 | array_column.getData().get(values_vec_offset + i, value); |
| 114 | const auto & key = keys_vec.getElement(keys_vec_offset + i); |
| 115 | |
| 116 | if (!keepKey(key)) |
| 117 | { |
| 118 | continue; |
| 119 | } |
| 120 | |
| 121 | IteratorType it; |
| 122 | if constexpr (IsDecimalNumber<T>) |
| 123 | { |
| 124 | UInt32 scale = keys_vec.getData().getScale(); |
| 125 | it = merged_maps.find(DecimalField<T>(key, scale)); |
| 126 | } |
| 127 | else |
| 128 | it = merged_maps.find(key); |
| 129 | |
| 130 | if (it != merged_maps.end()) |
| 131 | applyVisitor(FieldVisitorSum(value), it->second[col]); |
| 132 | else |
| 133 | { |
| 134 | // Create a value array for this key |
| 135 | Array new_values; |
| 136 | new_values.resize(values_types.size()); |
| 137 | for (size_t k = 0; k < new_values.size(); ++k) |
| 138 | new_values[k] = (k == col) ? value : values_types[k]->getDefault(); |
| 139 | |
| 140 | if constexpr (IsDecimalNumber<T>) |
| 141 | { |
| 142 | UInt32 scale = keys_vec.getData().getScale(); |
| 143 | merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values)); |
| 144 | } |
| 145 | else |
| 146 | merged_maps.emplace(key, std::move(new_values)); |
| 147 | } |
| 148 | } |
| 149 | } |
| 150 | } |
| 151 | |
| 152 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
| 153 | { |
| 154 | auto & merged_maps = this->data(place).merged_maps; |
| 155 | const auto & rhs_maps = this->data(rhs).merged_maps; |
| 156 | |
| 157 | for (const auto & elem : rhs_maps) |
| 158 | { |
| 159 | const auto & it = merged_maps.find(elem.first); |
| 160 | if (it != merged_maps.end()) |
| 161 | { |
| 162 | for (size_t col = 0; col < values_types.size(); ++col) |
| 163 | applyVisitor(FieldVisitorSum(elem.second[col]), it->second[col]); |
| 164 | } |
| 165 | else |
| 166 | merged_maps[elem.first] = elem.second; |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
| 171 | { |
| 172 | const auto & merged_maps = this->data(place).merged_maps; |
| 173 | size_t size = merged_maps.size(); |
| 174 | writeVarUInt(size, buf); |
| 175 | |
| 176 | for (const auto & elem : merged_maps) |
| 177 | { |
| 178 | keys_type->serializeBinary(elem.first, buf); |
| 179 | for (size_t col = 0; col < values_types.size(); ++col) |
| 180 | values_types[col]->serializeBinary(elem.second[col], buf); |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
| 185 | { |
| 186 | auto & merged_maps = this->data(place).merged_maps; |
| 187 | size_t size = 0; |
| 188 | readVarUInt(size, buf); |
| 189 | |
| 190 | for (size_t i = 0; i < size; ++i) |
| 191 | { |
| 192 | Field key; |
| 193 | keys_type->deserializeBinary(key, buf); |
| 194 | |
| 195 | Array values; |
| 196 | values.resize(values_types.size()); |
| 197 | for (size_t col = 0; col < values_types.size(); ++col) |
| 198 | values_types[col]->deserializeBinary(values[col], buf); |
| 199 | |
| 200 | if constexpr (IsDecimalNumber<T>) |
| 201 | merged_maps[key.get<DecimalField<T>>()] = values; |
| 202 | else |
| 203 | merged_maps[key.get<T>()] = values; |
| 204 | } |
| 205 | } |
| 206 | |
| 207 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
| 208 | { |
| 209 | // Final step does compaction of keys that have zero values, this mutates the state |
| 210 | auto & merged_maps = this->data(const_cast<AggregateDataPtr>(place)).merged_maps; |
| 211 | for (auto it = merged_maps.cbegin(); it != merged_maps.cend();) |
| 212 | { |
| 213 | // Key is not compacted if it has at least one non-zero value |
| 214 | bool erase = true; |
| 215 | for (size_t col = 0; col < values_types.size(); ++col) |
| 216 | { |
| 217 | if (it->second[col] != values_types[col]->getDefault()) |
| 218 | { |
| 219 | erase = false; |
| 220 | break; |
| 221 | } |
| 222 | } |
| 223 | |
| 224 | if (erase) |
| 225 | it = merged_maps.erase(it); |
| 226 | else |
| 227 | ++it; |
| 228 | } |
| 229 | |
| 230 | size_t size = merged_maps.size(); |
| 231 | |
| 232 | auto & to_tuple = assert_cast<ColumnTuple &>(to); |
| 233 | auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0)); |
| 234 | auto & to_keys_col = to_keys_arr.getData(); |
| 235 | |
| 236 | // Advance column offsets |
| 237 | auto & to_keys_offsets = to_keys_arr.getOffsets(); |
| 238 | to_keys_offsets.push_back(to_keys_offsets.back() + size); |
| 239 | to_keys_col.reserve(size); |
| 240 | |
| 241 | for (size_t col = 0; col < values_types.size(); ++col) |
| 242 | { |
| 243 | auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)); |
| 244 | auto & to_values_offsets = to_values_arr.getOffsets(); |
| 245 | to_values_offsets.push_back(to_values_offsets.back() + size); |
| 246 | to_values_arr.getData().reserve(size); |
| 247 | } |
| 248 | |
| 249 | // Write arrays of keys and values |
| 250 | for (const auto & elem : merged_maps) |
| 251 | { |
| 252 | // Write array of keys into column |
| 253 | to_keys_col.insert(elem.first); |
| 254 | |
| 255 | // Write 0..n arrays of values |
| 256 | for (size_t col = 0; col < values_types.size(); ++col) |
| 257 | { |
| 258 | auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData(); |
| 259 | to_values_col.insert(elem.second[col]); |
| 260 | } |
| 261 | } |
| 262 | } |
| 263 | |
| 264 | bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); } |
| 265 | }; |
| 266 | |
| 267 | template <typename T, typename OverflowPolicy> |
| 268 | class AggregateFunctionSumMap final : |
| 269 | public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy>, OverflowPolicy> |
| 270 | { |
| 271 | private: |
| 272 | using Self = AggregateFunctionSumMap<T, OverflowPolicy>; |
| 273 | using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>; |
| 274 | |
| 275 | public: |
| 276 | AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_) |
| 277 | : Base{keys_type_, values_types_, argument_types_, {}} |
| 278 | {} |
| 279 | |
| 280 | String getName() const override { return "sumMap" ; } |
| 281 | |
| 282 | bool keepKey(const T &) const { return true; } |
| 283 | }; |
| 284 | |
| 285 | template <typename T, typename OverflowPolicy> |
| 286 | class AggregateFunctionSumMapFiltered final : |
| 287 | public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T, OverflowPolicy>, OverflowPolicy> |
| 288 | { |
| 289 | private: |
| 290 | using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy>; |
| 291 | using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>; |
| 292 | |
| 293 | std::unordered_set<T> keys_to_keep; |
| 294 | |
| 295 | public: |
| 296 | AggregateFunctionSumMapFiltered( |
| 297 | const DataTypePtr & keys_type_, const DataTypes & values_types_, const Array & keys_to_keep_, |
| 298 | const DataTypes & argument_types_, const Array & params_) |
| 299 | : Base{keys_type_, values_types_, argument_types_, params_} |
| 300 | { |
| 301 | keys_to_keep.reserve(keys_to_keep_.size()); |
| 302 | for (const Field & f : keys_to_keep_) |
| 303 | { |
| 304 | keys_to_keep.emplace(f.safeGet<NearestFieldType<T>>()); |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | String getName() const override { return "sumMapFiltered" ; } |
| 309 | |
| 310 | bool keepKey(const T & key) const { return keys_to_keep.count(key); } |
| 311 | }; |
| 312 | |
| 313 | } |
| 314 | |