1#include <Common/FieldVisitors.h>
2
3#include <IO/WriteHelpers.h>
4#include <IO/ReadHelpers.h>
5
6#include <Columns/ColumnAggregateFunction.h>
7
8#include <Common/typeid_cast.h>
9#include <Common/assert_cast.h>
10#include <Common/AlignedBuffer.h>
11
12#include <Formats/FormatSettings.h>
13#include <Formats/ProtobufReader.h>
14#include <Formats/ProtobufWriter.h>
15#include <DataTypes/DataTypeAggregateFunction.h>
16#include <DataTypes/DataTypeFactory.h>
17
18#include <AggregateFunctions/AggregateFunctionFactory.h>
19#include <Parsers/ASTFunction.h>
20#include <Parsers/ASTLiteral.h>
21#include <Parsers/ASTIdentifier.h>
22
23
24namespace DB
25{
26
namespace ErrorCodes
{
    /// Error codes thrown from this file (declared extern here, defined elsewhere).
    extern const int SYNTAX_ERROR;
    extern const int BAD_ARGUMENTS;
    extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int LOGICAL_ERROR;
}
36
37
38std::string DataTypeAggregateFunction::doGetName() const
39{
40 std::stringstream stream;
41 stream << "AggregateFunction(" << function->getName();
42
43 if (!parameters.empty())
44 {
45 stream << "(";
46 for (size_t i = 0; i < parameters.size(); ++i)
47 {
48 if (i)
49 stream << ", ";
50 stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
51 }
52 stream << ")";
53 }
54
55 for (const auto & argument_type : argument_types)
56 stream << ", " << argument_type->getName();
57
58 stream << ")";
59 return stream.str();
60}
61
62void DataTypeAggregateFunction::serializeBinary(const Field & field, WriteBuffer & ostr) const
63{
64 const String & s = get<const String &>(field);
65 writeVarUInt(s.size(), ostr);
66 writeString(s, ostr);
67}
68
69void DataTypeAggregateFunction::deserializeBinary(Field & field, ReadBuffer & istr) const
70{
71 UInt64 size;
72 readVarUInt(size, istr);
73 field = String();
74 String & s = get<String &>(field);
75 s.resize(size);
76 istr.readStrict(s.data(), size);
77}
78
79void DataTypeAggregateFunction::serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const
80{
81 function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], ostr);
82}
83
84void DataTypeAggregateFunction::deserializeBinary(IColumn & column, ReadBuffer & istr) const
85{
86 ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);
87
88 Arena & arena = column_concrete.createOrGetArena();
89 size_t size_of_state = function->sizeOfData();
90 AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData());
91
92 function->create(place);
93 try
94 {
95 function->deserialize(place, istr, &arena);
96 }
97 catch (...)
98 {
99 function->destroy(place);
100 throw;
101 }
102
103 column_concrete.getData().push_back(place);
104}
105
106void DataTypeAggregateFunction::serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const
107{
108 const ColumnAggregateFunction & real_column = typeid_cast<const ColumnAggregateFunction &>(column);
109 const ColumnAggregateFunction::Container & vec = real_column.getData();
110
111 ColumnAggregateFunction::Container::const_iterator it = vec.begin() + offset;
112 ColumnAggregateFunction::Container::const_iterator end = limit ? it + limit : vec.end();
113
114 if (end > vec.end())
115 end = vec.end();
116
117 for (; it != end; ++it)
118 function->serialize(*it, ostr);
119}
120
121void DataTypeAggregateFunction::deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double /*avg_value_size_hint*/) const
122{
123 ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(column);
124 ColumnAggregateFunction::Container & vec = real_column.getData();
125
126 Arena & arena = real_column.createOrGetArena();
127 real_column.set(function);
128 vec.reserve(vec.size() + limit);
129
130 size_t size_of_state = function->sizeOfData();
131 size_t align_of_state = function->alignOfData();
132
133 for (size_t i = 0; i < limit; ++i)
134 {
135 if (istr.eof())
136 break;
137
138 AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state);
139
140 function->create(place);
141
142 try
143 {
144 function->deserialize(place, istr, &arena);
145 }
146 catch (...)
147 {
148 function->destroy(place);
149 throw;
150 }
151
152 vec.push_back(place);
153 }
154}
155
156static String serializeToString(const AggregateFunctionPtr & function, const IColumn & column, size_t row_num)
157{
158 WriteBufferFromOwnString buffer;
159 function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer);
160 return buffer.str();
161}
162
163static void deserializeFromString(const AggregateFunctionPtr & function, IColumn & column, const String & s)
164{
165 ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);
166
167 Arena & arena = column_concrete.createOrGetArena();
168 size_t size_of_state = function->sizeOfData();
169 AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData());
170
171 function->create(place);
172
173 try
174 {
175 ReadBufferFromString istr(s);
176 function->deserialize(place, istr, &arena);
177 }
178 catch (...)
179 {
180 function->destroy(place);
181 throw;
182 }
183
184 column_concrete.getData().push_back(place);
185}
186
187void DataTypeAggregateFunction::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const
188{
189 writeString(serializeToString(function, column, row_num), ostr);
190}
191
192
193void DataTypeAggregateFunction::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const
194{
195 writeEscapedString(serializeToString(function, column, row_num), ostr);
196}
197
198
199void DataTypeAggregateFunction::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
200{
201 String s;
202 readEscapedString(s, istr);
203 deserializeFromString(function, column, s);
204}
205
206
207void DataTypeAggregateFunction::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const
208{
209 writeQuotedString(serializeToString(function, column, row_num), ostr);
210}
211
212
213void DataTypeAggregateFunction::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
214{
215 String s;
216 readQuotedStringWithSQLStyle(s, istr);
217 deserializeFromString(function, column, s);
218}
219
220
221void DataTypeAggregateFunction::deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
222{
223 String s;
224 readStringUntilEOF(s, istr);
225 deserializeFromString(function, column, s);
226}
227
228
229void DataTypeAggregateFunction::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const
230{
231 writeJSONString(serializeToString(function, column, row_num), ostr, settings);
232}
233
234
235void DataTypeAggregateFunction::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const
236{
237 String s;
238 readJSONString(s, istr);
239 deserializeFromString(function, column, s);
240}
241
242
243void DataTypeAggregateFunction::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const
244{
245 writeXMLString(serializeToString(function, column, row_num), ostr);
246}
247
248
249void DataTypeAggregateFunction::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const
250{
251 writeCSV(serializeToString(function, column, row_num), ostr);
252}
253
254
255void DataTypeAggregateFunction::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const
256{
257 String s;
258 readCSV(s, istr, settings.csv);
259 deserializeFromString(function, column, s);
260}
261
262
263void DataTypeAggregateFunction::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const
264{
265 if (value_index)
266 return;
267 value_index = static_cast<bool>(
268 protobuf.writeAggregateFunction(function, assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num]));
269}
270
271void DataTypeAggregateFunction::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const
272{
273 row_added = false;
274 ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column);
275 Arena & arena = column_concrete.createOrGetArena();
276 size_t size_of_state = function->sizeOfData();
277 AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData());
278 function->create(place);
279 try
280 {
281 if (!protobuf.readAggregateFunction(function, place, arena))
282 {
283 function->destroy(place);
284 return;
285 }
286 auto & container = column_concrete.getData();
287 if (allow_add_row)
288 {
289 container.emplace_back(place);
290 row_added = true;
291 }
292 else
293 container.back() = place;
294 }
295 catch (...)
296 {
297 function->destroy(place);
298 throw;
299 }
300}
301
302MutableColumnPtr DataTypeAggregateFunction::createColumn() const
303{
304 return ColumnAggregateFunction::create(function);
305}
306
307
308/// Create empty state
309Field DataTypeAggregateFunction::getDefault() const
310{
311 Field field = AggregateFunctionStateData();
312 field.get<AggregateFunctionStateData &>().name = getName();
313
314 AlignedBuffer place_buffer(function->sizeOfData(), function->alignOfData());
315 AggregateDataPtr place = place_buffer.data();
316
317 function->create(place);
318
319 try
320 {
321 WriteBufferFromString buffer_from_field(field.get<AggregateFunctionStateData &>().data);
322 function->serialize(place, buffer_from_field);
323 }
324 catch (...)
325 {
326 function->destroy(place);
327 throw;
328 }
329
330 function->destroy(place);
331
332 return field;
333}
334
335
336bool DataTypeAggregateFunction::equals(const IDataType & rhs) const
337{
338 return typeid(rhs) == typeid(*this) && getName() == rhs.getName();
339}
340
341
342static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments)
343{
344 String function_name;
345 AggregateFunctionPtr function;
346 DataTypes argument_types;
347 Array params_row;
348
349 if (!arguments || arguments->children.empty())
350 throw Exception("Data type AggregateFunction requires parameters: "
351 "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
352
353 if (const auto * parametric = arguments->children[0]->as<ASTFunction>())
354 {
355 if (parametric->parameters)
356 throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
357 function_name = parametric->name;
358
359 const ASTs & parameters = parametric->arguments->children;
360 params_row.resize(parameters.size());
361
362 for (size_t i = 0; i < parameters.size(); ++i)
363 {
364 const auto * literal = parameters[i]->as<ASTLiteral>();
365 if (!literal)
366 throw Exception("Parameters to aggregate functions must be literals",
367 ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
368
369 params_row[i] = literal->value;
370 }
371 }
372 else if (auto opt_name = tryGetIdentifierName(arguments->children[0]))
373 {
374 function_name = *opt_name;
375 }
376 else if (arguments->children[0]->as<ASTLiteral>())
377 {
378 throw Exception("Aggregate function name for data type AggregateFunction must be passed as identifier (without quotes) or function",
379 ErrorCodes::BAD_ARGUMENTS);
380 }
381 else
382 throw Exception("Unexpected AST element passed as aggregate function name for data type AggregateFunction. Must be identifier or function.",
383 ErrorCodes::BAD_ARGUMENTS);
384
385 for (size_t i = 1; i < arguments->children.size(); ++i)
386 argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
387
388 if (function_name.empty())
389 throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
390
391 function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
392 return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row);
393}
394
395void registerDataTypeAggregateFunction(DataTypeFactory & factory)
396{
397 factory.registerDataType("AggregateFunction", create);
398}
399
400
401}
402