1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#ifndef _WIN32
19#include <sys/wait.h>
20#include <unistd.h>
21#endif
22
23#include <algorithm>
24#include <chrono>
25#include <cstdint>
26#include <cstdio>
27#include <cstdlib>
28#include <functional>
29#include <future>
30#include <memory>
31#include <string>
32#include <thread>
33#include <vector>
34
35#include <gtest/gtest.h>
36
37#include "arrow/status.h"
38#include "arrow/test-util.h"
39#include "arrow/util/io-util.h"
40#include "arrow/util/macros.h"
41#include "arrow/util/thread-pool.h"
42
43namespace arrow {
44namespace internal {
45
// Suspend the calling thread for `seconds` seconds (fractional values allowed).
// The value is converted to a whole number of nanoseconds, truncating any
// sub-nanosecond remainder, before being handed to std::this_thread::sleep_for.
static void sleep_for(double seconds) {
  const int64_t nanos = static_cast<int64_t>(seconds * 1e9);
  const std::chrono::nanoseconds duration(nanos);
  std::this_thread::sleep_for(duration);
}
50
51static void busy_wait(double seconds, std::function<bool()> predicate) {
52 const double period = 0.001;
53 for (int i = 0; !predicate() && i * period < seconds; ++i) {
54 sleep_for(period);
55 }
56}
57
// Task-style addition: computes `x + y` and stores it through `out`.
template <typename T>
static void task_add(T x, T y, T* out) {
  T& result = *out;
  result = x + y;
}
62
63template <typename T>
64static void task_slow_add(double seconds, T x, T y, T* out) {
65 sleep_for(seconds);
66 *out = x + y;
67}
68
// Signature shared by the "add" tasks: (x, y, out) with the sum stored in *out.
using AddTaskFunc = std::function<void(int, int, int*)>;
70
// Value-returning addition: returns `x + y`.
template <typename T>
static T add(T x, T y) {
  T sum = x + y;
  return sum;
}
75
76template <typename T>
77static T slow_add(double seconds, T x, T y) {
78 sleep_for(seconds);
79 return x + y;
80}
81
// Adds `y` to `x` in place (the caller's object is modified) and returns a
// copy of the updated value.
template <typename T>
static T inplace_add(T& x, T y) {
  x += y;
  return x;
}
86
87// A class to spawn "add" tasks to a pool and check the results when done
88
89class AddTester {
90 public:
91 explicit AddTester(int nadds) : nadds(nadds), xs(nadds), ys(nadds), outs(nadds, -1) {
92 int x = 0, y = 0;
93 std::generate(xs.begin(), xs.end(), [&] {
94 ++x;
95 return x;
96 });
97 std::generate(ys.begin(), ys.end(), [&] {
98 y += 10;
99 return y;
100 });
101 }
102
103 AddTester(AddTester&&) = default;
104
105 void SpawnTasks(ThreadPool* pool, AddTaskFunc add_func) {
106 for (int i = 0; i < nadds; ++i) {
107 ASSERT_OK(pool->Spawn([=] { add_func(xs[i], ys[i], &outs[i]); }));
108 }
109 }
110
111 void CheckResults() {
112 for (int i = 0; i < nadds; ++i) {
113 ASSERT_EQ(outs[i], (i + 1) * 11);
114 }
115 }
116
117 void CheckNotAllComputed() {
118 for (int i = 0; i < nadds; ++i) {
119 if (outs[i] == -1) {
120 return;
121 }
122 }
123 ASSERT_TRUE(0) << "all values were computed";
124 }
125
126 private:
127 ARROW_DISALLOW_COPY_AND_ASSIGN(AddTester);
128
129 int nadds;
130 std::vector<int> xs;
131 std::vector<int> ys;
132 std::vector<int> outs;
133};
134
135class TestThreadPool : public ::testing::Test {
136 public:
137 void TearDown() {
138 fflush(stdout);
139 fflush(stderr);
140 }
141
142 std::shared_ptr<ThreadPool> MakeThreadPool() { return MakeThreadPool(4); }
143
144 std::shared_ptr<ThreadPool> MakeThreadPool(int threads) {
145 std::shared_ptr<ThreadPool> pool;
146 Status st = ThreadPool::Make(threads, &pool);
147 return pool;
148 }
149
150 void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func) {
151 AddTester add_tester(nadds);
152 add_tester.SpawnTasks(pool, add_func);
153 ASSERT_OK(pool->Shutdown());
154 add_tester.CheckResults();
155 }
156
157 void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds,
158 AddTaskFunc add_func) {
159 // Same as SpawnAdds, but do the task spawning from multiple threads
160 std::vector<AddTester> add_testers;
161 std::vector<std::thread> threads;
162 for (int i = 0; i < nthreads; ++i) {
163 add_testers.emplace_back(nadds);
164 }
165 for (auto& add_tester : add_testers) {
166 threads.emplace_back([&] { add_tester.SpawnTasks(pool, add_func); });
167 }
168 for (auto& thread : threads) {
169 thread.join();
170 }
171 ASSERT_OK(pool->Shutdown());
172 for (auto& add_tester : add_testers) {
173 add_tester.CheckResults();
174 }
175 }
176};
177
178TEST_F(TestThreadPool, ConstructDestruct) {
179 // Stress shutdown-at-destruction logic
180 for (int threads : {1, 2, 3, 8, 32, 70}) {
181 auto pool = this->MakeThreadPool(threads);
182 }
183}
184
185// Correctness and stress tests using Spawn() and Shutdown()
186
187TEST_F(TestThreadPool, Spawn) {
188 auto pool = this->MakeThreadPool(3);
189 SpawnAdds(pool.get(), 7, task_add<int>);
190}
191
192TEST_F(TestThreadPool, StressSpawn) {
193 auto pool = this->MakeThreadPool(30);
194 SpawnAdds(pool.get(), 1000, task_add<int>);
195}
196
197TEST_F(TestThreadPool, StressSpawnThreaded) {
198 auto pool = this->MakeThreadPool(30);
199 SpawnAddsThreaded(pool.get(), 20, 100, task_add<int>);
200}
201
202TEST_F(TestThreadPool, SpawnSlow) {
203 // This checks that Shutdown() waits for all tasks to finish
204 auto pool = this->MakeThreadPool(2);
205 SpawnAdds(pool.get(), 7, [](int x, int y, int* out) {
206 return task_slow_add(0.02 /* seconds */, x, y, out);
207 });
208}
209
210TEST_F(TestThreadPool, StressSpawnSlow) {
211 auto pool = this->MakeThreadPool(30);
212 SpawnAdds(pool.get(), 1000, [](int x, int y, int* out) {
213 return task_slow_add(0.002 /* seconds */, x, y, out);
214 });
215}
216
217TEST_F(TestThreadPool, StressSpawnSlowThreaded) {
218 auto pool = this->MakeThreadPool(30);
219 SpawnAddsThreaded(pool.get(), 20, 100, [](int x, int y, int* out) {
220 return task_slow_add(0.002 /* seconds */, x, y, out);
221 });
222}
223
224TEST_F(TestThreadPool, QuickShutdown) {
225 AddTester add_tester(100);
226 {
227 auto pool = this->MakeThreadPool(3);
228 add_tester.SpawnTasks(pool.get(), [](int x, int y, int* out) {
229 return task_slow_add(0.02 /* seconds */, x, y, out);
230 });
231 ASSERT_OK(pool->Shutdown(false /* wait */));
232 add_tester.CheckNotAllComputed();
233 }
234 add_tester.CheckNotAllComputed();
235}
236
237TEST_F(TestThreadPool, SetCapacity) {
238 auto pool = this->MakeThreadPool(3);
239 ASSERT_EQ(pool->GetCapacity(), 3);
240 ASSERT_EQ(pool->GetActualCapacity(), 3);
241
242 ASSERT_OK(pool->SetCapacity(5));
243 ASSERT_EQ(pool->GetCapacity(), 5);
244 ASSERT_EQ(pool->GetActualCapacity(), 5);
245
246 ASSERT_OK(pool->SetCapacity(2));
247 ASSERT_EQ(pool->GetCapacity(), 2);
248 // Wait for workers to wake up and secede
249 busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; });
250 ASSERT_EQ(pool->GetActualCapacity(), 2);
251
252 ASSERT_OK(pool->SetCapacity(5));
253 ASSERT_EQ(pool->GetCapacity(), 5);
254 ASSERT_EQ(pool->GetActualCapacity(), 5);
255
256 // Downsize while tasks are pending
257 for (int i = 0; i < 10; ++i) {
258 ASSERT_OK(pool->Spawn(std::bind(sleep_for, 0.01 /* seconds */)));
259 }
260 ASSERT_OK(pool->SetCapacity(2));
261 ASSERT_EQ(pool->GetCapacity(), 2);
262 busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; });
263 ASSERT_EQ(pool->GetActualCapacity(), 2);
264
265 // Ensure nothing got stuck
266 ASSERT_OK(pool->Shutdown());
267}
268
269// Test Submit() functionality
270
271TEST_F(TestThreadPool, Submit) {
272 auto pool = this->MakeThreadPool(3);
273 {
274 auto fut = pool->Submit(add<int>, 4, 5);
275 ASSERT_EQ(fut.get(), 9);
276 }
277 {
278 auto fut = pool->Submit(add<std::string>, "foo", "bar");
279 ASSERT_EQ(fut.get(), "foobar");
280 }
281 {
282 auto fut = pool->Submit(slow_add<int>, 0.01 /* seconds */, 4, 5);
283 ASSERT_EQ(fut.get(), 9);
284 }
285 {
286 // Reference passing
287 std::string s = "foo";
288 auto fut = pool->Submit(inplace_add<std::string>, std::ref(s), "bar");
289 ASSERT_EQ(fut.get(), "foobar");
290 ASSERT_EQ(s, "foobar");
291 }
292 {
293 // `void` return type
294 auto fut = pool->Submit(sleep_for, 0.001);
295 fut.get();
296 }
297}
298
299// Test fork safety on Unix
300
301#if !(defined(_WIN32) || defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER) || \
302 defined(THREAD_SANITIZER))
303TEST_F(TestThreadPool, ForkSafety) {
304 pid_t child_pid;
305 int child_status;
306
307 {
308 // Fork after task submission
309 auto pool = this->MakeThreadPool(3);
310 auto fut = pool->Submit(add<int>, 4, 5);
311 ASSERT_EQ(fut.get(), 9);
312
313 child_pid = fork();
314 if (child_pid == 0) {
315 // Child: thread pool should be usable
316 fut = pool->Submit(add<int>, 3, 4);
317 if (fut.get() != 7) {
318 std::exit(1);
319 }
320 // Shutting down shouldn't hang or fail
321 Status st = pool->Shutdown();
322 std::exit(st.ok() ? 0 : 2);
323 } else {
324 // Parent
325 ASSERT_GT(child_pid, 0);
326 ASSERT_GT(waitpid(child_pid, &child_status, 0), 0);
327 ASSERT_TRUE(WIFEXITED(child_status));
328 ASSERT_EQ(WEXITSTATUS(child_status), 0);
329 ASSERT_OK(pool->Shutdown());
330 }
331 }
332 {
333 // Fork after shutdown
334 auto pool = this->MakeThreadPool(3);
335 ASSERT_OK(pool->Shutdown());
336
337 child_pid = fork();
338 if (child_pid == 0) {
339 // Child
340 // Spawning a task should return with error (pool was shutdown)
341 Status st = pool->Spawn([] {});
342 if (!st.IsInvalid()) {
343 std::exit(1);
344 }
345 // Trigger destructor
346 pool.reset();
347 std::exit(0);
348 } else {
349 // Parent
350 ASSERT_GT(child_pid, 0);
351 ASSERT_GT(waitpid(child_pid, &child_status, 0), 0);
352 ASSERT_TRUE(WIFEXITED(child_status));
353 ASSERT_EQ(WEXITSTATUS(child_status), 0);
354 }
355 }
356}
357#endif
358
359TEST(TestGlobalThreadPool, Capacity) {
360 // Sanity check
361 auto pool = GetCpuThreadPool();
362 int capacity = pool->GetCapacity();
363 ASSERT_GT(capacity, 0);
364 ASSERT_EQ(pool->GetActualCapacity(), capacity);
365 ASSERT_EQ(GetCpuThreadPoolCapacity(), capacity);
366
367 // Exercise default capacity heuristic
368 ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
369 ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT"));
370 int hw_capacity = std::thread::hardware_concurrency();
371 ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
372 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "13"));
373 ASSERT_EQ(ThreadPool::DefaultCapacity(), 13);
374 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "7,5,13"));
375 ASSERT_EQ(ThreadPool::DefaultCapacity(), 7);
376 ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
377
378 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "1"));
379 ASSERT_EQ(ThreadPool::DefaultCapacity(), 1);
380 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "999"));
381 if (hw_capacity <= 999) {
382 ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
383 }
384 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "6,5,13"));
385 ASSERT_EQ(ThreadPool::DefaultCapacity(), 6);
386 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "2"));
387 ASSERT_EQ(ThreadPool::DefaultCapacity(), 2);
388
389 // Invalid env values
390 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "0"));
391 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "0"));
392 ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
393 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "zzz"));
394 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "x"));
395 ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
396 ASSERT_OK(SetEnvVar("OMP_THREAD_LIMIT", "-1"));
397 ASSERT_OK(SetEnvVar("OMP_NUM_THREADS", "99999999999999999999999999"));
398 ASSERT_EQ(ThreadPool::DefaultCapacity(), hw_capacity);
399
400 ASSERT_OK(DelEnvVar("OMP_NUM_THREADS"));
401 ASSERT_OK(DelEnvVar("OMP_THREAD_LIMIT"));
402}
403
404} // namespace internal
405} // namespace arrow
406