1 | #include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h> |
2 | |
3 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
4 | #include <AggregateFunctions/FactoryHelpers.h> |
5 | |
6 | #include <Core/TypeListNumber.h> |
7 | #include "registerAggregateFunctions.h" |
8 | |
9 | namespace DB |
10 | { |
11 | |
12 | namespace |
13 | { |
14 | |
15 | AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression( |
16 | const String & name, |
17 | const DataTypes & arguments, |
18 | const Array & params |
19 | ) |
20 | { |
21 | assertNoParameters(name, params); |
22 | assertBinary(name, arguments); |
23 | |
24 | const IDataType * x_arg = arguments.front().get(); |
25 | WhichDataType which_x = x_arg; |
26 | |
27 | const IDataType * y_arg = arguments.back().get(); |
28 | WhichDataType which_y = y_arg; |
29 | |
30 | |
31 | #define FOR_LEASTSQR_TYPES_2(M, T) \ |
32 | M(T, UInt8) \ |
33 | M(T, UInt16) \ |
34 | M(T, UInt32) \ |
35 | M(T, UInt64) \ |
36 | M(T, Int8) \ |
37 | M(T, Int16) \ |
38 | M(T, Int32) \ |
39 | M(T, Int64) \ |
40 | M(T, Float32) \ |
41 | M(T, Float64) |
42 | #define FOR_LEASTSQR_TYPES(M) \ |
43 | FOR_LEASTSQR_TYPES_2(M, UInt8) \ |
44 | FOR_LEASTSQR_TYPES_2(M, UInt16) \ |
45 | FOR_LEASTSQR_TYPES_2(M, UInt32) \ |
46 | FOR_LEASTSQR_TYPES_2(M, UInt64) \ |
47 | FOR_LEASTSQR_TYPES_2(M, Int8) \ |
48 | FOR_LEASTSQR_TYPES_2(M, Int16) \ |
49 | FOR_LEASTSQR_TYPES_2(M, Int32) \ |
50 | FOR_LEASTSQR_TYPES_2(M, Int64) \ |
51 | FOR_LEASTSQR_TYPES_2(M, Float32) \ |
52 | FOR_LEASTSQR_TYPES_2(M, Float64) |
53 | #define DISPATCH(T1, T2) \ |
54 | if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \ |
55 | return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \ |
56 | arguments, \ |
57 | params \ |
58 | ); |
59 | |
60 | FOR_LEASTSQR_TYPES(DISPATCH) |
61 | |
62 | #undef FOR_LEASTSQR_TYPES_2 |
63 | #undef FOR_LEASTSQR_TYPES |
64 | #undef DISPATCH |
65 | |
66 | throw Exception( |
67 | "Illegal types (" |
68 | + x_arg->getName() + ", " + y_arg->getName() |
69 | + ") of arguments of aggregate function " + name |
70 | + ", must be Native Ints, Native UInts or Floats" , |
71 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT |
72 | ); |
73 | } |
74 | |
75 | } |
76 | |
77 | void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory) |
78 | { |
79 | factory.registerFunction("simpleLinearRegression" , createAggregateFunctionSimpleLinearRegression); |
80 | } |
81 | |
82 | } |
83 | |