/*
 * Copyright 2017-present Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <folly/Optional.h>
#include <folly/concurrency/detail/ConcurrentHashMap-detail.h>
#include <folly/synchronization/Hazptr.h>
#include <atomic>
#include <mutex>

namespace folly {

/**
 * Based on Java's ConcurrentHashMap.
 *
 * Readers are always wait-free.
 * Writers are sharded, but take a lock.
 *
 * The interface is as close to std::unordered_map as possible, but there
 * are a handful of changes:
 *
 * * Iterators hold hazard pointers to the returned elements. Elements can
 *   only be accessed while iterators are still valid!
 *
 * * Therefore operator[] and at() return copies, since they do not
 *   return an iterator. The returned value is const, to remind you
 *   that changes do not affect the value in the map.
 *
 * * erase() calls the hash function, and may fail if the hash
 *   function throws an exception.
 *
 * * clear() initializes new segments, and is not noexcept.
 *
 * * The interface adds assign_if_equal, since find() doesn't take a lock.
 *
 * * Only the const version of find() is supported, and only const
 *   iterators. Mutation must use the functions provided, such as assign().
 *
 * * Iteration walks all the buckets in the table, unlike
 *   std::unordered_map, which iterates over a linked list of elements.
 *   If the table is sparse, this may be more expensive.
 *
 * * The rehash policy is power-of-two sizing, using the supplied load factor.
 *
 * * Allocator must be stateless.
 *
 * * ValueTypes without copy constructors will work, but pessimize the
 *   implementation.
 *
 * Comparisons:
 *   Single-threaded performance is extremely similar to std::unordered_map.
 *
 *   Multithreaded performance beats anything except the lock-free
 *   atomic maps (AtomicUnorderedMap, AtomicHashMap), BUT only
 *   if you can perfectly size the atomic maps and you don't
 *   need erase(). If you don't know the size in advance or
 *   your workload frequently calls erase(), this is the
 *   better choice.
 *
 * A brief usage sketch follows below.
 */
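
/*
 * Usage sketch (illustrative only; a minimal example of the API described
 * above, not part of the interface):
 *
 *   folly::ConcurrentHashMap<std::string, int> map;
 *   map.insert("apple", 1);           // inserts if not already present
 *   map.insert_or_assign("apple", 2); // inserts or overwrites
 *
 *   auto it = map.find("apple");
 *   if (it != map.cend()) {
 *     // The element is only guaranteed to stay alive while `it` (which
 *     // holds a hazard pointer) is valid.
 *     int v = it->second;
 *     // Conditional update: succeeds only if the current value equals `v`.
 *     map.assign_if_equal("apple", v, v + 1);
 *   }
 *
 *   map.erase("apple");
 */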

template <
    typename KeyType,
    typename ValueType,
    typename HashFn = std::hash<KeyType>,
    typename KeyEqual = std::equal_to<KeyType>,
    typename Allocator = std::allocator<uint8_t>,
    uint8_t ShardBits = 8,
    template <typename> class Atom = std::atomic,
    class Mutex = std::mutex>
class ConcurrentHashMap {
  using SegmentT = detail::ConcurrentHashMapSegment<
      KeyType,
      ValueType,
      ShardBits,
      HashFn,
      KeyEqual,
      Allocator,
      Atom,
      Mutex>;
  static constexpr uint64_t NumShards = (1 << ShardBits);
  // Slightly higher than 1.0 so that, even if hashing to shards isn't
  // perfectly balanced, reserve(size) will still work without rehashing.
  float load_factor_ = 1.05;

 public:
  class ConstIterator;

  typedef KeyType key_type;
  typedef ValueType mapped_type;
  typedef std::pair<const KeyType, ValueType> value_type;
  typedef std::size_t size_type;
  typedef HashFn hasher;
  typedef KeyEqual key_equal;
  typedef ConstIterator const_iterator;

  /*
   * Construct a ConcurrentHashMap with 1 << ShardBits shards, and the given
   * size and max_size. Both size and max_size will be rounded up to the
   * next power of two, if they are not already powers of two, so that we
   * can index into the shards efficiently.
   *
   * Insertion functions will throw bad_alloc if max_size is exceeded.
   */
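  /*
   * Example (illustrative): a map expected to hold roughly 10000 elements,
   * with no maximum size. The initial size is rounded up to the next power
   * of two (16384), spread across the 256 default shards:
   *
   *   folly::ConcurrentHashMap<uint64_t, std::string> m(10000);
   */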
  explicit ConcurrentHashMap(size_t size = 8, size_t max_size = 0) {
    size_ = folly::nextPowTwo(size);
    if (max_size != 0) {
      max_size_ = folly::nextPowTwo(max_size);
    }
    CHECK(max_size_ == 0 || max_size_ >= size_);
    for (uint64_t i = 0; i < NumShards; i++) {
      segments_[i].store(nullptr, std::memory_order_relaxed);
    }
  }

  ConcurrentHashMap(ConcurrentHashMap&& o) noexcept
      : size_(o.size_), max_size_(o.max_size_) {
    for (uint64_t i = 0; i < NumShards; i++) {
      segments_[i].store(
          o.segments_[i].load(std::memory_order_relaxed),
          std::memory_order_relaxed);
      o.segments_[i].store(nullptr, std::memory_order_relaxed);
    }
    batch_.store(o.batch(), std::memory_order_relaxed);
    o.batch_.store(nullptr, std::memory_order_relaxed);
  }

  ConcurrentHashMap& operator=(ConcurrentHashMap&& o) {
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_relaxed);
      if (seg) {
        seg->~SegmentT();
        Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
      }
      segments_[i].store(
          o.segments_[i].load(std::memory_order_relaxed),
          std::memory_order_relaxed);
      o.segments_[i].store(nullptr, std::memory_order_relaxed);
    }
    size_ = o.size_;
    max_size_ = o.max_size_;
    batch_shutdown_cleanup();
    batch_.store(o.batch(), std::memory_order_relaxed);
    o.batch_.store(nullptr, std::memory_order_relaxed);
    return *this;
  }

  ~ConcurrentHashMap() {
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_relaxed);
      if (seg) {
        seg->~SegmentT();
        Allocator().deallocate((uint8_t*)seg, sizeof(SegmentT));
      }
    }
    batch_shutdown_cleanup();
  }

  bool empty() const noexcept {
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_acquire);
      if (seg) {
        if (!seg->empty()) {
          return false;
        }
      }
    }
    return true;
  }

  ConstIterator find(const KeyType& k) const {
    auto segment = pickSegment(k);
    ConstIterator res(this, segment);
    auto seg = segments_[segment].load(std::memory_order_acquire);
    if (!seg || !seg->find(res.it_, k)) {
      res.segment_ = NumShards;
    }
    return res;
  }

  ConstIterator cend() const noexcept {
    return ConstIterator(NumShards);
  }

  ConstIterator cbegin() const noexcept {
    return ConstIterator(this);
  }

  ConstIterator end() const noexcept {
    return cend();
  }

  ConstIterator begin() const noexcept {
    return cbegin();
  }

  std::pair<ConstIterator, bool> insert(
      std::pair<key_type, mapped_type>&& foo) {
    auto segment = pickSegment(foo.first);
    std::pair<ConstIterator, bool> res(
        std::piecewise_construct,
        std::forward_as_tuple(this, segment),
        std::forward_as_tuple(false));
    res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
    return res;
  }

  template <typename Key, typename Value>
  std::pair<ConstIterator, bool> insert(Key&& k, Value&& v) {
    auto segment = pickSegment(k);
    std::pair<ConstIterator, bool> res(
        std::piecewise_construct,
        std::forward_as_tuple(this, segment),
        std::forward_as_tuple(false));
    res.second = ensureSegment(segment)->insert(
        res.first.it_, std::forward<Key>(k), std::forward<Value>(v));
    return res;
  }

  template <typename Key, typename... Args>
  std::pair<ConstIterator, bool> try_emplace(Key&& k, Args&&... args) {
    auto segment = pickSegment(k);
    std::pair<ConstIterator, bool> res(
        std::piecewise_construct,
        std::forward_as_tuple(this, segment),
        std::forward_as_tuple(false));
    res.second = ensureSegment(segment)->try_emplace(
        res.first.it_, std::forward<Key>(k), std::forward<Args>(args)...);
    return res;
  }

  template <typename... Args>
  std::pair<ConstIterator, bool> emplace(Args&&... args) {
    using Node = typename SegmentT::Node;
    auto node = (Node*)Allocator().allocate(sizeof(Node));
    new (node) Node(ensureBatch(), std::forward<Args>(args)...);
    auto segment = pickSegment(node->getItem().first);
    std::pair<ConstIterator, bool> res(
        std::piecewise_construct,
        std::forward_as_tuple(this, segment),
        std::forward_as_tuple(false));
    res.second = ensureSegment(segment)->emplace(
        res.first.it_, node->getItem().first, node);
    if (!res.second) {
      node->~Node();
      Allocator().deallocate((uint8_t*)node, sizeof(Node));
    }
    return res;
  }

  template <typename Key, typename Value>
  std::pair<ConstIterator, bool> insert_or_assign(Key&& k, Value&& v) {
    auto segment = pickSegment(k);
    std::pair<ConstIterator, bool> res(
        std::piecewise_construct,
        std::forward_as_tuple(this, segment),
        std::forward_as_tuple(false));
    res.second = ensureSegment(segment)->insert_or_assign(
        res.first.it_, std::forward<Key>(k), std::forward<Value>(v));
    return res;
  }

  template <typename Key, typename Value>
  folly::Optional<ConstIterator> assign(Key&& k, Value&& v) {
    auto segment = pickSegment(k);
    ConstIterator res(this, segment);
    auto seg = segments_[segment].load(std::memory_order_acquire);
    if (!seg) {
      return none;
    } else {
      auto r =
          seg->assign(res.it_, std::forward<Key>(k), std::forward<Value>(v));
      if (!r) {
        return none;
      }
    }
    return std::move(res);
  }

  // Assign the value at key k to desired, if and only if its current value
  // is equal to expected.
  template <typename Key, typename Value>
  folly::Optional<ConstIterator>
  assign_if_equal(Key&& k, const ValueType& expected, Value&& desired) {
    auto segment = pickSegment(k);
    ConstIterator res(this, segment);
    auto seg = segments_[segment].load(std::memory_order_acquire);
    if (!seg) {
      return none;
    } else {
      auto r = seg->assign_if_equal(
          res.it_,
          std::forward<Key>(k),
          expected,
          std::forward<Value>(desired));
      if (!r) {
        return none;
      }
    }
    return std::move(res);
  }
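
  // Example (illustrative): an optimistic, CAS-style increment built from
  // find() + assign_if_equal(); `map` and `key` are placeholder names.
  //
  //   auto it = map.find(key);
  //   if (it != map.cend()) {
  //     auto cur = it->second;
  //     // Only succeeds if the value is still `cur` at the time of the update.
  //     map.assign_if_equal(key, cur, cur + 1);
  //   }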

  // Copying wrappers around insert and find.
  // Only available for copyable types.
  const ValueType operator[](const KeyType& key) {
    auto item = insert(key, ValueType());
    return item.first->second;
  }

  const ValueType at(const KeyType& key) const {
    auto item = find(key);
    if (item == cend()) {
      throw std::out_of_range("at(): value out of range");
    }
    return item->second;
  }
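
  // Note (illustrative): because operator[] and at() return const copies,
  // writing through them does not modify the map; use insert_or_assign()
  // or assign() instead. For example, with placeholder names `map` and `k`:
  //
  //   auto v = map[k];                // copy of the stored (or default) value
  //   map.insert_or_assign(k, v + 1); // actually updates the value in the map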

  // TODO update assign interface, operator[], at

  size_type erase(const key_type& k) {
    auto segment = pickSegment(k);
    auto seg = segments_[segment].load(std::memory_order_acquire);
    if (!seg) {
      return 0;
    } else {
      return seg->erase(k);
    }
  }

  // Calls the hash function, and therefore may throw.
  ConstIterator erase(ConstIterator& pos) {
    auto segment = pickSegment(pos->first);
    ConstIterator res(this, segment);
    ensureSegment(segment)->erase(res.it_, pos.it_);
    res.next(); // May point to a segment's end; advance to the next element.
    return res;
  }

  // Erase key k if and only if its value is equal to expected.
  size_type erase_if_equal(const key_type& k, const ValueType& expected) {
    auto segment = pickSegment(k);
    auto seg = segments_[segment].load(std::memory_order_acquire);
    if (!seg) {
      return 0;
    }
    return seg->erase_if_equal(k, expected);
  }

  // NOT noexcept: initializes new segment storage, which may throw.
  void clear() {
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_acquire);
      if (seg) {
        seg->clear();
      }
    }
  }

  void reserve(size_t count) {
    count = count >> ShardBits;
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_acquire);
      if (seg) {
        seg->rehash(count);
      }
    }
  }

  // This is a rolling size, and is not exact at any moment in time.
  size_t size() const noexcept {
    size_t res = 0;
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_acquire);
      if (seg) {
        res += seg->size();
      }
    }
    return res;
  }

  float max_load_factor() const {
    return load_factor_;
  }

  void max_load_factor(float factor) {
    for (uint64_t i = 0; i < NumShards; i++) {
      auto seg = segments_[i].load(std::memory_order_acquire);
      if (seg) {
        seg->max_load_factor(factor);
      }
    }
  }

  class ConstIterator {
   public:
    friend class ConcurrentHashMap;

    const value_type& operator*() const {
      return *it_;
    }

    const value_type* operator->() const {
      return &*it_;
    }

    ConstIterator& operator++() {
      ++it_;
      next();
      return *this;
    }

    bool operator==(const ConstIterator& o) const {
      return it_ == o.it_ && segment_ == o.segment_;
    }

    bool operator!=(const ConstIterator& o) const {
      return !(*this == o);
    }

    ConstIterator& operator=(const ConstIterator& o) = delete;

    ConstIterator& operator=(ConstIterator&& o) noexcept {
      if (this != &o) {
        it_ = std::move(o.it_);
        segment_ = std::exchange(o.segment_, uint64_t(NumShards));
        parent_ = std::exchange(o.parent_, nullptr);
      }
      return *this;
    }

    ConstIterator(const ConstIterator& o) = delete;

    ConstIterator(ConstIterator&& o) noexcept
        : it_(std::move(o.it_)),
          segment_(std::exchange(o.segment_, uint64_t(NumShards))),
          parent_(std::exchange(o.parent_, nullptr)) {}

    ConstIterator(const ConcurrentHashMap* parent, uint64_t segment)
        : segment_(segment), parent_(parent) {}

   private:
    // cbegin iterator
    explicit ConstIterator(const ConcurrentHashMap* parent)
        : it_(parent->ensureSegment(0)->cbegin()),
          segment_(0),
          parent_(parent) {
      // Advance to the first element, which could be in any shard.
      next();
    }

    // cend iterator
    explicit ConstIterator(uint64_t shards) : it_(nullptr), segment_(shards) {}

    void next() {
      while (segment_ < parent_->NumShards &&
             it_ == parent_->ensureSegment(segment_)->cend()) {
        SegmentT* seg{nullptr};
        while (!seg) {
          segment_++;
          if (segment_ < parent_->NumShards) {
            seg = parent_->segments_[segment_].load(std::memory_order_acquire);
            if (!seg) {
              continue;
            }
            it_ = seg->cbegin();
          }
          break;
        }
      }
    }

    typename SegmentT::Iterator it_;
    uint64_t segment_;
    const ConcurrentHashMap* parent_;
  };

 private:
  uint64_t pickSegment(const KeyType& k) const {
    auto h = HashFn()(k);
    // Use the lowest bits for our shard bits.
    //
    // This works well even if the hash function is biased towards the
    // low bits: The sharding will happen in the segments_ instead of
    // in the segment buckets, so we'll still get write sharding as
    // well.
    //
    // Low-bit bias happens often for std::hash using small numbers,
    // since the integer hash function is the identity function.
    return h & (NumShards - 1);
  }

  SegmentT* ensureSegment(uint64_t i) const {
    SegmentT* seg = segments_[i].load(std::memory_order_acquire);
    if (!seg) {
      auto b = ensureBatch();
      SegmentT* newseg = (SegmentT*)Allocator().allocate(sizeof(SegmentT));
      newseg = new (newseg)
          SegmentT(size_ >> ShardBits, load_factor_, max_size_ >> ShardBits, b);
      if (!segments_[i].compare_exchange_strong(seg, newseg)) {
        // seg is updated with new value, delete ours.
        newseg->~SegmentT();
        Allocator().deallocate((uint8_t*)newseg, sizeof(SegmentT));
      } else {
        seg = newseg;
      }
    }
    return seg;
  }

  hazptr_obj_batch<Atom>* batch() const noexcept {
    return batch_.load(std::memory_order_acquire);
  }

  hazptr_obj_batch<Atom>* ensureBatch() const {
    auto b = batch();
    if (!b) {
      auto storage = Allocator().allocate(sizeof(hazptr_obj_batch<Atom>));
      auto newbatch = new (storage) hazptr_obj_batch<Atom>();
      if (batch_.compare_exchange_strong(b, newbatch)) {
        b = newbatch;
      } else {
        newbatch->shutdown_and_reclaim();
        newbatch->~hazptr_obj_batch<Atom>();
        Allocator().deallocate(storage, sizeof(hazptr_obj_batch<Atom>));
      }
    }
    return b;
  }

  void batch_shutdown_cleanup() {
    auto b = batch();
    if (b) {
      b->shutdown_and_reclaim();
      hazptr_cleanup_batch_tag(b);
      b->~hazptr_obj_batch<Atom>();
      Allocator().deallocate((uint8_t*)b, sizeof(hazptr_obj_batch<Atom>));
    }
  }

  mutable Atom<SegmentT*> segments_[NumShards];
  size_t size_{0};
  size_t max_size_{0};
  mutable Atom<hazptr_obj_batch<Atom>*> batch_{nullptr};
};

} // namespace folly