1#pragma once
2
3#include <Columns/ColumnArray.h>
4#include <Common/assert_cast.h>
5#include <DataTypes/DataTypeArray.h>
6#include <AggregateFunctions/IAggregateFunction.h>
7#include <IO/WriteHelpers.h>
8
9
10namespace DB
11{
12
/// Error codes thrown by the code in this header.
/// They are only declared here (extern); the definitions are not part of this file.
namespace ErrorCodes
{
    /// Thrown when an argument type is not an array.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Thrown when the arrays passed for one row have different sizes.
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
}
18
19
20/** Not an aggregate function, but an adapter of aggregate functions,
21 * which any aggregate function `agg(x)` makes an aggregate function of the form `aggArray(x)`.
22 * The adapted aggregate function calculates nested aggregate function for each element of the array.
23 */
24class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFunctionArray>
25{
26private:
27 AggregateFunctionPtr nested_func;
28 size_t num_arguments;
29
30public:
31 AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments)
32 : IAggregateFunctionHelper<AggregateFunctionArray>(arguments, {})
33 , nested_func(nested_), num_arguments(arguments.size())
34 {
35 for (const auto & type : arguments)
36 if (!isArray(type))
37 throw Exception("All arguments for aggregate function " + getName() + " must be arrays", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
38 }
39
40 String getName() const override
41 {
42 return nested_func->getName() + "Array";
43 }
44
45 DataTypePtr getReturnType() const override
46 {
47 return nested_func->getReturnType();
48 }
49
50 void create(AggregateDataPtr place) const override
51 {
52 nested_func->create(place);
53 }
54
55 void destroy(AggregateDataPtr place) const noexcept override
56 {
57 nested_func->destroy(place);
58 }
59
60 bool hasTrivialDestructor() const override
61 {
62 return nested_func->hasTrivialDestructor();
63 }
64
65 size_t sizeOfData() const override
66 {
67 return nested_func->sizeOfData();
68 }
69
70 size_t alignOfData() const override
71 {
72 return nested_func->alignOfData();
73 }
74
75 bool isState() const override
76 {
77 return nested_func->isState();
78 }
79
80 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
81 {
82 const IColumn * nested[num_arguments];
83
84 for (size_t i = 0; i < num_arguments; ++i)
85 nested[i] = &assert_cast<const ColumnArray &>(*columns[i]).getData();
86
87 const ColumnArray & first_array_column = assert_cast<const ColumnArray &>(*columns[0]);
88 const IColumn::Offsets & offsets = first_array_column.getOffsets();
89
90 size_t begin = offsets[row_num - 1];
91 size_t end = offsets[row_num];
92
93 /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
94 for (size_t i = 1; i < num_arguments; ++i)
95 {
96 const ColumnArray & ith_column = assert_cast<const ColumnArray &>(*columns[i]);
97 const IColumn::Offsets & ith_offsets = ith_column.getOffsets();
98
99 if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
100 throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
101 }
102
103 for (size_t i = begin; i < end; ++i)
104 nested_func->add(place, nested, i, arena);
105 }
106
107 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
108 {
109 nested_func->merge(place, rhs, arena);
110 }
111
112 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
113 {
114 nested_func->serialize(place, buf);
115 }
116
117 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
118 {
119 nested_func->deserialize(place, buf, arena);
120 }
121
122 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
123 {
124 nested_func->insertResultInto(place, to);
125 }
126
127 bool allocatesMemoryInArena() const override
128 {
129 return nested_func->allocatesMemoryInArena();
130 }
131
132 AggregateFunctionPtr getNestedFunction() const { return nested_func; }
133};
134
135}
136