1 | //************************************ bs::framework - Copyright 2018 Marko Pintera **************************************// |
2 | //*********** Licensed under the MIT license. See LICENSE.md for full terms. This notice is not to be removed. ***********// |
3 | #include "Threading/BsTaskScheduler.h" |
4 | #include "Threading/BsThreadPool.h" |
5 | |
6 | namespace bs |
7 | { |
8 | Task::Task(const PrivatelyConstruct& dummy, const String& name, std::function<void()> taskWorker, |
9 | TaskPriority priority, SPtr<Task> dependency) |
10 | : mName(name), mPriority(priority), mTaskWorker(std::move(taskWorker)), mTaskDependency(std::move(dependency)) |
11 | { |
12 | |
13 | } |
14 | |
15 | SPtr<Task> Task::create(const String& name, std::function<void()> taskWorker, TaskPriority priority, |
16 | SPtr<Task> dependency) |
17 | { |
18 | return bs_shared_ptr_new<Task>(PrivatelyConstruct(), name, std::move(taskWorker), priority, std::move(dependency)); |
19 | } |
20 | |
21 | bool Task::isComplete() const |
22 | { |
23 | return mState == 2; |
24 | } |
25 | |
26 | bool Task::isCanceled() const |
27 | { |
28 | return mState == 3; |
29 | } |
30 | |
31 | bool Task::hasStarted() const |
32 | { |
33 | UINT32 state = mState; |
34 | |
35 | return state == 1 || state == 2; |
36 | } |
37 | |
38 | void Task::wait() |
39 | { |
40 | if(mParent != nullptr) |
41 | mParent->waitUntilComplete(this); |
42 | } |
43 | |
44 | void Task::cancel() |
45 | { |
46 | mState = 3; |
47 | } |
48 | |
49 | TaskGroup::TaskGroup(const PrivatelyConstruct& dummy, String name, std::function<void(UINT32)> taskWorker, |
50 | UINT32 count, TaskPriority priority, SPtr<Task> dependency) |
51 | : mName(std::move(name)), mCount(count), mPriority(priority), mTaskWorker(std::move(taskWorker)) |
52 | , mTaskDependency(std::move(dependency)) |
53 | { |
54 | |
55 | } |
56 | |
57 | SPtr<TaskGroup> TaskGroup::create(String name, std::function<void(UINT32)> taskWorker, UINT32 count, |
58 | TaskPriority priority, SPtr<Task> dependency) |
59 | { |
60 | return bs_shared_ptr_new<TaskGroup>(PrivatelyConstruct(), std::move(name), std::move(taskWorker), count, priority, |
61 | std::move(dependency)); |
62 | } |
63 | |
64 | bool TaskGroup::isComplete() const |
65 | { |
66 | return mNumRemainingTasks == 0; |
67 | } |
68 | |
69 | void TaskGroup::wait() |
70 | { |
71 | if(mParent != nullptr) |
72 | mParent->waitUntilComplete(this); |
73 | } |
74 | |
75 | TaskScheduler::TaskScheduler() |
76 | :mTaskQueue(&TaskScheduler::taskCompare) |
77 | { |
78 | mMaxActiveTasks = BS_THREAD_HARDWARE_CONCURRENCY; |
79 | |
80 | mTaskSchedulerThread = ThreadPool::instance().run("TaskScheduler" , std::bind(&TaskScheduler::runMain, this)); |
81 | } |
82 | |
83 | TaskScheduler::~TaskScheduler() |
84 | { |
85 | // Wait until all tasks complete |
86 | { |
87 | Lock activeTaskLock(mReadyMutex); |
88 | |
89 | while (mActiveTasks.size() > 0) |
90 | { |
91 | SPtr<Task> task = mActiveTasks[0]; |
92 | activeTaskLock.unlock(); |
93 | |
94 | task->wait(); |
95 | activeTaskLock.lock(); |
96 | } |
97 | } |
98 | |
99 | // Start shutdown of the main queue worker and wait until it exits |
100 | { |
101 | Lock lock(mReadyMutex); |
102 | |
103 | mShutdown = true; |
104 | } |
105 | |
106 | mTaskReadyCond.notify_one(); |
107 | |
108 | mTaskSchedulerThread.blockUntilComplete(); |
109 | } |
110 | |
111 | void TaskScheduler::addTask(SPtr<Task> task) |
112 | { |
113 | Lock lock(mReadyMutex); |
114 | |
115 | assert(task->mState != 1 && "Task is already executing, it cannot be executed again until it finishes." ); |
116 | |
117 | task->mParent = this; |
118 | task->mTaskId = mNextTaskId++; |
119 | task->mState.store(0); // Reset state in case the task is getting re-queued |
120 | |
121 | mCheckTasks = true; |
122 | mTaskQueue.insert(std::move(task)); |
123 | |
124 | // Wake main scheduler thread |
125 | mTaskReadyCond.notify_one(); |
126 | } |
127 | |
128 | void TaskScheduler::addTaskGroup(const SPtr<TaskGroup>& taskGroup) |
129 | { |
130 | Lock lock(mReadyMutex); |
131 | |
132 | for(UINT32 i = 0; i < taskGroup->mCount; i++) |
133 | { |
134 | const auto worker = [i, taskGroup] |
135 | { |
136 | taskGroup->mTaskWorker(i); |
137 | --taskGroup->mNumRemainingTasks; |
138 | }; |
139 | |
140 | SPtr<Task> task = Task::create(taskGroup->mName, worker, taskGroup->mPriority, taskGroup->mTaskDependency); |
141 | task->mParent = this; |
142 | task->mTaskId = mNextTaskId++; |
143 | task->mState.store(0); // Reset state in case the task is getting re-queued |
144 | |
145 | mCheckTasks = true; |
146 | mTaskQueue.insert(std::move(task)); |
147 | } |
148 | |
149 | taskGroup->mParent = this; |
150 | |
151 | // Wake main scheduler thread |
152 | mTaskReadyCond.notify_one(); |
153 | } |
154 | |
155 | void TaskScheduler::addWorker() |
156 | { |
157 | Lock lock(mReadyMutex); |
158 | |
159 | mMaxActiveTasks++; |
160 | |
161 | // A spot freed up, queue new tasks on main scheduler thread if they exist |
162 | mTaskReadyCond.notify_one(); |
163 | } |
164 | |
165 | void TaskScheduler::removeWorker() |
166 | { |
167 | Lock lock(mReadyMutex); |
168 | |
169 | if(mMaxActiveTasks > 0) |
170 | mMaxActiveTasks--; |
171 | } |
172 | |
// Body of the main scheduler thread (started by the constructor). Sleeps until new work is flagged and an
// active-task slot is free, then moves runnable tasks from mTaskQueue onto ThreadPool threads, never exceeding
// mMaxActiveTasks. Exits once mShutdown is set.
void TaskScheduler::runMain()
{
	while(true)
	{
		Lock lock(mReadyMutex);

		// Sleep until work was flagged (mCheckTasks) and a slot is free, or until shutdown is requested
		while((!mCheckTasks || (UINT32)mActiveTasks.size() >= mMaxActiveTasks) && !mShutdown)
			mTaskReadyCond.wait(lock);

		// Consume the flag; it is set again below if this pass needs to be retried
		mCheckTasks = false;

		if(mShutdown)
			break;

		// mTaskQueue is ordered by taskCompare(): higher priority first, then queue order
		for(auto iter = mTaskQueue.begin(); iter != mTaskQueue.end();)
		{
			// Concurrency limit reached; the rest waits until runTask() frees a slot and sets mCheckTasks
			if ((UINT32)mActiveTasks.size() >= mMaxActiveTasks)
				break;

			SPtr<Task> curTask = *iter;

			// Canceled tasks are removed from the queue without being run
			if(curTask->isCanceled())
			{
				iter = mTaskQueue.erase(iter);
				continue;
			}

			// Leave tasks whose dependency hasn't completed in the queue. runTask() sets mCheckTasks whenever a
			// task finishes, so they are reconsidered later.
			// NOTE(review): a canceled dependency never reports isComplete(), so the dependent task is never
			// dispatched here (waitUntilComplete() would still run it inline). Confirm this is intended.
			if(curTask->mTaskDependency != nullptr && !curTask->mTaskDependency->isComplete())
			{
				++iter;
				continue;
			}

			// Spin until a thread becomes available. This happens primarily because our mActiveTasks count and
			// ThreadPool's thread idle count aren't synced, so while the task manager thinks it's free to run new
			// tasks, the ThreadPool might still have those threads as running, meaning their allocation will fail.
			// So we just spin here for a bit, in that rare case. (Setting mCheckTasks makes the outer loop retry
			// right away instead of sleeping.)
			if(ThreadPool::instance().getNumAvailable() == 0)
			{
				mCheckTasks = true;
				break;
			}

			// Hand the task over: remove it from the queue, mark it running (state 1), track it as active and
			// execute it through runTask() on a pool thread
			iter = mTaskQueue.erase(iter);

			curTask->mState.store(1);
			mActiveTasks.push_back(curTask);

			ThreadPool::instance().run(curTask->mName, std::bind(&TaskScheduler::runTask, this, curTask));
		}
	}
}
225 | |
226 | void TaskScheduler::runTask(SPtr<Task> task) |
227 | { |
228 | task->mTaskWorker(); |
229 | |
230 | { |
231 | Lock lock(mReadyMutex); |
232 | |
233 | auto findIter = std::find(mActiveTasks.begin(), mActiveTasks.end(), task); |
234 | if (findIter != mActiveTasks.end()) |
235 | mActiveTasks.erase(findIter); |
236 | } |
237 | |
238 | { |
239 | Lock lock(mCompleteMutex); |
240 | task->mState.store(2); |
241 | |
242 | mTaskCompleteCond.notify_all(); |
243 | } |
244 | |
245 | // Wake the main scheduler thread in case there are other tasks waiting or this task was someone's dependency |
246 | { |
247 | Lock lock(mReadyMutex); |
248 | |
249 | mCheckTasks = true; |
250 | mTaskReadyCond.notify_one(); |
251 | } |
252 | } |
253 | |
254 | void TaskScheduler::waitUntilComplete(const Task* task) |
255 | { |
256 | if(task->isCanceled()) |
257 | return; |
258 | |
259 | if(task->mTaskDependency) |
260 | task->mTaskDependency->wait(); |
261 | |
262 | // If we haven't started executing the task yet, just execute it right here |
263 | SPtr<Task> queuedTask; |
264 | { |
265 | Lock lock(mReadyMutex); |
266 | |
267 | if(!task->hasStarted()) |
268 | { |
269 | auto iterFind = std::find_if(mTaskQueue.begin(), mTaskQueue.end(), |
270 | [task](const SPtr<Task>& x) { return x.get() == task; }); |
271 | |
272 | assert(iterFind != mTaskQueue.end()); |
273 | |
274 | queuedTask = *iterFind; |
275 | mTaskQueue.erase(iterFind); |
276 | |
277 | queuedTask->mState.store(1); |
278 | } |
279 | } |
280 | |
281 | if(queuedTask) |
282 | { |
283 | runTask(queuedTask); |
284 | return; |
285 | } |
286 | |
287 | // Otherwise we wait until the task completes |
288 | { |
289 | Lock lock(mCompleteMutex); |
290 | |
291 | while(!task->isComplete()) |
292 | { |
293 | addWorker(); |
294 | mTaskCompleteCond.wait(lock); |
295 | removeWorker(); |
296 | } |
297 | } |
298 | } |
299 | |
300 | void TaskScheduler::waitUntilComplete(const TaskGroup* taskGroup) |
301 | { |
302 | Lock lock(mCompleteMutex); |
303 | |
304 | while (taskGroup->mNumRemainingTasks > 0) |
305 | { |
306 | addWorker(); |
307 | mTaskCompleteCond.wait(lock); |
308 | removeWorker(); |
309 | } |
310 | } |
311 | |
312 | bool TaskScheduler::taskCompare(const SPtr<Task>& lhs, const SPtr<Task>& rhs) |
313 | { |
314 | // If priority is the same, sort by the order the tasks were queued |
315 | if(lhs->mPriority == rhs->mPriority) |
316 | return lhs->mTaskId < rhs->mTaskId; |
317 | |
318 | // Otherwise the task with the higher priority always goes first |
319 | return lhs->mPriority > rhs->mPriority; |
320 | } |
321 | } |
322 | |