1 | #pragma once |
2 | |
3 | #include <type_traits> |
4 | |
5 | #include <IO/WriteHelpers.h> |
6 | #include <IO/ReadHelpers.h> |
7 | |
8 | #include <DataTypes/DataTypesNumber.h> |
9 | #include <DataTypes/DataTypesDecimal.h> |
10 | #include <Columns/ColumnVector.h> |
11 | |
12 | #include <AggregateFunctions/IAggregateFunction.h> |
13 | |
14 | |
15 | namespace DB |
16 | { |
17 | |
18 | template <typename T> |
19 | struct AggregateFunctionSumData |
20 | { |
21 | T sum{}; |
22 | |
23 | void add(T value) |
24 | { |
25 | sum += value; |
26 | } |
27 | |
28 | void merge(const AggregateFunctionSumData & rhs) |
29 | { |
30 | sum += rhs.sum; |
31 | } |
32 | |
33 | void write(WriteBuffer & buf) const |
34 | { |
35 | writeBinary(sum, buf); |
36 | } |
37 | |
38 | void read(ReadBuffer & buf) |
39 | { |
40 | readBinary(sum, buf); |
41 | } |
42 | |
43 | T get() const |
44 | { |
45 | return sum; |
46 | } |
47 | }; |
48 | |
49 | template <typename T> |
50 | struct AggregateFunctionSumKahanData |
51 | { |
52 | static_assert(std::is_floating_point_v<T>, |
53 | "It doesn't make sense to use Kahan Summation algorithm for non floating point types" ); |
54 | |
55 | T sum{}; |
56 | T compensation{}; |
57 | |
58 | void add(T value) |
59 | { |
60 | auto compensated_value = value - compensation; |
61 | auto new_sum = sum + compensated_value; |
62 | compensation = (new_sum - sum) - compensated_value; |
63 | sum = new_sum; |
64 | } |
65 | |
66 | void merge(const AggregateFunctionSumKahanData & rhs) |
67 | { |
68 | auto raw_sum = sum + rhs.sum; |
69 | auto rhs_compensated = raw_sum - sum; |
70 | auto compensations = ((rhs.sum - rhs_compensated) + (sum - (raw_sum - rhs_compensated))) + compensation + rhs.compensation; |
71 | sum = raw_sum + compensations; |
72 | compensation = compensations - (sum - raw_sum); |
73 | } |
74 | |
75 | void write(WriteBuffer & buf) const |
76 | { |
77 | writeBinary(sum, buf); |
78 | writeBinary(compensation, buf); |
79 | } |
80 | |
81 | void read(ReadBuffer & buf) |
82 | { |
83 | readBinary(sum, buf); |
84 | readBinary(compensation, buf); |
85 | } |
86 | |
87 | T get() const |
88 | { |
89 | return sum; |
90 | } |
91 | }; |
92 | |
93 | |
94 | /// Counts the sum of the numbers. |
95 | template <typename T, typename TResult, typename Data> |
96 | class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>> |
97 | { |
98 | public: |
99 | using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<TResult>, DataTypeNumber<TResult>>; |
100 | using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; |
101 | using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<TResult>, ColumnVector<TResult>>; |
102 | |
103 | String getName() const override { return "sum" ; } |
104 | |
105 | AggregateFunctionSum(const DataTypes & argument_types_) |
106 | : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types_, {}) |
107 | , scale(0) |
108 | {} |
109 | |
110 | AggregateFunctionSum(const IDataType & data_type, const DataTypes & argument_types_) |
111 | : IAggregateFunctionDataHelper<Data, AggregateFunctionSum<T, TResult, Data>>(argument_types_, {}) |
112 | , scale(getDecimalScale(data_type)) |
113 | {} |
114 | |
115 | DataTypePtr getReturnType() const override |
116 | { |
117 | if constexpr (IsDecimalNumber<T>) |
118 | return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale); |
119 | else |
120 | return std::make_shared<ResultDataType>(); |
121 | } |
122 | |
123 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
124 | { |
125 | const auto & column = static_cast<const ColVecType &>(*columns[0]); |
126 | this->data(place).add(column.getData()[row_num]); |
127 | } |
128 | |
129 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
130 | { |
131 | this->data(place).merge(this->data(rhs)); |
132 | } |
133 | |
134 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
135 | { |
136 | this->data(place).write(buf); |
137 | } |
138 | |
139 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
140 | { |
141 | this->data(place).read(buf); |
142 | } |
143 | |
144 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
145 | { |
146 | auto & column = static_cast<ColVecResult &>(to); |
147 | column.getData().push_back(this->data(place).get()); |
148 | } |
149 | |
150 | private: |
151 | UInt32 scale; |
152 | }; |
153 | |
154 | } |
155 | |