1 | #pragma once |
2 | |
3 | #include <AggregateFunctions/AggregateFunctionAvg.h> |
4 | |
5 | namespace DB |
6 | { |
7 | template <typename T, typename Data> |
8 | class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>> |
9 | { |
10 | public: |
11 | using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>::AggregateFunctionAvgBase; |
12 | |
13 | using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; |
14 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
15 | { |
16 | const auto & values = static_cast<const ColVecType &>(*columns[0]); |
17 | const auto & weights = static_cast<const ColVecType &>(*columns[1]); |
18 | |
19 | this->data(place).numerator += values.getData()[row_num] * weights.getData()[row_num]; |
20 | this->data(place).denominator += weights.getData()[row_num]; |
21 | } |
22 | |
23 | String getName() const override { return "avgWeighted" ; } |
24 | }; |
25 | |
26 | } |
27 | |