1 | #pragma once |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <DataTypes/DataTypeTuple.h> |
8 | |
9 | #include <Columns/ColumnArray.h> |
10 | #include <Columns/ColumnTuple.h> |
11 | #include <Columns/ColumnVector.h> |
12 | #include <Columns/ColumnDecimal.h> |
13 | |
14 | #include <Common/FieldVisitors.h> |
15 | #include <Common/assert_cast.h> |
16 | #include <AggregateFunctions/IAggregateFunction.h> |
17 | #include <map> |
18 | |
19 | |
20 | namespace DB |
21 | { |
22 | |
/// Error codes referenced from this header. They are only declared here
/// (`extern`); their definitions live outside this file.
/// Within this header only LOGICAL_ERROR is thrown (on mismatched keys/values
/// array sizes); the other two are not used in the code visible here.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
29 | |
/// Aggregation state of the sumMap family of functions.
/// T is the type used for keys inside the state map.
template <typename T>
struct AggregateFunctionSumMapData
{
    // Map needs to be ordered to maintain function properties
    // (the result keys are emitted by iterating this map in order).
    // Each mapped Array holds one accumulated value per values argument.
    std::map<T, Array> merged_maps;
};
36 | |
/** Aggregate function that takes at least two arguments: keys and values, and as a result, builds a tuple of at least 2 arrays -
  * ordered keys and a variable number of argument values summed up by corresponding keys.
  *
  * This function is the most useful when using SummingMergeTree to sum Nested columns whose names end in "Map".
41 | * |
42 | * Example: sumMap(k, v...) of: |
43 | * k v |
44 | * [1,2,3] [10,10,10] |
45 | * [3,4,5] [10,10,10] |
46 | * [4,5,6] [10,10,10] |
47 | * [6,7,8] [10,10,10] |
48 | * [7,5,3] [5,15,25] |
49 | * [8,9,10] [20,20,20] |
50 | * will return: |
51 | * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) |
52 | */ |
53 | |
54 | template <typename T, typename Derived, typename OverflowPolicy> |
55 | class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< |
56 | AggregateFunctionSumMapData<NearestFieldType<T>>, Derived> |
57 | { |
58 | private: |
59 | using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; |
60 | |
61 | DataTypePtr keys_type; |
62 | DataTypes values_types; |
63 | |
64 | public: |
65 | AggregateFunctionSumMapBase( |
66 | const DataTypePtr & keys_type_, const DataTypes & values_types_, |
67 | const DataTypes & argument_types_, const Array & params_) |
68 | : IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types_, params_) |
69 | , keys_type(keys_type_), values_types(values_types_) {} |
70 | |
71 | String getName() const override { return "sumMap" ; } |
72 | |
73 | DataTypePtr getReturnType() const override |
74 | { |
75 | DataTypes types; |
76 | types.emplace_back(std::make_shared<DataTypeArray>(keys_type)); |
77 | |
78 | for (const auto & value_type : values_types) |
79 | types.emplace_back(std::make_shared<DataTypeArray>(OverflowPolicy::promoteType(value_type))); |
80 | |
81 | return std::make_shared<DataTypeTuple>(types); |
82 | } |
83 | |
84 | void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override |
85 | { |
86 | // Column 0 contains array of keys of known type |
87 | const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]); |
88 | const IColumn::Offsets & offsets0 = array_column0.getOffsets(); |
89 | const auto & keys_vec = static_cast<const ColVecType &>(array_column0.getData()); |
90 | const size_t keys_vec_offset = offsets0[row_num - 1]; |
91 | const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset); |
92 | |
93 | // Columns 1..n contain arrays of numeric values to sum |
94 | auto & merged_maps = this->data(place).merged_maps; |
95 | for (size_t col = 0, size = values_types.size(); col < size; ++col) |
96 | { |
97 | Field value; |
98 | const ColumnArray & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]); |
99 | const IColumn::Offsets & offsets = array_column.getOffsets(); |
100 | const size_t values_vec_offset = offsets[row_num - 1]; |
101 | const size_t values_vec_size = (offsets[row_num] - values_vec_offset); |
102 | |
103 | // Expect key and value arrays to be of same length |
104 | if (keys_vec_size != values_vec_size) |
105 | throw Exception("Sizes of keys and values arrays do not match" , ErrorCodes::LOGICAL_ERROR); |
106 | |
107 | // Insert column values for all keys |
108 | for (size_t i = 0; i < keys_vec_size; ++i) |
109 | { |
110 | using MapType = std::decay_t<decltype(merged_maps)>; |
111 | using IteratorType = typename MapType::iterator; |
112 | |
113 | array_column.getData().get(values_vec_offset + i, value); |
114 | const auto & key = keys_vec.getElement(keys_vec_offset + i); |
115 | |
116 | if (!keepKey(key)) |
117 | { |
118 | continue; |
119 | } |
120 | |
121 | IteratorType it; |
122 | if constexpr (IsDecimalNumber<T>) |
123 | { |
124 | UInt32 scale = keys_vec.getData().getScale(); |
125 | it = merged_maps.find(DecimalField<T>(key, scale)); |
126 | } |
127 | else |
128 | it = merged_maps.find(key); |
129 | |
130 | if (it != merged_maps.end()) |
131 | applyVisitor(FieldVisitorSum(value), it->second[col]); |
132 | else |
133 | { |
134 | // Create a value array for this key |
135 | Array new_values; |
136 | new_values.resize(values_types.size()); |
137 | for (size_t k = 0; k < new_values.size(); ++k) |
138 | new_values[k] = (k == col) ? value : values_types[k]->getDefault(); |
139 | |
140 | if constexpr (IsDecimalNumber<T>) |
141 | { |
142 | UInt32 scale = keys_vec.getData().getScale(); |
143 | merged_maps.emplace(DecimalField<T>(key, scale), std::move(new_values)); |
144 | } |
145 | else |
146 | merged_maps.emplace(key, std::move(new_values)); |
147 | } |
148 | } |
149 | } |
150 | } |
151 | |
152 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
153 | { |
154 | auto & merged_maps = this->data(place).merged_maps; |
155 | const auto & rhs_maps = this->data(rhs).merged_maps; |
156 | |
157 | for (const auto & elem : rhs_maps) |
158 | { |
159 | const auto & it = merged_maps.find(elem.first); |
160 | if (it != merged_maps.end()) |
161 | { |
162 | for (size_t col = 0; col < values_types.size(); ++col) |
163 | applyVisitor(FieldVisitorSum(elem.second[col]), it->second[col]); |
164 | } |
165 | else |
166 | merged_maps[elem.first] = elem.second; |
167 | } |
168 | } |
169 | |
170 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
171 | { |
172 | const auto & merged_maps = this->data(place).merged_maps; |
173 | size_t size = merged_maps.size(); |
174 | writeVarUInt(size, buf); |
175 | |
176 | for (const auto & elem : merged_maps) |
177 | { |
178 | keys_type->serializeBinary(elem.first, buf); |
179 | for (size_t col = 0; col < values_types.size(); ++col) |
180 | values_types[col]->serializeBinary(elem.second[col], buf); |
181 | } |
182 | } |
183 | |
184 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
185 | { |
186 | auto & merged_maps = this->data(place).merged_maps; |
187 | size_t size = 0; |
188 | readVarUInt(size, buf); |
189 | |
190 | for (size_t i = 0; i < size; ++i) |
191 | { |
192 | Field key; |
193 | keys_type->deserializeBinary(key, buf); |
194 | |
195 | Array values; |
196 | values.resize(values_types.size()); |
197 | for (size_t col = 0; col < values_types.size(); ++col) |
198 | values_types[col]->deserializeBinary(values[col], buf); |
199 | |
200 | if constexpr (IsDecimalNumber<T>) |
201 | merged_maps[key.get<DecimalField<T>>()] = values; |
202 | else |
203 | merged_maps[key.get<T>()] = values; |
204 | } |
205 | } |
206 | |
207 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
208 | { |
209 | // Final step does compaction of keys that have zero values, this mutates the state |
210 | auto & merged_maps = this->data(const_cast<AggregateDataPtr>(place)).merged_maps; |
211 | for (auto it = merged_maps.cbegin(); it != merged_maps.cend();) |
212 | { |
213 | // Key is not compacted if it has at least one non-zero value |
214 | bool erase = true; |
215 | for (size_t col = 0; col < values_types.size(); ++col) |
216 | { |
217 | if (it->second[col] != values_types[col]->getDefault()) |
218 | { |
219 | erase = false; |
220 | break; |
221 | } |
222 | } |
223 | |
224 | if (erase) |
225 | it = merged_maps.erase(it); |
226 | else |
227 | ++it; |
228 | } |
229 | |
230 | size_t size = merged_maps.size(); |
231 | |
232 | auto & to_tuple = assert_cast<ColumnTuple &>(to); |
233 | auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0)); |
234 | auto & to_keys_col = to_keys_arr.getData(); |
235 | |
236 | // Advance column offsets |
237 | auto & to_keys_offsets = to_keys_arr.getOffsets(); |
238 | to_keys_offsets.push_back(to_keys_offsets.back() + size); |
239 | to_keys_col.reserve(size); |
240 | |
241 | for (size_t col = 0; col < values_types.size(); ++col) |
242 | { |
243 | auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)); |
244 | auto & to_values_offsets = to_values_arr.getOffsets(); |
245 | to_values_offsets.push_back(to_values_offsets.back() + size); |
246 | to_values_arr.getData().reserve(size); |
247 | } |
248 | |
249 | // Write arrays of keys and values |
250 | for (const auto & elem : merged_maps) |
251 | { |
252 | // Write array of keys into column |
253 | to_keys_col.insert(elem.first); |
254 | |
255 | // Write 0..n arrays of values |
256 | for (size_t col = 0; col < values_types.size(); ++col) |
257 | { |
258 | auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData(); |
259 | to_values_col.insert(elem.second[col]); |
260 | } |
261 | } |
262 | } |
263 | |
264 | bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); } |
265 | }; |
266 | |
267 | template <typename T, typename OverflowPolicy> |
268 | class AggregateFunctionSumMap final : |
269 | public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, OverflowPolicy>, OverflowPolicy> |
270 | { |
271 | private: |
272 | using Self = AggregateFunctionSumMap<T, OverflowPolicy>; |
273 | using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>; |
274 | |
275 | public: |
276 | AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_) |
277 | : Base{keys_type_, values_types_, argument_types_, {}} |
278 | {} |
279 | |
280 | String getName() const override { return "sumMap" ; } |
281 | |
282 | bool keepKey(const T &) const { return true; } |
283 | }; |
284 | |
285 | template <typename T, typename OverflowPolicy> |
286 | class AggregateFunctionSumMapFiltered final : |
287 | public AggregateFunctionSumMapBase<T, AggregateFunctionSumMapFiltered<T, OverflowPolicy>, OverflowPolicy> |
288 | { |
289 | private: |
290 | using Self = AggregateFunctionSumMapFiltered<T, OverflowPolicy>; |
291 | using Base = AggregateFunctionSumMapBase<T, Self, OverflowPolicy>; |
292 | |
293 | std::unordered_set<T> keys_to_keep; |
294 | |
295 | public: |
296 | AggregateFunctionSumMapFiltered( |
297 | const DataTypePtr & keys_type_, const DataTypes & values_types_, const Array & keys_to_keep_, |
298 | const DataTypes & argument_types_, const Array & params_) |
299 | : Base{keys_type_, values_types_, argument_types_, params_} |
300 | { |
301 | keys_to_keep.reserve(keys_to_keep_.size()); |
302 | for (const Field & f : keys_to_keep_) |
303 | { |
304 | keys_to_keep.emplace(f.safeGet<NearestFieldType<T>>()); |
305 | } |
306 | } |
307 | |
308 | String getName() const override { return "sumMapFiltered" ; } |
309 | |
310 | bool keepKey(const T & key) const { return keys_to_keep.count(key); } |
311 | }; |
312 | |
313 | } |
314 | |