1 | #pragma once |
2 | |
3 | #include <AggregateFunctions/IAggregateFunction.h> |
4 | #include <Columns/ColumnVector.h> |
5 | #include <Columns/ColumnTuple.h> |
6 | #include <Common/assert_cast.h> |
7 | #include <DataTypes/DataTypeNullable.h> |
8 | #include <DataTypes/DataTypesNumber.h> |
9 | #include <DataTypes/DataTypeTuple.h> |
10 | #include <IO/ReadHelpers.h> |
11 | #include <IO/WriteHelpers.h> |
12 | #include <limits> |
13 | |
14 | namespace DB |
15 | { |
16 | |
17 | namespace ErrorCodes |
18 | { |
19 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
20 | } |
21 | |
22 | template <typename X, typename Y, typename Ret> |
23 | struct AggregateFunctionSimpleLinearRegressionData final |
24 | { |
25 | size_t count = 0; |
26 | Ret sum_x = 0; |
27 | Ret sum_y = 0; |
28 | Ret sum_xx = 0; |
29 | Ret sum_xy = 0; |
30 | |
31 | void add(X x, Y y) |
32 | { |
33 | count += 1; |
34 | sum_x += x; |
35 | sum_y += y; |
36 | sum_xx += x * x; |
37 | sum_xy += x * y; |
38 | } |
39 | |
40 | void merge(const AggregateFunctionSimpleLinearRegressionData & other) |
41 | { |
42 | count += other.count; |
43 | sum_x += other.sum_x; |
44 | sum_y += other.sum_y; |
45 | sum_xx += other.sum_xx; |
46 | sum_xy += other.sum_xy; |
47 | } |
48 | |
49 | void serialize(WriteBuffer & buf) const |
50 | { |
51 | writeBinary(count, buf); |
52 | writeBinary(sum_x, buf); |
53 | writeBinary(sum_y, buf); |
54 | writeBinary(sum_xx, buf); |
55 | writeBinary(sum_xy, buf); |
56 | } |
57 | |
58 | void deserialize(ReadBuffer & buf) |
59 | { |
60 | readBinary(count, buf); |
61 | readBinary(sum_x, buf); |
62 | readBinary(sum_y, buf); |
63 | readBinary(sum_xx, buf); |
64 | readBinary(sum_xy, buf); |
65 | } |
66 | |
67 | Ret getK() const |
68 | { |
69 | Ret divisor = sum_xx * count - sum_x * sum_x; |
70 | |
71 | if (divisor == 0) |
72 | return std::numeric_limits<Ret>::quiet_NaN(); |
73 | |
74 | return (sum_xy * count - sum_x * sum_y) / divisor; |
75 | } |
76 | |
77 | Ret getB(Ret k) const |
78 | { |
79 | if (count == 0) |
80 | return std::numeric_limits<Ret>::quiet_NaN(); |
81 | |
82 | return (sum_y - k * sum_x) / count; |
83 | } |
84 | }; |
85 | |
86 | /// Calculates simple linear regression parameters. |
87 | /// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation. |
88 | template <typename X, typename Y, typename Ret = Float64> |
89 | class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper< |
90 | AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>, |
91 | AggregateFunctionSimpleLinearRegression<X, Y, Ret> |
92 | > |
93 | { |
94 | public: |
95 | AggregateFunctionSimpleLinearRegression( |
96 | const DataTypes & arguments, |
97 | const Array & params |
98 | ): |
99 | IAggregateFunctionDataHelper< |
100 | AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>, |
101 | AggregateFunctionSimpleLinearRegression<X, Y, Ret> |
102 | > {arguments, params} |
103 | { |
104 | // notice: arguments has been checked before |
105 | } |
106 | |
107 | String getName() const override |
108 | { |
109 | return "simpleLinearRegression" ; |
110 | } |
111 | |
112 | void add( |
113 | AggregateDataPtr place, |
114 | const IColumn ** columns, |
115 | size_t row_num, |
116 | Arena * |
117 | ) const override |
118 | { |
119 | auto col_x = assert_cast<const ColumnVector<X> *>(columns[0]); |
120 | auto col_y = assert_cast<const ColumnVector<Y> *>(columns[1]); |
121 | |
122 | X x = col_x->getData()[row_num]; |
123 | Y y = col_y->getData()[row_num]; |
124 | |
125 | this->data(place).add(x, y); |
126 | } |
127 | |
128 | void merge( |
129 | AggregateDataPtr place, |
130 | ConstAggregateDataPtr rhs, Arena * |
131 | ) const override |
132 | { |
133 | this->data(place).merge(this->data(rhs)); |
134 | } |
135 | |
136 | void serialize( |
137 | ConstAggregateDataPtr place, |
138 | WriteBuffer & buf |
139 | ) const override |
140 | { |
141 | this->data(place).serialize(buf); |
142 | } |
143 | |
144 | void deserialize( |
145 | AggregateDataPtr place, |
146 | ReadBuffer & buf, Arena * |
147 | ) const override |
148 | { |
149 | this->data(place).deserialize(buf); |
150 | } |
151 | |
152 | DataTypePtr getReturnType() const override |
153 | { |
154 | DataTypes types |
155 | { |
156 | std::make_shared<DataTypeNumber<Ret>>(), |
157 | std::make_shared<DataTypeNumber<Ret>>(), |
158 | }; |
159 | |
160 | Strings names |
161 | { |
162 | "k" , |
163 | "b" , |
164 | }; |
165 | |
166 | return std::make_shared<DataTypeTuple>( |
167 | std::move(types), |
168 | std::move(names) |
169 | ); |
170 | } |
171 | |
172 | void insertResultInto( |
173 | ConstAggregateDataPtr place, |
174 | IColumn & to |
175 | ) const override |
176 | { |
177 | Ret k = this->data(place).getK(); |
178 | Ret b = this->data(place).getB(k); |
179 | |
180 | auto & col_tuple = assert_cast<ColumnTuple &>(to); |
181 | auto & col_k = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(0)); |
182 | auto & col_b = assert_cast<ColumnVector<Ret> &>(col_tuple.getColumn(1)); |
183 | |
184 | col_k.getData().push_back(k); |
185 | col_b.getData().push_back(b); |
186 | } |
187 | }; |
188 | |
189 | } |
190 | |