1#pragma once
2
3#include <AggregateFunctions/IAggregateFunction.h>
4#include <Columns/ColumnVector.h>
5#include <Columns/ColumnTuple.h>
6#include <Common/assert_cast.h>
7#include <DataTypes/DataTypeNullable.h>
8#include <DataTypes/DataTypesNumber.h>
9#include <DataTypes/DataTypeTuple.h>
10#include <IO/ReadHelpers.h>
11#include <IO/WriteHelpers.h>
12#include <limits>
13
14namespace DB
15{
16
namespace ErrorCodes
{
    /// NOTE(review): this error code is not referenced anywhere in this header.
    /// Translation units that include the header may rely on this declaration,
    /// so confirm that before removing it.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
21
22template <typename X, typename Y, typename Ret>
23struct AggregateFunctionSimpleLinearRegressionData final
24{
25 size_t count = 0;
26 Ret sum_x = 0;
27 Ret sum_y = 0;
28 Ret sum_xx = 0;
29 Ret sum_xy = 0;
30
31 void add(X x, Y y)
32 {
33 count += 1;
34 sum_x += x;
35 sum_y += y;
36 sum_xx += x * x;
37 sum_xy += x * y;
38 }
39
40 void merge(const AggregateFunctionSimpleLinearRegressionData & other)
41 {
42 count += other.count;
43 sum_x += other.sum_x;
44 sum_y += other.sum_y;
45 sum_xx += other.sum_xx;
46 sum_xy += other.sum_xy;
47 }
48
49 void serialize(WriteBuffer & buf) const
50 {
51 writeBinary(count, buf);
52 writeBinary(sum_x, buf);
53 writeBinary(sum_y, buf);
54 writeBinary(sum_xx, buf);
55 writeBinary(sum_xy, buf);
56 }
57
58 void deserialize(ReadBuffer & buf)
59 {
60 readBinary(count, buf);
61 readBinary(sum_x, buf);
62 readBinary(sum_y, buf);
63 readBinary(sum_xx, buf);
64 readBinary(sum_xy, buf);
65 }
66
67 Ret getK() const
68 {
69 Ret divisor = sum_xx * count - sum_x * sum_x;
70
71 if (divisor == 0)
72 return std::numeric_limits<Ret>::quiet_NaN();
73
74 return (sum_xy * count - sum_x * sum_y) / divisor;
75 }
76
77 Ret getB(Ret k) const
78 {
79 if (count == 0)
80 return std::numeric_limits<Ret>::quiet_NaN();
81
82 return (sum_y - k * sum_x) / count;
83 }
84};
85
86/// Calculates simple linear regression parameters.
87/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
88template <typename X, typename Y, typename Ret = Float64>
89class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
90 AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
91 AggregateFunctionSimpleLinearRegression<X, Y, Ret>
92>
93{
94public:
95 AggregateFunctionSimpleLinearRegression(
96 const DataTypes & arguments,
97 const Array & params
98 ):
99 IAggregateFunctionDataHelper<
100 AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
101 AggregateFunctionSimpleLinearRegression<X, Y, Ret>
102 > {arguments, params}
103 {
104 // notice: arguments has been checked before
105 }
106
107 String getName() const override
108 {
109 return "simpleLinearRegression";
110 }
111
112 void add(
113 AggregateDataPtr place,
114 const IColumn ** columns,
115 size_t row_num,
116 Arena *
117 ) const override
118 {
119 auto col_x = assert_cast<const ColumnVector<X> *>(columns[0]);
120 auto col_y = assert_cast<const ColumnVector<Y> *>(columns[1]);
121
122 X x = col_x->getData()[row_num];
123 Y y = col_y->getData()[row_num];
124
125 this->data(place).add(x, y);
126 }
127
128 void merge(
129 AggregateDataPtr place,
130 ConstAggregateDataPtr rhs, Arena *
131 ) const override
132 {
133 this->data(place).merge(this->data(rhs));
134 }
135
136 void serialize(
137 ConstAggregateDataPtr place,
138 WriteBuffer & buf
139 ) const override
140 {
141 this->data(place).serialize(buf);
142 }
143
144 void deserialize(
145 AggregateDataPtr place,
146 ReadBuffer & buf, Arena *
147 ) const override
148 {
149 this->data(place).deserialize(buf);
150 }
151
152 DataTypePtr getReturnType() const override
153 {
154 DataTypes types
155 {
156 std::make_shared<DataTypeNumber<Ret>>(),
157 std::make_shared<DataTypeNumber<Ret>>(),
158 };
159
160 Strings names
161 {
162 "k",
163 "b",
164 };
165
166 return std::make_shared<DataTypeTuple>(
167 std::move(types),
168 std::move(names)
169 );
170 }
171
172 void insertResultInto(
173 ConstAggregateDataPtr place,
174 IColumn & to
175 ) const override
176 {
177 Ret k = this->data(place).getK();
178 Ret b = this->data(place).getB(k);
179
180 auto & col_tuple = assert_cast<ColumnTuple &>(to);
181 auto & col_k = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(0));
182 auto & col_b = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(1));
183
184 col_k.getData().push_back(k);
185 col_b.getData().push_back(b);
186 }
187};
188
189}
190