1#include "duckdb/function/aggregate/algebraic_functions.hpp"
2#include "duckdb/common/vector_operations/vector_operations.hpp"
3#include "duckdb/function/function_set.hpp"
4
5#include <cmath>
6
7using namespace duckdb;
8using namespace std;
9
// Running accumulator shared by the variance / standard deviation aggregates.
// The fields are deliberately left without default initializers; they are
// zeroed by STDDevBaseOperation::Initialize.
struct stddev_state_t {
	// n: number of values folded into this state so far
	uint64_t count;
	// M1: running mean of the values seen so far
	double mean;
	// M2: running sum of squared deviations from the mean
	double dsquared;
};
15
16// Streaming approximate standard deviation using Welford's
17// method, DOI: 10.2307/1266577
18struct STDDevBaseOperation {
19 template <class STATE> static void Initialize(STATE *state) {
20 state->count = 0;
21 state->mean = 0;
22 state->dsquared = 0;
23 }
24
25 template <class INPUT_TYPE, class STATE, class OP>
26 static void Operation(STATE *state, INPUT_TYPE *input_data, nullmask_t &nullmask, idx_t idx) {
27 // update running mean and d^2
28 state->count++;
29 const double input = input_data[idx];
30 const double mean_differential = (input - state->mean) / state->count;
31 const double new_mean = state->mean + mean_differential;
32 const double dsquared_increment = (input - new_mean) * (input - state->mean);
33 const double new_dsquared = state->dsquared + dsquared_increment;
34
35 state->mean = new_mean;
36 state->dsquared = new_dsquared;
37 }
38
39 template <class INPUT_TYPE, class STATE, class OP>
40 static void ConstantOperation(STATE *state, INPUT_TYPE *input_data, nullmask_t &nullmask, idx_t count) {
41 for (idx_t i = 0; i < count; i++) {
42 Operation<INPUT_TYPE, STATE, OP>(state, input_data, nullmask, 0);
43 }
44 }
45
46 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
47 if (target->count == 0) {
48 *target = source;
49 } else if (source.count > 0) {
50 const auto count = target->count + source.count;
51 const auto mean = (source.count * source.mean + target->count * target->mean) / count;
52 const auto delta = source.mean - target->mean;
53 target->dsquared =
54 source.dsquared + target->dsquared + delta * delta * source.count * target->count / count;
55 target->mean = mean;
56 target->count = count;
57 }
58 }
59
60 static bool IgnoreNull() {
61 return true;
62 }
63};
64
65struct VarSampOperation : public STDDevBaseOperation {
66 template <class T, class STATE>
67 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
68 if (state->count == 0) {
69 nullmask[idx] = true;
70 } else {
71 target[idx] = state->count > 1 ? (state->dsquared / (state->count - 1)) : 0;
72 if (!Value::DoubleIsValid(target[idx])) {
73 throw OutOfRangeException("VARSAMP is out of range!");
74 }
75 }
76 }
77};
78
79struct VarPopOperation : public STDDevBaseOperation {
80 template <class T, class STATE>
81 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
82 if (state->count == 0) {
83 nullmask[idx] = true;
84 } else {
85 target[idx] = state->count > 1 ? (state->dsquared / state->count) : 0;
86 if (!Value::DoubleIsValid(target[idx])) {
87 throw OutOfRangeException("VARPOP is out of range!");
88 }
89 }
90 }
91};
92
93struct STDDevSampOperation : public STDDevBaseOperation {
94 template <class T, class STATE>
95 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
96 if (state->count == 0) {
97 nullmask[idx] = true;
98 } else {
99 target[idx] = state->count > 1 ? sqrt(state->dsquared / (state->count - 1)) : 0;
100 if (!Value::DoubleIsValid(target[idx])) {
101 throw OutOfRangeException("STDDEV_SAMP is out of range!");
102 }
103 }
104 }
105};
106
107struct STDDevPopOperation : public STDDevBaseOperation {
108 template <class T, class STATE>
109 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
110 if (state->count == 0) {
111 nullmask[idx] = true;
112 } else {
113 target[idx] = state->count > 1 ? sqrt(state->dsquared / state->count) : 0;
114 if (!Value::DoubleIsValid(target[idx])) {
115 throw OutOfRangeException("STDDEV_POP is out of range!");
116 }
117 }
118 }
119};
120
121void StdDevSampFun::RegisterFunction(BuiltinFunctions &set) {
122 AggregateFunctionSet stddev_samp("stddev_samp");
123 stddev_samp.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, STDDevSampOperation>(
124 SQLType::DOUBLE, SQLType::DOUBLE));
125 set.AddFunction(stddev_samp);
126}
127
128void StdDevPopFun::RegisterFunction(BuiltinFunctions &set) {
129 AggregateFunctionSet stddev_pop("stddev_pop");
130 stddev_pop.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, STDDevPopOperation>(
131 SQLType::DOUBLE, SQLType::DOUBLE));
132 set.AddFunction(stddev_pop);
133}
134
135void VarPopFun::RegisterFunction(BuiltinFunctions &set) {
136 AggregateFunctionSet var_pop("var_pop");
137 var_pop.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, VarPopOperation>(
138 SQLType::DOUBLE, SQLType::DOUBLE));
139 set.AddFunction(var_pop);
140}
141
142void VarSampFun::RegisterFunction(BuiltinFunctions &set) {
143 AggregateFunctionSet var_samp("var_samp");
144 var_samp.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, VarSampOperation>(
145 SQLType::DOUBLE, SQLType::DOUBLE));
146 set.AddFunction(var_samp);
147}
148