1 | #pragma once |
2 | |
3 | #include <Columns/ColumnArray.h> |
4 | #include <Common/assert_cast.h> |
5 | #include <DataTypes/DataTypeArray.h> |
6 | #include <AggregateFunctions/IAggregateFunction.h> |
7 | #include <IO/WriteHelpers.h> |
8 | |
9 | |
10 | namespace DB |
11 | { |
12 | |
13 | namespace ErrorCodes |
14 | { |
15 | extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; |
16 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
17 | } |
18 | |
19 | |
20 | /** Not an aggregate function, but an adapter of aggregate functions, |
21 | * which any aggregate function `agg(x)` makes an aggregate function of the form `aggArray(x)`. |
22 | * The adapted aggregate function calculates nested aggregate function for each element of the array. |
23 | */ |
24 | class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFunctionArray> |
25 | { |
26 | private: |
27 | AggregateFunctionPtr nested_func; |
28 | size_t num_arguments; |
29 | |
30 | public: |
31 | AggregateFunctionArray(AggregateFunctionPtr nested_, const DataTypes & arguments) |
32 | : IAggregateFunctionHelper<AggregateFunctionArray>(arguments, {}) |
33 | , nested_func(nested_), num_arguments(arguments.size()) |
34 | { |
35 | for (const auto & type : arguments) |
36 | if (!isArray(type)) |
37 | throw Exception("All arguments for aggregate function " + getName() + " must be arrays" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
38 | } |
39 | |
40 | String getName() const override |
41 | { |
42 | return nested_func->getName() + "Array" ; |
43 | } |
44 | |
45 | DataTypePtr getReturnType() const override |
46 | { |
47 | return nested_func->getReturnType(); |
48 | } |
49 | |
50 | void create(AggregateDataPtr place) const override |
51 | { |
52 | nested_func->create(place); |
53 | } |
54 | |
55 | void destroy(AggregateDataPtr place) const noexcept override |
56 | { |
57 | nested_func->destroy(place); |
58 | } |
59 | |
60 | bool hasTrivialDestructor() const override |
61 | { |
62 | return nested_func->hasTrivialDestructor(); |
63 | } |
64 | |
65 | size_t sizeOfData() const override |
66 | { |
67 | return nested_func->sizeOfData(); |
68 | } |
69 | |
70 | size_t alignOfData() const override |
71 | { |
72 | return nested_func->alignOfData(); |
73 | } |
74 | |
75 | bool isState() const override |
76 | { |
77 | return nested_func->isState(); |
78 | } |
79 | |
80 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
81 | { |
82 | const IColumn * nested[num_arguments]; |
83 | |
84 | for (size_t i = 0; i < num_arguments; ++i) |
85 | nested[i] = &assert_cast<const ColumnArray &>(*columns[i]).getData(); |
86 | |
87 | const ColumnArray & first_array_column = assert_cast<const ColumnArray &>(*columns[0]); |
88 | const IColumn::Offsets & offsets = first_array_column.getOffsets(); |
89 | |
90 | size_t begin = offsets[row_num - 1]; |
91 | size_t end = offsets[row_num]; |
92 | |
93 | /// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance. |
94 | for (size_t i = 1; i < num_arguments; ++i) |
95 | { |
96 | const ColumnArray & ith_column = assert_cast<const ColumnArray &>(*columns[i]); |
97 | const IColumn::Offsets & ith_offsets = ith_column.getOffsets(); |
98 | |
99 | if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin)) |
100 | throw Exception("Arrays passed to " + getName() + " aggregate function have different sizes" , ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); |
101 | } |
102 | |
103 | for (size_t i = begin; i < end; ++i) |
104 | nested_func->add(place, nested, i, arena); |
105 | } |
106 | |
107 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
108 | { |
109 | nested_func->merge(place, rhs, arena); |
110 | } |
111 | |
112 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
113 | { |
114 | nested_func->serialize(place, buf); |
115 | } |
116 | |
117 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
118 | { |
119 | nested_func->deserialize(place, buf, arena); |
120 | } |
121 | |
122 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
123 | { |
124 | nested_func->insertResultInto(place, to); |
125 | } |
126 | |
127 | bool allocatesMemoryInArena() const override |
128 | { |
129 | return nested_func->allocatesMemoryInArena(); |
130 | } |
131 | |
132 | AggregateFunctionPtr getNestedFunction() const { return nested_func; } |
133 | }; |
134 | |
135 | } |
136 | |