1 | #pragma once |
2 | |
3 | #include <IO/WriteHelpers.h> |
4 | #include <IO/ReadHelpers.h> |
5 | |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <DataTypes/DataTypesNumber.h> |
8 | #include <DataTypes/DataTypeString.h> |
9 | |
10 | #include <Columns/ColumnArray.h> |
11 | |
12 | #include <Common/HashTable/HashSet.h> |
13 | #include <Common/HashTable/HashTableKeyHolder.h> |
14 | #include <Common/assert_cast.h> |
15 | |
16 | #include <AggregateFunctions/IAggregateFunction.h> |
17 | |
18 | #define AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE 0xFFFFFF |
19 | |
20 | |
21 | namespace DB |
22 | { |
23 | |
24 | |
25 | template <typename T> |
26 | struct AggregateFunctionGroupUniqArrayData |
27 | { |
28 | /// When creating, the hash table must be small. |
29 | using Set = HashSet< |
30 | T, |
31 | DefaultHash<T>, |
32 | HashTableGrower<4>, |
33 | HashTableAllocatorWithStackMemory<sizeof(T) * (1 << 4)> |
34 | >; |
35 | |
36 | Set value; |
37 | }; |
38 | |
39 | |
40 | /// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. |
41 | template <typename T, typename Tlimit_num_elem> |
42 | class AggregateFunctionGroupUniqArray |
43 | : public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T, Tlimit_num_elem>> |
44 | { |
45 | static constexpr bool limit_num_elems = Tlimit_num_elem::value; |
46 | UInt64 max_elems; |
47 | |
48 | private: |
49 | using State = AggregateFunctionGroupUniqArrayData<T>; |
50 | |
51 | public: |
52 | AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) |
53 | : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, |
54 | AggregateFunctionGroupUniqArray<T, Tlimit_num_elem>>({argument_type}, {}), |
55 | max_elems(max_elems_) {} |
56 | |
57 | String getName() const override { return "groupUniqArray" ; } |
58 | |
59 | DataTypePtr getReturnType() const override |
60 | { |
61 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeNumber<T>>()); |
62 | } |
63 | |
64 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
65 | { |
66 | if (limit_num_elems && this->data(place).value.size() >= max_elems) |
67 | return; |
68 | this->data(place).value.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]); |
69 | } |
70 | |
71 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
72 | { |
73 | if (!limit_num_elems) |
74 | this->data(place).value.merge(this->data(rhs).value); |
75 | else |
76 | { |
77 | auto & cur_set = this->data(place).value; |
78 | auto & rhs_set = this->data(rhs).value; |
79 | |
80 | for (auto & rhs_elem : rhs_set) |
81 | { |
82 | if (cur_set.size() >= max_elems) |
83 | return; |
84 | cur_set.insert(rhs_elem.getValue()); |
85 | } |
86 | } |
87 | } |
88 | |
89 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
90 | { |
91 | auto & set = this->data(place).value; |
92 | size_t size = set.size(); |
93 | writeVarUInt(size, buf); |
94 | for (const auto & elem : set) |
95 | writeIntBinary(elem, buf); |
96 | } |
97 | |
98 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
99 | { |
100 | this->data(place).value.read(buf); |
101 | } |
102 | |
103 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
104 | { |
105 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
106 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
107 | |
108 | const typename State::Set & set = this->data(place).value; |
109 | size_t size = set.size(); |
110 | |
111 | offsets_to.push_back(offsets_to.back() + size); |
112 | |
113 | typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData(); |
114 | size_t old_size = data_to.size(); |
115 | data_to.resize(old_size + size); |
116 | |
117 | size_t i = 0; |
118 | for (auto it = set.begin(); it != set.end(); ++it, ++i) |
119 | data_to[old_size + i] = it->getValue(); |
120 | } |
121 | }; |
122 | |
123 | |
124 | /// Generic implementation, it uses serialized representation as object descriptor. |
125 | struct AggregateFunctionGroupUniqArrayGenericData |
126 | { |
127 | static constexpr size_t INIT_ELEMS = 2; /// adjustable |
128 | static constexpr size_t ELEM_SIZE = sizeof(HashSetCellWithSavedHash<StringRef, StringRefHash>); |
129 | using Set = HashSetWithSavedHash<StringRef, StringRefHash, HashTableGrower<INIT_ELEMS>, HashTableAllocatorWithStackMemory<INIT_ELEMS * ELEM_SIZE>>; |
130 | |
131 | Set value; |
132 | }; |
133 | |
/// Appends the value described by `str` to the column `data_to`.
/// Only declared here; the explicit specializations for is_plain_column = false / true
/// are defined after AggregateFunctionGroupUniqArrayGeneric.
template <bool is_plain_column>
static void deserializeAndInsertImpl(StringRef str, IColumn & data_to);
136 | |
137 | /** Template parameter with true value should be used for columns that store their elements in memory continuously. |
138 | * For such columns groupUniqArray() can be implemented more efficiently (especially for small numeric arrays). |
139 | */ |
140 | template <bool is_plain_column = false, typename Tlimit_num_elem = std::false_type> |
141 | class AggregateFunctionGroupUniqArrayGeneric |
142 | : public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, Tlimit_num_elem>> |
143 | { |
144 | DataTypePtr & input_data_type; |
145 | |
146 | static constexpr bool limit_num_elems = Tlimit_num_elem::value; |
147 | UInt64 max_elems; |
148 | |
149 | using State = AggregateFunctionGroupUniqArrayGenericData; |
150 | |
151 | static auto getKeyHolder(const IColumn & column, size_t row_num, Arena & arena) |
152 | { |
153 | if constexpr (is_plain_column) |
154 | { |
155 | return ArenaKeyHolder{column.getDataAt(row_num), arena}; |
156 | } |
157 | else |
158 | { |
159 | const char * begin = nullptr; |
160 | StringRef serialized = column.serializeValueIntoArena(row_num, arena, begin); |
161 | return SerializedKeyHolder{serialized, arena}; |
162 | } |
163 | } |
164 | |
165 | static void deserializeAndInsert(StringRef str, IColumn & data_to) |
166 | { |
167 | return deserializeAndInsertImpl<is_plain_column>(str, data_to); |
168 | } |
169 | |
170 | public: |
171 | AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max()) |
172 | : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, Tlimit_num_elem>>({input_data_type_}, {}) |
173 | , input_data_type(this->argument_types[0]) |
174 | , max_elems(max_elems_) {} |
175 | |
176 | String getName() const override { return "groupUniqArray" ; } |
177 | |
178 | DataTypePtr getReturnType() const override |
179 | { |
180 | return std::make_shared<DataTypeArray>(input_data_type); |
181 | } |
182 | |
183 | bool allocatesMemoryInArena() const override |
184 | { |
185 | return true; |
186 | } |
187 | |
188 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
189 | { |
190 | auto & set = this->data(place).value; |
191 | writeVarUInt(set.size(), buf); |
192 | |
193 | for (const auto & elem : set) |
194 | { |
195 | writeStringBinary(elem.getValue(), buf); |
196 | } |
197 | } |
198 | |
199 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override |
200 | { |
201 | auto & set = this->data(place).value; |
202 | size_t size; |
203 | readVarUInt(size, buf); |
204 | //TODO: set.reserve(size); |
205 | |
206 | for (size_t i = 0; i < size; ++i) |
207 | { |
208 | set.insert(readStringBinaryInto(*arena, buf)); |
209 | } |
210 | } |
211 | |
212 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override |
213 | { |
214 | auto & set = this->data(place).value; |
215 | if (limit_num_elems && set.size() >= max_elems) |
216 | return; |
217 | |
218 | bool inserted; |
219 | State::Set::LookupResult it; |
220 | auto key_holder = getKeyHolder(*columns[0], row_num, *arena); |
221 | set.emplace(key_holder, it, inserted); |
222 | } |
223 | |
224 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override |
225 | { |
226 | auto & cur_set = this->data(place).value; |
227 | auto & rhs_set = this->data(rhs).value; |
228 | |
229 | bool inserted; |
230 | State::Set::LookupResult it; |
231 | for (auto & rhs_elem : rhs_set) |
232 | { |
233 | if (limit_num_elems && cur_set.size() >= max_elems) |
234 | return; |
235 | |
236 | // We have to copy the keys to our arena. |
237 | assert(arena != nullptr); |
238 | cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted); |
239 | } |
240 | } |
241 | |
242 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
243 | { |
244 | ColumnArray & arr_to = assert_cast<ColumnArray &>(to); |
245 | ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); |
246 | IColumn & data_to = arr_to.getData(); |
247 | |
248 | auto & set = this->data(place).value; |
249 | offsets_to.push_back(offsets_to.back() + set.size()); |
250 | |
251 | for (auto & elem : set) |
252 | { |
253 | deserializeAndInsert(elem.getValue(), data_to); |
254 | } |
255 | } |
256 | }; |
257 | |
258 | template <> |
259 | inline void deserializeAndInsertImpl<false>(StringRef str, IColumn & data_to) |
260 | { |
261 | data_to.deserializeAndInsertFromArena(str.data); |
262 | } |
263 | |
264 | template <> |
265 | inline void deserializeAndInsertImpl<true>(StringRef str, IColumn & data_to) |
266 | { |
267 | data_to.insertData(str.data, str.size); |
268 | } |
269 | |
270 | #undef AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE |
271 | |
272 | } |
273 | |