1//************************************ bs::framework - Copyright 2018 Marko Pintera **************************************//
2//*********** Licensed under the MIT license. See LICENSE.md for full terms. This notice is not to be removed. ***********//
3#include "Threading/BsTaskScheduler.h"
4#include "Threading/BsThreadPool.h"
5
6namespace bs
7{
8 Task::Task(const PrivatelyConstruct& dummy, const String& name, std::function<void()> taskWorker,
9 TaskPriority priority, SPtr<Task> dependency)
10 : mName(name), mPriority(priority), mTaskWorker(std::move(taskWorker)), mTaskDependency(std::move(dependency))
11 {
12
13 }
14
15 SPtr<Task> Task::create(const String& name, std::function<void()> taskWorker, TaskPriority priority,
16 SPtr<Task> dependency)
17 {
18 return bs_shared_ptr_new<Task>(PrivatelyConstruct(), name, std::move(taskWorker), priority, std::move(dependency));
19 }
20
21 bool Task::isComplete() const
22 {
23 return mState == 2;
24 }
25
26 bool Task::isCanceled() const
27 {
28 return mState == 3;
29 }
30
31 bool Task::hasStarted() const
32 {
33 UINT32 state = mState;
34
35 return state == 1 || state == 2;
36 }
37
38 void Task::wait()
39 {
40 if(mParent != nullptr)
41 mParent->waitUntilComplete(this);
42 }
43
44 void Task::cancel()
45 {
46 mState = 3;
47 }
48
49 TaskGroup::TaskGroup(const PrivatelyConstruct& dummy, String name, std::function<void(UINT32)> taskWorker,
50 UINT32 count, TaskPriority priority, SPtr<Task> dependency)
51 : mName(std::move(name)), mCount(count), mPriority(priority), mTaskWorker(std::move(taskWorker))
52 , mTaskDependency(std::move(dependency))
53 {
54
55 }
56
57 SPtr<TaskGroup> TaskGroup::create(String name, std::function<void(UINT32)> taskWorker, UINT32 count,
58 TaskPriority priority, SPtr<Task> dependency)
59 {
60 return bs_shared_ptr_new<TaskGroup>(PrivatelyConstruct(), std::move(name), std::move(taskWorker), count, priority,
61 std::move(dependency));
62 }
63
64 bool TaskGroup::isComplete() const
65 {
66 return mNumRemainingTasks == 0;
67 }
68
69 void TaskGroup::wait()
70 {
71 if(mParent != nullptr)
72 mParent->waitUntilComplete(this);
73 }
74
75 TaskScheduler::TaskScheduler()
76 :mTaskQueue(&TaskScheduler::taskCompare)
77 {
78 mMaxActiveTasks = BS_THREAD_HARDWARE_CONCURRENCY;
79
80 mTaskSchedulerThread = ThreadPool::instance().run("TaskScheduler", std::bind(&TaskScheduler::runMain, this));
81 }
82
83 TaskScheduler::~TaskScheduler()
84 {
85 // Wait until all tasks complete
86 {
87 Lock activeTaskLock(mReadyMutex);
88
89 while (mActiveTasks.size() > 0)
90 {
91 SPtr<Task> task = mActiveTasks[0];
92 activeTaskLock.unlock();
93
94 task->wait();
95 activeTaskLock.lock();
96 }
97 }
98
99 // Start shutdown of the main queue worker and wait until it exits
100 {
101 Lock lock(mReadyMutex);
102
103 mShutdown = true;
104 }
105
106 mTaskReadyCond.notify_one();
107
108 mTaskSchedulerThread.blockUntilComplete();
109 }
110
111 void TaskScheduler::addTask(SPtr<Task> task)
112 {
113 Lock lock(mReadyMutex);
114
115 assert(task->mState != 1 && "Task is already executing, it cannot be executed again until it finishes.");
116
117 task->mParent = this;
118 task->mTaskId = mNextTaskId++;
119 task->mState.store(0); // Reset state in case the task is getting re-queued
120
121 mCheckTasks = true;
122 mTaskQueue.insert(std::move(task));
123
124 // Wake main scheduler thread
125 mTaskReadyCond.notify_one();
126 }
127
128 void TaskScheduler::addTaskGroup(const SPtr<TaskGroup>& taskGroup)
129 {
130 Lock lock(mReadyMutex);
131
132 for(UINT32 i = 0; i < taskGroup->mCount; i++)
133 {
134 const auto worker = [i, taskGroup]
135 {
136 taskGroup->mTaskWorker(i);
137 --taskGroup->mNumRemainingTasks;
138 };
139
140 SPtr<Task> task = Task::create(taskGroup->mName, worker, taskGroup->mPriority, taskGroup->mTaskDependency);
141 task->mParent = this;
142 task->mTaskId = mNextTaskId++;
143 task->mState.store(0); // Reset state in case the task is getting re-queued
144
145 mCheckTasks = true;
146 mTaskQueue.insert(std::move(task));
147 }
148
149 taskGroup->mParent = this;
150
151 // Wake main scheduler thread
152 mTaskReadyCond.notify_one();
153 }
154
155 void TaskScheduler::addWorker()
156 {
157 Lock lock(mReadyMutex);
158
159 mMaxActiveTasks++;
160
161 // A spot freed up, queue new tasks on main scheduler thread if they exist
162 mTaskReadyCond.notify_one();
163 }
164
165 void TaskScheduler::removeWorker()
166 {
167 Lock lock(mReadyMutex);
168
169 if(mMaxActiveTasks > 0)
170 mMaxActiveTasks--;
171 }
172
173 void TaskScheduler::runMain()
174 {
175 while(true)
176 {
177 Lock lock(mReadyMutex);
178
179 while((!mCheckTasks || (UINT32)mActiveTasks.size() >= mMaxActiveTasks) && !mShutdown)
180 mTaskReadyCond.wait(lock);
181
182 mCheckTasks = false;
183
184 if(mShutdown)
185 break;
186
187 for(auto iter = mTaskQueue.begin(); iter != mTaskQueue.end();)
188 {
189 if ((UINT32)mActiveTasks.size() >= mMaxActiveTasks)
190 break;
191
192 SPtr<Task> curTask = *iter;
193
194 if(curTask->isCanceled())
195 {
196 iter = mTaskQueue.erase(iter);
197 continue;
198 }
199
200 if(curTask->mTaskDependency != nullptr && !curTask->mTaskDependency->isComplete())
201 {
202 ++iter;
203 continue;
204 }
205
206 // Spin until a thread becomes available. This happens primarily because our mActiveTask count and
207 // ThreadPool's thread idle count aren't synced, so while the task manager thinks it's free to run new
208 // tasks, the ThreadPool might still have those threads as running, meaning their allocation will fail.
209 // So we just spin here for a bit, in that rare case.
210 if(ThreadPool::instance().getNumAvailable() == 0)
211 {
212 mCheckTasks = true;
213 break;
214 }
215
216 iter = mTaskQueue.erase(iter);
217
218 curTask->mState.store(1);
219 mActiveTasks.push_back(curTask);
220
221 ThreadPool::instance().run(curTask->mName, std::bind(&TaskScheduler::runTask, this, curTask));
222 }
223 }
224 }
225
226 void TaskScheduler::runTask(SPtr<Task> task)
227 {
228 task->mTaskWorker();
229
230 {
231 Lock lock(mReadyMutex);
232
233 auto findIter = std::find(mActiveTasks.begin(), mActiveTasks.end(), task);
234 if (findIter != mActiveTasks.end())
235 mActiveTasks.erase(findIter);
236 }
237
238 {
239 Lock lock(mCompleteMutex);
240 task->mState.store(2);
241
242 mTaskCompleteCond.notify_all();
243 }
244
245 // Wake the main scheduler thread in case there are other tasks waiting or this task was someone's dependency
246 {
247 Lock lock(mReadyMutex);
248
249 mCheckTasks = true;
250 mTaskReadyCond.notify_one();
251 }
252 }
253
254 void TaskScheduler::waitUntilComplete(const Task* task)
255 {
256 if(task->isCanceled())
257 return;
258
259 if(task->mTaskDependency)
260 task->mTaskDependency->wait();
261
262 // If we haven't started executing the task yet, just execute it right here
263 SPtr<Task> queuedTask;
264 {
265 Lock lock(mReadyMutex);
266
267 if(!task->hasStarted())
268 {
269 auto iterFind = std::find_if(mTaskQueue.begin(), mTaskQueue.end(),
270 [task](const SPtr<Task>& x) { return x.get() == task; });
271
272 assert(iterFind != mTaskQueue.end());
273
274 queuedTask = *iterFind;
275 mTaskQueue.erase(iterFind);
276
277 queuedTask->mState.store(1);
278 }
279 }
280
281 if(queuedTask)
282 {
283 runTask(queuedTask);
284 return;
285 }
286
287 // Otherwise we wait until the task completes
288 {
289 Lock lock(mCompleteMutex);
290
291 while(!task->isComplete())
292 {
293 addWorker();
294 mTaskCompleteCond.wait(lock);
295 removeWorker();
296 }
297 }
298 }
299
300 void TaskScheduler::waitUntilComplete(const TaskGroup* taskGroup)
301 {
302 Lock lock(mCompleteMutex);
303
304 while (taskGroup->mNumRemainingTasks > 0)
305 {
306 addWorker();
307 mTaskCompleteCond.wait(lock);
308 removeWorker();
309 }
310 }
311
312 bool TaskScheduler::taskCompare(const SPtr<Task>& lhs, const SPtr<Task>& rhs)
313 {
314 // If priority is the same, sort by the order the tasks were queued
315 if(lhs->mPriority == rhs->mPriority)
316 return lhs->mTaskId < rhs->mTaskId;
317
318 // Otherwise the task with the higher priority always goes first
319 return lhs->mPriority > rhs->mPriority;
320 }
321}
322