1#pragma once
2
3#include <IO/WriteHelpers.h>
4#include <IO/ReadHelpers.h>
5#include <DataTypes/DataTypesNumber.h>
6#include <AggregateFunctions/IAggregateFunction.h>
7#include <Columns/ColumnsNumber.h>
8#include <Common/assert_cast.h>
9
10#include <cmath>
11
12
13namespace DB
14{
15
16namespace
17{
18
19/// This function returns true if both values are large and comparable.
20/// It is used to calculate the mean value by merging two sources.
21/// It means that if the sizes of both sources are large and comparable, then we must apply a special
22/// formula guaranteeing more stability.
23bool areComparable(UInt64 a, UInt64 b)
24{
25 const Float64 sensitivity = 0.001;
26 const UInt64 threshold = 10000;
27
28 if ((a == 0) || (b == 0))
29 return false;
30
31 auto res = std::minmax(a, b);
32 return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold));
33}
34
35}
36
/** Statistical aggregate functions
 * varSamp - sample variance
 * stddevSamp - sample standard deviation
 * varPop - population variance
 * stddevPop - population standard deviation
 * covarSamp - sample covariance
 * covarPop - population covariance
 * corr - correlation
 */
46
47/** Parallel and incremental algorithm for calculating variance.
48 * Source: "Updating formulae and a pairwise algorithm for computing sample variances"
49 * (Chan et al., Stanford University, 12.1979)
50 */
51template <typename T, typename Op>
52class AggregateFunctionVarianceData
53{
54public:
55 void update(const IColumn & column, size_t row_num)
56 {
57 T received = assert_cast<const ColumnVector<T> &>(column).getData()[row_num];
58 Float64 val = static_cast<Float64>(received);
59 Float64 delta = val - mean;
60
61 ++count;
62 mean += delta / count;
63 m2 += delta * (val - mean);
64 }
65
66 void mergeWith(const AggregateFunctionVarianceData & source)
67 {
68 UInt64 total_count = count + source.count;
69 if (total_count == 0)
70 return;
71
72 Float64 factor = static_cast<Float64>(count * source.count) / total_count;
73 Float64 delta = mean - source.mean;
74
75 if (areComparable(count, source.count))
76 mean = (source.count * source.mean + count * mean) / total_count;
77 else
78 mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
79
80 m2 += source.m2 + delta * delta * factor;
81 count = total_count;
82 }
83
84 void serialize(WriteBuffer & buf) const
85 {
86 writeVarUInt(count, buf);
87 writeBinary(mean, buf);
88 writeBinary(m2, buf);
89 }
90
91 void deserialize(ReadBuffer & buf)
92 {
93 readVarUInt(count, buf);
94 readBinary(mean, buf);
95 readBinary(m2, buf);
96 }
97
98 void publish(IColumn & to) const
99 {
100 assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(m2, count));
101 }
102
103private:
104 UInt64 count = 0;
105 Float64 mean = 0.0;
106 Float64 m2 = 0.0;
107};
108
109/** The main code for the implementation of varSamp, stddevSamp, varPop, stddevPop.
110 */
111template <typename T, typename Op>
112class AggregateFunctionVariance final
113 : public IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>
114{
115public:
116 AggregateFunctionVariance(const DataTypePtr & arg)
117 : IAggregateFunctionDataHelper<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op>>({arg}, {}) {}
118
119 String getName() const override { return Op::name; }
120
121 DataTypePtr getReturnType() const override
122 {
123 return std::make_shared<DataTypeFloat64>();
124 }
125
126 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
127 {
128 this->data(place).update(*columns[0], row_num);
129 }
130
131 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
132 {
133 this->data(place).mergeWith(this->data(rhs));
134 }
135
136 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
137 {
138 this->data(place).serialize(buf);
139 }
140
141 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
142 {
143 this->data(place).deserialize(buf);
144 }
145
146 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
147 {
148 this->data(place).publish(to);
149 }
150};
151
152/** Implementing the varSamp function.
153 */
154struct AggregateFunctionVarSampImpl
155{
156 static constexpr auto name = "varSamp";
157
158 static inline Float64 apply(Float64 m2, UInt64 count)
159 {
160 if (count < 2)
161 return std::numeric_limits<Float64>::infinity();
162 else
163 return m2 / (count - 1);
164 }
165};
166
167/** Implementing the stddevSamp function.
168 */
169struct AggregateFunctionStdDevSampImpl
170{
171 static constexpr auto name = "stddevSamp";
172
173 static inline Float64 apply(Float64 m2, UInt64 count)
174 {
175 return sqrt(AggregateFunctionVarSampImpl::apply(m2, count));
176 }
177};
178
179/** Implementing the varPop function.
180 */
181struct AggregateFunctionVarPopImpl
182{
183 static constexpr auto name = "varPop";
184
185 static inline Float64 apply(Float64 m2, UInt64 count)
186 {
187 if (count == 0)
188 return std::numeric_limits<Float64>::infinity();
189 else if (count == 1)
190 return 0.0;
191 else
192 return m2 / count;
193 }
194};
195
196/** Implementing the stddevPop function.
197 */
198struct AggregateFunctionStdDevPopImpl
199{
200 static constexpr auto name = "stddevPop";
201
202 static inline Float64 apply(Float64 m2, UInt64 count)
203 {
204 return sqrt(AggregateFunctionVarPopImpl::apply(m2, count));
205 }
206};
207
208/** If `compute_marginal_moments` flag is set this class provides the successor
209 * CovarianceData support of marginal moments for calculating the correlation.
210 */
211template <bool compute_marginal_moments>
212class BaseCovarianceData
213{
214protected:
215 void incrementMarginalMoments(Float64, Float64) {}
216 void mergeWith(const BaseCovarianceData &) {}
217 void serialize(WriteBuffer &) const {}
218 void deserialize(const ReadBuffer &) {}
219};
220
221template <>
222class BaseCovarianceData<true>
223{
224protected:
225 void incrementMarginalMoments(Float64 left_incr, Float64 right_incr)
226 {
227 left_m2 += left_incr;
228 right_m2 += right_incr;
229 }
230
231 void mergeWith(const BaseCovarianceData & source)
232 {
233 left_m2 += source.left_m2;
234 right_m2 += source.right_m2;
235 }
236
237 void serialize(WriteBuffer & buf) const
238 {
239 writeBinary(left_m2, buf);
240 writeBinary(right_m2, buf);
241 }
242
243 void deserialize(ReadBuffer & buf)
244 {
245 readBinary(left_m2, buf);
246 readBinary(right_m2, buf);
247 }
248
249protected:
250 Float64 left_m2 = 0.0;
251 Float64 right_m2 = 0.0;
252};
253
254/** Parallel and incremental algorithm for calculating covariance.
255 * Source: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms"
256 * (J. Bennett et al., Sandia National Laboratories,
257 * 2009 IEEE International Conference on Cluster Computing)
258 */
259template <typename T, typename U, typename Op, bool compute_marginal_moments>
260class CovarianceData : public BaseCovarianceData<compute_marginal_moments>
261{
262private:
263 using Base = BaseCovarianceData<compute_marginal_moments>;
264
265public:
266 void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
267 {
268 T left_received = assert_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
269 Float64 left_val = static_cast<Float64>(left_received);
270 Float64 left_delta = left_val - left_mean;
271
272 U right_received = assert_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
273 Float64 right_val = static_cast<Float64>(right_received);
274 Float64 right_delta = right_val - right_mean;
275
276 Float64 old_right_mean = right_mean;
277
278 ++count;
279
280 left_mean += left_delta / count;
281 right_mean += right_delta / count;
282 co_moment += (left_val - left_mean) * (right_val - old_right_mean);
283
284 /// Update the marginal moments, if any.
285 if (compute_marginal_moments)
286 {
287 Float64 left_incr = left_delta * (left_val - left_mean);
288 Float64 right_incr = right_delta * (right_val - right_mean);
289 Base::incrementMarginalMoments(left_incr, right_incr);
290 }
291 }
292
293 void mergeWith(const CovarianceData & source)
294 {
295 UInt64 total_count = count + source.count;
296 if (total_count == 0)
297 return;
298
299 Float64 factor = static_cast<Float64>(count * source.count) / total_count;
300 Float64 left_delta = left_mean - source.left_mean;
301 Float64 right_delta = right_mean - source.right_mean;
302
303 if (areComparable(count, source.count))
304 {
305 left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
306 right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
307 }
308 else
309 {
310 left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
311 right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
312 }
313
314 co_moment += source.co_moment + left_delta * right_delta * factor;
315 count = total_count;
316
317 /// Update the marginal moments, if any.
318 if (compute_marginal_moments)
319 {
320 Float64 left_incr = left_delta * left_delta * factor;
321 Float64 right_incr = right_delta * right_delta * factor;
322 Base::mergeWith(source);
323 Base::incrementMarginalMoments(left_incr, right_incr);
324 }
325 }
326
327 void serialize(WriteBuffer & buf) const
328 {
329 writeVarUInt(count, buf);
330 writeBinary(left_mean, buf);
331 writeBinary(right_mean, buf);
332 writeBinary(co_moment, buf);
333 Base::serialize(buf);
334 }
335
336 void deserialize(ReadBuffer & buf)
337 {
338 readVarUInt(count, buf);
339 readBinary(left_mean, buf);
340 readBinary(right_mean, buf);
341 readBinary(co_moment, buf);
342 Base::deserialize(buf);
343 }
344
345 void publish(IColumn & to) const
346 {
347 if constexpr (compute_marginal_moments)
348 assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, Base::left_m2, Base::right_m2, count));
349 else
350 assert_cast<ColumnFloat64 &>(to).getData().push_back(Op::apply(co_moment, count));
351 }
352
353private:
354 UInt64 count = 0;
355 Float64 left_mean = 0.0;
356 Float64 right_mean = 0.0;
357 Float64 co_moment = 0.0;
358};
359
360template <typename T, typename U, typename Op, bool compute_marginal_moments = false>
361class AggregateFunctionCovariance final
362 : public IAggregateFunctionDataHelper<
363 CovarianceData<T, U, Op, compute_marginal_moments>,
364 AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>
365{
366public:
367 AggregateFunctionCovariance(const DataTypes & args) : IAggregateFunctionDataHelper<
368 CovarianceData<T, U, Op, compute_marginal_moments>,
369 AggregateFunctionCovariance<T, U, Op, compute_marginal_moments>>(args, {}) {}
370
371 String getName() const override { return Op::name; }
372
373 DataTypePtr getReturnType() const override
374 {
375 return std::make_shared<DataTypeFloat64>();
376 }
377
378 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
379 {
380 this->data(place).update(*columns[0], *columns[1], row_num);
381 }
382
383 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
384 {
385 this->data(place).mergeWith(this->data(rhs));
386 }
387
388 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
389 {
390 this->data(place).serialize(buf);
391 }
392
393 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
394 {
395 this->data(place).deserialize(buf);
396 }
397
398 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
399 {
400 this->data(place).publish(to);
401 }
402};
403
404/** Implementing the covarSamp function.
405 */
406struct AggregateFunctionCovarSampImpl
407{
408 static constexpr auto name = "covarSamp";
409
410 static inline Float64 apply(Float64 co_moment, UInt64 count)
411 {
412 if (count < 2)
413 return std::numeric_limits<Float64>::infinity();
414 else
415 return co_moment / (count - 1);
416 }
417};
418
419/** Implementing the covarPop function.
420 */
421struct AggregateFunctionCovarPopImpl
422{
423 static constexpr auto name = "covarPop";
424
425 static inline Float64 apply(Float64 co_moment, UInt64 count)
426 {
427 if (count == 0)
428 return std::numeric_limits<Float64>::infinity();
429 else if (count == 1)
430 return 0.0;
431 else
432 return co_moment / count;
433 }
434};
435
436/** `corr` function implementation.
437 */
438struct AggregateFunctionCorrImpl
439{
440 static constexpr auto name = "corr";
441
442 static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count)
443 {
444 if (count < 2)
445 return std::numeric_limits<Float64>::infinity();
446 else
447 return co_moment / sqrt(left_m2 * right_m2);
448 }
449};
450
/// Aliases that bind the generic aggregate function classes to a concrete operation.

/// varSamp over values of type T.
template <typename T>
using AggregateFunctionVarSampStable = AggregateFunctionVariance<T, AggregateFunctionVarSampImpl>;

/// stddevSamp over values of type T.
template <typename T>
using AggregateFunctionStddevSampStable = AggregateFunctionVariance<T, AggregateFunctionStdDevSampImpl>;

/// varPop over values of type T.
template <typename T>
using AggregateFunctionVarPopStable = AggregateFunctionVariance<T, AggregateFunctionVarPopImpl>;

/// stddevPop over values of type T.
template <typename T>
using AggregateFunctionStddevPopStable = AggregateFunctionVariance<T, AggregateFunctionStdDevPopImpl>;

/// covarSamp over pairs of values of types T and U.
template <typename T, typename U>
using AggregateFunctionCovarSampStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarSampImpl>;

/// covarPop over pairs of values of types T and U.
template <typename T, typename U>
using AggregateFunctionCovarPopStable = AggregateFunctionCovariance<T, U, AggregateFunctionCovarPopImpl>;

/// corr over pairs of values of types T and U; the last argument enables the marginal moments.
template <typename T, typename U>
using AggregateFunctionCorrStable = AggregateFunctionCovariance<T, U, AggregateFunctionCorrImpl, true>;
471
472}
473