1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #ifndef _WIN32 |
19 | #include <sys/wait.h> |
20 | #include <unistd.h> |
21 | #endif |
22 | |
23 | #include <algorithm> |
24 | #include <chrono> |
25 | #include <cstdint> |
26 | #include <cstdio> |
27 | #include <cstdlib> |
28 | #include <functional> |
29 | #include <future> |
30 | #include <memory> |
31 | #include <string> |
32 | #include <thread> |
33 | #include <vector> |
34 | |
35 | #include <gtest/gtest.h> |
36 | |
37 | #include "arrow/status.h" |
38 | #include "arrow/test-util.h" |
39 | #include "arrow/util/io-util.h" |
40 | #include "arrow/util/macros.h" |
41 | #include "arrow/util/thread-pool.h" |
42 | |
43 | namespace arrow { |
44 | namespace internal { |
45 | |
// Suspend the calling thread for the given number of seconds.
// Fractional values are accepted; the duration is truncated to whole
// nanoseconds before being handed to std::this_thread::sleep_for().
static void sleep_for(double seconds) {
  const auto nanos = static_cast<int64_t>(seconds * 1e9);
  const std::chrono::nanoseconds duration(nanos);
  std::this_thread::sleep_for(duration);
}
50 | |
51 | static void busy_wait(double seconds, std::function<bool()> predicate) { |
52 | const double period = 0.001; |
53 | for (int i = 0; !predicate() && i * period < seconds; ++i) { |
54 | sleep_for(period); |
55 | } |
56 | } |
57 | |
// Task-style addition: stores `x + y` into the location pointed to by `out`.
// `out` must be a valid, writable pointer.
template <typename T>
static void task_add(T x, T y, T* out) {
  T& result = *out;
  result = x + y;
}
62 | |
63 | template <typename T> |
64 | static void task_slow_add(double seconds, T x, T y, T* out) { |
65 | sleep_for(seconds); |
66 | *out = x + y; |
67 | } |
68 | |
// Signature of an "add" task: takes two operands and a pointer to the result slot.
using AddTaskFunc = std::function<void(int, int, int*)>;
70 | |
// Value-returning addition: returns `x + y`.
template <typename T>
static T add(T x, T y) {
  T sum = x + y;
  return sum;
}
75 | |
76 | template <typename T> |
77 | static T slow_add(double seconds, T x, T y) { |
78 | sleep_for(seconds); |
79 | return x + y; |
80 | } |
81 | |
// Adds `y` to `x` in place (through the reference) and returns a copy of the
// updated value.
template <typename T>
static T inplace_add(T& x, T y) {
  x += y;
  return x;
}
86 | |
87 | // A class to spawn "add" tasks to a pool and check the results when done |
88 | |
89 | class AddTester { |
90 | public: |
91 | explicit AddTester(int nadds) : nadds(nadds), xs(nadds), ys(nadds), outs(nadds, -1) { |
92 | int x = 0, y = 0; |
93 | std::generate(xs.begin(), xs.end(), [&] { |
94 | ++x; |
95 | return x; |
96 | }); |
97 | std::generate(ys.begin(), ys.end(), [&] { |
98 | y += 10; |
99 | return y; |
100 | }); |
101 | } |
102 | |
103 | AddTester(AddTester&&) = default; |
104 | |
105 | void SpawnTasks(ThreadPool* pool, AddTaskFunc add_func) { |
106 | for (int i = 0; i < nadds; ++i) { |
107 | ASSERT_OK(pool->Spawn([=] { add_func(xs[i], ys[i], &outs[i]); })); |
108 | } |
109 | } |
110 | |
111 | void CheckResults() { |
112 | for (int i = 0; i < nadds; ++i) { |
113 | ASSERT_EQ(outs[i], (i + 1) * 11); |
114 | } |
115 | } |
116 | |
117 | void CheckNotAllComputed() { |
118 | for (int i = 0; i < nadds; ++i) { |
119 | if (outs[i] == -1) { |
120 | return; |
121 | } |
122 | } |
123 | ASSERT_TRUE(0) << "all values were computed" ; |
124 | } |
125 | |
126 | private: |
127 | ARROW_DISALLOW_COPY_AND_ASSIGN(AddTester); |
128 | |
129 | int nadds; |
130 | std::vector<int> xs; |
131 | std::vector<int> ys; |
132 | std::vector<int> outs; |
133 | }; |
134 | |
135 | class TestThreadPool : public ::testing::Test { |
136 | public: |
137 | void TearDown() { |
138 | fflush(stdout); |
139 | fflush(stderr); |
140 | } |
141 | |
142 | std::shared_ptr<ThreadPool> MakeThreadPool() { return MakeThreadPool(4); } |
143 | |
144 | std::shared_ptr<ThreadPool> MakeThreadPool(int threads) { |
145 | std::shared_ptr<ThreadPool> pool; |
146 | Status st = ThreadPool::Make(threads, &pool); |
147 | return pool; |
148 | } |
149 | |
150 | void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func) { |
151 | AddTester add_tester(nadds); |
152 | add_tester.SpawnTasks(pool, add_func); |
153 | ASSERT_OK(pool->Shutdown()); |
154 | add_tester.CheckResults(); |
155 | } |
156 | |
157 | void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, |
158 | AddTaskFunc add_func) { |
159 | // Same as SpawnAdds, but do the task spawning from multiple threads |
160 | std::vector<AddTester> add_testers; |
161 | std::vector<std::thread> threads; |
162 | for (int i = 0; i < nthreads; ++i) { |
163 | add_testers.emplace_back(nadds); |
164 | } |
165 | for (auto& add_tester : add_testers) { |
166 | threads.emplace_back([&] { add_tester.SpawnTasks(pool, add_func); }); |
167 | } |
168 | for (auto& thread : threads) { |
169 | thread.join(); |
170 | } |
171 | ASSERT_OK(pool->Shutdown()); |
172 | for (auto& add_tester : add_testers) { |
173 | add_tester.CheckResults(); |
174 | } |
175 | } |
176 | }; |
177 | |
178 | TEST_F(TestThreadPool, ConstructDestruct) { |
179 | // Stress shutdown-at-destruction logic |
180 | for (int threads : {1, 2, 3, 8, 32, 70}) { |
181 | auto pool = this->MakeThreadPool(threads); |
182 | } |
183 | } |
184 | |
185 | // Correctness and stress tests using Spawn() and Shutdown() |
186 | |
187 | TEST_F(TestThreadPool, Spawn) { |
188 | auto pool = this->MakeThreadPool(3); |
189 | SpawnAdds(pool.get(), 7, task_add<int>); |
190 | } |
191 | |
192 | TEST_F(TestThreadPool, StressSpawn) { |
193 | auto pool = this->MakeThreadPool(30); |
194 | SpawnAdds(pool.get(), 1000, task_add<int>); |
195 | } |
196 | |
197 | TEST_F(TestThreadPool, StressSpawnThreaded) { |
198 | auto pool = this->MakeThreadPool(30); |
199 | SpawnAddsThreaded(pool.get(), 20, 100, task_add<int>); |
200 | } |
201 | |
202 | TEST_F(TestThreadPool, SpawnSlow) { |
203 | // This checks that Shutdown() waits for all tasks to finish |
204 | auto pool = this->MakeThreadPool(2); |
205 | SpawnAdds(pool.get(), 7, [](int x, int y, int* out) { |
206 | return task_slow_add(0.02 /* seconds */, x, y, out); |
207 | }); |
208 | } |
209 | |
210 | TEST_F(TestThreadPool, StressSpawnSlow) { |
211 | auto pool = this->MakeThreadPool(30); |
212 | SpawnAdds(pool.get(), 1000, [](int x, int y, int* out) { |
213 | return task_slow_add(0.002 /* seconds */, x, y, out); |
214 | }); |
215 | } |
216 | |
217 | TEST_F(TestThreadPool, StressSpawnSlowThreaded) { |
218 | auto pool = this->MakeThreadPool(30); |
219 | SpawnAddsThreaded(pool.get(), 20, 100, [](int x, int y, int* out) { |
220 | return task_slow_add(0.002 /* seconds */, x, y, out); |
221 | }); |
222 | } |
223 | |
224 | TEST_F(TestThreadPool, QuickShutdown) { |
225 | AddTester add_tester(100); |
226 | { |
227 | auto pool = this->MakeThreadPool(3); |
228 | add_tester.SpawnTasks(pool.get(), [](int x, int y, int* out) { |
229 | return task_slow_add(0.02 /* seconds */, x, y, out); |
230 | }); |
231 | ASSERT_OK(pool->Shutdown(false /* wait */)); |
232 | add_tester.CheckNotAllComputed(); |
233 | } |
234 | add_tester.CheckNotAllComputed(); |
235 | } |
236 | |
237 | TEST_F(TestThreadPool, SetCapacity) { |
238 | auto pool = this->MakeThreadPool(3); |
239 | ASSERT_EQ(pool->GetCapacity(), 3); |
240 | ASSERT_EQ(pool->GetActualCapacity(), 3); |
241 | |
242 | ASSERT_OK(pool->SetCapacity(5)); |
243 | ASSERT_EQ(pool->GetCapacity(), 5); |
244 | ASSERT_EQ(pool->GetActualCapacity(), 5); |
245 | |
246 | ASSERT_OK(pool->SetCapacity(2)); |
247 | ASSERT_EQ(pool->GetCapacity(), 2); |
248 | // Wait for workers to wake up and secede |
249 | busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; }); |
250 | ASSERT_EQ(pool->GetActualCapacity(), 2); |
251 | |
252 | ASSERT_OK(pool->SetCapacity(5)); |
253 | ASSERT_EQ(pool->GetCapacity(), 5); |
254 | ASSERT_EQ(pool->GetActualCapacity(), 5); |
255 | |
256 | // Downsize while tasks are pending |
257 | for (int i = 0; i < 10; ++i) { |
258 | ASSERT_OK(pool->Spawn(std::bind(sleep_for, 0.01 /* seconds */))); |
259 | } |
260 | ASSERT_OK(pool->SetCapacity(2)); |
261 | ASSERT_EQ(pool->GetCapacity(), 2); |
262 | busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; }); |
263 | ASSERT_EQ(pool->GetActualCapacity(), 2); |
264 | |
265 | // Ensure nothing got stuck |
266 | ASSERT_OK(pool->Shutdown()); |
267 | } |
268 | |
269 | // Test Submit() functionality |
270 | |
271 | TEST_F(TestThreadPool, Submit) { |
272 | auto pool = this->MakeThreadPool(3); |
273 | { |
274 | auto fut = pool->Submit(add<int>, 4, 5); |
275 | ASSERT_EQ(fut.get(), 9); |
276 | } |
277 | { |
278 | auto fut = pool->Submit(add<std::string>, "foo" , "bar" ); |
279 | ASSERT_EQ(fut.get(), "foobar" ); |
280 | } |
281 | { |
282 | auto fut = pool->Submit(slow_add<int>, 0.01 /* seconds */, 4, 5); |
283 | ASSERT_EQ(fut.get(), 9); |
284 | } |
285 | { |
286 | // Reference passing |
287 | std::string s = "foo" ; |
288 | auto fut = pool->Submit(inplace_add<std::string>, std::ref(s), "bar" ); |
289 | ASSERT_EQ(fut.get(), "foobar" ); |
290 | ASSERT_EQ(s, "foobar" ); |
291 | } |
292 | { |
293 | // `void` return type |
294 | auto fut = pool->Submit(sleep_for, 0.001); |
295 | fut.get(); |
296 | } |
297 | } |
298 | |
299 | // Test fork safety on Unix |
300 | |
301 | #if !(defined(_WIN32) || defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER) || \ |
302 | defined(THREAD_SANITIZER)) |
303 | TEST_F(TestThreadPool, ForkSafety) { |
304 | pid_t child_pid; |
305 | int child_status; |
306 | |
307 | { |
308 | // Fork after task submission |
309 | auto pool = this->MakeThreadPool(3); |
310 | auto fut = pool->Submit(add<int>, 4, 5); |
311 | ASSERT_EQ(fut.get(), 9); |
312 | |
313 | child_pid = fork(); |
314 | if (child_pid == 0) { |
315 | // Child: thread pool should be usable |
316 | fut = pool->Submit(add<int>, 3, 4); |
317 | if (fut.get() != 7) { |
318 | std::exit(1); |
319 | } |
320 | // Shutting down shouldn't hang or fail |
321 | Status st = pool->Shutdown(); |
322 | std::exit(st.ok() ? 0 : 2); |
323 | } else { |
324 | // Parent |
325 | ASSERT_GT(child_pid, 0); |
326 | ASSERT_GT(waitpid(child_pid, &child_status, 0), 0); |
327 | ASSERT_TRUE(WIFEXITED(child_status)); |
328 | ASSERT_EQ(WEXITSTATUS(child_status), 0); |
329 | ASSERT_OK(pool->Shutdown()); |
330 | } |
331 | } |
332 | { |
333 | // Fork after shutdown |
334 | auto pool = this->MakeThreadPool(3); |
335 | ASSERT_OK(pool->Shutdown()); |
336 | |
337 | child_pid = fork(); |
338 | if (child_pid == 0) { |
339 | // Child |
340 | // Spawning a task should return with error (pool was shutdown) |
341 | Status st = pool->Spawn([] {}); |
342 | if (!st.IsInvalid()) { |
343 | std::exit(1); |
344 | } |
345 | // Trigger destructor |
346 | pool.reset(); |
347 | std::exit(0); |
348 | } else { |
349 | // Parent |
350 | ASSERT_GT(child_pid, 0); |
351 | ASSERT_GT(waitpid(child_pid, &child_status, 0), 0); |
352 | ASSERT_TRUE(WIFEXITED(child_status)); |
353 | ASSERT_EQ(WEXITSTATUS(child_status), 0); |
354 | } |
355 | } |
356 | } |
357 | #endif |
358 | |
359 | TEST(TestGlobalThreadPool, Capacity) { |
360 | // Sanity check |
361 | auto pool = GetCpuThreadPool(); |
362 | int capacity = pool->GetCapacity(); |
363 | ASSERT_GT(capacity, 0); |
364 | ASSERT_EQ(pool->GetActualCapacity(), capacity); |
365 | ASSERT_EQ(GetCpuThreadPoolCapacity(), capacity); |
366 | |
367 | // Exercise default capacity heuristic |
368 | ASSERT_OK(DelEnvVar("OMP_NUM_THREADS" )); |
369 | ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT" )); |
370 | int hw_capacity = std::thread::hardware_concurrency(); |
371 | ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity); |
372 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "13" )); |
373 | ASSERT_EQ(ThreadPool::DefaultCapacity(), 13); |
374 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "7,5,13" )); |
375 | ASSERT_EQ(ThreadPool::DefaultCapacity(), 7); |
376 | ASSERT_OK(DelEnvVar("OMP_NUM_THREADS" )); |
377 | |
378 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "1" )); |
379 | ASSERT_EQ(ThreadPool::DefaultCapacity(), 1); |
380 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "999" )); |
381 | if (hw_capacity <= 999) { |
382 | ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity); |
383 | } |
384 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "6,5,13" )); |
385 | ASSERT_EQ(ThreadPool::DefaultCapacity(), 6); |
386 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "2" )); |
387 | ASSERT_EQ(ThreadPool::DefaultCapacity(), 2); |
388 | |
389 | // Invalid env values |
390 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "0" )); |
391 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "0" )); |
392 | ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity); |
393 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "zzz" )); |
394 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "x" )); |
395 | ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity); |
396 | ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT" , "-1" )); |
397 | ASSERT_OK(SetEnvVar("OMP_NUM_THREADS" , "99999999999999999999999999" )); |
398 | ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity); |
399 | |
400 | ASSERT_OK(DelEnvVar("OMP_NUM_THREADS" )); |
401 | ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT" )); |
402 | } |
403 | |
404 | } // namespace internal |
405 | } // namespace arrow |
406 | |