1#pragma once
2
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnsNumber.h>
#include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h>

#include <utility>
7
8
9namespace DB
10{
11
namespace ErrorCodes
{
    /// Error codes thrown by the AggregateFunctionIf constructor.
    /// They are only declared here (extern, no initializer); the definitions
    /// live in some other translation unit.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
17
18/** Not an aggregate function, but an adapter of aggregate functions,
19 * which any aggregate function `agg(x)` makes an aggregate function of the form `aggIf(x, cond)`.
20 * The adapted aggregate function takes two arguments - a value and a condition,
21 * and calculates the nested aggregate function for the values when the condition is satisfied.
22 * For example, avgIf(x, cond) calculates the average x if `cond`.
23 */
24class AggregateFunctionIf final : public IAggregateFunctionHelper<AggregateFunctionIf>
25{
26private:
27 AggregateFunctionPtr nested_func;
28 size_t num_arguments;
29
30public:
31 AggregateFunctionIf(AggregateFunctionPtr nested, const DataTypes & types)
32 : IAggregateFunctionHelper<AggregateFunctionIf>(types, nested->getParameters())
33 , nested_func(nested), num_arguments(types.size())
34 {
35 if (num_arguments == 0)
36 throw Exception("Aggregate function " + getName() + " require at least one argument", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
37
38 if (!isUInt8(types.back()))
39 throw Exception("Last argument for aggregate function " + getName() + " must be UInt8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
40 }
41
42 String getName() const override
43 {
44 return nested_func->getName() + "If";
45 }
46
47 DataTypePtr getReturnType() const override
48 {
49 return nested_func->getReturnType();
50 }
51
52 void create(AggregateDataPtr place) const override
53 {
54 nested_func->create(place);
55 }
56
57 void destroy(AggregateDataPtr place) const noexcept override
58 {
59 nested_func->destroy(place);
60 }
61
62 bool hasTrivialDestructor() const override
63 {
64 return nested_func->hasTrivialDestructor();
65 }
66
67 size_t sizeOfData() const override
68 {
69 return nested_func->sizeOfData();
70 }
71
72 size_t alignOfData() const override
73 {
74 return nested_func->alignOfData();
75 }
76
77 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
78 {
79 if (assert_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num])
80 nested_func->add(place, columns, row_num, arena);
81 }
82
83 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
84 {
85 nested_func->merge(place, rhs, arena);
86 }
87
88 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
89 {
90 nested_func->serialize(place, buf);
91 }
92
93 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
94 {
95 nested_func->deserialize(place, buf, arena);
96 }
97
98 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
99 {
100 nested_func->insertResultInto(place, to);
101 }
102
103 bool allocatesMemoryInArena() const override
104 {
105 return nested_func->allocatesMemoryInArena();
106 }
107
108 bool isState() const override
109 {
110 return nested_func->isState();
111 }
112};
113
114}
115