1 | #pragma once |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | #include <DataTypes/DataTypesNumber.h> |
6 | #include <AggregateFunctions/IAggregateFunction.h> |
7 | #include <Columns/ColumnsNumber.h> |
8 | #include <Common/assert_cast.h> |
9 | |
10 | #include <cmath> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
16 | namespace |
17 | { |
18 | |
19 | /// This function returns true if both values are large and comparable. |
20 | /// It is used to calculate the mean value by merging two sources. |
21 | /// It means that if the sizes of both sources are large and comparable, then we must apply a special |
22 | /// formula guaranteeing more stability. |
23 | bool areComparable(UInt64 a, UInt64 b) |
24 | { |
25 | const Float64 sensitivity = 0.001; |
26 | const UInt64 threshold = 10000; |
27 | |
28 | if ((a == 0) || (b == 0)) |
29 | return false; |
30 | |
31 | auto res = std::minmax(a, b); |
32 | return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold)); |
33 | } |
34 | |
35 | } |
36 | |
/** Statistical aggregate functions
  * varSamp - sample variance
  * stddevSamp - sample standard deviation
  * varPop - population variance
  * stddevPop - population standard deviation
  * covarSamp - sample covariance
  * covarPop - population covariance
  * corr - correlation coefficient
  */
46 | |
47 | /** Parallel and incremental algorithm for calculating variance. |
48 | * Source: "Updating formulae and a pairwise algorithm for computing sample variances" |
49 | * (Chan et al., Stanford University, 12.1979) |
50 | */ |
51 | template <typename T, typename Op> |
52 | class AggregateFunctionVarianceData |
53 | { |
54 | public: |
55 | void update(const IColumn & column, size_t row_num) |
56 | { |
57 | T received = assert_cast<const ColumnVector<T> &>(column).getData()[row_num]; |
58 | Float64 val = static_cast<Float64>(received); |
59 | Float64 delta = val - mean; |
60 | |
61 | ++count; |
62 | mean += delta / count; |
63 | m2 += delta * (val - mean); |
64 | } |
65 | |
66 | void mergeWith(const AggregateFunctionVarianceData & source) |
67 | { |
68 | UInt64 total_count = count + source.count; |
69 | if (total_count == 0) |
70 | return; |
71 | |
72 | Float64 factor = static_cast<Float64>(count * source.count) / total_count; |
73 | Float64 delta = mean - source.mean; |
74 | |
75 | if (areComparable(count, source.count)) |
76 | mean = (source.count * source.mean + count * mean) / total_count; |
77 | else |
78 | mean = source.mean + delta * (static_cast<Float64>(count) / total_count); |
79 | |
80 | m2 += source.m2 + delta * delta * factor; |
81 | count = total_count; |
82 | } |
83 | |
84 | void serialize(WriteBuffer & buf) const |
85 | { |
86 | writeVarUInt(count, buf); |
87 | writeBinary(mean, buf); |
88 | writeBinary(m2, buf); |
89 | } |
90 | |
91 | void deserialize(ReadBuffer & buf) |
92 | { |
93 | readVarUInt(count, buf); |
94 | readBinary(mean, buf); |
95 | readBinary(m2, buf); |
96 | } |
97 | |
98 | void publish(IColumn & to) const |
99 | { |
100 | assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(m2, count)); |
101 | } |
102 | |
103 | private: |
104 | UInt64 count = 0; |
105 | Float64 mean = 0.0; |
106 | Float64 m2 = 0.0; |
107 | }; |
108 | |
109 | /** The main code for the implementation of varSamp, stddevSamp, varPop, stddevPop. |
110 | */ |
111 | template <typename T, typename Op> |
112 | class AggregateFunctionVariance final |
113 | : public IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>> |
114 | { |
115 | public: |
116 | AggregateFunctionVariance(const DataTypePtr & arg) |
117 | : IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {} |
118 | |
119 | String getName() const override { return Op::name; } |
120 | |
121 | DataTypePtr getReturnType() const override |
122 | { |
123 | return std::make_shared<DataTypeFloat64>(); |
124 | } |
125 | |
126 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
127 | { |
128 | this->data(place).update(*columns[0], row_num); |
129 | } |
130 | |
131 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
132 | { |
133 | this->data(place).mergeWith(this->data(rhs)); |
134 | } |
135 | |
136 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
137 | { |
138 | this->data(place).serialize(buf); |
139 | } |
140 | |
141 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
142 | { |
143 | this->data(place).deserialize(buf); |
144 | } |
145 | |
146 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
147 | { |
148 | this->data(place).publish(to); |
149 | } |
150 | }; |
151 | |
152 | /** Implementing the varSamp function. |
153 | */ |
154 | struct AggregateFunctionVarSampImpl |
155 | { |
156 | static constexpr auto name = "varSamp" ; |
157 | |
158 | static inline Float64 apply(Float64 m2, UInt64 count) |
159 | { |
160 | if (count < 2) |
161 | return std::numeric_limits<Float64>::infinity(); |
162 | else |
163 | return m2 / (count - 1); |
164 | } |
165 | }; |
166 | |
167 | /** Implementing the stddevSamp function. |
168 | */ |
169 | struct AggregateFunctionStdDevSampImpl |
170 | { |
171 | static constexpr auto name = "stddevSamp" ; |
172 | |
173 | static inline Float64 apply(Float64 m2, UInt64 count) |
174 | { |
175 | return sqrt(AggregateFunctionVarSampImpl::apply(m2, count)); |
176 | } |
177 | }; |
178 | |
179 | /** Implementing the varPop function. |
180 | */ |
181 | struct AggregateFunctionVarPopImpl |
182 | { |
183 | static constexpr auto name = "varPop" ; |
184 | |
185 | static inline Float64 apply(Float64 m2, UInt64 count) |
186 | { |
187 | if (count == 0) |
188 | return std::numeric_limits<Float64>::infinity(); |
189 | else if (count == 1) |
190 | return 0.0; |
191 | else |
192 | return m2 / count; |
193 | } |
194 | }; |
195 | |
196 | /** Implementing the stddevPop function. |
197 | */ |
198 | struct AggregateFunctionStdDevPopImpl |
199 | { |
200 | static constexpr auto name = "stddevPop" ; |
201 | |
202 | static inline Float64 apply(Float64 m2, UInt64 count) |
203 | { |
204 | return sqrt(AggregateFunctionVarPopImpl::apply(m2, count)); |
205 | } |
206 | }; |
207 | |
208 | /** If `compute_marginal_moments` flag is set this class provides the successor |
209 | * CovarianceData support of marginal moments for calculating the correlation. |
210 | */ |
211 | template <bool compute_marginal_moments> |
212 | class BaseCovarianceData |
213 | { |
214 | protected: |
215 | void incrementMarginalMoments(Float64, Float64) {} |
216 | void mergeWith(const BaseCovarianceData &) {} |
217 | void serialize(WriteBuffer &) const {} |
218 | void deserialize(const ReadBuffer &) {} |
219 | }; |
220 | |
221 | template <> |
222 | class BaseCovarianceData<true> |
223 | { |
224 | protected: |
225 | void incrementMarginalMoments(Float64 left_incr, Float64 right_incr) |
226 | { |
227 | left_m2 += left_incr; |
228 | right_m2 += right_incr; |
229 | } |
230 | |
231 | void mergeWith(const BaseCovarianceData & source) |
232 | { |
233 | left_m2 += source.left_m2; |
234 | right_m2 += source.right_m2; |
235 | } |
236 | |
237 | void serialize(WriteBuffer & buf) const |
238 | { |
239 | writeBinary(left_m2, buf); |
240 | writeBinary(right_m2, buf); |
241 | } |
242 | |
243 | void deserialize(ReadBuffer & buf) |
244 | { |
245 | readBinary(left_m2, buf); |
246 | readBinary(right_m2, buf); |
247 | } |
248 | |
249 | protected: |
250 | Float64 left_m2 = 0.0; |
251 | Float64 right_m2 = 0.0; |
252 | }; |
253 | |
254 | /** Parallel and incremental algorithm for calculating covariance. |
255 | * Source: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms" |
256 | * (J. Bennett et al., Sandia National Laboratories, |
257 | * 2009 IEEE International Conference on Cluster Computing) |
258 | */ |
259 | template <typename T, typename U, typename Op, bool compute_marginal_moments> |
260 | class CovarianceData : public BaseCovarianceData<compute_marginal_moments> |
261 | { |
262 | private: |
263 | using Base = BaseCovarianceData<compute_marginal_moments>; |
264 | |
265 | public: |
266 | void update(const IColumn & column_left, const IColumn & column_right, size_t row_num) |
267 | { |
268 | T left_received = assert_cast<const ColumnVector<T> &>(column_left).getData()[row_num]; |
269 | Float64 left_val = static_cast<Float64>(left_received); |
270 | Float64 left_delta = left_val - left_mean; |
271 | |
272 | U right_received = assert_cast<const ColumnVector<U> &>(column_right).getData()[row_num]; |
273 | Float64 right_val = static_cast<Float64>(right_received); |
274 | Float64 right_delta = right_val - right_mean; |
275 | |
276 | Float64 old_right_mean = right_mean; |
277 | |
278 | ++count; |
279 | |
280 | left_mean += left_delta / count; |
281 | right_mean += right_delta / count; |
282 | co_moment += (left_val - left_mean) * (right_val - old_right_mean); |
283 | |
284 | /// Update the marginal moments, if any. |
285 | if (compute_marginal_moments) |
286 | { |
287 | Float64 left_incr = left_delta * (left_val - left_mean); |
288 | Float64 right_incr = right_delta * (right_val - right_mean); |
289 | Base::incrementMarginalMoments(left_incr, right_incr); |
290 | } |
291 | } |
292 | |
293 | void mergeWith(const CovarianceData & source) |
294 | { |
295 | UInt64 total_count = count + source.count; |
296 | if (total_count == 0) |
297 | return; |
298 | |
299 | Float64 factor = static_cast<Float64>(count * source.count) / total_count; |
300 | Float64 left_delta = left_mean - source.left_mean; |
301 | Float64 right_delta = right_mean - source.right_mean; |
302 | |
303 | if (areComparable(count, source.count)) |
304 | { |
305 | left_mean = (source.count * source.left_mean + count * left_mean) / total_count; |
306 | right_mean = (source.count * source.right_mean + count * right_mean) / total_count; |
307 | } |
308 | else |
309 | { |
310 | left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count); |
311 | right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count); |
312 | } |
313 | |
314 | co_moment += source.co_moment + left_delta * right_delta * factor; |
315 | count = total_count; |
316 | |
317 | /// Update the marginal moments, if any. |
318 | if (compute_marginal_moments) |
319 | { |
320 | Float64 left_incr = left_delta * left_delta * factor; |
321 | Float64 right_incr = right_delta * right_delta * factor; |
322 | Base::mergeWith(source); |
323 | Base::incrementMarginalMoments(left_incr, right_incr); |
324 | } |
325 | } |
326 | |
327 | void serialize(WriteBuffer & buf) const |
328 | { |
329 | writeVarUInt(count, buf); |
330 | writeBinary(left_mean, buf); |
331 | writeBinary(right_mean, buf); |
332 | writeBinary(co_moment, buf); |
333 | Base::serialize(buf); |
334 | } |
335 | |
336 | void deserialize(ReadBuffer & buf) |
337 | { |
338 | readVarUInt(count, buf); |
339 | readBinary(left_mean, buf); |
340 | readBinary(right_mean, buf); |
341 | readBinary(co_moment, buf); |
342 | Base::deserialize(buf); |
343 | } |
344 | |
345 | void publish(IColumn & to) const |
346 | { |
347 | if constexpr (compute_marginal_moments) |
348 | assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, Base::left_m2, Base::right_m2, count)); |
349 | else |
350 | assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, count)); |
351 | } |
352 | |
353 | private: |
354 | UInt64 count = 0; |
355 | Float64 left_mean = 0.0; |
356 | Float64 right_mean = 0.0; |
357 | Float64 co_moment = 0.0; |
358 | }; |
359 | |
360 | template <typename T, typename U, typename Op, bool compute_marginal_moments = false> |
361 | class AggregateFunctionCovariance final |
362 | : public IAggregateFunctionDataHelper< |
363 | CovarianceData<T, U, Op, compute_marginal_moments>, |
364 | AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>> |
365 | { |
366 | public: |
367 | AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper< |
368 | CovarianceData<T, U, Op, compute_marginal_moments>, |
369 | AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {} |
370 | |
371 | String getName() const override { return Op::name; } |
372 | |
373 | DataTypePtr getReturnType() const override |
374 | { |
375 | return std::make_shared<DataTypeFloat64>(); |
376 | } |
377 | |
378 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
379 | { |
380 | this->data(place).update(*columns[0], *columns[1], row_num); |
381 | } |
382 | |
383 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
384 | { |
385 | this->data(place).mergeWith(this->data(rhs)); |
386 | } |
387 | |
388 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
389 | { |
390 | this->data(place).serialize(buf); |
391 | } |
392 | |
393 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
394 | { |
395 | this->data(place).deserialize(buf); |
396 | } |
397 | |
398 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
399 | { |
400 | this->data(place).publish(to); |
401 | } |
402 | }; |
403 | |
404 | /** Implementing the covarSamp function. |
405 | */ |
406 | struct AggregateFunctionCovarSampImpl |
407 | { |
408 | static constexpr auto name = "covarSamp" ; |
409 | |
410 | static inline Float64 apply(Float64 co_moment, UInt64 count) |
411 | { |
412 | if (count < 2) |
413 | return std::numeric_limits<Float64>::infinity(); |
414 | else |
415 | return co_moment / (count - 1); |
416 | } |
417 | }; |
418 | |
419 | /** Implementing the covarPop function. |
420 | */ |
421 | struct AggregateFunctionCovarPopImpl |
422 | { |
423 | static constexpr auto name = "covarPop" ; |
424 | |
425 | static inline Float64 apply(Float64 co_moment, UInt64 count) |
426 | { |
427 | if (count == 0) |
428 | return std::numeric_limits<Float64>::infinity(); |
429 | else if (count == 1) |
430 | return 0.0; |
431 | else |
432 | return co_moment / count; |
433 | } |
434 | }; |
435 | |
436 | /** `corr` function implementation. |
437 | */ |
438 | struct AggregateFunctionCorrImpl |
439 | { |
440 | static constexpr auto name = "corr" ; |
441 | |
442 | static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) |
443 | { |
444 | if (count < 2) |
445 | return std::numeric_limits<Float64>::infinity(); |
446 | else |
447 | return co_moment / sqrt(left_m2 * right_m2); |
448 | } |
449 | }; |
450 | |
451 | template <typename T> |
452 | using AggregateFunctionVarSampStable = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>; |
453 | |
454 | template <typename T> |
455 | using AggregateFunctionStddevSampStable = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>; |
456 | |
457 | template <typename T> |
458 | using AggregateFunctionVarPopStable = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>; |
459 | |
460 | template <typename T> |
461 | using AggregateFunctionStddevPopStable = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>; |
462 | |
463 | template <typename T, typename U> |
464 | using AggregateFunctionCovarSampStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>; |
465 | |
466 | template <typename T, typename U> |
467 | using AggregateFunctionCovarPopStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>; |
468 | |
469 | template <typename T, typename U> |
470 | using AggregateFunctionCorrStable = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>; |
471 | |
472 | } |
473 | |