1#include "duckdb/function/aggregate/algebraic_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/function/function_set.hpp"
6
7using namespace duckdb;
8using namespace std;
9
//! Running state for the AVG aggregate: how many non-skipped values were
//! folded in, and their running sum in the accumulator type T.
template <typename T>
struct avg_state_t {
	//! Number of input values accumulated so far.
	uint64_t count;
	//! Sum of the input values accumulated so far.
	T sum;
};
14
15struct AverageFunction {
16 template <class STATE> static void Initialize(STATE *state) {
17 state->count = 0;
18 state->sum = 0;
19 }
20
21 template <class INPUT_TYPE, class STATE, class OP>
22 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
23 state->sum += input[idx];
24 state->count++;
25 }
26
27 template <class INPUT_TYPE, class STATE, class OP>
28 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
29 state->count += count;
30 state->sum += input[0] * count;
31 }
32
33 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
34 target->count += source.count;
35 target->sum += source.sum;
36 }
37
38 template <class T, class STATE>
39 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
40 if (!Value::DoubleIsValid(state->sum)) {
41 throw OutOfRangeException("AVG is out of range!");
42 } else if (state->count == 0) {
43 nullmask[idx] = true;
44 } else {
45 target[idx] = state->sum / state->count;
46 }
47 }
48
49 static bool IgnoreNull() {
50 return true;
51 }
52};
53
54void AvgFun::RegisterFunction(BuiltinFunctions &set) {
55 AggregateFunctionSet avg("avg");
56 avg.AddFunction(AggregateFunction::UnaryAggregate<avg_state_t<double>, double, double, AverageFunction>(
57 SQLType::DOUBLE, SQLType::DOUBLE));
58 set.AddFunction(avg);
59}
60