1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | // Private header, not to be exported |
19 | |
20 | #ifndef ARROW_UTIL_HASHING_H |
21 | #define ARROW_UTIL_HASHING_H |
22 | |
23 | #include <algorithm> |
24 | #include <cassert> |
25 | #include <cmath> |
26 | #include <cstdint> |
27 | #include <cstring> |
28 | #include <limits> |
29 | #include <memory> |
30 | #include <string> |
31 | #include <type_traits> |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | #include "arrow/array.h" |
36 | #include "arrow/buffer.h" |
37 | #include "arrow/builder.h" |
38 | #include "arrow/type.h" |
39 | #include "arrow/type_traits.h" |
40 | #include "arrow/util/bit-util.h" |
41 | #include "arrow/util/checked_cast.h" |
42 | #include "arrow/util/hash-util.h" |
43 | #include "arrow/util/macros.h" |
44 | #include "arrow/util/string_view.h" |
45 | |
46 | namespace arrow { |
47 | namespace internal { |
48 | |
49 | // XXX would it help to have a 32-bit hash value on large datasets? |
50 | typedef uint64_t hash_t; |
51 | |
52 | // Notes about the choice of a hash function. |
53 | // - xxHash64 is extremely fast on large enough data |
54 | // - for small- to medium-sized data, there are better choices |
//   (see comprehensive benchmark results at
56 | // https://aras-p.info/blog/2016/08/09/More-Hash-Function-Tests/) |
57 | // - for very small fixed-size data (<= 16 bytes, e.g. Decimal128), it is |
58 | // beneficial to define specialized hash functions |
59 | // - while xxHash and others have good statistical properties, we can relax those |
60 | // a bit if it helps performance (especially if the hash table implementation |
61 | // has a good collision resolution strategy) |
62 | |
63 | template <uint64_t AlgNum> |
64 | inline hash_t ComputeStringHash(const void* data, int64_t length); |
65 | |
66 | template <typename Scalar, uint64_t AlgNum> |
67 | struct ScalarHelperBase { |
68 | static bool CompareScalars(Scalar u, Scalar v) { return u == v; } |
69 | |
70 | static hash_t ComputeHash(const Scalar& value) { |
71 | // Generic hash computation for scalars. Simply apply the string hash |
72 | // to the bit representation of the value. |
73 | |
74 | // XXX in the case of FP values, we'd like equal values to have the same hash, |
75 | // even if they have different bit representations... |
76 | return ComputeStringHash<AlgNum>(&value, sizeof(value)); |
77 | } |
78 | }; |
79 | |
80 | template <typename Scalar, uint64_t AlgNum = 0, typename Enable = void> |
81 | struct ScalarHelper : public ScalarHelperBase<Scalar, AlgNum> {}; |
82 | |
83 | template <typename Scalar, uint64_t AlgNum> |
84 | struct ScalarHelper<Scalar, AlgNum, |
85 | typename std::enable_if<std::is_integral<Scalar>::value>::type> |
86 | : public ScalarHelperBase<Scalar, AlgNum> { |
87 | // ScalarHelper specialization for integers |
88 | |
89 | static hash_t ComputeHash(const Scalar& value) { |
90 | // Faster hash computation for integers. |
91 | |
92 | // Two of xxhash's prime multipliers (which are chosen for their |
93 | // bit dispersion properties) |
94 | static constexpr uint64_t multipliers[] = {11400714785074694791ULL, |
95 | 14029467366897019727ULL}; |
96 | |
97 | // Multiplying by the prime number mixes the low bits into the high bits, |
98 | // then byte-swapping (which is a single CPU instruction) allows the |
99 | // combined high and low bits to participate in the initial hash table index. |
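    // For instance (a sketch, not computed output): two adjacent keys such
    // as 42 and 43 yield products that differ by the multiplier itself,
    // i.e. in bits spread across the whole word; after the byte swap those
    // differences reach the low bits used for the table index, so
    // consecutive keys do not cluster.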
100 | auto h = static_cast<hash_t>(value); |
101 | return BitUtil::ByteSwap(multipliers[AlgNum] * h); |
102 | } |
103 | }; |
104 | |
105 | template <typename Scalar, uint64_t AlgNum> |
106 | struct ScalarHelper< |
107 | Scalar, AlgNum, |
108 | typename std::enable_if<std::is_same<util::string_view, Scalar>::value>::type> |
109 | : public ScalarHelperBase<Scalar, AlgNum> { |
110 | // ScalarHelper specialization for util::string_view |
111 | |
112 | static hash_t ComputeHash(const util::string_view& value) { |
113 | return ComputeStringHash<AlgNum>(value.data(), static_cast<int64_t>(value.size())); |
114 | } |
115 | }; |
116 | |
117 | template <typename Scalar, uint64_t AlgNum> |
118 | struct ScalarHelper<Scalar, AlgNum, |
119 | typename std::enable_if<std::is_floating_point<Scalar>::value>::type> |
120 | : public ScalarHelperBase<Scalar, AlgNum> { |
121 | // ScalarHelper specialization for reals |
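  // (Unlike the built-in ==, CompareScalars below treats two NaNs as
  // equal, so NaN-valued keys can still be memoized consistently.)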
122 | |
123 | static bool CompareScalars(Scalar u, Scalar v) { |
124 | if (std::isnan(u)) { |
125 | // XXX should we do a bit-precise comparison? |
126 | return std::isnan(v); |
127 | } |
128 | return u == v; |
129 | } |
130 | }; |
131 | |
132 | template <uint64_t AlgNum = 0> |
133 | hash_t ComputeStringHash(const void* data, int64_t length) { |
134 | if (ARROW_PREDICT_TRUE(length <= 16)) { |
135 | // Specialize for small hash strings, as they are quite common as |
136 | // hash table keys. |
137 | auto p = reinterpret_cast<const uint8_t*>(data); |
138 | auto n = static_cast<uint32_t>(length); |
139 | if (n <= 8) { |
140 | if (n <= 3) { |
141 | if (n == 0) { |
142 | return 1U; |
143 | } |
144 | uint32_t x = (n << 24) ^ (p[0] << 16) ^ (p[n / 2] << 8) ^ p[n - 1]; |
145 | return ScalarHelper<uint32_t, AlgNum>::ComputeHash(x); |
146 | } |
147 | // 4 <= length <= 8 |
148 | // We can read the string as two overlapping 32-bit ints, apply |
149 | // different hash functions to each of them in parallel, then XOR |
150 | // the results |
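      // (For example, with n == 6: y covers bytes 0..3 and x covers
      // bytes 2..5, so all 6 bytes contribute to the result.)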
151 | uint32_t x, y; |
152 | hash_t hx, hy; |
153 | // XXX those are unaligned accesses. Should we have a facility for that? |
154 | x = *reinterpret_cast<const uint32_t*>(p + n - 4); |
155 | y = *reinterpret_cast<const uint32_t*>(p); |
156 | hx = ScalarHelper<uint32_t, AlgNum>::ComputeHash(x); |
157 | hy = ScalarHelper<uint32_t, AlgNum ^ 1>::ComputeHash(y); |
158 | return n ^ hx ^ hy; |
159 | } |
    // 8 < length <= 16
161 | // Apply the same principle as above |
162 | uint64_t x, y; |
163 | hash_t hx, hy; |
164 | x = *reinterpret_cast<const uint64_t*>(p + n - 8); |
165 | y = *reinterpret_cast<const uint64_t*>(p); |
166 | hx = ScalarHelper<uint64_t, AlgNum>::ComputeHash(x); |
167 | hy = ScalarHelper<uint64_t, AlgNum ^ 1>::ComputeHash(y); |
168 | return n ^ hx ^ hy; |
169 | } |
170 | |
171 | if (HashUtil::have_hardware_crc32) { |
    // DoubleCrcHash is faster than Murmur2.
173 | auto h = HashUtil::DoubleCrcHash(data, static_cast<int32_t>(length), AlgNum); |
174 | return ScalarHelper<uint64_t, AlgNum>::ComputeHash(h); |
175 | } else { |
176 | // Fall back on 64-bit Murmur2 for longer strings. |
177 | // It has decent speed for medium-sized strings. There may be faster |
178 | // hashes on long strings such as xxHash, but that may not matter much |
179 | // for the typical length distribution of hash keys. |
180 | return HashUtil::MurmurHash2_64(data, static_cast<int>(length), AlgNum); |
181 | } |
182 | } |
183 | |
184 | // XXX add a HashEq<ArrowType> struct with both hash and compare functions? |
185 | |
186 | // ---------------------------------------------------------------------- |
187 | // An open-addressing insert-only hash table (no deletes) |
188 | |
189 | template <typename Payload> |
190 | class HashTable { |
191 | public: |
192 | struct Entry { |
193 | hash_t h; |
194 | Payload payload; |
195 | }; |
196 | |
197 | explicit HashTable(uint64_t capacity) { |
198 | // Presize for at least 8 elements |
199 | capacity = std::max(capacity, static_cast<uint64_t>(8U)); |
200 | size_ = BitUtil::NextPower2(capacity * 4U); |
201 | size_mask_ = size_ - 1; |
202 | n_filled_ = 0; |
203 | // This will zero out hash entries, marking them empty |
204 | entries_.resize(size_); |
205 | } |
206 | |
207 | // Lookup with non-linear probing |
208 | // cmp_func should have signature bool(const Payload*). |
  // Returns an (Entry*, found) pair.
210 | template <typename CmpFunc> |
211 | std::pair<Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) { |
212 | auto p = Lookup<DoCompare, CmpFunc>(h, entries_.data(), size_mask_, |
213 | std::forward<CmpFunc>(cmp_func)); |
214 | return {&entries_[p.first], p.second}; |
215 | } |
216 | |
217 | template <typename CmpFunc> |
218 | std::pair<const Entry*, bool> Lookup(hash_t h, CmpFunc&& cmp_func) const { |
219 | auto p = Lookup<DoCompare, CmpFunc>(h, entries_.data(), size_mask_, |
220 | std::forward<CmpFunc>(cmp_func)); |
221 | return {&entries_[p.first], p.second}; |
222 | } |
223 | |
224 | void Insert(Entry* entry, hash_t h, const Payload& payload) { |
225 | assert(entry->h == 0); |
226 | entry->h = FixHash(h); |
227 | entry->payload = payload; |
228 | ++n_filled_; |
229 | if (NeedUpsizing()) { |
230 | // Resizing is expensive, avoid doing it too often |
231 | Upsize(size_ * 4); |
232 | } |
233 | } |
234 | |
235 | uint64_t size() const { return n_filled_; } |
236 | |
237 | // Visit all non-empty entries in the table |
238 | // The visit_func should have signature void(const Entry*) |
239 | template <typename VisitFunc> |
240 | void VisitEntries(VisitFunc&& visit_func) const { |
241 | for (const auto& entry : entries_) { |
242 | if (entry.h != 0U) { |
243 | visit_func(&entry); |
244 | } |
245 | } |
246 | } |
247 | |
248 | protected: |
249 | // NoCompare is for when the value is known not to exist in the table |
250 | enum CompareKind { DoCompare, NoCompare }; |
251 | |
252 | // The workhorse lookup function |
253 | template <CompareKind CKind, typename CmpFunc> |
254 | std::pair<uint64_t, bool> Lookup(hash_t h, const Entry* entries, uint64_t size_mask, |
255 | CmpFunc&& cmp_func) const { |
256 | static constexpr uint8_t perturb_shift = 5; |
257 | |
258 | uint64_t index, perturb; |
259 | const Entry* entry; |
260 | |
261 | h = FixHash(h); |
262 | index = h & size_mask; |
263 | perturb = (h >> perturb_shift) + 1U; |
264 | |
265 | while (true) { |
266 | entry = &entries[index]; |
267 | if (CompareEntry<CKind, CmpFunc>(h, entry, std::forward<CmpFunc>(cmp_func))) { |
268 | // Found |
269 | return {index, true}; |
270 | } |
271 | if (entry->h == 0U) { |
272 | // Empty slot |
273 | return {index, false}; |
274 | } |
275 | |
      // Perturbation logic inspired by CPython's set / dict object.
277 | // The goal is that all 64 bits of the unmasked hash value eventually |
278 | // participate in the probing sequence, to minimize clustering. |
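      // (Sketch: with size_mask = 7, the indices visited are h & 7, then
      // (index + perturb) & 7, and so on, with perturb folding in ever
      // higher bits of h at each step.)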
279 | index = (index + perturb) & size_mask; |
280 | perturb = (perturb >> perturb_shift) + 1U; |
281 | } |
282 | } |
283 | |
284 | template <CompareKind CKind, typename CmpFunc> |
285 | bool CompareEntry(hash_t h, const Entry* entry, CmpFunc&& cmp_func) const { |
286 | if (CKind == NoCompare) { |
287 | return false; |
288 | } else { |
289 | return entry->h == h && cmp_func(&entry->payload); |
290 | } |
291 | } |
292 | |
293 | bool NeedUpsizing() const { |
294 | // Keep the load factor <= 1/2 |
295 | return n_filled_ * 2U >= size_; |
296 | } |
297 | |
298 | void Upsize(uint64_t new_size) { |
299 | assert(new_size > size_); |
300 | uint64_t new_mask = new_size - 1; |
301 | assert((new_size & new_mask) == 0); // it's a power of two |
302 | |
303 | std::vector<Entry> new_entries(new_size); |
304 | for (auto& entry : entries_) { |
305 | hash_t h = entry.h; |
306 | if (h != 0) { |
307 | // Dummy compare function (will not be called) |
308 | auto cmp_func = [](const Payload*) { return false; }; |
309 | // Non-empty slot, move into new |
310 | auto p = Lookup<NoCompare>(h, new_entries.data(), new_mask, cmp_func); |
311 | assert(!p.second); // shouldn't have found a matching entry |
312 | Entry* new_entry = &new_entries[p.first]; |
313 | new_entry->h = h; |
314 | new_entry->payload = entry.payload; |
315 | } |
316 | } |
317 | std::swap(entries_, new_entries); |
318 | size_ = new_size; |
319 | size_mask_ = new_mask; |
320 | } |
321 | |
322 | hash_t FixHash(hash_t h) const { |
323 | // 0 is used to indicate empty entries |
324 | return (h == 0U) ? 42U : h; |
325 | } |
326 | |
327 | uint64_t size_; |
328 | uint64_t size_mask_; |
329 | uint64_t n_filled_; |
330 | std::vector<Entry> entries_; |
331 | }; |
332 | |
333 | // XXX typedef memo_index_t int32_t ? |
334 | |
335 | // ---------------------------------------------------------------------- |
336 | // A memoization table for memory-cheap scalar values. |
337 | |
// The memoization table remembers the insertion index of each key and
// allows looking it up afterwards.
340 | |
341 | template <typename Scalar, template <class> class HashTableTemplateType = HashTable> |
342 | class ScalarMemoTable { |
343 | public: |
344 | explicit ScalarMemoTable(int64_t entries = 0) |
345 | : hash_table_(static_cast<uint64_t>(entries)) {} |
346 | |
347 | int32_t Get(const Scalar& value) const { |
348 | auto cmp_func = [value](const Payload* payload) -> bool { |
349 | return ScalarHelper<Scalar, 0>::CompareScalars(payload->value, value); |
350 | }; |
351 | hash_t h = ComputeHash(value); |
352 | auto p = hash_table_.Lookup(h, cmp_func); |
353 | if (p.second) { |
354 | return p.first->payload.memo_index; |
355 | } else { |
356 | return -1; |
357 | } |
358 | } |
359 | |
360 | template <typename Func1, typename Func2> |
361 | int32_t GetOrInsert(const Scalar& value, Func1&& on_found, Func2&& on_not_found) { |
362 | auto cmp_func = [value](const Payload* payload) -> bool { |
363 | return ScalarHelper<Scalar, 0>::CompareScalars(value, payload->value); |
364 | }; |
365 | hash_t h = ComputeHash(value); |
366 | auto p = hash_table_.Lookup(h, cmp_func); |
367 | int32_t memo_index; |
368 | if (p.second) { |
369 | memo_index = p.first->payload.memo_index; |
370 | on_found(memo_index); |
371 | } else { |
372 | memo_index = size(); |
373 | hash_table_.Insert(p.first, h, {value, memo_index}); |
374 | on_not_found(memo_index); |
375 | } |
376 | return memo_index; |
377 | } |
378 | |
379 | int32_t GetOrInsert(const Scalar& value) { |
380 | return GetOrInsert(value, [](int32_t i) {}, [](int32_t i) {}); |
381 | } |
382 | |
383 | // The number of entries in the memo table |
384 | // (which is also 1 + the largest memo index) |
385 | int32_t size() const { return static_cast<int32_t>(hash_table_.size()); } |
386 | |
387 | // Copy values starting from index `start` into `out_data` |
388 | void CopyValues(int32_t start, Scalar* out_data) const { |
389 | hash_table_.VisitEntries([=](const HashTableEntry* entry) { |
390 | int32_t index = entry->payload.memo_index - start; |
391 | if (index >= 0) { |
392 | out_data[index] = entry->payload.value; |
393 | } |
394 | }); |
395 | } |
396 | |
397 | void CopyValues(Scalar* out_data) const { CopyValues(0, out_data); } |
398 | |
399 | protected: |
400 | struct Payload { |
401 | Scalar value; |
402 | int32_t memo_index; |
403 | }; |
404 | |
405 | using HashTableType = HashTableTemplateType<Payload>; |
406 | using HashTableEntry = typename HashTableType::Entry; |
407 | HashTableType hash_table_; |
408 | |
409 | hash_t ComputeHash(const Scalar& value) const { |
410 | return ScalarHelper<Scalar, 0>::ComputeHash(value); |
411 | } |
412 | }; |
413 | |
414 | // ---------------------------------------------------------------------- |
415 | // A memoization table for small scalar values, using direct indexing |
416 | |
417 | template <typename Scalar, typename Enable = void> |
418 | struct SmallScalarTraits {}; |
419 | |
420 | template <> |
421 | struct SmallScalarTraits<bool> { |
422 | static constexpr int32_t cardinality = 2; |
423 | |
424 | static uint32_t AsIndex(bool value) { return value ? 1 : 0; } |
425 | }; |
426 | |
427 | template <typename Scalar> |
428 | struct SmallScalarTraits<Scalar, |
429 | typename std::enable_if<std::is_integral<Scalar>::value>::type> { |
430 | using Unsigned = typename std::make_unsigned<Scalar>::type; |
431 | |
432 | static constexpr int32_t cardinality = 1U + std::numeric_limits<Unsigned>::max(); |
433 | |
434 | static uint32_t AsIndex(Scalar value) { return static_cast<Unsigned>(value); } |
435 | }; |
436 | |
437 | template <typename Scalar, template <class> class HashTableTemplateType = HashTable> |
438 | class SmallScalarMemoTable { |
439 | public: |
440 | explicit SmallScalarMemoTable(int64_t entries = 0) { |
441 | std::fill(value_to_index_, value_to_index_ + cardinality, -1); |
442 | index_to_value_.reserve(cardinality); |
443 | } |
444 | |
445 | int32_t Get(const Scalar value) const { |
446 | auto value_index = AsIndex(value); |
447 | return value_to_index_[value_index]; |
448 | } |
449 | |
450 | template <typename Func1, typename Func2> |
451 | int32_t GetOrInsert(const Scalar value, Func1&& on_found, Func2&& on_not_found) { |
452 | auto value_index = AsIndex(value); |
453 | auto memo_index = value_to_index_[value_index]; |
454 | if (memo_index < 0) { |
455 | memo_index = static_cast<int32_t>(index_to_value_.size()); |
456 | index_to_value_.push_back(value); |
457 | value_to_index_[value_index] = memo_index; |
458 | assert(memo_index < cardinality); |
459 | on_not_found(memo_index); |
460 | } else { |
461 | on_found(memo_index); |
462 | } |
463 | return memo_index; |
464 | } |
465 | |
466 | int32_t GetOrInsert(const Scalar value) { |
467 | return GetOrInsert(value, [](int32_t i) {}, [](int32_t i) {}); |
468 | } |
469 | |
470 | // The number of entries in the memo table |
471 | // (which is also 1 + the largest memo index) |
472 | int32_t size() const { return static_cast<int32_t>(index_to_value_.size()); } |
473 | |
474 | // Copy values starting from index `start` into `out_data` |
475 | void CopyValues(int32_t start, Scalar* out_data) const { |
    // Note: memcpy takes a byte count, so scale by the element size
    memcpy(out_data, &index_to_value_[start], (size() - start) * sizeof(Scalar));
477 | } |
478 | |
479 | void CopyValues(Scalar* out_data) const { CopyValues(0, out_data); } |
480 | |
481 | const std::vector<Scalar>& values() const { return index_to_value_; } |
482 | |
483 | protected: |
484 | static constexpr auto cardinality = SmallScalarTraits<Scalar>::cardinality; |
  static_assert(cardinality <= 256,
                "cardinality too large for direct-addressed table");
486 | |
487 | uint32_t AsIndex(Scalar value) const { |
488 | return SmallScalarTraits<Scalar>::AsIndex(value); |
489 | } |
490 | |
491 | int32_t value_to_index_[cardinality]; |
492 | std::vector<Scalar> index_to_value_; |
493 | }; |
494 | |
495 | // ---------------------------------------------------------------------- |
496 | // A memoization table for variable-sized binary data. |
497 | |
498 | class BinaryMemoTable { |
499 | public: |
500 | explicit BinaryMemoTable(int64_t entries = 0, int64_t values_size = -1) |
501 | : hash_table_(static_cast<uint64_t>(entries)) { |
502 | offsets_.reserve(entries + 1); |
503 | offsets_.push_back(0); |
504 | if (values_size == -1) { |
505 | values_.reserve(entries * 4); // A conservative heuristic |
506 | } else { |
507 | values_.reserve(values_size); |
508 | } |
509 | } |
510 | |
511 | int32_t Get(const void* data, int32_t length) const { |
512 | hash_t h = ComputeStringHash<0>(data, length); |
513 | auto p = Lookup(h, data, length); |
514 | if (p.second) { |
515 | return p.first->payload.memo_index; |
516 | } else { |
517 | return -1; |
518 | } |
519 | } |
520 | |
521 | int32_t Get(const std::string& value) const { |
522 | return Get(value.data(), static_cast<int32_t>(value.length())); |
523 | } |
524 | |
525 | int32_t Get(const util::string_view& value) const { |
526 | return Get(value.data(), static_cast<int32_t>(value.length())); |
527 | } |
528 | |
529 | template <typename Func1, typename Func2> |
530 | int32_t GetOrInsert(const void* data, int32_t length, Func1&& on_found, |
531 | Func2&& on_not_found) { |
532 | hash_t h = ComputeStringHash<0>(data, length); |
533 | auto p = Lookup(h, data, length); |
534 | int32_t memo_index; |
535 | if (p.second) { |
536 | memo_index = p.first->payload.memo_index; |
537 | on_found(memo_index); |
538 | } else { |
539 | memo_index = size(); |
540 | // Insert offset |
541 | auto offset = static_cast<int32_t>(values_.size()); |
542 | assert(offsets_.size() == static_cast<uint32_t>(memo_index + 1)); |
543 | assert(offsets_[memo_index] == offset); |
544 | offsets_.push_back(offset + length); |
545 | // Insert string value |
546 | values_.append(static_cast<const char*>(data), length); |
547 | // Insert hash entry |
548 | hash_table_.Insert(const_cast<HashTableEntry*>(p.first), h, {memo_index}); |
549 | |
550 | on_not_found(memo_index); |
551 | } |
552 | return memo_index; |
553 | } |
554 | |
555 | template <typename Func1, typename Func2> |
556 | int32_t GetOrInsert(const util::string_view& value, Func1&& on_found, |
557 | Func2&& on_not_found) { |
558 | return GetOrInsert(value.data(), static_cast<int32_t>(value.length()), |
559 | std::forward<Func1>(on_found), std::forward<Func2>(on_not_found)); |
560 | } |
561 | |
562 | int32_t GetOrInsert(const void* data, int32_t length) { |
563 | return GetOrInsert(data, length, [](int32_t i) {}, [](int32_t i) {}); |
564 | } |
565 | |
566 | int32_t GetOrInsert(const util::string_view& value) { |
567 | return GetOrInsert(value.data(), static_cast<int32_t>(value.length())); |
568 | } |
569 | |
570 | int32_t GetOrInsert(const std::string& value) { |
571 | return GetOrInsert(value.data(), static_cast<int32_t>(value.length())); |
572 | } |
573 | |
574 | // The number of entries in the memo table |
575 | // (which is also 1 + the largest memo index) |
576 | int32_t size() const { return static_cast<int32_t>(hash_table_.size()); } |
577 | |
578 | int32_t values_size() const { return static_cast<int32_t>(values_.size()); } |
579 | |
580 | const uint8_t* values_data() const { |
581 | return reinterpret_cast<const uint8_t*>(values_.data()); |
582 | } |
583 | |
  // Copy (n + 1) offsets starting from index `start` into `out_data`,
  // where n is the number of values from `start` onwards
585 | template <class Offset> |
586 | void CopyOffsets(int32_t start, Offset* out_data) const { |
587 | auto delta = offsets_[start]; |
588 | for (uint32_t i = start; i < offsets_.size(); ++i) { |
589 | auto adjusted_offset = offsets_[i] - delta; |
590 | auto cast_offset = static_cast<Offset>(adjusted_offset); |
591 | assert(static_cast<int32_t>(cast_offset) == adjusted_offset); // avoid truncation |
592 | *out_data++ = cast_offset; |
593 | } |
594 | } |
595 | |
596 | template <class Offset> |
597 | void CopyOffsets(Offset* out_data) const { |
598 | CopyOffsets(0, out_data); |
599 | } |
600 | |
601 | // Copy values starting from index `start` into `out_data` |
602 | void CopyValues(int32_t start, uint8_t* out_data) const { |
603 | CopyValues(start, -1, out_data); |
604 | } |
605 | |
606 | // Same as above, but check output size in debug mode |
607 | void CopyValues(int32_t start, int64_t out_size, uint8_t* out_data) const { |
608 | int32_t offset = offsets_[start]; |
609 | auto length = values_.size() - static_cast<size_t>(offset); |
610 | if (out_size != -1) { |
611 | assert(static_cast<int64_t>(length) == out_size); |
612 | } |
613 | memcpy(out_data, values_.data() + offset, length); |
614 | } |
615 | |
616 | void CopyValues(uint8_t* out_data) const { CopyValues(0, -1, out_data); } |
617 | |
618 | void CopyValues(int64_t out_size, uint8_t* out_data) const { |
619 | CopyValues(0, out_size, out_data); |
620 | } |
621 | |
  // Visit the stored values in insertion order, starting from index `start`.
  // The visitor function should have the signature `void(util::string_view)`
  // or `void(const util::string_view&)`.
625 | template <typename VisitFunc> |
626 | void VisitValues(int32_t start, VisitFunc&& visit) const { |
627 | for (uint32_t i = start; i < offsets_.size() - 1; ++i) { |
628 | visit( |
629 | util::string_view(values_.data() + offsets_[i], offsets_[i + 1] - offsets_[i])); |
630 | } |
631 | } |
632 | |
633 | protected: |
634 | struct Payload { |
635 | int32_t memo_index; |
636 | }; |
637 | |
638 | using HashTableType = HashTable<Payload>; |
639 | using HashTableEntry = typename HashTable<Payload>::Entry; |
640 | HashTableType hash_table_; |
641 | |
642 | std::vector<int32_t> offsets_; |
643 | std::string values_; |
644 | |
645 | std::pair<const HashTableEntry*, bool> Lookup(hash_t h, const void* data, |
646 | int32_t length) const { |
647 | auto cmp_func = [=](const Payload* payload) { |
648 | int32_t start, stop; |
649 | start = offsets_[payload->memo_index]; |
650 | stop = offsets_[payload->memo_index + 1]; |
651 | return length == stop - start && memcmp(data, values_.data() + start, length) == 0; |
652 | }; |
653 | return hash_table_.Lookup(h, cmp_func); |
654 | } |
655 | }; |
656 | |
657 | template <typename T, typename Enable = void> |
658 | struct HashTraits {}; |
659 | |
660 | template <> |
661 | struct HashTraits<BooleanType> { |
662 | using MemoTableType = SmallScalarMemoTable<bool>; |
663 | }; |
664 | |
665 | template <typename T> |
666 | struct HashTraits<T, enable_if_8bit_int<T>> { |
667 | using c_type = typename T::c_type; |
668 | using MemoTableType = SmallScalarMemoTable<typename T::c_type>; |
669 | }; |
670 | |
671 | template <typename T> |
672 | struct HashTraits< |
673 | T, typename std::enable_if<has_c_type<T>::value && !is_8bit_int<T>::value>::type> { |
674 | using c_type = typename T::c_type; |
675 | using MemoTableType = ScalarMemoTable<c_type, HashTable>; |
676 | }; |
677 | |
678 | template <typename T> |
679 | struct HashTraits<T, enable_if_binary<T>> { |
680 | using MemoTableType = BinaryMemoTable; |
681 | }; |
682 | |
683 | template <typename T> |
684 | struct HashTraits<T, enable_if_fixed_size_binary<T>> { |
685 | using MemoTableType = BinaryMemoTable; |
686 | }; |
687 | |
688 | template <typename T, typename Enable = void> |
689 | struct DictionaryTraits {}; |
690 | |
691 | template <> |
692 | struct DictionaryTraits<BooleanType> { |
693 | using T = BooleanType; |
694 | using MemoTableType = typename HashTraits<T>::MemoTableType; |
695 | |
696 | static Status GetDictionaryArrayData(MemoryPool* pool, |
697 | const std::shared_ptr<DataType>& type, |
698 | const MemoTableType& memo_table, |
699 | int64_t start_offset, |
700 | std::shared_ptr<ArrayData>* out) { |
701 | BooleanBuilder builder(pool); |
702 | const auto& bool_values = memo_table.values(); |
703 | auto it = bool_values.begin() + start_offset; |
704 | for (; it != bool_values.end(); ++it) { |
705 | RETURN_NOT_OK(builder.Append(*it)); |
706 | } |
707 | return builder.FinishInternal(out); |
708 | } |
709 | }; |
710 | |
711 | template <typename T> |
712 | struct DictionaryTraits<T, enable_if_has_c_type<T>> { |
713 | using c_type = typename T::c_type; |
714 | using MemoTableType = typename HashTraits<T>::MemoTableType; |
715 | |
716 | static Status GetDictionaryArrayData(MemoryPool* pool, |
717 | const std::shared_ptr<DataType>& type, |
718 | const MemoTableType& memo_table, |
719 | int64_t start_offset, |
720 | std::shared_ptr<ArrayData>* out) { |
721 | std::shared_ptr<Buffer> dict_buffer; |
722 | auto dict_length = static_cast<int64_t>(memo_table.size()) - start_offset; |
723 | // This makes a copy, but we assume a dictionary array is usually small |
724 | // compared to the size of the dictionary-using array. |
725 | // (also, copying the dictionary values is cheap compared to the cost |
726 | // of building the memo table) |
727 | RETURN_NOT_OK( |
728 | AllocateBuffer(pool, TypeTraits<T>::bytes_required(dict_length), &dict_buffer)); |
729 | memo_table.CopyValues(static_cast<int32_t>(start_offset), |
730 | reinterpret_cast<c_type*>(dict_buffer->mutable_data())); |
731 | *out = ArrayData::Make(type, dict_length, {nullptr, dict_buffer}, 0 /* null_count */); |
732 | return Status::OK(); |
733 | } |
734 | }; |
735 | |
736 | template <typename T> |
737 | struct DictionaryTraits<T, enable_if_binary<T>> { |
738 | using MemoTableType = typename HashTraits<T>::MemoTableType; |
739 | |
740 | static Status GetDictionaryArrayData(MemoryPool* pool, |
741 | const std::shared_ptr<DataType>& type, |
742 | const MemoTableType& memo_table, |
743 | int64_t start_offset, |
744 | std::shared_ptr<ArrayData>* out) { |
745 | std::shared_ptr<Buffer> dict_offsets; |
746 | std::shared_ptr<Buffer> dict_data; |
747 | |
748 | // Create the offsets buffer |
749 | auto dict_length = static_cast<int64_t>(memo_table.size() - start_offset); |
750 | RETURN_NOT_OK(AllocateBuffer( |
751 | pool, TypeTraits<Int32Type>::bytes_required(dict_length + 1), &dict_offsets)); |
752 | auto raw_offsets = reinterpret_cast<int32_t*>(dict_offsets->mutable_data()); |
753 | memo_table.CopyOffsets(static_cast<int32_t>(start_offset), raw_offsets); |
754 | |
755 | // Create the data buffer |
756 | DCHECK_EQ(raw_offsets[0], 0); |
757 | RETURN_NOT_OK(AllocateBuffer(pool, raw_offsets[dict_length], &dict_data)); |
758 | memo_table.CopyValues(static_cast<int32_t>(start_offset), dict_data->size(), |
759 | dict_data->mutable_data()); |
760 | |
761 | *out = ArrayData::Make(type, dict_length, {nullptr, dict_offsets, dict_data}, |
762 | 0 /* null_count */); |
763 | return Status::OK(); |
764 | } |
765 | }; |
766 | |
767 | template <typename T> |
768 | struct DictionaryTraits<T, enable_if_fixed_size_binary<T>> { |
769 | using MemoTableType = typename HashTraits<T>::MemoTableType; |
770 | |
771 | static Status GetDictionaryArrayData(MemoryPool* pool, |
772 | const std::shared_ptr<DataType>& type, |
773 | const MemoTableType& memo_table, |
774 | int64_t start_offset, |
775 | std::shared_ptr<ArrayData>* out) { |
776 | const T& concrete_type = internal::checked_cast<const T&>(*type); |
777 | std::shared_ptr<Buffer> dict_data; |
778 | |
779 | // Create the data buffer |
780 | auto dict_length = static_cast<int64_t>(memo_table.size() - start_offset); |
781 | auto data_length = dict_length * concrete_type.byte_width(); |
782 | RETURN_NOT_OK(AllocateBuffer(pool, data_length, &dict_data)); |
783 | memo_table.CopyValues(static_cast<int32_t>(start_offset), data_length, |
784 | dict_data->mutable_data()); |
785 | |
786 | *out = ArrayData::Make(type, dict_length, {nullptr, dict_data}, 0 /* null_count */); |
787 | return Status::OK(); |
788 | } |
789 | }; |
790 | |
791 | } // namespace internal |
792 | } // namespace arrow |
793 | |
794 | #endif // ARROW_UTIL_HASHING_H |
795 | |