1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/common/vector_operations/aggregate_executor.hpp"
6#include "duckdb/common/operator/numeric_binary_operators.hpp"
7
8using namespace std;
9
10namespace duckdb {
11
//! Per-group state of the SUM aggregate.
//! `isset` distinguishes "no value was ever added" (finalized as NULL by
//! SumOperation::Finalize) from "values were added and happen to total 0".
struct sum_state_t {
	//! Running total; every supported input type is accumulated as a double
	double value;
	//! Becomes true once at least one input has been added to `value`
	bool isset;
};
16
17struct SumOperation {
18 template <class STATE> static void Initialize(STATE *state) {
19 state->value = 0;
20 state->isset = false;
21 }
22
23 template <class INPUT_TYPE, class STATE, class OP>
24 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
25 state->isset = true;
26 state->value += input[idx];
27 }
28
29 template <class INPUT_TYPE, class STATE, class OP>
30 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
31 state->isset = true;
32 state->value += (double)input[0] * (double)count;
33 }
34
35 template <class T, class STATE>
36 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
37 if (!state->isset) {
38 nullmask[idx] = true;
39 } else {
40 if (!Value::DoubleIsValid(state->value)) {
41 throw OutOfRangeException("SUM is out of range!");
42 }
43 target[idx] = state->value;
44 }
45 }
46
47 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
48 if (!source.isset) {
49 // source is NULL, nothing to do
50 return;
51 }
52 if (!target->isset) {
53 // target is NULL, use source value directly
54 *target = source;
55 } else {
56 // else perform the operation
57 target->value += source.value;
58 }
59 }
60
61 static bool IgnoreNull() {
62 return true;
63 }
64};
65
66void SumFun::RegisterFunction(BuiltinFunctions &set) {
67 AggregateFunctionSet sum("sum");
68 // integer sums to bigint
69 sum.AddFunction(AggregateFunction::UnaryAggregate<sum_state_t, int32_t, double, SumOperation>(SQLType::INTEGER,
70 SQLType::DOUBLE));
71 sum.AddFunction(AggregateFunction::UnaryAggregate<sum_state_t, int64_t, double, SumOperation>(SQLType::BIGINT,
72 SQLType::DOUBLE));
73 // float sums to float
74 sum.AddFunction(
75 AggregateFunction::UnaryAggregate<sum_state_t, double, double, SumOperation>(SQLType::DOUBLE, SQLType::DOUBLE));
76
77 set.AddFunction(sum);
78}
79
80} // namespace duckdb
81