/*
 * Copyright 2018-present Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/Portability.h>

#if FOLLY_HAS_COROUTINES

#include <folly/experimental/coro/SharedMutex.h>

#include <cassert>
#include <cstddef>
#include <utility>

using namespace folly::coro;

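// A note on the state encoding used throughout this file (the authoritative
// definitions live in SharedMutex.h): lockedFlagAndReaderCount_ packs the
// whole lock state into one word. kUnlocked means no holders,
// kExclusiveLockFlag marks a writer, and each active reader contributes one
// kSharedLockCountIncrement, so N readers are encoded as
// N * kSharedLockCountIncrement.
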
SharedMutexFair::~SharedMutexFair() {
  // The mutex must not be destroyed while it is held or has queued waiters.
  assert(state_->lockedFlagAndReaderCount_ == kUnlocked);
  assert(state_->waitersHead_ == nullptr);
}

bool SharedMutexFair::try_lock() noexcept {
  auto lock = state_.contextualLock();
  if (lock->lockedFlagAndReaderCount_ == kUnlocked) {
    // No readers and no writer: take the lock exclusively without suspending.
    lock->lockedFlagAndReaderCount_ = kExclusiveLockFlag;
    return true;
  }
  return false;
}

bool SharedMutexFair::try_lock_shared() noexcept {
  auto lock = state_.contextualLock();
  // Succeed if the mutex is unlocked, or if it is already held in shared
  // mode and no-one is queued waiting. Refusing to barge past queued
  // waiters (typically a writer) is what keeps this mutex fair.
  if (lock->lockedFlagAndReaderCount_ == kUnlocked ||
      (lock->lockedFlagAndReaderCount_ >= kSharedLockCountIncrement &&
       lock->waitersHead_ == nullptr)) {
    lock->lockedFlagAndReaderCount_ += kSharedLockCountIncrement;
    return true;
  }
  return false;
}

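// Illustrative only: the try_*() forms never suspend, so they can also be
// used from non-coroutine code. A minimal sketch:
//
//   if (mutex.try_lock_shared()) {
//     // ... read the protected data ...
//     mutex.unlock_shared();
//   }
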
void SharedMutexFair::unlock() noexcept {
  LockAwaiterBase* awaitersToResume = nullptr;
  {
    auto lockedState = state_.contextualLock();
    assert(lockedState->lockedFlagAndReaderCount_ == kExclusiveLockFlag);
    awaitersToResume = unlockOrGetNextWaitersToResume(*lockedState);
  }

  // Resume waiting coroutines only after the internal lock has been
  // released, since a resumed coroutine may immediately touch the mutex.
  resumeWaiters(awaitersToResume);
}

void SharedMutexFair::unlock_shared() noexcept {
  LockAwaiterBase* awaitersToResume = nullptr;
  {
    auto lockedState = state_.contextualLock();
    assert(lockedState->lockedFlagAndReaderCount_ >= kSharedLockCountIncrement);
    lockedState->lockedFlagAndReaderCount_ -= kSharedLockCountIncrement;
    if (lockedState->lockedFlagAndReaderCount_ != kUnlocked) {
      // Other readers still hold the lock; nothing further to do.
      return;
    }

    // This was the last reader; hand the lock to the next waiter(s), if any.
    awaitersToResume = unlockOrGetNextWaitersToResume(*lockedState);
  }

  resumeWaiters(awaitersToResume);
}

SharedMutexFair::LockAwaiterBase*
SharedMutexFair::unlockOrGetNextWaitersToResume(
    SharedMutexFair::State& state) noexcept {
  auto* head = state.waitersHead_;
  if (head != nullptr) {
    if (head->lockType_ == LockType::EXCLUSIVE) {
      // Hand the lock to a single exclusive waiter.
      state.waitersHead_ = std::exchange(head->nextAwaiter_, nullptr);
      state.lockedFlagAndReaderCount_ = kExclusiveLockFlag;
    } else {
      std::size_t newState = kSharedLockCountIncrement;

      // Scan for a run of SHARED lock types at the head of the queue so
      // that all adjacent readers can be resumed as a single batch.
      auto* last = head;
      auto* next = last->nextAwaiter_;
      while (next != nullptr && next->lockType_ == LockType::SHARED) {
        last = next;
        next = next->nextAwaiter_;
        newState += kSharedLockCountIncrement;
      }

      // Detach the run [head, last] and record one shared-lock count per
      // resumed reader.
      last->nextAwaiter_ = nullptr;
      state.lockedFlagAndReaderCount_ = newState;
      state.waitersHead_ = next;
    }

    if (state.waitersHead_ == nullptr) {
      // The queue is now empty; reset the tail pointer to the head slot.
      state.waitersTailNext_ = &state.waitersHead_;
    }
  } else {
    // No waiters: the mutex simply becomes unlocked.
    state.lockedFlagAndReaderCount_ = kUnlocked;
  }

  return head;
}

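// Worked example (illustrative): with the waiter queue
//   [SHARED, SHARED, EXCLUSIVE, SHARED]
// a release resumes the first two readers as one batch
// (lockedFlagAndReaderCount_ becomes 2 * kSharedLockCountIncrement) and
// leaves [EXCLUSIVE, SHARED] queued. The trailing reader waits behind the
// writer rather than barging ahead of it; that is the fairness property.
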
void SharedMutexFair::resumeWaiters(LockAwaiterBase* awaiters) noexcept {
  // Advance past each node before resuming it: resuming a coroutine may
  // destroy its awaiter, invalidating nextAwaiter_.
  while (awaiters != nullptr) {
    std::exchange(awaiters, awaiters->nextAwaiter_)->resume();
  }
}

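// A minimal coroutine usage sketch (illustrative only; assumes the
// co_scoped_lock()/co_scoped_lock_shared() awaitables declared in
// SharedMutex.h and folly::coro::Task from <folly/experimental/coro/Task.h>):
//
//   folly::coro::Task<int> read(folly::coro::SharedMutexFair& mutex,
//                               const int& value) {
//     auto lock = co_await mutex.co_scoped_lock_shared();
//     co_return value;  // the shared lock is released when 'lock' is destroyed
//   }
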
#endif // FOLLY_HAS_COROUTINES