1#include <Common/FieldVisitors.h>
2#include <Common/typeid_cast.h>
3
4#include <IO/ReadHelpers.h>
5
6#include <Columns/ColumnAggregateFunction.h>
7
8#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
9#include <DataTypes/DataTypeLowCardinality.h>
10#include <DataTypes/DataTypeTuple.h>
11#include <DataTypes/DataTypeArray.h>
12#include <DataTypes/DataTypeFactory.h>
13
14#include <AggregateFunctions/AggregateFunctionFactory.h>
15#include <Parsers/ASTFunction.h>
16#include <Parsers/ASTLiteral.h>
17#include <Parsers/ASTIdentifier.h>
18
19#include <boost/algorithm/string/join.hpp>
20
21namespace DB
22{
23
24namespace ErrorCodes
25{
26 extern const int SYNTAX_ERROR;
27 extern const int BAD_ARGUMENTS;
28 extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
29 extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
30 extern const int LOGICAL_ERROR;
31}
32
33static const std::vector<String> supported_functions{"any", "anyLast", "min", "max", "sum", "groupBitAnd", "groupBitOr", "groupBitXor"};
34
35
36String DataTypeCustomSimpleAggregateFunction::getName() const
37{
38 std::stringstream stream;
39 stream << "SimpleAggregateFunction(" << function->getName();
40
41 if (!parameters.empty())
42 {
43 stream << "(";
44 for (size_t i = 0; i < parameters.size(); ++i)
45 {
46 if (i)
47 stream << ", ";
48 stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]);
49 }
50 stream << ")";
51 }
52
53 for (const auto & argument_type : argument_types)
54 stream << ", " << argument_type->getName();
55
56 stream << ")";
57 return stream.str();
58}
59
60
61static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const String & /*type_name*/, const ASTPtr & arguments)
62{
63 String function_name;
64 AggregateFunctionPtr function;
65 DataTypes argument_types;
66 Array params_row;
67
68 if (!arguments || arguments->children.empty())
69 throw Exception("Data type SimpleAggregateFunction requires parameters: "
70 "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
71
72 if (const ASTFunction * parametric = arguments->children[0]->as<ASTFunction>())
73 {
74 if (parametric->parameters)
75 throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR);
76 function_name = parametric->name;
77
78 const ASTs & parameters = parametric->arguments->as<ASTExpressionList &>().children;
79 params_row.resize(parameters.size());
80
81 for (size_t i = 0; i < parameters.size(); ++i)
82 {
83 const ASTLiteral * lit = parameters[i]->as<ASTLiteral>();
84 if (!lit)
85 throw Exception("Parameters to aggregate functions must be literals",
86 ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
87
88 params_row[i] = lit->value;
89 }
90 }
91 else if (auto opt_name = tryGetIdentifierName(arguments->children[0]))
92 {
93 function_name = *opt_name;
94 }
95 else if (arguments->children[0]->as<ASTLiteral>())
96 {
97 throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function",
98 ErrorCodes::BAD_ARGUMENTS);
99 }
100 else
101 throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function.",
102 ErrorCodes::BAD_ARGUMENTS);
103
104 for (size_t i = 1; i < arguments->children.size(); ++i)
105 argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i]));
106
107 if (function_name.empty())
108 throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
109
110 function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row);
111
112 // check function
113 if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
114 {
115 throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
116 ErrorCodes::BAD_ARGUMENTS);
117 }
118
119 DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());
120
121 if (!function->getReturnType()->equals(*removeLowCardinality(storage_type)))
122 {
123 throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(),
124 ErrorCodes::BAD_ARGUMENTS);
125 }
126
127 DataTypeCustomNamePtr custom_name = std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, argument_types, params_row);
128
129 return std::make_pair(storage_type, std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
130}
131
132void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory)
133{
134 factory.registerDataTypeCustom("SimpleAggregateFunction", create);
135}
136
137}
138