1#include "duckdb/function/aggregate/algebraic_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/function/function_set.hpp"
6#include <cmath>
7
8using namespace duckdb;
9using namespace std;
10
//! Running aggregate state shared by covar_pop and covar_samp.
struct covar_state_t {
	//! Number of (x, y) pairs folded into this state so far.
	uint64_t count;
	//! Running mean of the x values.
	double meanx;
	//! Running mean of the y values.
	double meany;
	//! Running co-moment; finalize divides it by count (covar_pop) or count - 1 (covar_samp).
	double co_moment;
};
17
18struct CovarOperation {
19 template <class STATE> static void Initialize(STATE *state) {
20 state->count = 0;
21 state->meanx = 0;
22 state->meany = 0;
23 state->co_moment = 0;
24 }
25
26 template <class A_TYPE, class B_TYPE, class STATE, class OP>
27 static void Operation(STATE *state, A_TYPE *x_data, B_TYPE *y_data, nullmask_t &anullmask, nullmask_t &bnullmask,
28 idx_t xidx, idx_t yidx) {
29 // update running mean and d^2
30 const uint64_t n = ++(state->count);
31
32 const auto x = x_data[xidx];
33 const double dx = (x - state->meanx);
34 const double meanx = state->meanx + dx / n;
35
36 const auto y = y_data[yidx];
37 const double dy = (y - state->meany);
38 const double meany = state->meany + dy / n;
39
40 const double C = state->co_moment + dx * (y - meany);
41
42 state->meanx = meanx;
43 state->meany = meany;
44 state->co_moment = C;
45 }
46
47 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
48 if (target->count == 0) {
49 *target = source;
50 } else if (source.count > 0) {
51 const auto count = target->count + source.count;
52 const auto meanx = (source.count * source.meanx + target->count * target->meanx) / count;
53 const auto meany = (source.count * source.meany + target->count * target->meany) / count;
54
55 // Schubert and Gertz SSDBM 2018, equation 21
56 const auto deltax = target->meanx - source.meanx;
57 const auto deltay = target->meany - source.meany;
58 target->co_moment =
59 source.co_moment + target->co_moment + deltax * deltay * source.count * target->count / count;
60 target->meanx = meanx;
61 target->meany = meany;
62 target->count = count;
63 }
64 }
65
66 static bool IgnoreNull() {
67 return true;
68 }
69};
70
71struct CovarPopOperation : public CovarOperation {
72 template <class T, class STATE>
73 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
74 if (state->count == 0) {
75 nullmask[idx] = true;
76 } else {
77 target[idx] = state->co_moment / state->count;
78 }
79 }
80};
81
82struct CovarSampOperation : public CovarOperation {
83 template <class T, class STATE>
84 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
85 if ((state->count) < 2) {
86 nullmask[idx] = true;
87 } else {
88 target[idx] = state->co_moment / (state->count - 1);
89 }
90 }
91};
92
93void CovarPopFun::RegisterFunction(BuiltinFunctions &set) {
94 AggregateFunctionSet covar_pop("covar_pop");
95 covar_pop.AddFunction(AggregateFunction::BinaryAggregate<covar_state_t, double, double, double, CovarPopOperation>(
96 SQLType::DOUBLE, SQLType::DOUBLE, SQLType::DOUBLE));
97 set.AddFunction(covar_pop);
98}
99
100void CovarSampFun::RegisterFunction(BuiltinFunctions &set) {
101 AggregateFunctionSet covar_samp("covar_samp");
102 covar_samp.AddFunction(
103 AggregateFunction::BinaryAggregate<covar_state_t, double, double, double, CovarSampOperation>(
104 SQLType::DOUBLE, SQLType::DOUBLE, SQLType::DOUBLE));
105 set.AddFunction(covar_samp);
106}
107