1 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
2 | #include <AggregateFunctions/AggregateFunctionTopK.h> |
3 | #include <AggregateFunctions/Helpers.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | #include <DataTypes/DataTypeDate.h> |
6 | #include <DataTypes/DataTypeDateTime.h> |
7 | #include "registerAggregateFunctions.h" |
8 | |
9 | #define TOP_K_MAX_SIZE 0xFFFFFF |
10 | |
11 | |
12 | namespace DB |
13 | { |
14 | |
namespace ErrorCodes
{
    /// Error codes thrown by the topK factory in this file; their definitions live elsewhere in the project.
    extern const int ARGUMENT_OUT_OF_BOUND;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
21 | |
22 | |
23 | namespace |
24 | { |
25 | |
26 | /// Substitute return type for Date and DateTime |
27 | template <bool is_weighted> |
28 | class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted> |
29 | { |
30 | using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK; |
31 | DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); } |
32 | }; |
33 | |
34 | template <bool is_weighted> |
35 | class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted> |
36 | { |
37 | using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK; |
38 | DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); } |
39 | }; |
40 | |
41 | |
42 | template <bool is_weighted> |
43 | static IAggregateFunction * (const DataTypePtr & argument_type, UInt64 threshold, UInt64 load_factor, const Array & params) |
44 | { |
45 | WhichDataType which(argument_type); |
46 | if (which.idx == TypeIndex::Date) |
47 | return new AggregateFunctionTopKDate<is_weighted>(threshold, load_factor, {argument_type}, params); |
48 | if (which.idx == TypeIndex::DateTime) |
49 | return new AggregateFunctionTopKDateTime<is_weighted>(threshold, load_factor, {argument_type}, params); |
50 | |
51 | /// Check that we can use plain version of AggregateFunctionTopKGeneric |
52 | if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) |
53 | return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, load_factor, argument_type, params); |
54 | else |
55 | return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, load_factor, argument_type, params); |
56 | } |
57 | |
58 | |
59 | template <bool is_weighted> |
60 | AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params) |
61 | { |
62 | if (!is_weighted) |
63 | { |
64 | assertUnary(name, argument_types); |
65 | } |
66 | else |
67 | { |
68 | assertBinary(name, argument_types); |
69 | if (!isInteger(argument_types[1])) |
70 | throw Exception("The second argument for aggregate function 'topKWeighted' must have integer type" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
71 | } |
72 | |
73 | UInt64 threshold = 10; /// default values |
74 | UInt64 load_factor = 3; |
75 | |
76 | if (!params.empty()) |
77 | { |
78 | if (params.size() > 2) |
79 | throw Exception("Aggregate function " + name + " requires two parameters or less." , |
80 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
81 | |
82 | UInt64 k = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]); |
83 | if (params.size() == 2) |
84 | { |
85 | load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]); |
86 | |
87 | if (load_factor < 1) |
88 | throw Exception("Too small parameter for aggregate function " + name + ". Minimum: 1" , |
89 | ErrorCodes::ARGUMENT_OUT_OF_BOUND); |
90 | } |
91 | |
92 | if (k > TOP_K_MAX_SIZE) |
93 | throw Exception("Too large parameter for aggregate function " + name + ". Maximum: " + toString(TOP_K_MAX_SIZE), |
94 | ErrorCodes::ARGUMENT_OUT_OF_BOUND); |
95 | |
96 | if (k == 0) |
97 | throw Exception("Parameter 0 is illegal for aggregate function " + name, |
98 | ErrorCodes::ARGUMENT_OUT_OF_BOUND); |
99 | |
100 | threshold = k; |
101 | } |
102 | |
103 | AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold, load_factor, argument_types, params)); |
104 | |
105 | if (!res) |
106 | res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold, load_factor, params)); |
107 | |
108 | if (!res) |
109 | throw Exception("Illegal type " + argument_types[0]->getName() + |
110 | " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
111 | |
112 | return res; |
113 | } |
114 | |
115 | } |
116 | |
117 | void registerAggregateFunctionTopK(AggregateFunctionFactory & factory) |
118 | { |
119 | factory.registerFunction("topK" , createAggregateFunctionTopK<false>); |
120 | factory.registerFunction("topKWeighted" , createAggregateFunctionTopK<true>); |
121 | } |
122 | |
123 | } |
124 | |