1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
#include "arrow/util/thread-pool.h"

#include <algorithm>
#include <condition_variable>
#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "arrow/util/io-util.h"
#include "arrow/util/logging.h"
30 | |
31 | namespace arrow { |
32 | namespace internal { |
33 | |
// Shared mutable state of the pool.  Held through a shared_ptr so that
// worker threads can outlive the ThreadPool object itself (see WorkerLoop
// and ProtectAgainstFork, which swaps the state wholesale in a forked child).
struct ThreadPool::State {
  State() : desired_capacity_(0), please_shutdown_(false), quick_shutdown_(false) {}

  // NOTE: in case locking becomes too expensive, we can investigate lock-free FIFOs
  // such as https://github.com/cameron314/concurrentqueue

  // Protects every member below.
  std::mutex mutex_;
  // Signalled when tasks are pushed or when workers must re-check
  // shutdown / capacity conditions.
  std::condition_variable cv_;
  // Signalled by the last exiting worker so Shutdown() can stop waiting.
  std::condition_variable cv_shutdown_;

  // Live workers.  std::list keeps iterators stable, which lets each worker
  // erase its own entry in WorkerLoop() via the iterator captured at launch.
  std::list<std::thread> workers_;
  // Trashcan for finished threads
  std::vector<std::thread> finished_workers_;
  // FIFO of tasks waiting to be executed.
  std::deque<std::function<void()>> pending_tasks_;

  // Desired number of threads
  int desired_capacity_;
  // Are we shutting down?
  bool please_shutdown_;
  // If true, pending tasks are dropped instead of drained on shutdown.
  bool quick_shutdown_;
};
55 | |
// Construct an empty pool (capacity 0, no workers); callers are expected to
// follow up with SetCapacity() to actually launch threads.
ThreadPool::ThreadPool()
    : sp_state_(std::make_shared<ThreadPool::State>()),
      state_(sp_state_.get()),
      shutdown_on_destroy_(true) {
#ifndef _WIN32
  // Remember the creating process so ProtectAgainstFork() can detect that
  // we are running in a fork()ed child and must rebuild internal state.
  pid_ = getpid();
#endif
}
64 | |
ThreadPool::~ThreadPool() {
  // Best-effort quick shutdown: don't wait for pending tasks, and ignore
  // the returned Status (a destructor cannot propagate errors).
  if (shutdown_on_destroy_) {
    ARROW_UNUSED(Shutdown(false /* wait */));
  }
}
70 | |
// Detect that we are running in a child created by fork() and, if so,
// discard the parent's state (whose worker threads do not exist in the
// child) and relaunch workers.  Must be called before touching state_ in
// every public entry point.  No-op on Windows, which has no fork().
void ThreadPool::ProtectAgainstFork() {
#ifndef _WIN32
  pid_t current_pid = getpid();
  if (pid_ != current_pid) {
    // Reinitialize internal state in child process after fork()
    // Ideally we would use pthread_at_fork(), but that doesn't allow
    // storing an argument, hence we'd need to maintain a list of all
    // existing ThreadPools.
    int capacity = state_->desired_capacity_;

    // Fresh State: the old one may hold threads/tasks that only existed
    // in the parent.  Shutdown flags are carried over.
    auto new_state = std::make_shared<ThreadPool::State>();
    new_state->please_shutdown_ = state_->please_shutdown_;
    new_state->quick_shutdown_ = state_->quick_shutdown_;

    pid_ = current_pid;
    sp_state_ = new_state;
    state_ = sp_state_.get();

    // Launch worker threads anew
    if (!state_->please_shutdown_) {
      ARROW_UNUSED(SetCapacity(capacity));
    }
  }
#endif
}
96 | |
97 | Status ThreadPool::SetCapacity(int threads) { |
98 | ProtectAgainstFork(); |
99 | std::unique_lock<std::mutex> lock(state_->mutex_); |
100 | if (state_->please_shutdown_) { |
101 | return Status::Invalid("operation forbidden during or after shutdown" ); |
102 | } |
103 | if (threads <= 0) { |
104 | return Status::Invalid("ThreadPool capacity must be > 0" ); |
105 | } |
106 | CollectFinishedWorkersUnlocked(); |
107 | |
108 | state_->desired_capacity_ = threads; |
109 | int diff = static_cast<int>(threads - state_->workers_.size()); |
110 | if (diff > 0) { |
111 | LaunchWorkersUnlocked(diff); |
112 | } else if (diff < 0) { |
113 | // Wake threads to ask them to stop |
114 | state_->cv_.notify_all(); |
115 | } |
116 | return Status::OK(); |
117 | } |
118 | |
119 | int ThreadPool::GetCapacity() { |
120 | ProtectAgainstFork(); |
121 | std::unique_lock<std::mutex> lock(state_->mutex_); |
122 | return state_->desired_capacity_; |
123 | } |
124 | |
125 | int ThreadPool::GetActualCapacity() { |
126 | ProtectAgainstFork(); |
127 | std::unique_lock<std::mutex> lock(state_->mutex_); |
128 | return static_cast<int>(state_->workers_.size()); |
129 | } |
130 | |
// Stop the pool.  If `wait` is true, workers drain the pending task queue
// first; otherwise pending tasks are discarded.  Blocks until every worker
// has exited its loop, then joins the underlying OS threads.
Status ThreadPool::Shutdown(bool wait) {
  ProtectAgainstFork();
  std::unique_lock<std::mutex> lock(state_->mutex_);

  if (state_->please_shutdown_) {
    return Status::Invalid("Shutdown() already called");
  }
  state_->please_shutdown_ = true;
  state_->quick_shutdown_ = !wait;
  // Wake every worker so it observes the shutdown flags.
  state_->cv_.notify_all();
  // Workers erase themselves from workers_ and signal cv_shutdown_;
  // wait until the last one is gone.
  state_->cv_shutdown_.wait(lock, [this] { return state_->workers_.empty(); });
  if (!state_->quick_shutdown_) {
    // A graceful shutdown must have drained the queue completely.
    DCHECK_EQ(state_->pending_tasks_.size(), 0);
  } else {
    state_->pending_tasks_.clear();
  }
  // Join the exited threads so no OS thread outlives the pool.
  CollectFinishedWorkersUnlocked();
  return Status::OK();
}
150 | |
151 | void ThreadPool::CollectFinishedWorkersUnlocked() { |
152 | for (auto& thread : state_->finished_workers_) { |
153 | // Make sure OS thread has exited |
154 | thread.join(); |
155 | } |
156 | state_->finished_workers_.clear(); |
157 | } |
158 | |
// Launch `threads` new workers.  Caller must hold state_->mutex_: the new
// thread's first action in WorkerLoop() is to lock that mutex, so it cannot
// observe `*it` before the assignment below completes.
void ThreadPool::LaunchWorkersUnlocked(int threads) {
  // Copy the shared_ptr so each worker keeps State alive even if the
  // ThreadPool object is destroyed or its state is swapped after fork().
  std::shared_ptr<State> state = sp_state_;

  for (int i = 0; i < threads; i++) {
    // Default-construct the list node first so the iterator to it can be
    // captured by the thread closure (std::list iterators stay valid).
    state_->workers_.emplace_back();
    auto it = --(state_->workers_.end());
    *it = std::thread([state, it] { WorkerLoop(state, it); });
  }
}
168 | |
// Body of each worker thread.  `state` keeps the pool state alive; `it`
// points at this thread's own entry in state->workers_.  The mutex is held
// throughout, except while actually executing a task.
void ThreadPool::WorkerLoop(std::shared_ptr<State> state,
                            std::list<std::thread>::iterator it) {
  std::unique_lock<std::mutex> lock(state->mutex_);

  // Since we hold the lock, `it` now points to the correct thread object
  // (LaunchWorkersUnlocked has exited)
  DCHECK_EQ(std::this_thread::get_id(), it->get_id());

  // If too many threads, we should secede from the pool
  const auto should_secede = [&]() -> bool {
    return state->workers_.size() > static_cast<size_t>(state->desired_capacity_);
  };

  while (true) {
    // By the time this thread is started, some tasks may have been pushed
    // or shutdown could even have been requested.  So we only wait on the
    // condition variable at the end of the loop.

    // Execute pending tasks if any
    while (!state->pending_tasks_.empty() && !state->quick_shutdown_) {
      // We check this opportunistically at each loop iteration since
      // it releases the lock below.
      if (should_secede()) {
        break;
      }
      {
        // Pop the task while locked, run it unlocked so other workers and
        // producers can make progress concurrently.
        std::function<void()> task = std::move(state->pending_tasks_.front());
        state->pending_tasks_.pop_front();
        lock.unlock();
        task();
      }
      lock.lock();
    }
    // Now either the queue is empty *or* a quick shutdown was requested
    if (state->please_shutdown_ || should_secede()) {
      break;
    }
    // Wait for next wakeup
    state->cv_.wait(lock);
  }

  // We're done.  Move our thread object to the trashcan of finished
  // workers.  This has two motivations:
  // 1) the thread object doesn't get destroyed before this function finishes
  //    (but we could call thread::detach() instead)
  // 2) we can explicitly join() the trashcan threads to make sure all OS threads
  //    are exited before the ThreadPool is destroyed.  Otherwise subtle
  //    timing conditions can lead to false positives with Valgrind.
  DCHECK_EQ(std::this_thread::get_id(), it->get_id());
  state->finished_workers_.push_back(std::move(*it));
  state->workers_.erase(it);
  if (state->please_shutdown_) {
    // Notify the function waiting in Shutdown().
    state->cv_shutdown_.notify_one();
  }
}
225 | |
// Enqueue a task for execution by a worker.  Fails once shutdown has been
// requested.  The task is queued under the lock; the notification is issued
// after releasing it, so the awakened worker doesn't immediately block on
// the mutex we still hold.
Status ThreadPool::SpawnReal(std::function<void()> task) {
  {
    ProtectAgainstFork();
    std::lock_guard<std::mutex> lock(state_->mutex_);
    if (state_->please_shutdown_) {
      return Status::Invalid("operation forbidden during or after shutdown");
    }
    // Opportunistically reap exited workers while we hold the lock anyway.
    CollectFinishedWorkersUnlocked();
    state_->pending_tasks_.push_back(std::move(task));
  }
  state_->cv_.notify_one();
  return Status::OK();
}
239 | |
240 | Status ThreadPool::Make(int threads, std::shared_ptr<ThreadPool>* out) { |
241 | auto pool = std::shared_ptr<ThreadPool>(new ThreadPool()); |
242 | RETURN_NOT_OK(pool->SetCapacity(threads)); |
243 | *out = std::move(pool); |
244 | return Status::OK(); |
245 | } |
246 | |
247 | // ---------------------------------------------------------------------- |
248 | // Global thread pool |
249 | |
250 | static int ParseOMPEnvVar(const char* name) { |
251 | // OMP_NUM_THREADS is a comma-separated list of positive integers. |
252 | // We are only interested in the first (top-level) number. |
253 | std::string str; |
254 | if (!GetEnvVar(name, &str).ok()) { |
255 | return 0; |
256 | } |
257 | auto first_comma = str.find_first_of(','); |
258 | if (first_comma != std::string::npos) { |
259 | str = str.substr(0, first_comma); |
260 | } |
261 | try { |
262 | return std::max(0, std::stoi(str)); |
263 | } catch (...) { |
264 | return 0; |
265 | } |
266 | } |
267 | |
268 | int ThreadPool::DefaultCapacity() { |
269 | int capacity, limit; |
270 | capacity = ParseOMPEnvVar("OMP_NUM_THREADS" ); |
271 | if (capacity == 0) { |
272 | capacity = std::thread::hardware_concurrency(); |
273 | } |
274 | limit = ParseOMPEnvVar("OMP_THREAD_LIMIT" ); |
275 | if (limit > 0) { |
276 | capacity = std::min(limit, capacity); |
277 | } |
278 | if (capacity == 0) { |
279 | ARROW_LOG(WARNING) << "Failed to determine the number of available threads, " |
280 | "using a hardcoded arbitrary value" ; |
281 | capacity = 4; |
282 | } |
283 | return capacity; |
284 | } |
285 | |
286 | // Helper for the singleton pattern |
287 | std::shared_ptr<ThreadPool> ThreadPool::MakeCpuThreadPool() { |
288 | std::shared_ptr<ThreadPool> pool; |
289 | DCHECK_OK(ThreadPool::Make(ThreadPool::DefaultCapacity(), &pool)); |
290 | // On Windows, the global ThreadPool destructor may be called after |
291 | // non-main threads have been killed by the OS, and hang in a condition |
292 | // variable. |
293 | // On Unix, we want to avoid leak reports by Valgrind. |
294 | #ifdef _WIN32 |
295 | pool->shutdown_on_destroy_ = false; |
296 | #endif |
297 | return pool; |
298 | } |
299 | |
300 | ThreadPool* GetCpuThreadPool() { |
301 | static std::shared_ptr<ThreadPool> singleton = ThreadPool::MakeCpuThreadPool(); |
302 | return singleton.get(); |
303 | } |
304 | |
305 | } // namespace internal |
306 | |
307 | int GetCpuThreadPoolCapacity() { return internal::GetCpuThreadPool()->GetCapacity(); } |
308 | |
309 | Status SetCpuThreadPoolCapacity(int threads) { |
310 | return internal::GetCpuThreadPool()->SetCapacity(threads); |
311 | } |
312 | |
313 | } // namespace arrow |
314 | |