1#pragma once
2
3#include <cmath>
4
5#include <common/arithmeticOverflow.h>
6
7#include <IO/WriteHelpers.h>
8#include <IO/ReadHelpers.h>
9
10#include <AggregateFunctions/IAggregateFunction.h>
11
12#include <DataTypes/DataTypesNumber.h>
13#include <DataTypes/DataTypesDecimal.h>
14#include <Columns/ColumnVector.h>
15#include <Columns/ColumnDecimal.h>
16
17
/** These are simple, numerically unstable
 * implementations of variance/covariance/correlation functions.
 *
 * They are about two times faster than the stable variants,
 * but numerical errors may accumulate during summation.
 *
 * This implementation is selected by default,
 * following the "you don't pay for what you don't need" principle.
 *
 * For a more sophisticated implementation, see AggregateFunctionStatistics.h
 */
29
30namespace DB
31{
32
/// Error codes thrown from this header; the definitions live in another translation unit.
namespace ErrorCodes { extern const int DECIMAL_OVERFLOW; }
37
38
39/**
40 Calculating univariate central moments
41 Levels:
42 level 2 (pop & samp): var, stddev
43 level 3: skewness
44 level 4: kurtosis
45 References:
46 https://en.wikipedia.org/wiki/Moment_(mathematics)
47 https://en.wikipedia.org/wiki/Skewness
48 https://en.wikipedia.org/wiki/Kurtosis
49*/
50template <typename T, size_t _level>
51struct VarMoments
52{
53 T m[_level + 1]{};
54
55 void add(T x)
56 {
57 ++m[0];
58 m[1] += x;
59 m[2] += x * x;
60 if constexpr (_level >= 3) m[3] += x * x * x;
61 if constexpr (_level >= 4) m[4] += x * x * x * x;
62 }
63
64 void merge(const VarMoments & rhs)
65 {
66 m[0] += rhs.m[0];
67 m[1] += rhs.m[1];
68 m[2] += rhs.m[2];
69 if constexpr (_level >= 3) m[3] += rhs.m[3];
70 if constexpr (_level >= 4) m[4] += rhs.m[4];
71 }
72
73 void write(WriteBuffer & buf) const
74 {
75 writePODBinary(*this, buf);
76 }
77
78 void read(ReadBuffer & buf)
79 {
80 readPODBinary(*this, buf);
81 }
82
83 T NO_SANITIZE_UNDEFINED getPopulation() const
84 {
85 return (m[2] - m[1] * m[1] / m[0]) / m[0];
86 }
87
88 T NO_SANITIZE_UNDEFINED getSample() const
89 {
90 if (m[0] == 0)
91 return std::numeric_limits<T>::quiet_NaN();
92 return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1);
93 }
94
95 T NO_SANITIZE_UNDEFINED getMoment3() const
96 {
97 // to avoid accuracy problem
98 if (m[0] == 1)
99 return 0;
100 return (m[3]
101 - (3 * m[2]
102 - 2 * m[1] * m[1] / m[0]
103 ) * m[1] / m[0]
104 ) / m[0];
105 }
106
107 T NO_SANITIZE_UNDEFINED getMoment4() const
108 {
109 // to avoid accuracy problem
110 if (m[0] == 1)
111 return 0;
112 return (m[4]
113 - (4 * m[3]
114 - (6 * m[2]
115 - 3 * m[1] * m[1] / m[0]
116 ) * m[1] / m[0]
117 ) * m[1] / m[0]
118 ) / m[0];
119 }
120};
121
122template <typename T, size_t _level>
123struct VarMomentsDecimal
124{
125 using NativeType = typename T::NativeType;
126
127 UInt64 m0{};
128 NativeType m[_level]{};
129
130 NativeType & getM(size_t i)
131 {
132 return m[i - 1];
133 }
134
135 const NativeType & getM(size_t i) const
136 {
137 return m[i - 1];
138 }
139
140 void add(NativeType x)
141 {
142 ++m0;
143 getM(1) += x;
144
145 NativeType tmp;
146 if (common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2)))
147 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
148 if constexpr (_level >= 3)
149 if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3)))
150 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
151 if constexpr (_level >= 4)
152 if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4)))
153 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
154 }
155
156 void merge(const VarMomentsDecimal & rhs)
157 {
158 m0 += rhs.m0;
159 getM(1) += rhs.getM(1);
160
161 if (common::addOverflow(getM(2), rhs.getM(2), getM(2)))
162 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
163 if constexpr (_level >= 3)
164 if (common::addOverflow(getM(3), rhs.getM(3), getM(3)))
165 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
166 if constexpr (_level >= 4)
167 if (common::addOverflow(getM(4), rhs.getM(4), getM(4)))
168 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
169 }
170
171 void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
172 void read(ReadBuffer & buf) { readPODBinary(*this, buf); }
173
174 Float64 getPopulation(UInt32 scale) const
175 {
176 if (m0 == 0)
177 return std::numeric_limits<Float64>::infinity();
178
179 NativeType tmp;
180 if (common::mulOverflow(getM(1), getM(1), tmp) ||
181 common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
182 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
183 return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
184 }
185
186 Float64 getSample(UInt32 scale) const
187 {
188 if (m0 == 0)
189 return std::numeric_limits<Float64>::quiet_NaN();
190 if (m0 == 1)
191 return std::numeric_limits<Float64>::infinity();
192
193 NativeType tmp;
194 if (common::mulOverflow(getM(1), getM(1), tmp) ||
195 common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
196 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
197 return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
198 }
199
200 Float64 getMoment3(UInt32 scale) const
201 {
202 if (m0 == 0)
203 return std::numeric_limits<Float64>::infinity();
204
205 NativeType tmp;
206 if (common::mulOverflow(2 * getM(1), getM(1), tmp) ||
207 common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) ||
208 common::mulOverflow(tmp, getM(1), tmp) ||
209 common::subOverflow(getM(3), NativeType(tmp / m0), tmp))
210 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
211 return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
212 }
213
214 Float64 getMoment4(UInt32 scale) const
215 {
216 if (m0 == 0)
217 return std::numeric_limits<Float64>::infinity();
218
219 NativeType tmp;
220 if (common::mulOverflow(3 * getM(1), getM(1), tmp) ||
221 common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) ||
222 common::mulOverflow(tmp, getM(1), tmp) ||
223 common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) ||
224 common::mulOverflow(tmp, getM(1), tmp) ||
225 common::subOverflow(getM(4), NativeType(tmp / m0), tmp))
226 throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
227 return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
228 }
229};
230
231/**
232 Calculating multivariate central moments
233 Levels:
234 level 2 (pop & samp): covar
235 References:
236 https://en.wikipedia.org/wiki/Moment_(mathematics)
237*/
238template <typename T>
239struct CovarMoments
240{
241 T m0{};
242 T x1{};
243 T y1{};
244 T xy{};
245
246 void add(T x, T y)
247 {
248 ++m0;
249 x1 += x;
250 y1 += y;
251 xy += x * y;
252 }
253
254 void merge(const CovarMoments & rhs)
255 {
256 m0 += rhs.m0;
257 x1 += rhs.x1;
258 y1 += rhs.y1;
259 xy += rhs.xy;
260 }
261
262 void write(WriteBuffer & buf) const
263 {
264 writePODBinary(*this, buf);
265 }
266
267 void read(ReadBuffer & buf)
268 {
269 readPODBinary(*this, buf);
270 }
271
272 T NO_SANITIZE_UNDEFINED getPopulation() const
273 {
274 return (xy - x1 * y1 / m0) / m0;
275 }
276
277 T NO_SANITIZE_UNDEFINED getSample() const
278 {
279 if (m0 == 0)
280 return std::numeric_limits<T>::quiet_NaN();
281 return (xy - x1 * y1 / m0) / (m0 - 1);
282 }
283};
284
285template <typename T>
286struct CorrMoments
287{
288 T m0{};
289 T x1{};
290 T y1{};
291 T xy{};
292 T x2{};
293 T y2{};
294
295 void add(T x, T y)
296 {
297 ++m0;
298 x1 += x;
299 y1 += y;
300 xy += x * y;
301 x2 += x * x;
302 y2 += y * y;
303 }
304
305 void merge(const CorrMoments & rhs)
306 {
307 m0 += rhs.m0;
308 x1 += rhs.x1;
309 y1 += rhs.y1;
310 xy += rhs.xy;
311 x2 += rhs.x2;
312 y2 += rhs.y2;
313 }
314
315 void write(WriteBuffer & buf) const
316 {
317 writePODBinary(*this, buf);
318 }
319
320 void read(ReadBuffer & buf)
321 {
322 readPODBinary(*this, buf);
323 }
324
325 T NO_SANITIZE_UNDEFINED get() const
326 {
327 return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
328 }
329};
330
331
/// Selects which statistic AggregateFunctionVarianceSimple computes and reports as its name.
enum class StatisticsFunctionKind
{
    /// Variance.
    varPop,
    varSamp,
    /// Standard deviation.
    stddevPop,
    stddevSamp,
    /// Skewness.
    skewPop,
    skewSamp,
    /// Kurtosis.
    kurtPop,
    kurtSamp,
    /// Covariance of two arguments.
    covarPop,
    covarSamp,
    /// Pearson correlation of two arguments.
    corr
};
341
342
343template <typename T, StatisticsFunctionKind _kind, size_t _level>
344struct StatFuncOneArg
345{
346 using Type1 = T;
347 using Type2 = T;
348 using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
349 using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128, _level>, VarMoments<ResultType, _level>>;
350
351 static constexpr StatisticsFunctionKind kind = _kind;
352 static constexpr UInt32 num_args = 1;
353};
354
355template <typename T1, typename T2, StatisticsFunctionKind _kind>
356struct StatFuncTwoArg
357{
358 using Type1 = T1;
359 using Type2 = T2;
360 using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>, Float32, Float64>;
361 using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments<ResultType>, CovarMoments<ResultType>>;
362
363 static constexpr StatisticsFunctionKind kind = _kind;
364 static constexpr UInt32 num_args = 2;
365};
366
367
368template <typename StatFunc>
369class AggregateFunctionVarianceSimple final
370 : public IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>
371{
372public:
373 using T1 = typename StatFunc::Type1;
374 using T2 = typename StatFunc::Type2;
375 using ColVecT1 = std::conditional_t<IsDecimalNumber<T1>, ColumnDecimal<T1>, ColumnVector<T1>>;
376 using ColVecT2 = std::conditional_t<IsDecimalNumber<T2>, ColumnDecimal<T2>, ColumnVector<T2>>;
377 using ResultType = typename StatFunc::ResultType;
378 using ColVecResult = ColumnVector<ResultType>;
379
380 AggregateFunctionVarianceSimple(const DataTypes & argument_types_)
381 : IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
382 , src_scale(0)
383 {}
384
385 AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_)
386 : IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>(argument_types_, {})
387 , src_scale(getDecimalScale(data_type))
388 {}
389
390 String getName() const override
391 {
392 if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
393 return "varPop";
394 if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
395 return "varSamp";
396 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
397 return "stddevPop";
398 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
399 return "stddevSamp";
400 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
401 return "skewPop";
402 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
403 return "skewSamp";
404 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
405 return "kurtPop";
406 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
407 return "kurtSamp";
408 if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
409 return "covarPop";
410 if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
411 return "covarSamp";
412 if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
413 return "corr";
414 __builtin_unreachable();
415 }
416
417 DataTypePtr getReturnType() const override
418 {
419 return std::make_shared<DataTypeNumber<ResultType>>();
420 }
421
422 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
423 {
424 if constexpr (StatFunc::num_args == 2)
425 this->data(place).add(
426 static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num],
427 static_cast<const ColVecT2 &>(*columns[1]).getData()[row_num]);
428 else
429 this->data(place).add(
430 static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num]);
431 }
432
433 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
434 {
435 this->data(place).merge(this->data(rhs));
436 }
437
438 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
439 {
440 this->data(place).write(buf);
441 }
442
443 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
444 {
445 this->data(place).read(buf);
446 }
447
448 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
449 {
450 const auto & data = this->data(place);
451 auto & dst = static_cast<ColVecResult &>(to).getData();
452
453 if constexpr (IsDecimalNumber<T1>)
454 {
455 if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
456 dst.push_back(data.getPopulation(src_scale * 2));
457 if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
458 dst.push_back(data.getSample(src_scale * 2));
459 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
460 dst.push_back(sqrt(data.getPopulation(src_scale * 2)));
461 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
462 dst.push_back(sqrt(data.getSample(src_scale * 2)));
463 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
464 {
465 Float64 var_value = data.getPopulation(src_scale * 2);
466
467 if (var_value > 0)
468 dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
469 else
470 dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
471 }
472 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
473 {
474 Float64 var_value = data.getSample(src_scale * 2);
475
476 if (var_value > 0)
477 dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
478 else
479 dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
480 }
481 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
482 {
483 Float64 var_value = data.getPopulation(src_scale * 2);
484
485 if (var_value > 0)
486 dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
487 else
488 dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
489 }
490 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
491 {
492 Float64 var_value = data.getSample(src_scale * 2);
493
494 if (var_value > 0)
495 dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
496 else
497 dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
498 }
499 }
500 else
501 {
502 if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
503 dst.push_back(data.getPopulation());
504 if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
505 dst.push_back(data.getSample());
506 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
507 dst.push_back(sqrt(data.getPopulation()));
508 if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
509 dst.push_back(sqrt(data.getSample()));
510 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
511 {
512 ResultType var_value = data.getPopulation();
513
514 if (var_value > 0)
515 dst.push_back(data.getMoment3() / pow(var_value, 1.5));
516 else
517 dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
518 }
519 if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
520 {
521 ResultType var_value = data.getSample();
522
523 if (var_value > 0)
524 dst.push_back(data.getMoment3() / pow(var_value, 1.5));
525 else
526 dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
527 }
528 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
529 {
530 ResultType var_value = data.getPopulation();
531
532 if (var_value > 0)
533 dst.push_back(data.getMoment4() / pow(var_value, 2));
534 else
535 dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
536 }
537 if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
538 {
539 ResultType var_value = data.getSample();
540
541 if (var_value > 0)
542 dst.push_back(data.getMoment4() / pow(var_value, 2));
543 else
544 dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
545 }
546 if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
547 dst.push_back(data.getPopulation());
548 if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
549 dst.push_back(data.getSample());
550 if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
551 dst.push_back(data.get());
552 }
553 }
554
555private:
556 UInt32 src_scale;
557};
558
559
560template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop, 2>>;
561template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp, 2>>;
562template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop, 2>>;
563template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp, 2>>;
564template <typename T> using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewPop, 3>>;
565template <typename T> using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewSamp, 3>>;
566template <typename T> using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtPop, 4>>;
567template <typename T> using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtSamp, 4>>;
568template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
569template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
570template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;
571
572}
573