1#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
2
3#include <AggregateFunctions/AggregateFunctionFactory.h>
4#include <AggregateFunctions/FactoryHelpers.h>
5
6#include <Core/TypeListNumber.h>
7#include "registerAggregateFunctions.h"
8
9namespace DB
10{
11
12namespace
13{
14
15AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
16 const String & name,
17 const DataTypes & arguments,
18 const Array & params
19)
20{
21 assertNoParameters(name, params);
22 assertBinary(name, arguments);
23
24 const IDataType * x_arg = arguments.front().get();
25 WhichDataType which_x = x_arg;
26
27 const IDataType * y_arg = arguments.back().get();
28 WhichDataType which_y = y_arg;
29
30
31 #define FOR_LEASTSQR_TYPES_2(M, T) \
32 M(T, UInt8) \
33 M(T, UInt16) \
34 M(T, UInt32) \
35 M(T, UInt64) \
36 M(T, Int8) \
37 M(T, Int16) \
38 M(T, Int32) \
39 M(T, Int64) \
40 M(T, Float32) \
41 M(T, Float64)
42 #define FOR_LEASTSQR_TYPES(M) \
43 FOR_LEASTSQR_TYPES_2(M, UInt8) \
44 FOR_LEASTSQR_TYPES_2(M, UInt16) \
45 FOR_LEASTSQR_TYPES_2(M, UInt32) \
46 FOR_LEASTSQR_TYPES_2(M, UInt64) \
47 FOR_LEASTSQR_TYPES_2(M, Int8) \
48 FOR_LEASTSQR_TYPES_2(M, Int16) \
49 FOR_LEASTSQR_TYPES_2(M, Int32) \
50 FOR_LEASTSQR_TYPES_2(M, Int64) \
51 FOR_LEASTSQR_TYPES_2(M, Float32) \
52 FOR_LEASTSQR_TYPES_2(M, Float64)
53 #define DISPATCH(T1, T2) \
54 if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
55 return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
56 arguments, \
57 params \
58 );
59
60 FOR_LEASTSQR_TYPES(DISPATCH)
61
62 #undef FOR_LEASTSQR_TYPES_2
63 #undef FOR_LEASTSQR_TYPES
64 #undef DISPATCH
65
66 throw Exception(
67 "Illegal types ("
68 + x_arg->getName() + ", " + y_arg->getName()
69 + ") of arguments of aggregate function " + name
70 + ", must be Native Ints, Native UInts or Floats",
71 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
72 );
73}
74
75}
76
77void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
78{
79 factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
80}
81
82}
83