1 | #include <Common/FieldVisitors.h> |
2 | #include <Common/typeid_cast.h> |
3 | |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <Columns/ColumnAggregateFunction.h> |
7 | |
8 | #include <DataTypes/DataTypeCustomSimpleAggregateFunction.h> |
9 | #include <DataTypes/DataTypeLowCardinality.h> |
10 | #include <DataTypes/DataTypeTuple.h> |
11 | #include <DataTypes/DataTypeArray.h> |
12 | #include <DataTypes/DataTypeFactory.h> |
13 | |
14 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
15 | #include <Parsers/ASTFunction.h> |
16 | #include <Parsers/ASTLiteral.h> |
17 | #include <Parsers/ASTIdentifier.h> |
18 | |
19 | #include <boost/algorithm/string/join.hpp> |
20 | |
21 | namespace DB |
22 | { |
23 | |
/// Error codes thrown by this translation unit; the definitions live elsewhere in the project.
namespace ErrorCodes
{
    /// Malformed declaration shape.
    extern const int SYNTAX_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;

    /// Semantically invalid declaration.
    extern const int BAD_ARGUMENTS;

    /// Internal invariant violation.
    extern const int LOGICAL_ERROR;
}
32 | |
33 | static const std::vector<String> supported_functions{"any" , "anyLast" , "min" , "max" , "sum" , "groupBitAnd" , "groupBitOr" , "groupBitXor" }; |
34 | |
35 | |
36 | String DataTypeCustomSimpleAggregateFunction::getName() const |
37 | { |
38 | std::stringstream stream; |
39 | stream << "SimpleAggregateFunction(" << function->getName(); |
40 | |
41 | if (!parameters.empty()) |
42 | { |
43 | stream << "(" ; |
44 | for (size_t i = 0; i < parameters.size(); ++i) |
45 | { |
46 | if (i) |
47 | stream << ", " ; |
48 | stream << applyVisitor(DB::FieldVisitorToString(), parameters[i]); |
49 | } |
50 | stream << ")" ; |
51 | } |
52 | |
53 | for (const auto & argument_type : argument_types) |
54 | stream << ", " << argument_type->getName(); |
55 | |
56 | stream << ")" ; |
57 | return stream.str(); |
58 | } |
59 | |
60 | |
61 | static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const String & /*type_name*/, const ASTPtr & arguments) |
62 | { |
63 | String function_name; |
64 | AggregateFunctionPtr function; |
65 | DataTypes argument_types; |
66 | Array params_row; |
67 | |
68 | if (!arguments || arguments->children.empty()) |
69 | throw Exception("Data type SimpleAggregateFunction requires parameters: " |
70 | "name of aggregate function and list of data types for arguments" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
71 | |
72 | if (const ASTFunction * parametric = arguments->children[0]->as<ASTFunction>()) |
73 | { |
74 | if (parametric->parameters) |
75 | throw Exception("Unexpected level of parameters to aggregate function" , ErrorCodes::SYNTAX_ERROR); |
76 | function_name = parametric->name; |
77 | |
78 | const ASTs & parameters = parametric->arguments->as<ASTExpressionList &>().children; |
79 | params_row.resize(parameters.size()); |
80 | |
81 | for (size_t i = 0; i < parameters.size(); ++i) |
82 | { |
83 | const ASTLiteral * lit = parameters[i]->as<ASTLiteral>(); |
84 | if (!lit) |
85 | throw Exception("Parameters to aggregate functions must be literals" , |
86 | ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); |
87 | |
88 | params_row[i] = lit->value; |
89 | } |
90 | } |
91 | else if (auto opt_name = tryGetIdentifierName(arguments->children[0])) |
92 | { |
93 | function_name = *opt_name; |
94 | } |
95 | else if (arguments->children[0]->as<ASTLiteral>()) |
96 | { |
97 | throw Exception("Aggregate function name for data type SimpleAggregateFunction must be passed as identifier (without quotes) or function" , |
98 | ErrorCodes::BAD_ARGUMENTS); |
99 | } |
100 | else |
101 | throw Exception("Unexpected AST element passed as aggregate function name for data type SimpleAggregateFunction. Must be identifier or function." , |
102 | ErrorCodes::BAD_ARGUMENTS); |
103 | |
104 | for (size_t i = 1; i < arguments->children.size(); ++i) |
105 | argument_types.push_back(DataTypeFactory::instance().get(arguments->children[i])); |
106 | |
107 | if (function_name.empty()) |
108 | throw Exception("Logical error: empty name of aggregate function passed" , ErrorCodes::LOGICAL_ERROR); |
109 | |
110 | function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); |
111 | |
112 | // check function |
113 | if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) |
114 | { |
115 | throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, "," ), |
116 | ErrorCodes::BAD_ARGUMENTS); |
117 | } |
118 | |
119 | DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName()); |
120 | |
121 | if (!function->getReturnType()->equals(*removeLowCardinality(storage_type))) |
122 | { |
123 | throw Exception("Incompatible data types between aggregate function '" + function->getName() + "' which returns " + function->getReturnType()->getName() + " and column storage type " + storage_type->getName(), |
124 | ErrorCodes::BAD_ARGUMENTS); |
125 | } |
126 | |
127 | DataTypeCustomNamePtr custom_name = std::make_unique<DataTypeCustomSimpleAggregateFunction>(function, argument_types, params_row); |
128 | |
129 | return std::make_pair(storage_type, std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr)); |
130 | } |
131 | |
132 | void registerDataTypeDomainSimpleAggregateFunction(DataTypeFactory & factory) |
133 | { |
134 | factory.registerDataTypeCustom("SimpleAggregateFunction" , create); |
135 | } |
136 | |
137 | } |
138 | |