1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/AggregateFunctionTopK.h>
3#include <AggregateFunctions/Helpers.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5#include <DataTypes/DataTypeDate.h>
6#include <DataTypes/DataTypeDateTime.h>
7#include "registerAggregateFunctions.h"
8
9#define TOP_K_MAX_SIZE 0xFFFFFF
10
11
12namespace DB
13{
14
namespace ErrorCodes
{
    /// Error codes thrown by the topK / topKWeighted factory in this file (defined elsewhere).
    extern const int ARGUMENT_OUT_OF_BOUND;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
21
22
23namespace
24{
25
26/// Substitute return type for Date and DateTime
27template <bool is_weighted>
28class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
29{
30 using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
31 DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
32};
33
34template <bool is_weighted>
35class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
36{
37 using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
38 DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
39};
40
41
42template <bool is_weighted>
43static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold, UInt64 load_factor, const Array & params)
44{
45 WhichDataType which(argument_type);
46 if (which.idx == TypeIndex::Date)
47 return new AggregateFunctionTopKDate<is_weighted>(threshold, load_factor, {argument_type}, params);
48 if (which.idx == TypeIndex::DateTime)
49 return new AggregateFunctionTopKDateTime<is_weighted>(threshold, load_factor, {argument_type}, params);
50
51 /// Check that we can use plain version of AggregateFunctionTopKGeneric
52 if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
53 return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, load_factor, argument_type, params);
54 else
55 return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, load_factor, argument_type, params);
56}
57
58
59template <bool is_weighted>
60AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params)
61{
62 if (!is_weighted)
63 {
64 assertUnary(name, argument_types);
65 }
66 else
67 {
68 assertBinary(name, argument_types);
69 if (!isInteger(argument_types[1]))
70 throw Exception("The second argument for aggregate function 'topKWeighted' must have integer type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
71 }
72
73 UInt64 threshold = 10; /// default values
74 UInt64 load_factor = 3;
75
76 if (!params.empty())
77 {
78 if (params.size() > 2)
79 throw Exception("Aggregate function " + name + " requires two parameters or less.",
80 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
81
82 UInt64 k = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
83 if (params.size() == 2)
84 {
85 load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
86
87 if (load_factor < 1)
88 throw Exception("Too small parameter for aggregate function " + name + ". Minimum: 1",
89 ErrorCodes::ARGUMENT_OUT_OF_BOUND);
90 }
91
92 if (k > TOP_K_MAX_SIZE)
93 throw Exception("Too large parameter for aggregate function " + name + ". Maximum: " + toString(TOP_K_MAX_SIZE),
94 ErrorCodes::ARGUMENT_OUT_OF_BOUND);
95
96 if (k == 0)
97 throw Exception("Parameter 0 is illegal for aggregate function " + name,
98 ErrorCodes::ARGUMENT_OUT_OF_BOUND);
99
100 threshold = k;
101 }
102
103 AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold, load_factor, argument_types, params));
104
105 if (!res)
106 res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold, load_factor, params));
107
108 if (!res)
109 throw Exception("Illegal type " + argument_types[0]->getName() +
110 " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
111
112 return res;
113}
114
115}
116
117void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
118{
119 factory.registerFunction("topK", createAggregateFunctionTopK<false>);
120 factory.registerFunction("topKWeighted", createAggregateFunctionTopK<true>);
121}
122
123}
124