1#pragma once
2
3#include <Common/HashTable/SmallTable.h>
4#include <Common/HashTable/HashSet.h>
5#include <Common/HyperLogLogCounter.h>
6#include <Core/Defines.h>
7
8
9namespace DB
10{
11
12namespace details
13{
14
15enum class ContainerType : UInt8 { SMALL = 1, MEDIUM = 2, LARGE = 3 };
16
17static inline ContainerType max(const ContainerType & lhs, const ContainerType & rhs)
18{
19 UInt8 res = std::max(static_cast<UInt8>(lhs), static_cast<UInt8>(rhs));
20 return static_cast<ContainerType>(res);
21}
22
23}
24
/** For a small number of keys, a fixed-size array stored inside the object itself is used.
  * For a medium number of keys, a hash set (HashContainer) is allocated.
  * For a large number of keys, a HyperLogLog counter is allocated.
  */
29template
30<
31 typename Key,
32 typename HashContainer,
33 UInt8 small_set_size_max,
34 UInt8 medium_set_power2_max,
35 UInt8 K,
36 typename Hash = IntHash32<Key>,
37 typename HashValueType = UInt32,
38 typename BiasEstimator = TrivialBiasEstimator,
39 HyperLogLogMode mode = HyperLogLogMode::FullFeatured,
40 typename DenominatorType = double
41>
42class CombinedCardinalityEstimator
43{
44public:
45 using Self = CombinedCardinalityEstimator
46 <
47 Key,
48 HashContainer,
49 small_set_size_max,
50 medium_set_power2_max,
51 K,
52 Hash,
53 HashValueType,
54 BiasEstimator,
55 mode,
56 DenominatorType
57 >;
58
59 using value_type = Key;
60
61private:
62 using Small = SmallSet<Key, small_set_size_max>;
63 using Medium = HashContainer;
64 using Large = HyperLogLogCounter<K, Hash, HashValueType, DenominatorType, BiasEstimator, mode>;
65
66public:
67 CombinedCardinalityEstimator()
68 {
69 setContainerType(details::ContainerType::SMALL);
70 }
71
72 ~CombinedCardinalityEstimator()
73 {
74 destroy();
75 }
76
77 void insert(Key value)
78 {
79 auto container_type = getContainerType();
80
81 if (container_type == details::ContainerType::SMALL)
82 {
83 if (small.find(value) == small.end())
84 {
85 if (!small.full())
86 small.insert(value);
87 else
88 {
89 toMedium();
90 getContainer<Medium>().insert(value);
91 }
92 }
93 }
94 else if (container_type == details::ContainerType::MEDIUM)
95 {
96 auto & container = getContainer<Medium>();
97 if (container.size() < medium_set_size_max)
98 container.insert(value);
99 else
100 {
101 toLarge();
102 getContainer<Large>().insert(value);
103 }
104 }
105 else if (container_type == details::ContainerType::LARGE)
106 getContainer<Large>().insert(value);
107 }
108
109 UInt64 size() const
110 {
111 auto container_type = getContainerType();
112
113 if (container_type == details::ContainerType::SMALL)
114 return small.size();
115 else if (container_type == details::ContainerType::MEDIUM)
116 return getContainer<Medium>().size();
117 else if (container_type == details::ContainerType::LARGE)
118 return getContainer<Large>().size();
119 else
120 throw Poco::Exception("Internal error", ErrorCodes::LOGICAL_ERROR);
121 }
122
123 void merge(const Self & rhs)
124 {
125 auto container_type = getContainerType();
126 auto max_container_type = details::max(container_type, rhs.getContainerType());
127
128 if (container_type != max_container_type)
129 {
130 if (max_container_type == details::ContainerType::MEDIUM)
131 toMedium();
132 else if (max_container_type == details::ContainerType::LARGE)
133 toLarge();
134 }
135
136 if (rhs.getContainerType() == details::ContainerType::SMALL)
137 {
138 for (const auto & x : rhs.small)
139 insert(x.getValue());
140 }
141 else if (rhs.getContainerType() == details::ContainerType::MEDIUM)
142 {
143 for (const auto & x : rhs.getContainer<Medium>())
144 insert(x.getValue());
145 }
146 else if (rhs.getContainerType() == details::ContainerType::LARGE)
147 getContainer<Large>().merge(rhs.getContainer<Large>());
148 }
149
150 /// You can only call for an empty object.
151 void read(DB::ReadBuffer & in)
152 {
153 UInt8 v;
154 readBinary(v, in);
155 auto container_type = static_cast<details::ContainerType>(v);
156
157 if (container_type == details::ContainerType::SMALL)
158 small.read(in);
159 else if (container_type == details::ContainerType::MEDIUM)
160 {
161 toMedium();
162 getContainer<Medium>().read(in);
163 }
164 else if (container_type == details::ContainerType::LARGE)
165 {
166 toLarge();
167 getContainer<Large>().read(in);
168 }
169 }
170
171 void readAndMerge(DB::ReadBuffer & in)
172 {
173 auto container_type = getContainerType();
174
175 /// If readAndMerge is called with an empty state, just deserialize
176 /// the state is specified as a parameter.
177 if ((container_type == details::ContainerType::SMALL) && small.empty())
178 {
179 read(in);
180 return;
181 }
182
183 UInt8 v;
184 readBinary(v, in);
185 auto rhs_container_type = static_cast<details::ContainerType>(v);
186
187 auto max_container_type = details::max(container_type, rhs_container_type);
188
189 if (container_type != max_container_type)
190 {
191 if (max_container_type == details::ContainerType::MEDIUM)
192 toMedium();
193 else if (max_container_type == details::ContainerType::LARGE)
194 toLarge();
195 }
196
197 if (rhs_container_type == details::ContainerType::SMALL)
198 {
199 typename Small::Reader reader(in);
200 while (reader.next())
201 insert(reader.get());
202 }
203 else if (rhs_container_type == details::ContainerType::MEDIUM)
204 {
205 typename Medium::Reader reader(in);
206 while (reader.next())
207 insert(reader.get());
208 }
209 else if (rhs_container_type == details::ContainerType::LARGE)
210 getContainer<Large>().readAndMerge(in);
211 }
212
213 void write(DB::WriteBuffer & out) const
214 {
215 auto container_type = getContainerType();
216 writeBinary(static_cast<UInt8>(container_type), out);
217
218 if (container_type == details::ContainerType::SMALL)
219 small.write(out);
220 else if (container_type == details::ContainerType::MEDIUM)
221 getContainer<Medium>().write(out);
222 else if (container_type == details::ContainerType::LARGE)
223 getContainer<Large>().write(out);
224 }
225
226private:
227 void toMedium()
228 {
229 if (getContainerType() != details::ContainerType::SMALL)
230 throw Poco::Exception("Internal error", ErrorCodes::LOGICAL_ERROR);
231
232 auto tmp_medium = std::make_unique<Medium>();
233
234 for (const auto & x : small)
235 tmp_medium->insert(x.getValue());
236
237 medium = tmp_medium.release();
238 setContainerType(details::ContainerType::MEDIUM);
239 }
240
241 void toLarge()
242 {
243 auto container_type = getContainerType();
244
245 if ((container_type != details::ContainerType::SMALL) && (container_type != details::ContainerType::MEDIUM))
246 throw Poco::Exception("Internal error", ErrorCodes::LOGICAL_ERROR);
247
248 auto tmp_large = std::make_unique<Large>();
249
250 if (container_type == details::ContainerType::SMALL)
251 {
252 for (const auto & x : small)
253 tmp_large->insert(x.getValue());
254 }
255 else if (container_type == details::ContainerType::MEDIUM)
256 {
257 for (const auto & x : getContainer<Medium>())
258 tmp_large->insert(x.getValue());
259
260 destroy();
261 }
262
263 large = tmp_large.release();
264 setContainerType(details::ContainerType::LARGE);
265 }
266
267 void NO_INLINE destroy()
268 {
269 auto container_type = getContainerType();
270
271 clearContainerType();
272
273 if (container_type == details::ContainerType::MEDIUM)
274 {
275 delete medium;
276 medium = nullptr;
277 }
278 else if (container_type == details::ContainerType::LARGE)
279 {
280 delete large;
281 large = nullptr;
282 }
283 }
284
285 template <typename T>
286 inline T & getContainer()
287 {
288 return *reinterpret_cast<T *>(address & mask);
289 }
290
291 template <typename T>
292 inline const T & getContainer() const
293 {
294 return *reinterpret_cast<T *>(address & mask);
295 }
296
297 void setContainerType(details::ContainerType t)
298 {
299 address &= mask;
300 address |= static_cast<UInt8>(t);
301 }
302
303 inline details::ContainerType getContainerType() const
304 {
305 return static_cast<details::ContainerType>(address & ~mask);
306 }
307
308 void clearContainerType()
309 {
310 address &= mask;
311 }
312
313private:
314 Small small;
315 union
316 {
317 Medium * medium;
318 Large * large;
319 UInt64 address = 0;
320 };
321 static const UInt64 mask = 0xFFFFFFFFFFFFFFFC;
322 static const UInt32 medium_set_size_max = 1UL << medium_set_power2_max;
323};
324
325}
326