1#pragma once
2
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>

#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>

#include <AggregateFunctions/IAggregateFunction.h>

#include <limits>
#include <type_traits>
11
12
13namespace DB
14{
namespace ErrorCodes
{
    /// NOTE(review): not referenced by any code visible in this header; presumably kept for
    /// code that includes it — confirm and drop the declaration if it is truly unused.
    extern const int LOGICAL_ERROR;
}
19
20template <typename T, typename Denominator>
21struct AggregateFunctionAvgData
22{
23 T numerator = 0;
24 Denominator denominator = 0;
25
26 template <typename ResultT>
27 ResultT NO_SANITIZE_UNDEFINED result() const
28 {
29 if constexpr (std::is_floating_point_v<ResultT>)
30 if constexpr (std::numeric_limits<ResultT>::is_iec559)
31 return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
32
33 if (denominator == 0)
34 return static_cast<ResultT>(0);
35 return static_cast<ResultT>(numerator / denominator);
36 }
37};
38
39/// Calculates arithmetic mean of numbers.
40template <typename T, typename Data, typename Derived>
41class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, Derived>
42{
43public:
44 using ResultType = std::conditional_t<IsDecimalNumber<T>, T, Float64>;
45 using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<T>, DataTypeNumber<Float64>>;
46 using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
47 using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<Float64>>;
48
49 /// ctor for native types
50 AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(0) {}
51
52 /// ctor for Decimals
53 AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
54 : IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(getDecimalScale(data_type))
55 {
56 }
57
58 DataTypePtr getReturnType() const override
59 {
60 if constexpr (IsDecimalNumber<T>)
61 return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
62 else
63 return std::make_shared<ResultDataType>();
64 }
65
66 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
67 {
68 this->data(place).numerator += this->data(rhs).numerator;
69 this->data(place).denominator += this->data(rhs).denominator;
70 }
71
72 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
73 {
74 writeBinary(this->data(place).numerator, buf);
75 writeVarUInt(this->data(place).denominator, buf);
76 }
77
78 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
79 {
80 readBinary(this->data(place).numerator, buf);
81 readVarUInt(this->data(place).denominator, buf);
82 }
83
84 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
85 {
86 auto & column = static_cast<ColVecResult &>(to);
87 column.getData().push_back(this->data(place).template result<ResultType>());
88 }
89
90protected:
91 UInt32 scale;
92};
93
94template <typename T, typename Data>
95class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>
96{
97public:
98 using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>::AggregateFunctionAvgBase;
99
100 using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
101 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
102 {
103 const auto & column = static_cast<const ColVecType &>(*columns[0]);
104 this->data(place).numerator += column.getData()[row_num];
105 this->data(place).denominator += 1;
106 }
107
108 String getName() const override { return "avg"; }
109};
110
111}
112