1 | #pragma once |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <DataTypes/DataTypesNumber.h> |
8 | #include <DataTypes/DataTypeString.h> |
9 | |
10 | #include <Columns/ColumnArray.h> |
11 | |
12 | #include <Common/SpaceSaving.h> |
13 | #include <Common/FieldVisitors.h> |
14 | #include <Common/assert_cast.h> |
15 | |
16 | #include <AggregateFunctions/IAggregateFunction.h> |
17 | |
18 | |
19 | namespace DB |
20 | { |
21 | |
22 | |
23 | template <typename T> |
24 | struct AggregateFunctionTopKData |
25 | { |
26 | using Set = SpaceSaving |
27 | < |
28 | T, |
29 | HashCRC32<T>, |
30 | HashTableGrower<4>, |
31 | HashTableAllocatorWithStackMemory<sizeof(T) * (1 << 4)> |
32 | >; |
33 | Set value; |
34 | }; |
35 | |
36 | |
37 | template <typename T, bool is_weighted> |
38 | class AggregateFunctionTopK |
39 | : public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>> |
40 | { |
41 | protected: |
42 | using State = AggregateFunctionTopKData<T>; |
43 | UInt64 threshold; |
44 | UInt64 reserved; |
45 | |
46 | public: |
47 | AggregateFunctionTopK(UInt64 threshold_, UInt64 load_factor, const DataTypes & argument_types_, const Array & params) |
48 | : IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params) |
49 | , threshold(threshold_), reserved(load_factor * threshold) {} |
50 | |
51 | String getName() const override { return is_weighted ? "topKWeighted" : "topK" ; } |
52 | |
53 | DataTypePtr getReturnType() const override |
54 | { |
55 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeNumber<T>>()); |
56 | } |
57 | |
58 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
59 | { |
60 | auto & set = this->data(place).value; |
61 | if (set.capacity() != reserved) |
62 | set.resize(reserved); |
63 | |
64 | if constexpr (is_weighted) |
65 | set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num)); |
66 | else |
67 | set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]); |
68 | } |
69 | |
70 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
71 | { |
72 | this->data(place).value.merge(this->data(rhs).value); |
73 | } |
74 | |
75 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
76 | { |
77 | this->data(place).value.write(buf); |
78 | } |
79 | |
80 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
81 | { |
82 | auto & set = this->data(place).value; |
83 | set.resize(reserved); |
84 | set.read(buf); |
85 | } |
86 | |
87 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
88 | { |
89 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
90 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
91 | |
92 | const typename State::Set & set = this->data(place).value; |
93 | auto result_vec = set.topK(threshold); |
94 | size_t size = result_vec.size(); |
95 | |
96 | offsets_to.push_back(offsets_to.back() + size); |
97 | |
98 | typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData(); |
99 | size_t old_size = data_to.size(); |
100 | data_to.resize(old_size + size); |
101 | |
102 | size_t i = 0; |
103 | for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i) |
104 | data_to[old_size + i] = it->key; |
105 | } |
106 | }; |
107 | |
108 | |
109 | /// Generic implementation, it uses serialized representation as object descriptor. |
110 | struct AggregateFunctionTopKGenericData |
111 | { |
112 | using Set = SpaceSaving |
113 | < |
114 | StringRef, |
115 | StringRefHash, |
116 | HashTableGrower<4>, |
117 | HashTableAllocatorWithStackMemory<sizeof(StringRef) * (1 << 4)> |
118 | >; |
119 | |
120 | Set value; |
121 | }; |
122 | |
123 | /** Template parameter with true value should be used for columns that store their elements in memory continuously. |
124 | * For such columns topK() can be implemented more efficiently (especially for small numeric arrays). |
125 | */ |
126 | template <bool is_plain_column, bool is_weighted> |
127 | class AggregateFunctionTopKGeneric : public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>> |
128 | { |
129 | private: |
130 | using State = AggregateFunctionTopKGenericData; |
131 | |
132 | UInt64 threshold; |
133 | UInt64 reserved; |
134 | DataTypePtr & input_data_type; |
135 | |
136 | static void deserializeAndInsert(StringRef str, IColumn & data_to); |
137 | |
138 | public: |
139 | AggregateFunctionTopKGeneric( |
140 | UInt64 threshold_, UInt64 load_factor, const DataTypePtr & input_data_type_, const Array & params) |
141 | : IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>({input_data_type_}, params) |
142 | , threshold(threshold_), reserved(load_factor * threshold), input_data_type(this->argument_types[0]) {} |
143 | |
144 | String getName() const override { return is_weighted ? "topKWeighted" : "topK" ; } |
145 | |
146 | DataTypePtr getReturnType() const override |
147 | { |
148 | return std::make_shared<DataTypeArray>(input_data_type); |
149 | } |
150 | |
151 | bool allocatesMemoryInArena() const override |
152 | { |
153 | return true; |
154 | } |
155 | |
156 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
157 | { |
158 | this->data(place).value.write(buf); |
159 | } |
160 | |
161 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
162 | { |
163 | auto & set = this->data(place).value; |
164 | set.clear(); |
165 | set.resize(reserved); |
166 | |
167 | // Specialized here because there's no deserialiser for StringRef |
168 | size_t size = 0; |
169 | readVarUInt(size, buf); |
170 | for (size_t i = 0; i < size; ++i) |
171 | { |
172 | auto ref = readStringBinaryInto(*arena, buf); |
173 | UInt64 count; |
174 | UInt64 error; |
175 | readVarUInt(count, buf); |
176 | readVarUInt(error, buf); |
177 | set.insert(ref, count, error); |
178 | arena->rollback(ref.size); |
179 | } |
180 | |
181 | set.readAlphaMap(buf); |
182 | } |
183 | |
184 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
185 | { |
186 | auto & set = this->data(place).value; |
187 | if (set.capacity() != reserved) |
188 | set.resize(reserved); |
189 | |
190 | if constexpr (is_plain_column) |
191 | { |
192 | if constexpr (is_weighted) |
193 | set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num)); |
194 | else |
195 | set.insert(columns[0]->getDataAt(row_num)); |
196 | } |
197 | else |
198 | { |
199 | const char * begin = nullptr; |
200 | StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin); |
201 | if constexpr (is_weighted) |
202 | set.insert(str_serialized, columns[1]->getUInt(row_num)); |
203 | else |
204 | set.insert(str_serialized); |
205 | arena->rollback(str_serialized.size); |
206 | } |
207 | } |
208 | |
209 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
210 | { |
211 | this->data(place).value.merge(this->data(rhs).value); |
212 | } |
213 | |
214 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
215 | { |
216 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
217 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
218 | IColumn & data_to = arr_to.getData(); |
219 | |
220 | auto result_vec = this->data(place).value.topK(threshold); |
221 | offsets_to.push_back(offsets_to.back() + result_vec.size()); |
222 | |
223 | for (auto & elem : result_vec) |
224 | { |
225 | if constexpr (is_plain_column) |
226 | data_to.insertData(elem.key.data, elem.key.size); |
227 | else |
228 | data_to.deserializeAndInsertFromArena(elem.key.data); |
229 | } |
230 | } |
231 | }; |
232 | |
233 | } |
234 | |