1/*
2 * Copyright 2018-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <algorithm>
18#include <cstdint>
19#include <limits>
20#include <utility>
21
22#include <glog/logging.h>
23
24namespace folly {
25
26namespace detail {
27
28// Internal cancellation state object.
29class CancellationState {
30 public:
31 FOLLY_NODISCARD static CancellationStateSourcePtr create();
32
33 private:
34 // Constructed initially with a CancellationSource reference count of 1.
35 CancellationState() noexcept;
36
37 ~CancellationState();
38
39 friend struct CancellationStateTokenDeleter;
40 friend struct CancellationStateSourceDeleter;
41
42 void removeTokenReference() noexcept;
43 void removeSourceReference() noexcept;
44
45 public:
46 FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept;
47
48 FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept;
49
50 bool tryAddCallback(
51 CancellationCallback* callback,
52 bool incrementRefCountIfSuccessful) noexcept;
53
54 void removeCallback(CancellationCallback* callback) noexcept;
55
56 bool isCancellationRequested() const noexcept;
57 bool canBeCancelled() const noexcept;
58
59 // Request cancellation.
60 // Return 'true' if cancellation had already been requested.
61 // Return 'false' if this was the first thread to request
62 // cancellation.
63 bool requestCancellation() noexcept;
64
65 private:
66 void lock() noexcept;
67 void unlock() noexcept;
68 void unlockAndIncrementTokenCount() noexcept;
69 void unlockAndDecrementTokenCount() noexcept;
70 bool tryLockAndCancelUnlessCancelled() noexcept;
71
72 template <typename Predicate>
73 bool tryLock(Predicate predicate) noexcept;
74
75 static bool canBeCancelled(std::uint64_t state) noexcept;
76 static bool isCancellationRequested(std::uint64_t state) noexcept;
77 static bool isLocked(std::uint64_t state) noexcept;
78
79 static constexpr std::uint64_t kCancellationRequestedFlag = 1;
80 static constexpr std::uint64_t kLockedFlag = 2;
81 static constexpr std::uint64_t kTokenReferenceCountIncrement = 4;
82 static constexpr std::uint64_t kSourceReferenceCountIncrement =
83 std::uint64_t(1) << 33u;
84 static constexpr std::uint64_t kTokenReferenceCountMask =
85 (kSourceReferenceCountIncrement - 1u) -
86 (kTokenReferenceCountIncrement - 1u);
87 static constexpr std::uint64_t kSourceReferenceCountMask =
88 std::numeric_limits<std::uint64_t>::max() -
89 (kSourceReferenceCountIncrement - 1u);
90
91 // Bit 0 - Cancellation Requested
92 // Bit 1 - Locked Flag
93 // Bits 2-32 - Token reference count (max ~2 billion)
94 // Bits 33-63 - Source reference count (max ~2 billion)
95 std::atomic<std::uint64_t> state_;
96 CancellationCallback* head_;
97 std::thread::id signallingThreadId_;
98};
99
100inline void CancellationStateTokenDeleter::operator()(
101 CancellationState* state) noexcept {
102 state->removeTokenReference();
103}
104
105inline void CancellationStateSourceDeleter::operator()(
106 CancellationState* state) noexcept {
107 state->removeSourceReference();
108}
109
110} // namespace detail
111
112inline CancellationToken::CancellationToken(
113 const CancellationToken& other) noexcept
114 : state_() {
115 if (other.state_) {
116 state_ = other.state_->addTokenReference();
117 }
118}
119
120inline CancellationToken::CancellationToken(CancellationToken&& other) noexcept
121 : state_(std::move(other.state_)) {}
122
123inline CancellationToken& CancellationToken::operator=(
124 const CancellationToken& other) noexcept {
125 if (state_ != other.state_) {
126 CancellationToken temp{other};
127 swap(temp);
128 }
129 return *this;
130}
131
132inline CancellationToken& CancellationToken::operator=(
133 CancellationToken&& other) noexcept {
134 state_ = std::move(other.state_);
135 return *this;
136}
137
138inline bool CancellationToken::isCancellationRequested() const noexcept {
139 return state_ != nullptr && state_->isCancellationRequested();
140}
141
142inline bool CancellationToken::canBeCancelled() const noexcept {
143 return state_ != nullptr && state_->canBeCancelled();
144}
145
146inline void CancellationToken::swap(CancellationToken& other) noexcept {
147 std::swap(state_, other.state_);
148}
149
150inline CancellationToken::CancellationToken(
151 detail::CancellationStateTokenPtr state) noexcept
152 : state_(std::move(state)) {}
153
154inline bool operator==(
155 const CancellationToken& a,
156 const CancellationToken& b) noexcept {
157 return a.state_ == b.state_;
158}
159
160inline bool operator!=(
161 const CancellationToken& a,
162 const CancellationToken& b) noexcept {
163 return !(a == b);
164}
165
166inline CancellationSource::CancellationSource()
167 : state_(detail::CancellationState::create()) {}
168
169inline CancellationSource::CancellationSource(
170 const CancellationSource& other) noexcept
171 : state_() {
172 if (other.state_) {
173 state_ = other.state_->addSourceReference();
174 }
175}
176
177inline CancellationSource::CancellationSource(
178 CancellationSource&& other) noexcept
179 : state_(std::move(other.state_)) {}
180
181inline CancellationSource& CancellationSource::operator=(
182 const CancellationSource& other) noexcept {
183 if (state_ != other.state_) {
184 CancellationSource temp{other};
185 swap(temp);
186 }
187 return *this;
188}
189
190inline CancellationSource& CancellationSource::operator=(
191 CancellationSource&& other) noexcept {
192 state_ = std::move(other.state_);
193 return *this;
194}
195
196inline bool CancellationSource::isCancellationRequested() const noexcept {
197 return state_ != nullptr && state_->isCancellationRequested();
198}
199
200inline bool CancellationSource::canBeCancelled() const noexcept {
201 return state_ != nullptr;
202}
203
204inline CancellationToken CancellationSource::getToken() const noexcept {
205 if (state_ != nullptr) {
206 return CancellationToken{state_->addTokenReference()};
207 }
208 return CancellationToken{};
209}
210
211inline bool CancellationSource::requestCancellation() const noexcept {
212 if (state_ != nullptr) {
213 return state_->requestCancellation();
214 }
215 return false;
216}
217
218inline void CancellationSource::swap(CancellationSource& other) noexcept {
219 std::swap(state_, other.state_);
220}
221
222template <
223 typename Callable,
224 std::enable_if_t<
225 std::is_constructible<CancellationCallback::VoidFunction, Callable>::
226 value,
227 int>>
228inline CancellationCallback::CancellationCallback(
229 CancellationToken&& ct,
230 Callable&& callable)
231 : next_(nullptr),
232 prevNext_(nullptr),
233 state_(nullptr),
234 callback_(static_cast<Callable&&>(callable)),
235 destructorHasRunInsideCallback_(nullptr),
236 callbackCompleted_(false) {
237 if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, false)) {
238 state_ = ct.state_.release();
239 }
240}
241
242template <
243 typename Callable,
244 std::enable_if_t<
245 std::is_constructible<CancellationCallback::VoidFunction, Callable>::
246 value,
247 int>>
248inline CancellationCallback::CancellationCallback(
249 const CancellationToken& ct,
250 Callable&& callable)
251 : next_(nullptr),
252 prevNext_(nullptr),
253 state_(nullptr),
254 callback_(static_cast<Callable&&>(callable)),
255 destructorHasRunInsideCallback_(nullptr),
256 callbackCompleted_(false) {
257 if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, true)) {
258 state_ = ct.state_.get();
259 }
260}
261
262inline CancellationCallback::~CancellationCallback() {
263 if (state_ != nullptr) {
264 state_->removeCallback(this);
265 }
266}
267
268inline void CancellationCallback::invokeCallback() noexcept {
269 // Invoke within a noexcept context so that we std::terminate() if it throws.
270 callback_();
271}
272
273namespace detail {
274
275inline CancellationStateSourcePtr CancellationState::create() {
276 return CancellationStateSourcePtr{new CancellationState()};
277}
278
279inline CancellationState::CancellationState() noexcept
280 : state_(kSourceReferenceCountIncrement),
281 head_(nullptr),
282 signallingThreadId_() {}
283
284inline CancellationStateTokenPtr
285CancellationState::addTokenReference() noexcept {
286 state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed);
287 return CancellationStateTokenPtr{this};
288}
289
290inline void CancellationState::removeTokenReference() noexcept {
291 const auto oldState = state_.fetch_sub(
292 kTokenReferenceCountIncrement, std::memory_order_acq_rel);
293 DCHECK(
294 (oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement);
295 if (oldState < (2 * kTokenReferenceCountIncrement)) {
296 delete this;
297 }
298}
299
300inline CancellationStateSourcePtr
301CancellationState::addSourceReference() noexcept {
302 state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed);
303 return CancellationStateSourcePtr{this};
304}
305
306inline void CancellationState::removeSourceReference() noexcept {
307 const auto oldState = state_.fetch_sub(
308 kSourceReferenceCountIncrement, std::memory_order_acq_rel);
309 DCHECK(
310 (oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement);
311 if (oldState <
312 (kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) {
313 delete this;
314 }
315}
316
317inline bool CancellationState::isCancellationRequested() const noexcept {
318 return isCancellationRequested(state_.load(std::memory_order_acquire));
319}
320
321inline bool CancellationState::canBeCancelled() const noexcept {
322 return canBeCancelled(state_.load(std::memory_order_acquire));
323}
324
325inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept {
326 // Can be cancelled if there is at least one CancellationSource ref-count
327 // or if cancellation has been requested.
328 return (state >= kSourceReferenceCountIncrement) ||
329 isCancellationRequested(state);
330}
331
332inline bool CancellationState::isCancellationRequested(
333 std::uint64_t state) noexcept {
334 return (state & kCancellationRequestedFlag) != 0;
335}
336
337inline bool CancellationState::isLocked(std::uint64_t state) noexcept {
338 return (state & kLockedFlag) != 0;
339}
340
341} // namespace detail
342
343} // namespace folly
344