1 | /* |
2 | * Copyright 2017-present Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #pragma once |
17 | |
18 | #include <folly/Optional.h> |
19 | #include <folly/concurrency/detail/ConcurrentHashMap-detail.h> |
20 | #include <folly/synchronization/Hazptr.h> |
#include <atomic>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <utility>
23 | |
24 | namespace folly { |
25 | |
26 | /** |
27 | * Based on Java's ConcurrentHashMap |
28 | * |
29 | * Readers are always wait-free. |
30 | * Writers are sharded, but take a lock. |
31 | * |
32 | * The interface is as close to std::unordered_map as possible, but there |
33 | * are a handful of changes: |
34 | * |
35 | * * Iterators hold hazard pointers to the returned elements. Elements can only |
36 | * be accessed while Iterators are still valid! |
37 | * |
38 | * * Therefore operator[] and at() return copies, since they do not |
39 | * return an iterator. The returned value is const, to remind you |
40 | * that changes do not affect the value in the map. |
41 | * |
42 | * * erase() calls the hash function, and may fail if the hash |
43 | * function throws an exception. |
44 | * |
 * * clear() initializes fresh bucket storage, and is not noexcept.
46 | * |
47 | * * The interface adds assign_if_equal, since find() doesn't take a lock. |
48 | * |
 * * Only a const version of find() is supported, along with const
 *   iterators. Mutation must use the functions provided, like assign().
51 | * |
 * * Iteration visits every bucket in the table, unlike
 *   std::unordered_map, which walks a linked list of elements.
 *   If the table is sparse, this may be more expensive.
55 | * |
 * * The rehash policy keeps bucket counts a power of two, using the
 *   supplied load factor.
57 | * |
58 | * * Allocator must be stateless. |
59 | * |
60 | * * ValueTypes without copy constructors will work, but pessimize the |
61 | * implementation. |
62 | * |
63 | * Comparisons: |
64 | * Single-threaded performance is extremely similar to std::unordered_map. |
65 | * |
 * Multithreaded performance beats anything except the lock-free
 * atomic maps (AtomicUnorderedMap, AtomicHashMap), but those win
 * only if you can size them perfectly in advance and you never
 * need erase(). If you don't know the size in advance, or your
 * workload frequently calls erase(), this is the better choice.
72 | */ |
73 | |
74 | template < |
75 | typename KeyType, |
76 | typename ValueType, |
77 | typename HashFn = std::hash<KeyType>, |
78 | typename KeyEqual = std::equal_to<KeyType>, |
79 | typename Allocator = std::allocator<uint8_t>, |
80 | uint8_t ShardBits = 8, |
81 | template <typename> class Atom = std::atomic, |
82 | class Mutex = std::mutex> |
83 | class ConcurrentHashMap { |
84 | using SegmentT = detail::ConcurrentHashMapSegment< |
85 | KeyType, |
86 | ValueType, |
87 | ShardBits, |
88 | HashFn, |
89 | KeyEqual, |
90 | Allocator, |
91 | Atom, |
92 | Mutex>; |
  static constexpr uint64_t NumShards = (uint64_t{1} << ShardBits);
  // Slightly higher than 1.0: even if hashing to shards isn't
  // perfectly balanced, reserve(size) will still avoid rehashing.
97 | float load_factor_ = 1.05; |
98 | |
99 | public: |
100 | class ConstIterator; |
101 | |
102 | typedef KeyType key_type; |
103 | typedef ValueType mapped_type; |
104 | typedef std::pair<const KeyType, ValueType> value_type; |
105 | typedef std::size_t size_type; |
106 | typedef HashFn hasher; |
107 | typedef KeyEqual key_equal; |
108 | typedef ConstIterator const_iterator; |
109 | |
110 | /* |
111 | * Construct a ConcurrentHashMap with 1 << ShardBits shards, size |
112 | * and max_size given. Both size and max_size will be rounded up to |
113 | * the next power of two, if they are not already a power of two, so |
114 | * that we can index in to Shards efficiently. |
115 | * |
116 | * Insertion functions will throw bad_alloc if max_size is exceeded. |
117 | */ |
118 | explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) { |
119 | size_ = folly::nextPowTwo(size); |
120 | if (max_size != 0) { |
121 | max_size_ = folly::nextPowTwo(max_size); |
122 | } |
123 | CHECK(max_size_ == 0 || max_size_ >= size_); |
124 | for (uint64_t i = 0; i < NumShards; i++) { |
125 | segments_[i].store(nullptr, std::memory_order_relaxed); |
126 | } |
127 | } |
128 | |
129 | ConcurrentHashMap(ConcurrentHashMap&& o) noexcept |
130 | : size_(o.size_), max_size_(o.max_size_) { |
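    // Relaxed ordering suffices here: moving from a map that other
    // threads are concurrently accessing is undefined behavior anyway.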
131 | for (uint64_t i = 0; i < NumShards; i++) { |
132 | segments_[i].store( |
133 | o.segments_[i].load(std::memory_order_relaxed), |
134 | std::memory_order_relaxed); |
135 | o.segments_[i].store(nullptr, std::memory_order_relaxed); |
136 | } |
137 | batch_.store(o.batch(), std::memory_order_relaxed); |
138 | o.batch_.store(nullptr, std::memory_order_relaxed); |
139 | } |
140 | |
141 | ConcurrentHashMap& operator=(ConcurrentHashMap&& o) { |
142 | for (uint64_t i = 0; i < NumShards; i++) { |
143 | auto seg = segments_[i].load(std::memory_order_relaxed); |
144 | if (seg) { |
145 | seg->~SegmentT(); |
146 | Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT)); |
147 | } |
148 | segments_[i].store( |
149 | o.segments_[i].load(std::memory_order_relaxed), |
150 | std::memory_order_relaxed); |
151 | o.segments_[i].store(nullptr, std::memory_order_relaxed); |
152 | } |
153 | size_ = o.size_; |
154 | max_size_ = o.max_size_; |
155 | batch_shutdown_cleanup(); |
156 | batch_.store(o.batch(), std::memory_order_relaxed); |
157 | o.batch_.store(nullptr, std::memory_order_relaxed); |
158 | return *this; |
159 | } |
160 | |
161 | ~ConcurrentHashMap() { |
162 | for (uint64_t i = 0; i < NumShards; i++) { |
163 | auto seg = segments_[i].load(std::memory_order_relaxed); |
164 | if (seg) { |
165 | seg->~SegmentT(); |
166 | Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT)); |
167 | } |
168 | } |
169 | batch_shutdown_cleanup(); |
170 | } |
171 | |
172 | bool empty() const noexcept { |
173 | for (uint64_t i = 0; i < NumShards; i++) { |
174 | auto seg = segments_[i].load(std::memory_order_acquire); |
175 | if (seg) { |
176 | if (!seg->empty()) { |
177 | return false; |
178 | } |
179 | } |
180 | } |
181 | return true; |
182 | } |
183 | |
184 | ConstIterator find(const KeyType& k) const { |
185 | auto segment = pickSegment(k); |
186 | ConstIterator res(this, segment); |
187 | auto seg = segments_[segment].load(std::memory_order_acquire); |
188 | if (!seg || !seg->find(res.it_, k)) { |
189 | res.segment_ = NumShards; |
190 | } |
191 | return res; |
192 | } |
193 | |
194 | ConstIterator cend() const noexcept { |
195 | return ConstIterator(NumShards); |
196 | } |
197 | |
198 | ConstIterator cbegin() const noexcept { |
199 | return ConstIterator(this); |
200 | } |
201 | |
202 | ConstIterator end() const noexcept { |
203 | return cend(); |
204 | } |
205 | |
206 | ConstIterator begin() const noexcept { |
207 | return cbegin(); |
208 | } |
209 | |
  std::pair<ConstIterator, bool> insert(
      std::pair<key_type, mapped_type>&& item) {
    auto segment = pickSegment(item.first);
213 | std::pair<ConstIterator, bool> res( |
214 | std::piecewise_construct, |
215 | std::forward_as_tuple(this, segment), |
216 | std::forward_as_tuple(false)); |
    res.second =
        ensureSegment(segment)->insert(res.first.it_, std::move(item));
218 | return res; |
219 | } |
220 | |
221 | template <typename Key, typename Value> |
222 | std::pair<ConstIterator, bool> insert(Key&& k, Value&& v) { |
223 | auto segment = pickSegment(k); |
224 | std::pair<ConstIterator, bool> res( |
225 | std::piecewise_construct, |
226 | std::forward_as_tuple(this, segment), |
227 | std::forward_as_tuple(false)); |
228 | res.second = ensureSegment(segment)->insert( |
229 | res.first.it_, std::forward<Key>(k), std::forward<Value>(v)); |
230 | return res; |
231 | } |
232 | |
233 | template <typename Key, typename... Args> |
234 | std::pair<ConstIterator, bool> try_emplace(Key&& k, Args&&... args) { |
235 | auto segment = pickSegment(k); |
236 | std::pair<ConstIterator, bool> res( |
237 | std::piecewise_construct, |
238 | std::forward_as_tuple(this, segment), |
239 | std::forward_as_tuple(false)); |
240 | res.second = ensureSegment(segment)->try_emplace( |
241 | res.first.it_, std::forward<Key>(k), std::forward<Args>(args)...); |
242 | return res; |
243 | } |
244 | |
245 | template <typename... Args> |
246 | std::pair<ConstIterator, bool> emplace(Args&&... args) { |
247 | using Node = typename SegmentT::Node; |
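    // Construct the node up front: the key is only available once the
    // key/value pair has been built, and we need it to pick a segment.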
248 | auto node = (Node*)Allocator().allocate(sizeof(Node)); |
249 | new (node) Node(ensureBatch(), std::forward<Args>(args)...); |
250 | auto segment = pickSegment(node->getItem().first); |
251 | std::pair<ConstIterator, bool> res( |
252 | std::piecewise_construct, |
253 | std::forward_as_tuple(this, segment), |
254 | std::forward_as_tuple(false)); |
255 | res.second = ensureSegment(segment)->emplace( |
256 | res.first.it_, node->getItem().first, node); |
257 | if (!res.second) { |
258 | node->~Node(); |
259 | Allocator().deallocate((uint8_t*)node, sizeof(Node)); |
260 | } |
261 | return res; |
262 | } |
263 | |
264 | template <typename Key, typename Value> |
265 | std::pair<ConstIterator, bool> insert_or_assign(Key&& k, Value&& v) { |
266 | auto segment = pickSegment(k); |
267 | std::pair<ConstIterator, bool> res( |
268 | std::piecewise_construct, |
269 | std::forward_as_tuple(this, segment), |
270 | std::forward_as_tuple(false)); |
271 | res.second = ensureSegment(segment)->insert_or_assign( |
272 | res.first.it_, std::forward<Key>(k), std::forward<Value>(v)); |
273 | return res; |
274 | } |
275 | |
276 | template <typename Key, typename Value> |
277 | folly::Optional<ConstIterator> assign(Key&& k, Value&& v) { |
278 | auto segment = pickSegment(k); |
279 | ConstIterator res(this, segment); |
280 | auto seg = segments_[segment].load(std::memory_order_acquire); |
281 | if (!seg) { |
282 | return none; |
283 | } else { |
284 | auto r = |
285 | seg->assign(res.it_, std::forward<Key>(k), std::forward<Value>(v)); |
286 | if (!r) { |
287 | return none; |
288 | } |
289 | } |
290 | return std::move(res); |
291 | } |
292 | |
  // Assign the value at key k to desired, if and only if the current
  // value is equal to expected.
294 | template <typename Key, typename Value> |
295 | folly::Optional<ConstIterator> |
296 | assign_if_equal(Key&& k, const ValueType& expected, Value&& desired) { |
297 | auto segment = pickSegment(k); |
298 | ConstIterator res(this, segment); |
299 | auto seg = segments_[segment].load(std::memory_order_acquire); |
300 | if (!seg) { |
301 | return none; |
302 | } else { |
303 | auto r = seg->assign_if_equal( |
304 | res.it_, |
305 | std::forward<Key>(k), |
306 | expected, |
307 | std::forward<Value>(desired)); |
308 | if (!r) { |
309 | return none; |
310 | } |
311 | } |
312 | return std::move(res); |
313 | } |
314 | |
315 | // Copying wrappers around insert and find. |
316 | // Only available for copyable types. |
317 | const ValueType operator[](const KeyType& key) { |
318 | auto item = insert(key, ValueType()); |
319 | return item.first->second; |
320 | } |
321 | |
322 | const ValueType at(const KeyType& key) const { |
323 | auto item = find(key); |
324 | if (item == cend()) { |
      throw std::out_of_range("at(): key not found");
326 | } |
327 | return item->second; |
328 | } |
329 | |
330 | // TODO update assign interface, operator[], at |
331 | |
332 | size_type erase(const key_type& k) { |
333 | auto segment = pickSegment(k); |
334 | auto seg = segments_[segment].load(std::memory_order_acquire); |
335 | if (!seg) { |
336 | return 0; |
337 | } else { |
338 | return seg->erase(k); |
339 | } |
340 | } |
341 | |
342 | // Calls the hash function, and therefore may throw. |
343 | ConstIterator erase(ConstIterator& pos) { |
344 | auto segment = pickSegment(pos->first); |
345 | ConstIterator res(this, segment); |
346 | ensureSegment(segment)->erase(res.it_, pos.it_); |
    res.next(); // May point to a segment's end; advance to the next element.
348 | return res; |
349 | } |
350 | |
  // Erase the entry for key k, if and only if its value is equal to expected.
352 | size_type erase_if_equal(const key_type& k, const ValueType& expected) { |
353 | auto segment = pickSegment(k); |
354 | auto seg = segments_[segment].load(std::memory_order_acquire); |
355 | if (!seg) { |
356 | return 0; |
357 | } |
358 | return seg->erase_if_equal(k, expected); |
359 | } |
360 | |
  // NOT noexcept: clearing a segment allocates a fresh bucket array,
  // which may throw.
362 | void clear() { |
363 | for (uint64_t i = 0; i < NumShards; i++) { |
364 | auto seg = segments_[i].load(std::memory_order_acquire); |
365 | if (seg) { |
366 | seg->clear(); |
367 | } |
368 | } |
369 | } |
370 | |
371 | void reserve(size_t count) { |
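    // The requested capacity is spread across all shards: each shard
    // gets count / NumShards buckets.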
372 | count = count >> ShardBits; |
373 | for (uint64_t i = 0; i < NumShards; i++) { |
374 | auto seg = segments_[i].load(std::memory_order_acquire); |
375 | if (seg) { |
376 | seg->rehash(count); |
377 | } |
378 | } |
379 | } |
380 | |
381 | // This is a rolling size, and is not exact at any moment in time. |
382 | size_t size() const noexcept { |
383 | size_t res = 0; |
384 | for (uint64_t i = 0; i < NumShards; i++) { |
385 | auto seg = segments_[i].load(std::memory_order_acquire); |
386 | if (seg) { |
387 | res += seg->size(); |
388 | } |
389 | } |
390 | return res; |
391 | } |
392 | |
393 | float max_load_factor() const { |
394 | return load_factor_; |
395 | } |
396 | |
397 | void max_load_factor(float factor) { |
398 | for (uint64_t i = 0; i < NumShards; i++) { |
399 | auto seg = segments_[i].load(std::memory_order_acquire); |
400 | if (seg) { |
401 | seg->max_load_factor(factor); |
402 | } |
403 | } |
404 | } |
405 | |
406 | class ConstIterator { |
407 | public: |
408 | friend class ConcurrentHashMap; |
409 | |
410 | const value_type& operator*() const { |
411 | return *it_; |
412 | } |
413 | |
414 | const value_type* operator->() const { |
415 | return &*it_; |
416 | } |
417 | |
418 | ConstIterator& operator++() { |
419 | ++it_; |
420 | next(); |
421 | return *this; |
422 | } |
423 | |
424 | bool operator==(const ConstIterator& o) const { |
425 | return it_ == o.it_ && segment_ == o.segment_; |
426 | } |
427 | |
428 | bool operator!=(const ConstIterator& o) const { |
429 | return !(*this == o); |
430 | } |
431 | |
432 | ConstIterator& operator=(const ConstIterator& o) = delete; |
433 | |
434 | ConstIterator& operator=(ConstIterator&& o) noexcept { |
435 | if (this != &o) { |
436 | it_ = std::move(o.it_); |
437 | segment_ = std::exchange(o.segment_, uint64_t(NumShards)); |
438 | parent_ = std::exchange(o.parent_, nullptr); |
439 | } |
440 | return *this; |
441 | } |
442 | |
443 | ConstIterator(const ConstIterator& o) = delete; |
444 | |
445 | ConstIterator(ConstIterator&& o) noexcept |
446 | : it_(std::move(o.it_)), |
447 | segment_(std::exchange(o.segment_, uint64_t(NumShards))), |
448 | parent_(std::exchange(o.parent_, nullptr)) {} |
449 | |
450 | ConstIterator(const ConcurrentHashMap* parent, uint64_t segment) |
451 | : segment_(segment), parent_(parent) {} |
452 | |
453 | private: |
454 | // cbegin iterator |
455 | explicit ConstIterator(const ConcurrentHashMap* parent) |
456 | : it_(parent->ensureSegment(0)->cbegin()), |
457 | segment_(0), |
458 | parent_(parent) { |
      // Advance to the first element, which could be in any shard.
460 | next(); |
461 | } |
462 | |
463 | // cend iterator |
464 | explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {} |
465 | |
466 | void next() { |
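      // If we're at the end of the current segment, skip forward to the
      // first element of the next constructed, non-empty segment (or to
      // the global end if none remains).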
467 | while (segment_ < parent_->NumShards && |
468 | it_ == parent_->ensureSegment(segment_)->cend()) { |
469 | SegmentT* seg{nullptr}; |
470 | while (!seg) { |
471 | segment_++; |
472 | if (segment_ < parent_->NumShards) { |
473 | seg = parent_->segments_[segment_].load(std::memory_order_acquire); |
474 | if (!seg) { |
475 | continue; |
476 | } |
477 | it_ = seg->cbegin(); |
478 | } |
479 | break; |
480 | } |
481 | } |
482 | } |
483 | |
484 | typename SegmentT::Iterator it_; |
485 | uint64_t segment_; |
486 | const ConcurrentHashMap* parent_; |
487 | }; |
488 | |
489 | private: |
490 | uint64_t pickSegment(const KeyType& k) const { |
491 | auto h = HashFn()(k); |
492 | // Use the lowest bits for our shard bits. |
493 | // |
494 | // This works well even if the hash function is biased towards the |
495 | // low bits: The sharding will happen in the segments_ instead of |
496 | // in the segment buckets, so we'll still get write sharding as |
497 | // well. |
498 | // |
499 | // Low-bit bias happens often for std::hash using small numbers, |
500 | // since the integer hash function is the identity function. |
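    // For example, with the default ShardBits = 8, this returns h % 256.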
501 | return h & (NumShards - 1); |
502 | } |
503 | |
504 | SegmentT* ensureSegment(uint64_t i) const { |
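    // Lazily construct the segment on first use. If two threads race,
    // the CAS below picks a winner and the loser destroys its copy.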
505 | SegmentT* seg = segments_[i].load(std::memory_order_acquire); |
506 | if (!seg) { |
507 | auto b = ensureBatch(); |
508 | SegmentT* newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT)); |
509 | newseg = new (newseg) |
510 | SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits, b); |
511 | if (!segments_[i].compare_exchange_strong(seg, newseg)) { |
        // The CAS loaded the winning segment into seg; destroy ours.
513 | newseg->~SegmentT(); |
514 | Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT)); |
515 | } else { |
516 | seg = newseg; |
517 | } |
518 | } |
519 | return seg; |
520 | } |
521 | |
522 | hazptr_obj_batch<Atom>* batch() const noexcept { |
523 | return batch_.load(std::memory_order_acquire); |
524 | } |
525 | |
526 | hazptr_obj_batch<Atom>* ensureBatch() const { |
527 | auto b = batch(); |
528 | if (!b) { |
529 | auto storage = Allocator().allocate(sizeof(hazptr_obj_batch<Atom>)); |
530 | auto newbatch = new (storage) hazptr_obj_batch<Atom>(); |
531 | if (batch_.compare_exchange_strong(b, newbatch)) { |
532 | b = newbatch; |
533 | } else { |
534 | newbatch->shutdown_and_reclaim(); |
535 | newbatch->~hazptr_obj_batch<Atom>(); |
536 | Allocator().deallocate(storage, sizeof(hazptr_obj_batch<Atom>)); |
537 | } |
538 | } |
539 | return b; |
540 | } |
541 | |
542 | void batch_shutdown_cleanup() { |
543 | auto b = batch(); |
544 | if (b) { |
545 | b->shutdown_and_reclaim(); |
546 | hazptr_cleanup_batch_tag(b); |
547 | b->~hazptr_obj_batch<Atom>(); |
548 | Allocator().deallocate((uint8_t*)b, sizeof(hazptr_obj_batch<Atom>)); |
549 | } |
550 | } |
551 | |
552 | mutable Atom<SegmentT*> segments_[NumShards]; |
553 | size_t size_{0}; |
554 | size_t max_size_{0}; |
555 | mutable Atom<hazptr_obj_batch<Atom>*> batch_{nullptr}; |
556 | }; |
557 | |
558 | } // namespace folly |
559 | |