1#pragma once
2
#include <array>
#include <utility>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnNullable.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
10
11
12namespace DB
13{
14
namespace ErrorCodes
{
    /// Only declared here (extern, no initializer); the definitions live elsewhere.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int LOGICAL_ERROR;
}
20
21
/// This class implements a wrapper around an aggregate function. Despite its name,
/// it is an adapter: it handles aggregate functions that are called with at least
/// one Nullable argument, and it skips every row in which at least one argument is NULL.
26
/// If every row contained a NULL, the result is determined by the "result_is_nullable" template parameter:
/// true - return NULL; false - return the value produced by the nested function's empty aggregation state.
29
30template <bool result_is_nullable, typename Derived>
31class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
32{
33protected:
34 AggregateFunctionPtr nested_function;
35 size_t prefix_size;
36
37 /** In addition to data for nested aggregate function, we keep a flag
38 * indicating - was there at least one non-NULL value accumulated.
39 * In case of no not-NULL values, the function will return NULL.
40 *
41 * We use prefix_size bytes for flag to satisfy the alignment requirement of nested state.
42 */
43
44 AggregateDataPtr nestedPlace(AggregateDataPtr place) const noexcept
45 {
46 return place + prefix_size;
47 }
48
49 ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) const noexcept
50 {
51 return place + prefix_size;
52 }
53
54 static void initFlag(AggregateDataPtr place) noexcept
55 {
56 if (result_is_nullable)
57 place[0] = 0;
58 }
59
60 static void setFlag(AggregateDataPtr place) noexcept
61 {
62 if (result_is_nullable)
63 place[0] = 1;
64 }
65
66 static bool getFlag(ConstAggregateDataPtr place) noexcept
67 {
68 return result_is_nullable ? place[0] : 1;
69 }
70
71public:
72 AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
73 : IAggregateFunctionHelper<Derived>(arguments, params), nested_function{nested_function_}
74 {
75 if (result_is_nullable)
76 prefix_size = nested_function->alignOfData();
77 else
78 prefix_size = 0;
79 }
80
81 String getName() const override
82 {
83 /// This is just a wrapper. The function for Nullable arguments is named the same as the nested function itself.
84 return nested_function->getName();
85 }
86
87 DataTypePtr getReturnType() const override
88 {
89 return result_is_nullable
90 ? makeNullable(nested_function->getReturnType())
91 : nested_function->getReturnType();
92 }
93
94 void create(AggregateDataPtr place) const override
95 {
96 initFlag(place);
97 nested_function->create(nestedPlace(place));
98 }
99
100 void destroy(AggregateDataPtr place) const noexcept override
101 {
102 nested_function->destroy(nestedPlace(place));
103 }
104
105 bool hasTrivialDestructor() const override
106 {
107 return nested_function->hasTrivialDestructor();
108 }
109
110 size_t sizeOfData() const override
111 {
112 return prefix_size + nested_function->sizeOfData();
113 }
114
115 size_t alignOfData() const override
116 {
117 return nested_function->alignOfData();
118 }
119
120 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
121 {
122 if (result_is_nullable && getFlag(rhs))
123 setFlag(place);
124
125 nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
126 }
127
128 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
129 {
130 bool flag = getFlag(place);
131 if (result_is_nullable)
132 writeBinary(flag, buf);
133 if (flag)
134 nested_function->serialize(nestedPlace(place), buf);
135 }
136
137 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
138 {
139 bool flag = 1;
140 if (result_is_nullable)
141 readBinary(flag, buf);
142 if (flag)
143 {
144 setFlag(place);
145 nested_function->deserialize(nestedPlace(place), buf, arena);
146 }
147 }
148
149 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
150 {
151 if (result_is_nullable)
152 {
153 ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
154 if (getFlag(place))
155 {
156 nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn());
157 to_concrete.getNullMapData().push_back(0);
158 }
159 else
160 {
161 to_concrete.insertDefault();
162 }
163 }
164 else
165 {
166 nested_function->insertResultInto(nestedPlace(place), to);
167 }
168 }
169
170 bool allocatesMemoryInArena() const override
171 {
172 return nested_function->allocatesMemoryInArena();
173 }
174
175 bool isState() const override
176 {
177 return nested_function->isState();
178 }
179};
180
181
/** There are two variants: one for a single argument and one for several (variadic) arguments.
  * The single-argument code path is much more efficient.
  */
185template <bool result_is_nullable>
186class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>
187{
188public:
189 AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
190 : AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_), arguments, params)
191 {
192 }
193
194 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
195 {
196 const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
197 if (!column->isNullAt(row_num))
198 {
199 this->setFlag(place);
200 const IColumn * nested_column = &column->getNestedColumn();
201 this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
202 }
203 }
204};
205
206
207template <bool result_is_nullable>
208class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>
209{
210public:
211 AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
212 : AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable>>(std::move(nested_function_), arguments, params),
213 number_of_arguments(arguments.size())
214 {
215 if (number_of_arguments == 1)
216 throw Exception("Logical error: single argument is passed to AggregateFunctionNullVariadic", ErrorCodes::LOGICAL_ERROR);
217
218 if (number_of_arguments > MAX_ARGS)
219 throw Exception("Maximum number of arguments for aggregate function with Nullable types is " + toString(size_t(MAX_ARGS)),
220 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
221
222 for (size_t i = 0; i < number_of_arguments; ++i)
223 is_nullable[i] = arguments[i]->isNullable();
224 }
225
226 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
227 {
228 /// This container stores the columns we really pass to the nested function.
229 const IColumn * nested_columns[number_of_arguments];
230
231 for (size_t i = 0; i < number_of_arguments; ++i)
232 {
233 if (is_nullable[i])
234 {
235 const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
236 if (nullable_col.isNullAt(row_num))
237 {
238 /// If at least one column has a null value in the current row,
239 /// we don't process this row.
240 return;
241 }
242 nested_columns[i] = &nullable_col.getNestedColumn();
243 }
244 else
245 nested_columns[i] = columns[i];
246 }
247
248 this->setFlag(place);
249 this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
250 }
251
252 bool allocatesMemoryInArena() const override
253 {
254 return this->nested_function->allocatesMemoryInArena();
255 }
256
257private:
258 enum { MAX_ARGS = 8 };
259 size_t number_of_arguments = 0;
260 std::array<char, MAX_ARGS> is_nullable; /// Plain array is better than std::vector due to one indirection less.
261};
262
263}
264