1 | #include "duckdb/function/aggregate/algebraic_functions.hpp" |
2 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
3 | #include "duckdb/function/function_set.hpp" |
4 | |
5 | #include <cmath> |
6 | |
7 | using namespace duckdb; |
8 | using namespace std; |
9 | |
//! Running aggregate state shared by the variance / standard-deviation aggregates in this file.
struct stddev_state_t {
	//! n: number of values folded into this state so far.
	uint64_t count;
	//! M1: running arithmetic mean of those values.
	double mean;
	//! M2: running sum of squared deviations from the mean.
	double dsquared;
};
15 | |
16 | // Streaming approximate standard deviation using Welford's |
17 | // method, DOI: 10.2307/1266577 |
18 | struct STDDevBaseOperation { |
19 | template <class STATE> static void Initialize(STATE *state) { |
20 | state->count = 0; |
21 | state->mean = 0; |
22 | state->dsquared = 0; |
23 | } |
24 | |
25 | template <class INPUT_TYPE, class STATE, class OP> |
26 | static void Operation(STATE *state, INPUT_TYPE *input_data, nullmask_t &nullmask, idx_t idx) { |
27 | // update running mean and d^2 |
28 | state->count++; |
29 | const double input = input_data[idx]; |
30 | const double mean_differential = (input - state->mean) / state->count; |
31 | const double new_mean = state->mean + mean_differential; |
32 | const double dsquared_increment = (input - new_mean) * (input - state->mean); |
33 | const double new_dsquared = state->dsquared + dsquared_increment; |
34 | |
35 | state->mean = new_mean; |
36 | state->dsquared = new_dsquared; |
37 | } |
38 | |
39 | template <class INPUT_TYPE, class STATE, class OP> |
40 | static void ConstantOperation(STATE *state, INPUT_TYPE *input_data, nullmask_t &nullmask, idx_t count) { |
41 | for (idx_t i = 0; i < count; i++) { |
42 | Operation<INPUT_TYPE, STATE, OP>(state, input_data, nullmask, 0); |
43 | } |
44 | } |
45 | |
46 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
47 | if (target->count == 0) { |
48 | *target = source; |
49 | } else if (source.count > 0) { |
50 | const auto count = target->count + source.count; |
51 | const auto mean = (source.count * source.mean + target->count * target->mean) / count; |
52 | const auto delta = source.mean - target->mean; |
53 | target->dsquared = |
54 | source.dsquared + target->dsquared + delta * delta * source.count * target->count / count; |
55 | target->mean = mean; |
56 | target->count = count; |
57 | } |
58 | } |
59 | |
60 | static bool IgnoreNull() { |
61 | return true; |
62 | } |
63 | }; |
64 | |
65 | struct VarSampOperation : public STDDevBaseOperation { |
66 | template <class T, class STATE> |
67 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
68 | if (state->count == 0) { |
69 | nullmask[idx] = true; |
70 | } else { |
71 | target[idx] = state->count > 1 ? (state->dsquared / (state->count - 1)) : 0; |
72 | if (!Value::DoubleIsValid(target[idx])) { |
73 | throw OutOfRangeException("VARSAMP is out of range!" ); |
74 | } |
75 | } |
76 | } |
77 | }; |
78 | |
79 | struct VarPopOperation : public STDDevBaseOperation { |
80 | template <class T, class STATE> |
81 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
82 | if (state->count == 0) { |
83 | nullmask[idx] = true; |
84 | } else { |
85 | target[idx] = state->count > 1 ? (state->dsquared / state->count) : 0; |
86 | if (!Value::DoubleIsValid(target[idx])) { |
87 | throw OutOfRangeException("VARPOP is out of range!" ); |
88 | } |
89 | } |
90 | } |
91 | }; |
92 | |
93 | struct STDDevSampOperation : public STDDevBaseOperation { |
94 | template <class T, class STATE> |
95 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
96 | if (state->count == 0) { |
97 | nullmask[idx] = true; |
98 | } else { |
99 | target[idx] = state->count > 1 ? sqrt(state->dsquared / (state->count - 1)) : 0; |
100 | if (!Value::DoubleIsValid(target[idx])) { |
101 | throw OutOfRangeException("STDDEV_SAMP is out of range!" ); |
102 | } |
103 | } |
104 | } |
105 | }; |
106 | |
107 | struct STDDevPopOperation : public STDDevBaseOperation { |
108 | template <class T, class STATE> |
109 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
110 | if (state->count == 0) { |
111 | nullmask[idx] = true; |
112 | } else { |
113 | target[idx] = state->count > 1 ? sqrt(state->dsquared / state->count) : 0; |
114 | if (!Value::DoubleIsValid(target[idx])) { |
115 | throw OutOfRangeException("STDDEV_POP is out of range!" ); |
116 | } |
117 | } |
118 | } |
119 | }; |
120 | |
121 | void StdDevSampFun::RegisterFunction(BuiltinFunctions &set) { |
122 | AggregateFunctionSet stddev_samp("stddev_samp" ); |
123 | stddev_samp.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, STDDevSampOperation>( |
124 | SQLType::DOUBLE, SQLType::DOUBLE)); |
125 | set.AddFunction(stddev_samp); |
126 | } |
127 | |
128 | void StdDevPopFun::RegisterFunction(BuiltinFunctions &set) { |
129 | AggregateFunctionSet stddev_pop("stddev_pop" ); |
130 | stddev_pop.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, STDDevPopOperation>( |
131 | SQLType::DOUBLE, SQLType::DOUBLE)); |
132 | set.AddFunction(stddev_pop); |
133 | } |
134 | |
135 | void VarPopFun::RegisterFunction(BuiltinFunctions &set) { |
136 | AggregateFunctionSet var_pop("var_pop" ); |
137 | var_pop.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, VarPopOperation>( |
138 | SQLType::DOUBLE, SQLType::DOUBLE)); |
139 | set.AddFunction(var_pop); |
140 | } |
141 | |
142 | void VarSampFun::RegisterFunction(BuiltinFunctions &set) { |
143 | AggregateFunctionSet var_samp("var_samp" ); |
144 | var_samp.AddFunction(AggregateFunction::UnaryAggregate<stddev_state_t, double, double, VarSampOperation>( |
145 | SQLType::DOUBLE, SQLType::DOUBLE)); |
146 | set.AddFunction(var_samp); |
147 | } |
148 | |