/*
 * Copyright 2015-present Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <array>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

#include <folly/Likely.h>
#include <folly/Portability.h>
#include <folly/ThreadLocal.h>
#include <folly/synchronization/AsymmetricMemoryBarrier.h>

namespace folly {

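// TLRefCount is a thread-local reference counter. While the counter is in
// LOCAL mode, operator++/-- update an uncontended per-thread count; the
// exact value is unknown in that mode, so those operators return an
// arbitrary nonzero placeholder. useGlobal() collapses all per-thread
// counts into a single shared atomic, after which the exact value is
// observable and can reach 0.
//
// A minimal usage sketch (the count starts at 1, representing the initial
// reference):
//
//   folly::TLRefCount count;
//   ++count;           // fast, thread-local
//   --count;
//   count.useGlobal(); // collapse before the count needs to be observed
//   if (--count == 0) {
//     // last reference dropped; safe to destroy the guarded object
//   }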
class TLRefCount {
 public:
  using Int = int64_t;

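  // collectGuard_ is used purely as a token: each LocalRefCount keeps a
  // copy, and useGlobal() waits until every copy has been dropped. The
  // deleter is a no-op; `this` is just a convenient non-null pointer.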
  TLRefCount()
      : localCount_([&]() { return new LocalRefCount(*this); }),
        collectGuard_(this, [](void*) {}) {}

  ~TLRefCount() noexcept {
    assert(globalCount_.load() == 0);
    assert(state_.load() == State::GLOBAL);
  }

  // Increment-if-not-zero: this can't increment the count from 0. In GLOBAL
  // mode it returns 0 if the count was already 0, or the new count
  // otherwise. In LOCAL mode the increment always succeeds, and an
  // arbitrary nonzero placeholder (42) is returned because the exact count
  // is unknown.
  Int operator++() noexcept {
    auto& localCount = *localCount_;

    if (++localCount) {
      return 42;
    }

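    // A useGlobal() transition may be in progress; taking globalMutex_
    // blocks until it completes, since useGlobal() holds the mutex for the
    // entire transition.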
    if (state_.load() == State::GLOBAL_TRANSITION) {
      std::lock_guard<std::mutex> lg(globalMutex_);
    }

    assert(state_.load() == State::GLOBAL);

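    // GLOBAL mode: increment only if the count is still nonzero, returning
    // 0 to signal failure (hence "can't increment from 0").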
    auto value = globalCount_.load();
    do {
      if (value == 0) {
        return 0;
      }
    } while (!globalCount_.compare_exchange_weak(value, value + 1));

    return value + 1;
  }

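  // Returns the new count in GLOBAL mode (0 means the last reference was
  // dropped); in LOCAL mode the value is unknown, so an arbitrary nonzero
  // placeholder is returned.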
  Int operator--() noexcept {
    auto& localCount = *localCount_;

    if (--localCount) {
      return 42;
    }

    if (state_.load() == State::GLOBAL_TRANSITION) {
      std::lock_guard<std::mutex> lg(globalMutex_);
    }

    assert(state_.load() == State::GLOBAL);

    return --globalCount_;
  }

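  // The exact count is only known in GLOBAL mode; otherwise an arbitrary
  // nonzero placeholder is returned.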
  Int operator*() const {
    if (state_ != State::GLOBAL) {
      return 42;
    }
    return globalCount_.load();
  }

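  // Collapse the counter into GLOBAL mode. After useGlobal() returns, every
  // thread-local count has been folded into the shared atomic and the exact
  // value is observable. This must be called before the count can be seen
  // to reach 0.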
  void useGlobal() noexcept {
    std::array<TLRefCount*, 1> ptrs{{this}};
    useGlobal(ptrs);
  }

  template <typename Container>
  static void useGlobal(const Container& refCountPtrs) {
#ifdef FOLLY_SANITIZE_THREAD
    // TSAN has a limitation on the number of locks held simultaneously, so
    // it's safer to call useGlobal() serially.
    if (refCountPtrs.size() > 1) {
      for (auto refCountPtr : refCountPtrs) {
        refCountPtr->useGlobal();
      }
      return;
    }
#endif

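    // Lock every counter and publish GLOBAL_TRANSITION. The locks are held
    // until the end of this function, so the slow paths of operator++/--
    // block until the transition is complete.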
    std::vector<std::unique_lock<std::mutex>> locks;
    for (auto refCountPtr : refCountPtrs) {
      locks.emplace_back(refCountPtr->globalMutex_);

      refCountPtr->state_ = State::GLOBAL_TRANSITION;
    }

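    // Pairs with the asymmetricLightBarrier() in LocalRefCount::update():
    // after this heavy barrier, any racing local update either is visible
    // to collect() or will notice the state change in its re-check.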
    asymmetricHeavyBarrier();

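    // Each LocalRefCount holds a copy of collectGuard_; once the weak_ptr
    // expires, every thread-local count has been flushed into globalCount_.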
    for (auto refCountPtr : refCountPtrs) {
      std::weak_ptr<void> collectGuardWeak = refCountPtr->collectGuard_;

      // Make sure we can't create new LocalRefCounts.
      refCountPtr->collectGuard_.reset();

      while (!collectGuardWeak.expired()) {
        auto accessor = refCountPtr->localCount_.accessAllThreads();
        for (auto& count : accessor) {
          count.collect();
        }
      }

      refCountPtr->state_ = State::GLOBAL;
    }
  }

 private:
  using AtomicInt = std::atomic<Int>;

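  // A counter moves through LOCAL -> GLOBAL_TRANSITION -> GLOBAL at most
  // once; there is no way back to LOCAL.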
  enum class State {
    LOCAL,
    GLOBAL_TRANSITION,
    GLOBAL,
  };

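  // Per-thread counter. It registers itself by copying collectGuard_ and,
  // on destruction or when asked via collect(), flushes its count into
  // globalCount_ exactly once.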
  class LocalRefCount {
   public:
    explicit LocalRefCount(TLRefCount& refCount) : refCount_(refCount) {
      std::lock_guard<std::mutex> lg(refCount.globalMutex_);

      collectGuard_ = refCount.collectGuard_;
    }

    ~LocalRefCount() {
      collect();
    }

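    // Flush this thread's count into the shared global count, at most once.
    // collectGuard_ doubles as the "not yet collected" token, and
    // collectCount_ records the value that was published.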
    void collect() {
      std::lock_guard<std::mutex> lg(collectMutex_);

      if (!collectGuard_) {
        return;
      }

      collectCount_ = count_.load();
      refCount_.globalCount_.fetch_add(collectCount_);
      collectGuard_.reset();
    }

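    // Both operators return true if the delta has been (or will be)
    // accounted for locally; false means it was not applied and the caller
    // must retry on the global count.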
    bool operator++() {
      return update(1);
    }

    bool operator--() {
      return update(-1);
    }

   private:
    bool update(Int delta) {
      if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
        return false;
      }

      // This is equivalent to an atomic fetch_add. We know this operation
      // is always performed from a single thread, and asymmetricLightBarrier()
      // makes it faster than an atomic fetch_add on platforms with native
      // support.
      auto count = count_.load(std::memory_order_relaxed) + delta;
      count_.store(count, std::memory_order_relaxed);

      asymmetricLightBarrier();

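      // Re-check after the light barrier: a useGlobal() transition may have
      // raced with the update above.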
      if (UNLIKELY(refCount_.state_.load() != State::LOCAL)) {
        std::lock_guard<std::mutex> lg(collectMutex_);

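        // If our guard is still alive, this counter hasn't been collected
        // yet, so collect() will observe the new count. Otherwise compare
        // against collectCount_: a mismatch means our delta was missed and
        // must be redone on the global count.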
        if (collectGuard_) {
          return true;
        }
        if (collectCount_ != count) {
          return false;
        }
      }

      return true;
    }

    AtomicInt count_{0};
    TLRefCount& refCount_;

    std::mutex collectMutex_;
    Int collectCount_{0};
    std::shared_ptr<void> collectGuard_;
  };

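  // While in LOCAL mode, the logical count is globalCount_ plus the sum of
  // all thread-local counts. The global count starts at 1, representing the
  // initial reference.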
  std::atomic<State> state_{State::LOCAL};
  folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
  AtomicInt globalCount_{1};
  std::mutex globalMutex_;
  std::shared_ptr<void> collectGuard_;
};

} // namespace folly