1#pragma once
2
3#include <IO/WriteHelpers.h>
4#include <IO/ReadHelpers.h>
5
6#include <DataTypes/DataTypeArray.h>
7#include <DataTypes/DataTypesNumber.h>
8#include <DataTypes/DataTypeString.h>
9
10#include <Columns/ColumnArray.h>
11
12#include <Common/HashTable/HashSet.h>
13#include <Common/HashTable/HashTableKeyHolder.h>
14#include <Common/assert_cast.h>
15
16#include <AggregateFunctions/IAggregateFunction.h>
17
18#define AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE 0xFFFFFF
19
20
21namespace DB
22{
23
24
25template <typename T>
26struct AggregateFunctionGroupUniqArrayData
27{
28 /// When creating, the hash table must be small.
29 using Set = HashSet<
30 T,
31 DefaultHash<T>,
32 HashTableGrower<4>,
33 HashTableAllocatorWithStackMemory<sizeof(T) * (1 << 4)>
34 >;
35
36 Set value;
37};
38
39
40/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types.
41template <typename T, typename Tlimit_num_elem>
42class AggregateFunctionGroupUniqArray
43 : public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>, AggregateFunctionGroupUniqArray<T, Tlimit_num_elem>>
44{
45 static constexpr bool limit_num_elems = Tlimit_num_elem::value;
46 UInt64 max_elems;
47
48private:
49 using State = AggregateFunctionGroupUniqArrayData<T>;
50
51public:
52 AggregateFunctionGroupUniqArray(const DataTypePtr & argument_type, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
53 : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayData<T>,
54 AggregateFunctionGroupUniqArray<T, Tlimit_num_elem>>({argument_type}, {}),
55 max_elems(max_elems_) {}
56
57 String getName() const override { return "groupUniqArray"; }
58
59 DataTypePtr getReturnType() const override
60 {
61 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeNumber<T>>());
62 }
63
64 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
65 {
66 if (limit_num_elems && this->data(place).value.size() >= max_elems)
67 return;
68 this->data(place).value.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
69 }
70
71 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
72 {
73 if (!limit_num_elems)
74 this->data(place).value.merge(this->data(rhs).value);
75 else
76 {
77 auto & cur_set = this->data(place).value;
78 auto & rhs_set = this->data(rhs).value;
79
80 for (auto & rhs_elem : rhs_set)
81 {
82 if (cur_set.size() >= max_elems)
83 return;
84 cur_set.insert(rhs_elem.getValue());
85 }
86 }
87 }
88
89 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
90 {
91 auto & set = this->data(place).value;
92 size_t size = set.size();
93 writeVarUInt(size, buf);
94 for (const auto & elem : set)
95 writeIntBinary(elem, buf);
96 }
97
98 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
99 {
100 this->data(place).value.read(buf);
101 }
102
103 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
104 {
105 ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
106 ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
107
108 const typename State::Set & set = this->data(place).value;
109 size_t size = set.size();
110
111 offsets_to.push_back(offsets_to.back() + size);
112
113 typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
114 size_t old_size = data_to.size();
115 data_to.resize(old_size + size);
116
117 size_t i = 0;
118 for (auto it = set.begin(); it != set.end(); ++it, ++i)
119 data_to[old_size + i] = it->getValue();
120 }
121};
122
123
124/// Generic implementation, it uses serialized representation as object descriptor.
125struct AggregateFunctionGroupUniqArrayGenericData
126{
127 static constexpr size_t INIT_ELEMS = 2; /// adjustable
128 static constexpr size_t ELEM_SIZE = sizeof(HashSetCellWithSavedHash<StringRef, StringRefHash>);
129 using Set = HashSetWithSavedHash<StringRef, StringRefHash, HashTableGrower<INIT_ELEMS>, HashTableAllocatorWithStackMemory<INIT_ELEMS * ELEM_SIZE>>;
130
131 Set value;
132};
133
134template <bool is_plain_column>
135static void deserializeAndInsertImpl(StringRef str, IColumn & data_to);
136
137/** Template parameter with true value should be used for columns that store their elements in memory continuously.
138 * For such columns groupUniqArray() can be implemented more efficiently (especially for small numeric arrays).
139 */
140template <bool is_plain_column = false, typename Tlimit_num_elem = std::false_type>
141class AggregateFunctionGroupUniqArrayGeneric
142 : public IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, Tlimit_num_elem>>
143{
144 DataTypePtr & input_data_type;
145
146 static constexpr bool limit_num_elems = Tlimit_num_elem::value;
147 UInt64 max_elems;
148
149 using State = AggregateFunctionGroupUniqArrayGenericData;
150
151 static auto getKeyHolder(const IColumn & column, size_t row_num, Arena & arena)
152 {
153 if constexpr (is_plain_column)
154 {
155 return ArenaKeyHolder{column.getDataAt(row_num), arena};
156 }
157 else
158 {
159 const char * begin = nullptr;
160 StringRef serialized = column.serializeValueIntoArena(row_num, arena, begin);
161 return SerializedKeyHolder{serialized, arena};
162 }
163 }
164
165 static void deserializeAndInsert(StringRef str, IColumn & data_to)
166 {
167 return deserializeAndInsertImpl<is_plain_column>(str, data_to);
168 }
169
170public:
171 AggregateFunctionGroupUniqArrayGeneric(const DataTypePtr & input_data_type_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
172 : IAggregateFunctionDataHelper<AggregateFunctionGroupUniqArrayGenericData, AggregateFunctionGroupUniqArrayGeneric<is_plain_column, Tlimit_num_elem>>({input_data_type_}, {})
173 , input_data_type(this->argument_types[0])
174 , max_elems(max_elems_) {}
175
176 String getName() const override { return "groupUniqArray"; }
177
178 DataTypePtr getReturnType() const override
179 {
180 return std::make_shared<DataTypeArray>(input_data_type);
181 }
182
183 bool allocatesMemoryInArena() const override
184 {
185 return true;
186 }
187
188 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
189 {
190 auto & set = this->data(place).value;
191 writeVarUInt(set.size(), buf);
192
193 for (const auto & elem : set)
194 {
195 writeStringBinary(elem.getValue(), buf);
196 }
197 }
198
199 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
200 {
201 auto & set = this->data(place).value;
202 size_t size;
203 readVarUInt(size, buf);
204 //TODO: set.reserve(size);
205
206 for (size_t i = 0; i < size; ++i)
207 {
208 set.insert(readStringBinaryInto(*arena, buf));
209 }
210 }
211
212 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
213 {
214 auto & set = this->data(place).value;
215 if (limit_num_elems && set.size() >= max_elems)
216 return;
217
218 bool inserted;
219 State::Set::LookupResult it;
220 auto key_holder = getKeyHolder(*columns[0], row_num, *arena);
221 set.emplace(key_holder, it, inserted);
222 }
223
224 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
225 {
226 auto & cur_set = this->data(place).value;
227 auto & rhs_set = this->data(rhs).value;
228
229 bool inserted;
230 State::Set::LookupResult it;
231 for (auto & rhs_elem : rhs_set)
232 {
233 if (limit_num_elems && cur_set.size() >= max_elems)
234 return;
235
236 // We have to copy the keys to our arena.
237 assert(arena != nullptr);
238 cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted);
239 }
240 }
241
242 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
243 {
244 ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
245 ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
246 IColumn & data_to = arr_to.getData();
247
248 auto & set = this->data(place).value;
249 offsets_to.push_back(offsets_to.back() + set.size());
250
251 for (auto & elem : set)
252 {
253 deserializeAndInsert(elem.getValue(), data_to);
254 }
255 }
256};
257
258template <>
259inline void deserializeAndInsertImpl<false>(StringRef str, IColumn & data_to)
260{
261 data_to.deserializeAndInsertFromArena(str.data);
262}
263
264template <>
265inline void deserializeAndInsertImpl<true>(StringRef str, IColumn & data_to)
266{
267 data_to.insertData(str.data, str.size);
268}
269
270#undef AGGREGATE_FUNCTION_GROUP_ARRAY_UNIQ_MAX_SIZE
271
272}
273