1/*
2 * Copyright 2017-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#pragma once
17#include <folly/DefaultKeepAliveExecutor.h>
18#include <folly/Memory.h>
19#include <folly/SharedMutex.h>
20#include <folly/executors/GlobalThreadPoolList.h>
21#include <folly/executors/task_queue/LifoSemMPMCQueue.h>
22#include <folly/executors/thread_factory/NamedThreadFactory.h>
23#include <folly/io/async/Request.h>
24#include <folly/portability/GFlags.h>
25#include <folly/synchronization/Baton.h>
26
27#include <algorithm>
28#include <mutex>
29#include <queue>
30
31#include <glog/logging.h>
32
33namespace folly {
34
/* Base class for implementing threadpool based executors.
 *
 * Dynamic thread behavior:
 *
 * ThreadPoolExecutors may vary their actual running number of threads
 * between minThreads_ and maxThreads_, tracked by activeThreads_.
 * The actual implementation of joining an idle thread is left to the
 * ThreadPoolExecutor subclass (typically by LifoSem try_take_for
 * timing out). Idle threads should be removed from threadList_,
 * threadsToJoin_ should be incremented, and activeThreads_ decremented.
 *
 * On task add(), if an executor can guarantee there is an active
 * thread that will handle the task, then nothing needs to be done.
 * If not, then ensureActiveThreads() should be called to possibly
 * start another pool thread, up to maxThreads_.
 *
 * ensureJoined() is called on add(), so that we can join idle
 * threads that have exited (a thread cannot join itself).
 */
55class ThreadPoolExecutor : public DefaultKeepAliveExecutor {
56 public:
57 explicit ThreadPoolExecutor(
58 size_t maxThreads,
59 size_t minThreads,
60 std::shared_ptr<ThreadFactory> threadFactory,
61 bool isWaitForAll = false);
62
63 ~ThreadPoolExecutor() override;
64
65 void add(Func func) override = 0;
66 virtual void
67 add(Func func, std::chrono::milliseconds expiration, Func expireCallback) = 0;
68
69 void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
70 CHECK(numThreads() == 0);
71 threadFactory_ = std::move(threadFactory);
72 }
73
74 std::shared_ptr<ThreadFactory> getThreadFactory() {
75 return threadFactory_;
76 }
77
78 size_t numThreads();
79 void setNumThreads(size_t numThreads);
80
81 // Return actual number of active threads -- this could be different from
82 // numThreads() due to ThreadPoolExecutor's dynamic behavior.
83 size_t numActiveThreads();
84
85 /*
86 * stop() is best effort - there is no guarantee that unexecuted tasks won't
87 * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
88 * behaves like join().
89 */
90 void stop();
91 void join();
92
93 /**
94 * Execute f against all ThreadPoolExecutors, primarily for retrieving and
95 * exporting stats.
96 */
97 static void withAll(FunctionRef<void(ThreadPoolExecutor&)> f);
98
99 struct PoolStats {
100 PoolStats()
101 : threadCount(0),
102 idleThreadCount(0),
103 activeThreadCount(0),
104 pendingTaskCount(0),
105 totalTaskCount(0),
106 maxIdleTime(0) {}
107 size_t threadCount, idleThreadCount, activeThreadCount;
108 uint64_t pendingTaskCount, totalTaskCount;
109 std::chrono::nanoseconds maxIdleTime;
110 };
111
112 PoolStats getPoolStats();
113 size_t getPendingTaskCount();
114 std::string getName();
115
116 struct TaskStats {
117 TaskStats() : expired(false), waitTime(0), runTime(0) {}
118 bool expired;
119 std::chrono::nanoseconds waitTime;
120 std::chrono::nanoseconds runTime;
121 };
122
123 using TaskStatsCallback = std::function<void(TaskStats)>;
124 void subscribeToTaskStats(TaskStatsCallback cb);
125
126 /**
127 * Base class for threads created with ThreadPoolExecutor.
128 * Some subclasses have methods that operate on these
129 * handles.
130 */
131 class ThreadHandle {
132 public:
133 virtual ~ThreadHandle() = default;
134 };
135
136 /**
137 * Observer interface for thread start/stop.
138 * Provides hooks so actions can be taken when
139 * threads are created
140 */
141 class Observer {
142 public:
143 virtual void threadStarted(ThreadHandle*) = 0;
144 virtual void threadStopped(ThreadHandle*) = 0;
145 virtual void threadPreviouslyStarted(ThreadHandle* h) {
146 threadStarted(h);
147 }
148 virtual void threadNotYetStopped(ThreadHandle* h) {
149 threadStopped(h);
150 }
151 virtual ~Observer() = default;
152 };
153
154 void addObserver(std::shared_ptr<Observer>);
155 void removeObserver(std::shared_ptr<Observer>);
156
157 void setThreadDeathTimeout(std::chrono::milliseconds timeout) {
158 threadTimeout_ = timeout;
159 }
160
161 protected:
162 // Prerequisite: threadListLock_ writelocked
163 void addThreads(size_t n);
164 // Prerequisite: threadListLock_ writelocked
165 void removeThreads(size_t n, bool isJoin);
166
167 struct TaskStatsCallbackRegistry;
168
169 struct alignas(hardware_destructive_interference_size) Thread
170 : public ThreadHandle {
171 explicit Thread(ThreadPoolExecutor* pool)
172 : id(nextId++),
173 handle(),
174 idle(true),
175 lastActiveTime(std::chrono::steady_clock::now()),
176 taskStatsCallbacks(pool->taskStatsCallbacks_) {}
177
178 ~Thread() override = default;
179
180 static std::atomic<uint64_t> nextId;
181 uint64_t id;
182 std::thread handle;
183 bool idle;
184 std::chrono::steady_clock::time_point lastActiveTime;
185 folly::Baton<> startupBaton;
186 std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks;
187 };
188
189 typedef std::shared_ptr<Thread> ThreadPtr;
190
191 struct Task {
192 explicit Task(
193 Func&& func,
194 std::chrono::milliseconds expiration,
195 Func&& expireCallback);
196 Func func_;
197 TaskStats stats_;
198 std::chrono::steady_clock::time_point enqueueTime_;
199 std::chrono::milliseconds expiration_;
200 Func expireCallback_;
201 std::shared_ptr<folly::RequestContext> context_;
202 };
203
204 static void runTask(const ThreadPtr& thread, Task&& task);
205
206 // The function that will be bound to pool threads. It must call
207 // thread->startupBaton.post() when it's ready to consume work.
208 virtual void threadRun(ThreadPtr thread) = 0;
209
210 // Stop n threads and put their ThreadPtrs in the stoppedThreads_ queue
211 // and remove them from threadList_, either synchronize or asynchronize
212 // Prerequisite: threadListLock_ writelocked
213 virtual void stopThreads(size_t n) = 0;
214
215 // Join n stopped threads and remove them from waitingForJoinThreads_ queue.
216 // Should not hold a lock because joining thread operation may invoke some
217 // cleanup operations on the thread, and those cleanup operations may
218 // require a lock on ThreadPoolExecutor.
219 void joinStoppedThreads(size_t n);
220
221 // Create a suitable Thread struct
222 virtual ThreadPtr makeThread() {
223 return std::make_shared<Thread>(this);
224 }
225
226 // Prerequisite: threadListLock_ readlocked or writelocked
227 virtual size_t getPendingTaskCountImpl() = 0;
228
229 class ThreadList {
230 public:
231 void add(const ThreadPtr& state) {
232 auto it = std::lower_bound(
233 vec_.begin(),
234 vec_.end(),
235 state,
236 // compare method is a static method of class
237 // and therefore cannot be inlined by compiler
238 // as a template predicate of the STL algorithm
239 // but wrapped up with the lambda function (lambda will be inlined)
240 // compiler can inline compare method as well
241 [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
242 return compare(ts1, ts2);
243 });
244 vec_.insert(it, state);
245 }
246
247 void remove(const ThreadPtr& state) {
248 auto itPair = std::equal_range(
249 vec_.begin(),
250 vec_.end(),
251 state,
252 // the same as above
253 [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
254 return compare(ts1, ts2);
255 });
256 CHECK(itPair.first != vec_.end());
257 CHECK(std::next(itPair.first) == itPair.second);
258 vec_.erase(itPair.first);
259 }
260
261 const std::vector<ThreadPtr>& get() const {
262 return vec_;
263 }
264
265 private:
266 static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
267 return ts1->id < ts2->id;
268 }
269
270 std::vector<ThreadPtr> vec_;
271 };
272
273 class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
274 public:
275 BlockingQueueAddResult add(ThreadPtr item) override;
276 ThreadPtr take() override;
277 size_t size() override;
278 folly::Optional<ThreadPtr> try_take_for(
279 std::chrono::milliseconds /*timeout */) override;
280
281 private:
282 folly::LifoSem sem_;
283 std::mutex mutex_;
284 std::queue<ThreadPtr> queue_;
285 };
286
287 std::shared_ptr<ThreadFactory> threadFactory_;
288 const bool isWaitForAll_; // whether to wait till event base loop exits
289
290 ThreadList threadList_;
291 SharedMutex threadListLock_;
292 StoppedThreadQueue stoppedThreads_;
293 std::atomic<bool> isJoin_{false}; // whether the current downsizing is a join
294
295 struct TaskStatsCallbackRegistry {
296 folly::ThreadLocal<bool> inCallback;
297 folly::Synchronized<std::vector<TaskStatsCallback>> callbackList;
298 };
299 std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks_;
300 std::vector<std::shared_ptr<Observer>> observers_;
301 folly::ThreadPoolListHook threadPoolHook_;
302
303 // Dynamic thread sizing functions and variables
304 void ensureActiveThreads();
305 void ensureJoined();
306 bool minActive();
307 bool tryTimeoutThread();
308
309 // These are only modified while holding threadListLock_, but
310 // are read without holding the lock.
311 std::atomic<size_t> maxThreads_{0};
312 std::atomic<size_t> minThreads_{0};
313 std::atomic<size_t> activeThreads_{0};
314
315 std::atomic<size_t> threadsToJoin_{0};
316 std::chrono::milliseconds threadTimeout_{0};
317
318 void joinKeepAliveOnce() {
319 if (!std::exchange(keepAliveJoined_, true)) {
320 joinKeepAlive();
321 }
322 }
323
324 bool keepAliveJoined_{false};
325};
326
327} // namespace folly
328