1#pragma once
2
3#include <type_traits>
4
5#include <IO/WriteHelpers.h>
6#include <IO/ReadHelpers.h>
7
8#include <DataTypes/DataTypesNumber.h>
9#include <DataTypes/DataTypesDecimal.h>
10#include <Columns/ColumnVector.h>
11
12#include <AggregateFunctions/IAggregateFunction.h>
13
14
15namespace DB
16{
17
18template <typename T>
19struct AggregateFunctionSumData
20{
21 T sum{};
22
23 void add(T value)
24 {
25 sum += value;
26 }
27
28 void merge(const AggregateFunctionSumData & rhs)
29 {
30 sum += rhs.sum;
31 }
32
33 void write(WriteBuffer & buf) const
34 {
35 writeBinary(sum, buf);
36 }
37
38 void read(ReadBuffer & buf)
39 {
40 readBinary(sum, buf);
41 }
42
43 T get() const
44 {
45 return sum;
46 }
47};
48
49template <typename T>
50struct AggregateFunctionSumKahanData
51{
52 static_assert(std::is_floating_point_v<T>,
53 "It doesn't make sense to use Kahan Summation algorithm for non floating point types");
54
55 T sum{};
56 T compensation{};
57
58 void add(T value)
59 {
60 auto compensated_value = value - compensation;
61 auto new_sum = sum + compensated_value;
62 compensation = (new_sum - sum) - compensated_value;
63 sum = new_sum;
64 }
65
66 void merge(const AggregateFunctionSumKahanData & rhs)
67 {
68 auto raw_sum = sum + rhs.sum;
69 auto rhs_compensated = raw_sum - sum;
70 auto compensations = ((rhs.sum - rhs_compensated) + (sum - (raw_sum - rhs_compensated))) + compensation + rhs.compensation;
71 sum = raw_sum + compensations;
72 compensation = compensations - (sum - raw_sum);
73 }
74
75 void write(WriteBuffer & buf) const
76 {
77 writeBinary(sum, buf);
78 writeBinary(compensation, buf);
79 }
80
81 void read(ReadBuffer & buf)
82 {
83 readBinary(sum, buf);
84 readBinary(compensation, buf);
85 }
86
87 T get() const
88 {
89 return sum;
90 }
91};
92
93
94/// Counts the sum of the numbers.
95template <typename T, typename TResult, typename Data>
96class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>
97{
98public:
99 using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<TResult>, DataTypeNumber<TResult>>;
100 using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
101 using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<TResult>, ColumnVector<TResult>>;
102
103 String getName() const override { return "sum"; }
104
105 AggregateFunctionSum(const DataTypes & argument_types_)
106 : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types_, {})
107 , scale(0)
108 {}
109
110 AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_)
111 : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types_, {})
112 , scale(getDecimalScale(data_type))
113 {}
114
115 DataTypePtr getReturnType() const override
116 {
117 if constexpr (IsDecimalNumber<T>)
118 return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
119 else
120 return std::make_shared<ResultDataType>();
121 }
122
123 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
124 {
125 const auto & column = static_cast<const ColVecType &>(*columns[0]);
126 this->data(place).add(column.getData()[row_num]);
127 }
128
129 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
130 {
131 this->data(place).merge(this->data(rhs));
132 }
133
134 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
135 {
136 this->data(place).write(buf);
137 }
138
139 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
140 {
141 this->data(place).read(buf);
142 }
143
144 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
145 {
146 auto & column = static_cast<ColVecResult &>(to);
147 column.getData().push_back(this->data(place).get());
148 }
149
150private:
151 UInt32 scale;
152};
153
154}
155