1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include <atomic> |
19 | #include <chrono> |
20 | #include <cstdint> |
21 | #include <functional> |
22 | #include <memory> |
23 | #include <random> |
24 | #include <thread> |
25 | #include <vector> |
26 | |
27 | #include <gtest/gtest.h> |
28 | |
29 | #include "arrow/status.h" |
30 | #include "arrow/test-util.h" |
31 | #include "arrow/util/task-group.h" |
32 | #include "arrow/util/thread-pool.h" |
33 | |
34 | namespace arrow { |
35 | namespace internal { |
36 | |
// Block the calling thread for roughly `seconds` seconds.
// The requested time is converted to a whole number of nanoseconds
// (fractional nanoseconds are truncated) before sleeping.
static void sleep_for(double seconds) {
  const double kNanosPerSecond = 1e9;
  const int64_t nanos = static_cast<int64_t>(seconds * kNanosPerSecond);
  std::this_thread::sleep_for(std::chrono::nanoseconds(nanos));
}
41 | |
// Generate `nsleeps` pseudo-random sleep durations (in seconds), each drawn
// from a uniform real distribution bounded by `min_seconds` and `max_seconds`.
//
// The random engine is default-constructed on every call (it is never seeded
// from an entropy source), so the same arguments always yield the same
// sequence.  A non-positive `nsleeps` yields an empty vector.
static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
                                                double max_seconds) {
  std::vector<double> sleeps;
  if (nsleeps > 0) {
    // The final size is known up front: allocate once instead of regrowing
    sleeps.reserve(static_cast<std::vector<double>::size_type>(nsleeps));
  }
  std::default_random_engine engine;
  std::uniform_real_distribution<> sleep_dist(min_seconds, max_seconds);
  for (int i = 0; i < nsleeps; ++i) {
    sleeps.push_back(sleep_dist(engine));
  }
  return sleeps;
}
53 | |
54 | // Check TaskGroup behaviour with a bunch of all-successful tasks |
55 | void TestTaskGroupSuccess(std::shared_ptr<TaskGroup> task_group) { |
56 | const int NTASKS = 10; |
57 | auto sleeps = RandomSleepDurations(NTASKS, 1e-3, 4e-3); |
58 | |
59 | // Add NTASKS sleeps |
60 | std::atomic<int> count(0); |
61 | for (int i = 0; i < NTASKS; ++i) { |
62 | task_group->Append([&, i]() { |
63 | sleep_for(sleeps[i]); |
64 | count += i; |
65 | return Status::OK(); |
66 | }); |
67 | } |
68 | ASSERT_TRUE(task_group->ok()); |
69 | |
70 | ASSERT_OK(task_group->Finish()); |
71 | ASSERT_TRUE(task_group->ok()); |
72 | ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2); |
73 | // Finish() is idempotent |
74 | ASSERT_OK(task_group->Finish()); |
75 | } |
76 | |
77 | // Check TaskGroup behaviour with some successful and some failing tasks |
78 | void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) { |
79 | const int NSUCCESSES = 2; |
80 | const int NERRORS = 20; |
81 | |
82 | std::atomic<int> count(0); |
83 | |
84 | for (int i = 0; i < NSUCCESSES; ++i) { |
85 | task_group->Append([&]() { |
86 | count++; |
87 | return Status::OK(); |
88 | }); |
89 | } |
90 | ASSERT_TRUE(task_group->ok()); |
91 | for (int i = 0; i < NERRORS; ++i) { |
92 | task_group->Append([&]() { |
93 | sleep_for(1e-2); |
94 | count++; |
95 | return Status::Invalid("some message" ); |
96 | }); |
97 | } |
98 | |
99 | // Task error is propagated |
100 | ASSERT_RAISES(Invalid, task_group->Finish()); |
101 | ASSERT_FALSE(task_group->ok()); |
102 | if (task_group->parallelism() == 1) { |
103 | // Serial: exactly two successes and an error |
104 | ASSERT_EQ(count.load(), 3); |
105 | } else { |
106 | // Parallel: at least two successes and an error |
107 | ASSERT_GE(count.load(), 3); |
108 | ASSERT_LE(count.load(), 2 * task_group->parallelism()); |
109 | } |
110 | // Finish() is idempotent |
111 | ASSERT_RAISES(Invalid, task_group->Finish()); |
112 | } |
113 | |
114 | // Check TaskGroup behaviour with a bunch of all-successful tasks and task groups |
115 | void TestTaskSubGroupsSuccess(std::shared_ptr<TaskGroup> task_group) { |
116 | const int NTASKS = 50; |
117 | const int NGROUPS = 7; |
118 | |
119 | auto sleeps = RandomSleepDurations(NTASKS, 1e-4, 1e-3); |
120 | std::vector<std::shared_ptr<TaskGroup>> groups = {task_group}; |
121 | |
122 | // Create some subgroups |
123 | for (int i = 0; i < NGROUPS - 1; ++i) { |
124 | groups.push_back(task_group->MakeSubGroup()); |
125 | } |
126 | |
127 | // Add NTASKS sleeps amonst all groups |
128 | std::atomic<int> count(0); |
129 | for (int i = 0; i < NTASKS; ++i) { |
130 | groups[i % NGROUPS]->Append([&, i]() { |
131 | sleep_for(sleeps[i]); |
132 | count += i; |
133 | return Status::OK(); |
134 | }); |
135 | } |
136 | ASSERT_TRUE(task_group->ok()); |
137 | |
138 | // Finish all subgroups first, then main group |
139 | for (int i = NGROUPS - 1; i >= 0; --i) { |
140 | ASSERT_OK(groups[i]->Finish()); |
141 | } |
142 | ASSERT_TRUE(task_group->ok()); |
143 | ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2); |
144 | // Finish() is idempotent |
145 | ASSERT_OK(task_group->Finish()); |
146 | } |
147 | |
148 | // Check TaskGroup behaviour with both successful and failing tasks and task groups |
149 | void TestTaskSubGroupsErrors(std::shared_ptr<TaskGroup> task_group) { |
150 | const int NTASKS = 50; |
151 | const int NGROUPS = 7; |
152 | const int FAIL_EVERY = 17; |
153 | std::vector<std::shared_ptr<TaskGroup>> groups = {task_group}; |
154 | |
155 | // Create some subgroups |
156 | for (int i = 0; i < NGROUPS - 1; ++i) { |
157 | groups.push_back(task_group->MakeSubGroup()); |
158 | } |
159 | |
160 | // Add NTASKS sleeps amonst all groups |
161 | for (int i = 0; i < NTASKS; ++i) { |
162 | groups[i % NGROUPS]->Append([&, i]() { |
163 | sleep_for(1e-3); |
164 | // As NGROUPS > NTASKS / FAIL_EVERY, some subgroups are successful |
165 | if (i % FAIL_EVERY == 0) { |
166 | return Status::Invalid("some message" ); |
167 | } else { |
168 | return Status::OK(); |
169 | } |
170 | }); |
171 | } |
172 | |
173 | // Finish all subgroups first, then main group |
174 | int nsuccessful = 0; |
175 | for (int i = NGROUPS - 1; i > 0; --i) { |
176 | Status st = groups[i]->Finish(); |
177 | if (st.ok()) { |
178 | ++nsuccessful; |
179 | } else { |
180 | ASSERT_RAISES(Invalid, st); |
181 | } |
182 | } |
183 | ASSERT_RAISES(Invalid, task_group->Finish()); |
184 | ASSERT_FALSE(task_group->ok()); |
185 | // Finish() is idempotent |
186 | ASSERT_RAISES(Invalid, task_group->Finish()); |
187 | } |
188 | |
189 | // Check TaskGroup behaviour with tasks spawning other tasks |
190 | void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) { |
191 | const int N = 6; |
192 | |
193 | std::atomic<int> count(0); |
194 | // Make a task that recursively spawns itself |
195 | std::function<std::function<Status()>(int)> make_task = [&](int i) { |
196 | return [&, i]() { |
197 | count++; |
198 | if (i > 0) { |
199 | // Exercise parallelism by spawning two tasks at once and then sleeping |
200 | task_group->Append(make_task(i - 1)); |
201 | task_group->Append(make_task(i - 1)); |
202 | sleep_for(1e-3); |
203 | } |
204 | return Status::OK(); |
205 | }; |
206 | }; |
207 | |
208 | task_group->Append(make_task(N)); |
209 | |
210 | ASSERT_OK(task_group->Finish()); |
211 | ASSERT_TRUE(task_group->ok()); |
212 | ASSERT_EQ(count.load(), (1 << (N + 1)) - 1); |
213 | } |
214 | |
215 | TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); } |
216 | |
217 | TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); } |
218 | |
219 | TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); } |
220 | |
221 | TEST(SerialTaskGroup, SubGroupsSuccess) { |
222 | TestTaskSubGroupsSuccess(TaskGroup::MakeSerial()); |
223 | } |
224 | |
225 | TEST(SerialTaskGroup, SubGroupsErrors) { |
226 | TestTaskSubGroupsErrors(TaskGroup::MakeSerial()); |
227 | } |
228 | |
229 | TEST(ThreadedTaskGroup, Success) { |
230 | auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); |
231 | TestTaskGroupSuccess(task_group); |
232 | } |
233 | |
234 | TEST(ThreadedTaskGroup, Errors) { |
235 | // Limit parallelism to ensure some tasks don't get started |
236 | // after the first failing ones |
237 | std::shared_ptr<ThreadPool> thread_pool; |
238 | ASSERT_OK(ThreadPool::Make(4, &thread_pool)); |
239 | |
240 | TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool.get())); |
241 | } |
242 | |
243 | TEST(ThreadedTaskGroup, TasksSpawnTasks) { |
244 | auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); |
245 | TestTasksSpawnTasks(task_group); |
246 | } |
247 | |
248 | TEST(ThreadedTaskGroup, SubGroupsSuccess) { |
249 | std::shared_ptr<ThreadPool> thread_pool; |
250 | ASSERT_OK(ThreadPool::Make(4, &thread_pool)); |
251 | |
252 | TestTaskSubGroupsSuccess(TaskGroup::MakeThreaded(thread_pool.get())); |
253 | } |
254 | |
255 | TEST(ThreadedTaskGroup, SubGroupsErrors) { |
256 | std::shared_ptr<ThreadPool> thread_pool; |
257 | ASSERT_OK(ThreadPool::Make(4, &thread_pool)); |
258 | |
259 | TestTaskSubGroupsErrors(TaskGroup::MakeThreaded(thread_pool.get())); |
260 | } |
261 | |
262 | } // namespace internal |
263 | } // namespace arrow |
264 | |