1 | #include <Common/FieldVisitors.h> |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <Columns/ColumnAggregateFunction.h> |
7 | |
8 | #include <Common/typeid_cast.h> |
9 | #include <Common/assert_cast.h> |
10 | #include <Common/AlignedBuffer.h> |
11 | |
12 | #include <Formats/FormatSettings.h> |
13 | #include <Formats/ProtobufReader.h> |
14 | #include <Formats/ProtobufWriter.h> |
15 | #include <DataTypes/DataTypeAggregateFunction.h> |
16 | #include <DataTypes/DataTypeFactory.h> |
17 | |
18 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
19 | #include <Parsers/ASTFunction.h> |
20 | #include <Parsers/ASTLiteral.h> |
21 | #include <Parsers/ASTIdentifier.h> |
22 | |
23 | |
24 | namespace DB |
25 | { |
26 | |
/// Error codes referenced by this translation unit; they are only declared here (extern) and defined elsewhere.
namespace ErrorCodes
{
    extern const int SYNTAX_ERROR;
    extern const int BAD_ARGUMENTS;
    extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int LOGICAL_ERROR;
    extern const int NOT_IMPLEMENTED;
}
36 | |
37 | |
38 | std::string DataTypeAggregateFunction::doGetName() const |
39 | { |
40 | std::stringstream stream; |
41 | stream << "AggregateFunction(" << function->getName(); |
42 | |
43 | if (!parameters.empty()) |
44 | { |
45 | stream << "(" ; |
46 | for (size_t i = 0; i < parameters.size(); ++i) |
47 | { |
48 | if (i) |
49 | stream << ", " ; |
50 | stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]); |
51 | } |
52 | stream << ")" ; |
53 | } |
54 | |
55 | for (const auto & argument_type : argument_types) |
56 | stream << ", " << argument_type->getName(); |
57 | |
58 | stream << ")" ; |
59 | return stream.str(); |
60 | } |
61 | |
62 | void DataTypeAggregateFunction::serializeBinary(const Field & field, WriteBuffer & ostr) const |
63 | { |
64 | const String & s = get<const String &>(field); |
65 | writeVarUInt(s.size(), ostr); |
66 | writeString(s, ostr); |
67 | } |
68 | |
69 | void DataTypeAggregateFunction::deserializeBinary(Field & field, ReadBuffer & istr) const |
70 | { |
71 | UInt64 size; |
72 | readVarUInt(size, istr); |
73 | field = String(); |
74 | String & s = get<String &>(field); |
75 | s.resize(size); |
76 | istr.readStrict(s.data(), size); |
77 | } |
78 | |
79 | void DataTypeAggregateFunction::serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const |
80 | { |
81 | function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], ostr); |
82 | } |
83 | |
84 | void DataTypeAggregateFunction::deserializeBinary(IColumn & column, ReadBuffer & istr) const |
85 | { |
86 | ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column); |
87 | |
88 | Arena & arena = column_concrete.createOrGetArena(); |
89 | size_t size_of_state = function->sizeOfData(); |
90 | AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); |
91 | |
92 | function->create(place); |
93 | try |
94 | { |
95 | function->deserialize(place, istr, &arena); |
96 | } |
97 | catch (...) |
98 | { |
99 | function->destroy(place); |
100 | throw; |
101 | } |
102 | |
103 | column_concrete.getData().push_back(place); |
104 | } |
105 | |
106 | void DataTypeAggregateFunction::serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const |
107 | { |
108 | const ColumnAggregateFunction & real_column = typeid_cast<const ColumnAggregateFunction &>(column); |
109 | const ColumnAggregateFunction::Container & vec = real_column.getData(); |
110 | |
111 | ColumnAggregateFunction::Container::const_iterator it = vec.begin() + offset; |
112 | ColumnAggregateFunction::Container::const_iterator end = limit ? it + limit : vec.end(); |
113 | |
114 | if (end > vec.end()) |
115 | end = vec.end(); |
116 | |
117 | for (; it != end; ++it) |
118 | function->serialize(*it, ostr); |
119 | } |
120 | |
121 | void DataTypeAggregateFunction::deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double /*avg_value_size_hint*/) const |
122 | { |
123 | ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(column); |
124 | ColumnAggregateFunction::Container & vec = real_column.getData(); |
125 | |
126 | Arena & arena = real_column.createOrGetArena(); |
127 | real_column.set(function); |
128 | vec.reserve(vec.size() + limit); |
129 | |
130 | size_t size_of_state = function->sizeOfData(); |
131 | size_t align_of_state = function->alignOfData(); |
132 | |
133 | for (size_t i = 0; i < limit; ++i) |
134 | { |
135 | if (istr.eof()) |
136 | break; |
137 | |
138 | AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); |
139 | |
140 | function->create(place); |
141 | |
142 | try |
143 | { |
144 | function->deserialize(place, istr, &arena); |
145 | } |
146 | catch (...) |
147 | { |
148 | function->destroy(place); |
149 | throw; |
150 | } |
151 | |
152 | vec.push_back(place); |
153 | } |
154 | } |
155 | |
156 | static String serializeToString(const AggregateFunctionPtr & function, const IColumn & column, size_t row_num) |
157 | { |
158 | WriteBufferFromOwnString buffer; |
159 | function->serialize(assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num], buffer); |
160 | return buffer.str(); |
161 | } |
162 | |
163 | static void deserializeFromString(const AggregateFunctionPtr & function, IColumn & column, const String & s) |
164 | { |
165 | ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column); |
166 | |
167 | Arena & arena = column_concrete.createOrGetArena(); |
168 | size_t size_of_state = function->sizeOfData(); |
169 | AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); |
170 | |
171 | function->create(place); |
172 | |
173 | try |
174 | { |
175 | ReadBufferFromString istr(s); |
176 | function->deserialize(place, istr, &arena); |
177 | } |
178 | catch (...) |
179 | { |
180 | function->destroy(place); |
181 | throw; |
182 | } |
183 | |
184 | column_concrete.getData().push_back(place); |
185 | } |
186 | |
187 | void DataTypeAggregateFunction::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const |
188 | { |
189 | writeString(serializeToString(function, column, row_num), ostr); |
190 | } |
191 | |
192 | |
193 | void DataTypeAggregateFunction::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const |
194 | { |
195 | writeEscapedString(serializeToString(function, column, row_num), ostr); |
196 | } |
197 | |
198 | |
199 | void DataTypeAggregateFunction::deserializeTextEscaped(IColumn & column, ReadBuffer & istr, const FormatSettings &) const |
200 | { |
201 | String s; |
202 | readEscapedString(s, istr); |
203 | deserializeFromString(function, column, s); |
204 | } |
205 | |
206 | |
207 | void DataTypeAggregateFunction::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const |
208 | { |
209 | writeQuotedString(serializeToString(function, column, row_num), ostr); |
210 | } |
211 | |
212 | |
213 | void DataTypeAggregateFunction::deserializeTextQuoted(IColumn & column, ReadBuffer & istr, const FormatSettings &) const |
214 | { |
215 | String s; |
216 | readQuotedStringWithSQLStyle(s, istr); |
217 | deserializeFromString(function, column, s); |
218 | } |
219 | |
220 | |
221 | void DataTypeAggregateFunction::deserializeWholeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const |
222 | { |
223 | String s; |
224 | readStringUntilEOF(s, istr); |
225 | deserializeFromString(function, column, s); |
226 | } |
227 | |
228 | |
229 | void DataTypeAggregateFunction::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const |
230 | { |
231 | writeJSONString(serializeToString(function, column, row_num), ostr, settings); |
232 | } |
233 | |
234 | |
235 | void DataTypeAggregateFunction::deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const |
236 | { |
237 | String s; |
238 | readJSONString(s, istr); |
239 | deserializeFromString(function, column, s); |
240 | } |
241 | |
242 | |
243 | void DataTypeAggregateFunction::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const |
244 | { |
245 | writeXMLString(serializeToString(function, column, row_num), ostr); |
246 | } |
247 | |
248 | |
249 | void DataTypeAggregateFunction::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const |
250 | { |
251 | writeCSV(serializeToString(function, column, row_num), ostr); |
252 | } |
253 | |
254 | |
255 | void DataTypeAggregateFunction::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const |
256 | { |
257 | String s; |
258 | readCSV(s, istr, settings.csv); |
259 | deserializeFromString(function, column, s); |
260 | } |
261 | |
262 | |
263 | void DataTypeAggregateFunction::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const |
264 | { |
265 | if (value_index) |
266 | return; |
267 | value_index = static_cast<bool>( |
268 | protobuf.writeAggregateFunction(function, assert_cast<const ColumnAggregateFunction &>(column).getData()[row_num])); |
269 | } |
270 | |
271 | void DataTypeAggregateFunction::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const |
272 | { |
273 | row_added = false; |
274 | ColumnAggregateFunction & column_concrete = assert_cast<ColumnAggregateFunction &>(column); |
275 | Arena & arena = column_concrete.createOrGetArena(); |
276 | size_t size_of_state = function->sizeOfData(); |
277 | AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); |
278 | function->create(place); |
279 | try |
280 | { |
281 | if (!protobuf.readAggregateFunction(function, place, arena)) |
282 | { |
283 | function->destroy(place); |
284 | return; |
285 | } |
286 | auto & container = column_concrete.getData(); |
287 | if (allow_add_row) |
288 | { |
289 | container.emplace_back(place); |
290 | row_added = true; |
291 | } |
292 | else |
293 | container.back() = place; |
294 | } |
295 | catch (...) |
296 | { |
297 | function->destroy(place); |
298 | throw; |
299 | } |
300 | } |
301 | |
/// Creates a new ColumnAggregateFunction bound to this type's aggregate function (`function`).
MutableColumnPtr DataTypeAggregateFunction::createColumn() const
{
    return ColumnAggregateFunction::create(function);
}
306 | |
307 | |
308 | /// Create empty state |
309 | Field DataTypeAggregateFunction::getDefault() const |
310 | { |
311 | Field field = AggregateFunctionStateData(); |
312 | field.get<AggregateFunctionStateData &>().name = getName(); |
313 | |
314 | AlignedBuffer place_buffer(function->sizeOfData(), function->alignOfData()); |
315 | AggregateDataPtr place = place_buffer.data(); |
316 | |
317 | function->create(place); |
318 | |
319 | try |
320 | { |
321 | WriteBufferFromString buffer_from_field(field.get<AggregateFunctionStateData &>().data); |
322 | function->serialize(place, buffer_from_field); |
323 | } |
324 | catch (...) |
325 | { |
326 | function->destroy(place); |
327 | throw; |
328 | } |
329 | |
330 | function->destroy(place); |
331 | |
332 | return field; |
333 | } |
334 | |
335 | |
336 | bool DataTypeAggregateFunction::equals(const IDataType & rhs) const |
337 | { |
338 | return typeid(rhs) == typeid(*this) && getName() == rhs.getName(); |
339 | } |
340 | |
341 | |
342 | static DataTypePtr create(const String & /*type_name*/, const ASTPtr & arguments) |
343 | { |
344 | String function_name; |
345 | AggregateFunctionPtr function; |
346 | DataTypes argument_types; |
347 | Array params_row; |
348 | |
349 | if (!arguments || arguments->children.empty()) |
350 | throw Exception("Data type AggregateFunction requires parameters: " |
351 | "name of aggregate function and list of data types for arguments" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
352 | |
353 | if (const auto * parametric = arguments->children[0]->as<ASTFunction>()) |
354 | { |
355 | if (parametric->parameters) |
356 | throw Exception("Unexpected level of parameters to aggregate function" , ErrorCodes::SYNTAX_ERROR); |
357 | function_name = parametric->name; |
358 | |
359 | const ASTs & parameters = parametric->arguments->children; |
360 | params_row.resize(parameters.size()); |
361 | |
362 | for (size_t i = 0; i < parameters.size(); ++i) |
363 | { |
364 | const auto * literal = parameters[i]->as<ASTLiteral>(); |
365 | if (!literal) |
366 | throw Exception("Parameters to aggregate functions must be literals" , |
367 | ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); |
368 | |
369 | params_row[i] = literal->value; |
370 | } |
371 | } |
372 | else if (auto opt_name = tryGetIdentifierName(arguments->children[0])) |
373 | { |
374 | function_name = *opt_name; |
375 | } |
376 | else if (arguments->children[0]->as<ASTLiteral>()) |
377 | { |
378 | throw Exception("Aggregate function name for data type AggregateFunction must be passed as identifier (without quotes) or function" , |
379 | ErrorCodes::BAD_ARGUMENTS); |
380 | } |
381 | else |
382 | throw Exception("Unexpected AST element passed as aggregate function name for data type AggregateFunction. Must be identifier or function." , |
383 | ErrorCodes::BAD_ARGUMENTS); |
384 | |
385 | for (size_t i = 1; i < arguments->children.size(); ++i) |
386 | argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i])); |
387 | |
388 | if (function_name.empty()) |
389 | throw Exception("Logical error: empty name of aggregate function passed" , ErrorCodes::LOGICAL_ERROR); |
390 | |
391 | function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); |
392 | return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row); |
393 | } |
394 | |
395 | void registerDataTypeAggregateFunction(DataTypeFactory & factory) |
396 | { |
397 | factory.registerDataType("AggregateFunction" , create); |
398 | } |
399 | |
400 | |
401 | } |
402 | |