1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionGroupArray.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include <DataTypes/DataTypeDate.h> |
6 | #include <DataTypes/DataTypeDateTime.h> |
7 | #include "registerAggregateFunctions.h" |
8 | |
9 | |
10 | namespace DB |
11 | { |
12 | |
13 | namespace ErrorCodes |
14 | { |
15 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
16 | extern const int BAD_ARGUMENTS; |
17 | } |
18 | |
19 | namespace |
20 | { |
21 | |
22 | template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs> |
23 | static IAggregateFunction * createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args) |
24 | { |
25 | WhichDataType which(argument_type); |
26 | if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate<UInt16, Data>(std::forward<TArgs>(args)...); |
27 | if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate<UInt32, Data>(std::forward<TArgs>(args)...); |
28 | return createWithNumericType<AggregateFunctionTemplate, Data, TArgs...>(argument_type, std::forward<TArgs>(args)...); |
29 | } |
30 | |
31 | |
32 | template <typename has_limit, typename ... TArgs> |
33 | inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, TArgs ... args) |
34 | { |
35 | if (auto res = createWithNumericOrTimeType<GroupArrayNumericImpl, has_limit>(*argument_type, argument_type, std::forward<TArgs>(args)...)) |
36 | return AggregateFunctionPtr(res); |
37 | |
38 | WhichDataType which(argument_type); |
39 | if (which.idx == TypeIndex::String) |
40 | return std::make_shared<GroupArrayGeneralListImpl<GroupArrayListNodeString, has_limit::value>>(argument_type, std::forward<TArgs>(args)...); |
41 | |
42 | return std::make_shared<GroupArrayGeneralListImpl<GroupArrayListNodeGeneral, has_limit::value>>(argument_type, std::forward<TArgs>(args)...); |
43 | } |
44 | |
45 | |
46 | static AggregateFunctionPtr createAggregateFunctionGroupArray(const std::string & name, const DataTypes & argument_types, const Array & parameters) |
47 | { |
48 | assertUnary(name, argument_types); |
49 | |
50 | bool limit_size = false; |
51 | UInt64 max_elems = std::numeric_limits<UInt64>::max(); |
52 | |
53 | if (parameters.empty()) |
54 | { |
55 | // no limit |
56 | } |
57 | else if (parameters.size() == 1) |
58 | { |
59 | auto type = parameters[0].getType(); |
60 | if (type != Field::Types::Int64 && type != Field::Types::UInt64) |
61 | throw Exception("Parameter for aggregate function " + name + " should be positive number" , ErrorCodes::BAD_ARGUMENTS); |
62 | |
63 | if ((type == Field::Types::Int64 && parameters[0].get<Int64>() < 0) || |
64 | (type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0)) |
65 | throw Exception("Parameter for aggregate function " + name + " should be positive number" , ErrorCodes::BAD_ARGUMENTS); |
66 | |
67 | limit_size = true; |
68 | max_elems = parameters[0].get<UInt64>(); |
69 | } |
70 | else |
71 | throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 0 or 1" , |
72 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
73 | |
74 | if (!limit_size) |
75 | return createAggregateFunctionGroupArrayImpl<std::false_type>(argument_types[0]); |
76 | else |
77 | return createAggregateFunctionGroupArrayImpl<std::true_type>(argument_types[0], max_elems); |
78 | } |
79 | |
80 | } |
81 | |
82 | |
83 | void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory) |
84 | { |
85 | factory.registerFunction("groupArray" , createAggregateFunctionGroupArray); |
86 | } |
87 | |
88 | } |
89 | |