1// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// This file contains a number of synchronization primitives for concurrency.
16//
17// You may be tempted to change this code to unlock the mutex before calling
18// std::condition_variable::notify_[one,all]. Please read
19// https://issuetracker.google.com/issues/133135427 before making this sort of
20// change.
21
22#ifndef sw_Synchronization_hpp
23#define sw_Synchronization_hpp
24
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <utility>
30
31namespace sw
32{
33
34// TaskEvents is an interface for notifying when tasks begin and end.
35// Tasks can be nested and/or overlapping.
36// TaskEvents is used for task queue synchronization.
class TaskEvents
{
public:
    // Signals that a task is about to begin.
    virtual void start() = 0;

    // Signals that a task has ended. Every finish() must be paired with an
    // earlier start().
    virtual void finish() = 0;

    // Convenience helper: invokes start() and then finish(), in that order.
    void complete()
    {
        start();
        finish();
    }

protected:
    // The destructor is protected, so code outside the class hierarchy cannot
    // delete an object through a TaskEvents pointer.
    virtual ~TaskEvents() = default;
};
51
// WaitGroup is a synchronization primitive that allows you to wait for a
// collection of asynchronous tasks to finish executing.
// Call add() before each task begins, and then call done() after each task
// has finished.
// At the same time, wait() can be used to block until all tasks have finished.
// WaitGroup takes its name after Golang's sync.WaitGroup.
58class WaitGroup : public TaskEvents
59{
60public:
61 // add() begins a new task.
62 void add()
63 {
64 std::unique_lock<std::mutex> lock(mutex);
65 ++count_;
66 }
67
68 // done() is called when a task of the WaitGroup has been completed.
69 // Returns true if there are no more tasks currently running in the
70 // WaitGroup.
71 bool done()
72 {
73 std::unique_lock<std::mutex> lock(mutex);
74 assert(count_ > 0);
75 --count_;
76 if(count_ == 0)
77 {
78 condition.notify_all();
79 }
80 return count_ == 0;
81 }
82
83 // wait() blocks until all the tasks have been finished.
84 void wait()
85 {
86 std::unique_lock<std::mutex> lock(mutex);
87 condition.wait(lock, [this] { return count_ == 0; });
88 }
89
90 // wait() blocks until all the tasks have been finished or the timeout
91 // has been reached, returning true if all tasks have been completed, or
92 // false if the timeout has been reached.
93 template <class CLOCK, class DURATION>
94 bool wait(const std::chrono::time_point<CLOCK, DURATION>& timeout)
95 {
96 std::unique_lock<std::mutex> lock(mutex);
97 return condition.wait_until(lock, timeout, [this] { return count_ == 0; });
98 }
99
100 // count() returns the number of times add() has been called without a call
101 // to done().
102 // Note: No lock is held after count() returns, so the count may immediately
103 // change after returning.
104 int32_t count()
105 {
106 std::unique_lock<std::mutex> lock(mutex);
107 return count_;
108 }
109
110 // TaskEvents compliance
111 void start() override { add(); }
112 void finish() override { done(); }
113
114private:
115 int32_t count_ = 0; // guarded by mutex
116 std::mutex mutex;
117 std::condition_variable condition;
118};
119
120// Chan is a thread-safe FIFO queue of type T.
121// Chan takes its name after Golang's chan.
template <typename T>
class Chan
{
public:
    Chan() = default;

    // take removes and returns the item at the front of the chan, blocking
    // until an item is available.
    T take();

    // tryTake returns a <T, bool> pair without waiting for an item.
    // If the chan is not empty, then the front item is removed and returned
    // together with true.
    // If the chan is empty, then a value-initialized T and false are returned.
    std::pair<T, bool> tryTake();

    // put appends a copy of v to the back of the chan and wakes one thread
    // waiting in take(). The chan is unbounded, so put never waits for space.
    void put(const T &v);

    // Returns the number of items in the chan.
    // Note: this may change as soon as the function returns, so it should
    // only be used for debugging.
    size_t count();

private:
    std::queue<T> queue;            // guarded by mutex
    std::mutex mutex;
    std::condition_variable added;  // notified by put() after each push
};

// Blocks until the queue is non-empty, then moves the front item out of the
// queue (rather than copying it) and returns it.
template <typename T>
T Chan<T>::take()
{
    std::unique_lock<std::mutex> lock(mutex);
    // Wait for an item to be added.
    added.wait(lock, [this] { return !queue.empty(); });
    // The front element is popped right after, so it is safe to move from it.
    T out = std::move(queue.front());
    queue.pop();
    return out;
}

// Non-blocking variant of take(): returns {T{}, false} when the queue is
// empty, otherwise moves the front item out and returns {item, true}.
template <typename T>
std::pair<T, bool> Chan<T>::tryTake()
{
    std::lock_guard<std::mutex> lock(mutex);
    if(queue.empty())
    {
        return std::make_pair(T{}, false);
    }
    // The front element is popped right after, so it is safe to move from it.
    std::pair<T, bool> out(std::move(queue.front()), true);
    queue.pop();
    return out;
}

// Pushes a copy of v onto the queue and notifies one waiter.
template <typename T>
void Chan<T>::put(const T &v)
{
    std::lock_guard<std::mutex> lock(mutex);
    queue.push(v);
    // Notify while the mutex is still held; see the note at the top of this
    // file before changing that.
    added.notify_one();
}

// Returns the current queue size, read under the lock.
template <typename T>
size_t Chan<T>::count()
{
    std::lock_guard<std::mutex> lock(mutex);
    return queue.size();
}
193
194} // namespace sw
195
196#endif // sw_Synchronization_hpp
197