1 | /* |
2 | * Copyright 2015-present Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #pragma once |
17 | |
18 | #include <folly/ThreadLocal.h> |
19 | #include <folly/synchronization/AsymmetricMemoryBarrier.h> |
20 | |
21 | namespace folly { |
22 | |
// Thread-local reference count.
//
// While the object is in LOCAL state, each thread increments/decrements its
// own LocalRefCount, so the hot path touches only thread-local state plus an
// asymmetricLightBarrier() — no contended atomic. Before the count can be
// observed or allowed to reach zero, useGlobal() must be called: it folds
// every thread's local delta into globalCount_ and switches the object to
// GLOBAL state, after which all updates go through the shared atomic.
class TLRefCount {
 public:
  using Int = int64_t;

  // localCount_ lazily creates one LocalRefCount per thread. collectGuard_
  // is a ref-counted token: each LocalRefCount copies it, and useGlobal()
  // watches a weak_ptr to it to know when every thread has been collected.
  // The no-op deleter means the shared_ptr owns nothing; only its use count
  // matters.
  TLRefCount()
      : localCount_([&]() { return new LocalRefCount(*this); }),
        collectGuard_(this, [](void*) {}) {}

  // May only be destroyed after useGlobal() completed and the count dropped
  // to zero — both invariants are asserted here.
  ~TLRefCount() noexcept {
    assert(globalCount_.load() == 0);
    assert(state_.load() == State::GLOBAL);
  }

  // This can't increment from 0.
  //
  // While updates are handled locally, returns 42 — a dummy positive value,
  // since the true total is spread across threads and not known here. After
  // the transition to GLOBAL, returns the new count, or 0 if the count had
  // already reached 0 (the increment is refused so a dead object cannot be
  // resurrected).
  Int operator++() noexcept {
    auto& localCount = *localCount_;

    if (++localCount) {
      return 42;
    }

    // Local update failed: a LOCAL -> GLOBAL transition is in progress (or
    // finished). Briefly acquiring globalMutex_ blocks until useGlobal() has
    // released it, i.e. until the transition is complete.
    if (state_.load() == State::GLOBAL_TRANSITION) {
      std::lock_guard<std::mutex> lg(globalMutex_);
    }

    assert(state_.load() == State::GLOBAL);

    // CAS loop: increment only while the count is still non-zero.
    auto value = globalCount_.load();
    do {
      if (value == 0) {
        return 0;
      }
    } while (!globalCount_.compare_exchange_weak(value, value + 1));

    return value + 1;
  }

  // Returns 42 (dummy positive value) while in local mode; otherwise the new
  // global count — 0 means the last reference was just released.
  Int operator--() noexcept {
    auto& localCount = *localCount_;

    if (--localCount) {
      return 42;
    }

    // Wait for an in-flight LOCAL -> GLOBAL transition to finish (see
    // operator++ for the mutex handshake).
    if (state_.load() == State::GLOBAL_TRANSITION) {
      std::lock_guard<std::mutex> lg(globalMutex_);
    }

    assert(state_.load() == State::GLOBAL);

    // Post-decrement yields the old value, so subtract 1 to return the new
    // count (equivalent to --globalCount_).
    return globalCount_-- - 1;
  }

  // Reads the count. Only meaningful after useGlobal(): in local mode the
  // total is distributed across threads, so a dummy 42 is returned.
  Int operator*() const {
    if (state_ != State::GLOBAL) {
      return 42;
    }
    return globalCount_.load();
  }

  // Collapse all thread-local counters into globalCount_ and switch this
  // object to GLOBAL state. Must happen before the count can be observed or
  // can reach zero.
  void useGlobal() noexcept {
    std::array<TLRefCount*, 1> ptrs{{this}};
    useGlobal(ptrs);
  }

  // Batch form: transitions every ref-count in refCountPtrs under a single
  // asymmetricHeavyBarrier(), holding all of their globalMutex_ locks for
  // the duration of the transition.
  template <typename Container>
  static void useGlobal(const Container& refCountPtrs) {
#ifdef FOLLY_SANITIZE_THREAD
    // TSAN has a limitation for the number of locks held concurrently, so it's
    // safer to call useGlobal() serially.
    if (refCountPtrs.size() > 1) {
      for (auto refCountPtr : refCountPtrs) {
        refCountPtr->useGlobal();
      }
      return;
    }
#endif

    // Hold every globalMutex_ until this function returns: threads that see
    // GLOBAL_TRANSITION block on the mutex in operator++/-- until the
    // transition is done.
    std::vector<std::unique_lock<std::mutex>> lgs_;
    for (auto refCountPtr : refCountPtrs) {
      lgs_.emplace_back(refCountPtr->globalMutex_);

      refCountPtr->state_ = State::GLOBAL_TRANSITION;
    }

    // Pairs with the asymmetricLightBarrier() in LocalRefCount::update():
    // after this, any thread that already passed its post-update state check
    // is guaranteed visible to us, and later updates will observe
    // GLOBAL_TRANSITION and fail over to the global path.
    asymmetricHeavyBarrier();

    for (auto refCountPtr : refCountPtrs) {
      std::weak_ptr<void> collectGuardWeak = refCountPtr->collectGuard_;

      // Make sure we can't create new LocalRefCounts
      refCountPtr->collectGuard_.reset();

      // Every LocalRefCount holds a copy of collectGuard_ that it drops in
      // collect(); spin until all threads' deltas have been folded into
      // globalCount_ (i.e. every copy of the token is gone).
      while (!collectGuardWeak.expired()) {
        auto accessor = refCountPtr->localCount_.accessAllThreads();
        for (auto& count : accessor) {
          count.collect();
        }
      }

      refCountPtr->state_ = State::GLOBAL;
    }
  }

 private:
  using AtomicInt = std::atomic<Int>;

  // Lifecycle is strictly LOCAL -> GLOBAL_TRANSITION -> GLOBAL.
  enum class State {
    LOCAL, // updates go to per-thread LocalRefCounts
    GLOBAL_TRANSITION, // useGlobal() is collecting the per-thread counters
    GLOBAL, // all updates go through globalCount_
  };

  // Per-thread counter: accumulates a signed delta in count_, which is folded
  // into the owner's globalCount_ exactly once by collect().
  class LocalRefCount {
   public:
    explicit LocalRefCount(TLRefCount& refCount) : refCount_(refCount) {
      // globalMutex_ serializes construction against useGlobal(): we either
      // obtain a live collectGuard_ (and participate in local counting), or
      // an already-reset one, in which case every update() fails over to the
      // global path.
      std::lock_guard<std::mutex> lg(refCount.globalMutex_);

      collectGuard_ = refCount.collectGuard_;
    }

    // Thread exit: flush this thread's pending delta into globalCount_.
    ~LocalRefCount() {
      collect();
    }

    // Fold this thread's delta into the owner's global count, at most once.
    void collect() {
      std::lock_guard<std::mutex> lg(collectMutex_);

      if (!collectGuard_) {
        // Already collected (or never participated).
        return;
      }

      // Record the snapshot we published so update() can tell whether a
      // racing delta made it into the collection.
      collectCount_ = count_.load();
      refCount_.globalCount_.fetch_add(collectCount_);
      collectGuard_.reset();
    }

    // True iff the update was handled locally; false means the caller must
    // retry on the global counter.
    bool operator++() {
      return update(1);
    }

    bool operator--() {
      return update(-1);
    }

   private:
    bool update(Int delta) {
      // Fast pre-check: once the owner has left LOCAL state, don't even try.
      if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
        return false;
      }

      // This is equivalent to atomic fetch_add. We know that this operation
      // is always performed from a single thread. asymmetricLightBarrier()
      // makes things faster than atomic fetch_add on platforms with native
      // support.
      auto count = count_.load(std::memory_order_relaxed) + delta;
      count_.store(count, std::memory_order_relaxed);

      // Pairs with the asymmetricHeavyBarrier() in useGlobal(); orders the
      // relaxed store above against the state re-check below.
      asymmetricLightBarrier();

      // Re-check: a transition may have started while we were updating. If
      // so, decide under collectMutex_ whether our delta was included in the
      // collected snapshot.
      if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
        std::lock_guard<std::mutex> lg(collectMutex_);

        if (collectGuard_) {
          // collect() hasn't run for this thread yet, so it will pick up our
          // store — the update counts as handled locally.
          return true;
        }
        if (collectCount_ != count) {
          // collect() ran without seeing our store; the delta was lost
          // locally, so fail over to the global counter.
          return false;
        }
      }

      return true;
    }

    AtomicInt count_{0};
    TLRefCount& refCount_;

    // Guards collect() against a racing update() slow path on this counter.
    std::mutex collectMutex_;
    // The value of count_ that collect() folded into globalCount_.
    Int collectCount_{0};
    // Copy of the owner's collectGuard_ token; reset once collected.
    std::shared_ptr<void> collectGuard_;
  };

  std::atomic<State> state_{State::LOCAL};
  folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
  // Starts at 1: the object is created holding its initial reference.
  std::atomic<int64_t> globalCount_{1};
  std::mutex globalMutex_;
  std::shared_ptr<void> collectGuard_;
};
211 | |
212 | } // namespace folly |
213 | |