1 | #pragma once |
2 | |
3 | #include <AggregateFunctions/IAggregateFunction.h> |
4 | #include <Common/IFactoryWithAliases.h> |
5 | |
6 | |
7 | #include <functional> |
8 | #include <memory> |
9 | #include <string> |
10 | #include <unordered_map> |
11 | #include <vector> |
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
class IDataType;
class Context;

/// Shared handle to an immutable data type description.
using DataTypePtr = std::shared_ptr<const IDataType>;

/// Sequence of data type handles, e.g. the argument types passed to an aggregate function.
using DataTypes = std::vector<DataTypePtr>;
22 | |
23 | /** Creator have arguments: name of aggregate function, types of arguments, values of parameters. |
24 | * Parameters are for "parametric" aggregate functions. |
25 | * For example, in quantileWeighted(0.9)(x, weight), 0.9 is "parameter" and x, weight are "arguments". |
26 | */ |
27 | using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>; |
28 | |
29 | |
30 | /** Creates an aggregate function by name. |
31 | */ |
32 | class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionCreator> |
33 | { |
34 | public: |
35 | |
36 | static AggregateFunctionFactory & instance(); |
37 | |
38 | /// Register a function by its name. |
39 | /// No locking, you must register all functions before usage of get. |
40 | void registerFunction( |
41 | const String & name, |
42 | Creator creator, |
43 | CaseSensitiveness case_sensitiveness = CaseSensitive); |
44 | |
45 | /// Throws an exception if not found. |
46 | AggregateFunctionPtr get( |
47 | const String & name, |
48 | const DataTypes & argument_types, |
49 | const Array & parameters = {}, |
50 | int recursion_level = 0) const; |
51 | |
52 | /// Returns nullptr if not found. |
53 | AggregateFunctionPtr tryGet( |
54 | const String & name, |
55 | const DataTypes & argument_types, |
56 | const Array & parameters = {}) const; |
57 | |
58 | bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; |
59 | |
60 | private: |
61 | AggregateFunctionPtr getImpl( |
62 | const String & name, |
63 | const DataTypes & argument_types, |
64 | const Array & parameters, |
65 | int recursion_level) const; |
66 | |
67 | private: |
68 | using AggregateFunctions = std::unordered_map<String, Creator>; |
69 | |
70 | AggregateFunctions aggregate_functions; |
71 | |
72 | /// Case insensitive aggregate functions will be additionally added here with lowercased name. |
73 | AggregateFunctions case_insensitive_aggregate_functions; |
74 | |
75 | const AggregateFunctions & getCreatorMap() const override { return aggregate_functions; } |
76 | |
77 | const AggregateFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_aggregate_functions; } |
78 | |
79 | String getFactoryName() const override { return "AggregateFunctionFactory" ; } |
80 | |
81 | }; |
82 | |
83 | } |
84 | |