1#pragma once
2
3#include <IO/WriteHelpers.h>
4#include <IO/ReadHelpers.h>
5
6#include <DataTypes/DataTypeArray.h>
7#include <DataTypes/DataTypeTuple.h>
8
9#include <Columns/ColumnArray.h>
10#include <Columns/ColumnTuple.h>
11#include <Columns/ColumnVector.h>
12#include <Columns/ColumnDecimal.h>
13
14#include <Common/FieldVisitors.h>
15#include <Common/assert_cast.h>
16#include <AggregateFunctions/IAggregateFunction.h>
17#include <map>
18
19
20namespace DB
21{
22
/// Error codes that are only declared here (extern); their definitions live outside this file.
/// Within this header only LOGICAL_ERROR is referenced directly.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
29
/// Aggregation state of the sumMap functions.
/// Maps every key seen so far to an Array that holds one accumulated value per value argument.
template <typename T>
struct AggregateFunctionSumMapData
{
    // Map needs to be ordered to maintain function properties
    std::map<T, Array> merged_maps;
};
36
/** Aggregate function that takes at least two arguments: keys and values, and as a result builds a tuple of at least 2 arrays -
 * ordered keys and a variable number of argument values summed up by the corresponding keys.
 *
 * This function is most useful when using SummingMergeTree to sum Nested columns whose names end in "Map".
41 *
42 * Example: sumMap(k, v...) of:
43 * k v
44 * [1,2,3] [10,10,10]
45 * [3,4,5] [10,10,10]
46 * [4,5,6] [10,10,10]
47 * [6,7,8] [10,10,10]
48 * [7,5,3] [5,15,25]
49 * [8,9,10] [20,20,20]
50 * will return:
51 * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
52 */
53
54template <typename T, typename Derived, typename OverflowPolicy>
55class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper<
56 AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>
57{
58private:
59 using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
60
61 DataTypePtr keys_type;
62 DataTypes values_types;
63
64public:
65 AggregateFunctionSumMapBase(
66 const DataTypePtr & keys_type_, const DataTypes & values_types_,
67 const DataTypes & argument_types_, const Array & params_)
68 : IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types_, params_)
69 , keys_type(keys_type_), values_types(values_types_) {}
70
71 String getName() const override { return "sumMap"; }
72
73 DataTypePtr getReturnType() const override
74 {
75 DataTypes types;
76 types.emplace_back(std::make_shared<DataTypeArray>(keys_type));
77
78 for (const auto & value_type : values_types)
79 types.emplace_back(std::make_shared<DataTypeArray>(OverflowPolicy::promoteType(value_type)));
80
81 return std::make_shared<DataTypeTuple>(types);
82 }
83
84 void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
85 {
86 // Column 0 contains array of keys of known type
87 const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
88 const IColumn::Offsets & offsets0 = array_column0.getOffsets();
89 const auto & keys_vec = static_cast<const ColVecType &>(array_column0.getData());
90 const size_t keys_vec_offset = offsets0[row_num - 1];
91 const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset);
92
93 // Columns 1..n contain arrays of numeric values to sum
94 auto & merged_maps = this->data(place).merged_maps;
95 for (size_t col = 0, size = values_types.size(); col < size; ++col)
96 {
97 Field value;
98 const ColumnArray & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]);
99 const IColumn::Offsets & offsets = array_column.getOffsets();
100 const size_t values_vec_offset = offsets[row_num - 1];
101 const size_t values_vec_size = (offsets[row_num] - values_vec_offset);
102
103 // Expect key and value arrays to be of same length
104 if (keys_vec_size != values_vec_size)
105 throw Exception("Sizes of keys and values arrays do not match", ErrorCodes::LOGICAL_ERROR);
106
107 // Insert column values for all keys
108 for (size_t i = 0; i < keys_vec_size; ++i)
109 {
110 using MapType = std::decay_t<decltype(merged_maps)>;
111 using IteratorType = typename MapType::iterator;
112
113 array_column.getData().get(values_vec_offset + i, value);
114 const auto & key = keys_vec.getElement(keys_vec_offset + i);
115
116 if (!keepKey(key))
117 {
118 continue;
119 }
120
121 IteratorType it;
122 if constexpr (IsDecimalNumber<T>)
123 {
124 UInt32 scale = keys_vec.getData().getScale();
125 it = merged_maps.find(DecimalField<T>(key, scale));
126 }
127 else
128 it = merged_maps.find(key);
129
130 if (it != merged_maps.end())
131 applyVisitor(FieldVisitorSum(value), it->second[col]);
132 else
133 {
134 // Create a value array for this key
135 Array new_values;
136 new_values.resize(values_types.size());
137 for (size_t k = 0; k < new_values.size(); ++k)
138 new_values[k] = (k == col) ? value : values_types[k]->getDefault();
139
140 if constexpr (IsDecimalNumber<T>)
141 {
142 UInt32 scale = keys_vec.getData().getScale();
143 merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values));
144 }
145 else
146 merged_maps.emplace(key, std::move(new_values));
147 }
148 }
149 }
150 }
151
152 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
153 {
154 auto & merged_maps = this->data(place).merged_maps;
155 const auto & rhs_maps = this->data(rhs).merged_maps;
156
157 for (const auto & elem : rhs_maps)
158 {
159 const auto & it = merged_maps.find(elem.first);
160 if (it != merged_maps.end())
161 {
162 for (size_t col = 0; col < values_types.size(); ++col)
163 applyVisitor(FieldVisitorSum(elem.second[col]), it->second[col]);
164 }
165 else
166 merged_maps[elem.first] = elem.second;
167 }
168 }
169
170 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
171 {
172 const auto & merged_maps = this->data(place).merged_maps;
173 size_t size = merged_maps.size();
174 writeVarUInt(size, buf);
175
176 for (const auto & elem : merged_maps)
177 {
178 keys_type->serializeBinary(elem.first, buf);
179 for (size_t col = 0; col < values_types.size(); ++col)
180 values_types[col]->serializeBinary(elem.second[col], buf);
181 }
182 }
183
184 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
185 {
186 auto & merged_maps = this->data(place).merged_maps;
187 size_t size = 0;
188 readVarUInt(size, buf);
189
190 for (size_t i = 0; i < size; ++i)
191 {
192 Field key;
193 keys_type->deserializeBinary(key, buf);
194
195 Array values;
196 values.resize(values_types.size());
197 for (size_t col = 0; col < values_types.size(); ++col)
198 values_types[col]->deserializeBinary(values[col], buf);
199
200 if constexpr (IsDecimalNumber<T>)
201 merged_maps[key.get<DecimalField<T>>()] = values;
202 else
203 merged_maps[key.get<T>()] = values;
204 }
205 }
206
207 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
208 {
209 // Final step does compaction of keys that have zero values, this mutates the state
210 auto & merged_maps = this->data(const_cast<AggregateDataPtr>(place)).merged_maps;
211 for (auto it = merged_maps.cbegin(); it != merged_maps.cend();)
212 {
213 // Key is not compacted if it has at least one non-zero value
214 bool erase = true;
215 for (size_t col = 0; col < values_types.size(); ++col)
216 {
217 if (it->second[col] != values_types[col]->getDefault())
218 {
219 erase = false;
220 break;
221 }
222 }
223
224 if (erase)
225 it = merged_maps.erase(it);
226 else
227 ++it;
228 }
229
230 size_t size = merged_maps.size();
231
232 auto & to_tuple = assert_cast<ColumnTuple &>(to);
233 auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0));
234 auto & to_keys_col = to_keys_arr.getData();
235
236 // Advance column offsets
237 auto & to_keys_offsets = to_keys_arr.getOffsets();
238 to_keys_offsets.push_back(to_keys_offsets.back() + size);
239 to_keys_col.reserve(size);
240
241 for (size_t col = 0; col < values_types.size(); ++col)
242 {
243 auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1));
244 auto & to_values_offsets = to_values_arr.getOffsets();
245 to_values_offsets.push_back(to_values_offsets.back() + size);
246 to_values_arr.getData().reserve(size);
247 }
248
249 // Write arrays of keys and values
250 for (const auto & elem : merged_maps)
251 {
252 // Write array of keys into column
253 to_keys_col.insert(elem.first);
254
255 // Write 0..n arrays of values
256 for (size_t col = 0; col < values_types.size(); ++col)
257 {
258 auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData();
259 to_values_col.insert(elem.second[col]);
260 }
261 }
262 }
263
264 bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
265};
266
267template <typename T, typename OverflowPolicy>
268class AggregateFunctionSumMap final :
269 public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy>, OverflowPolicy>
270{
271private:
272 using Self = AggregateFunctionSumMap<T, OverflowPolicy>;
273 using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
274
275public:
276 AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_)
277 : Base{keys_type_, values_types_, argument_types_, {}}
278 {}
279
280 String getName() const override { return "sumMap"; }
281
282 bool keepKey(const T &) const { return true; }
283};
284
285template <typename T, typename OverflowPolicy>
286class AggregateFunctionSumMapFiltered final :
287 public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T, OverflowPolicy>, OverflowPolicy>
288{
289private:
290 using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy>;
291 using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>;
292
293 std::unordered_set<T> keys_to_keep;
294
295public:
296 AggregateFunctionSumMapFiltered(
297 const DataTypePtr & keys_type_, const DataTypes & values_types_, const Array & keys_to_keep_,
298 const DataTypes & argument_types_, const Array & params_)
299 : Base{keys_type_, values_types_, argument_types_, params_}
300 {
301 keys_to_keep.reserve(keys_to_keep_.size());
302 for (const Field & f : keys_to_keep_)
303 {
304 keys_to_keep.emplace(f.safeGet<NearestFieldType<T>>());
305 }
306 }
307
308 String getName() const override { return "sumMapFiltered"; }
309
310 bool keepKey(const T & key) const { return keys_to_keep.count(key); }
311};
312
313}
314