1 | // Copyright 2019 The SwiftShader Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // This file contains a number of synchronization primitives for concurrency. |
16 | // |
17 | // You may be tempted to change this code to unlock the mutex before calling |
18 | // std::condition_variable::notify_[one,all]. Please read |
19 | // https://issuetracker.google.com/issues/133135427 before making this sort of |
20 | // change. |
21 | |
22 | #ifndef sw_Synchronization_hpp |
23 | #define sw_Synchronization_hpp |
24 | |
#include <assert.h>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <queue>
#include <utility>
30 | |
31 | namespace sw |
32 | { |
33 | |
// TaskEvents is an interface for notifying when tasks begin and end.
// Tasks may be nested and/or overlap one another.
// It is used to synchronize task queues.
class TaskEvents
{
public:
	// Invoked immediately before a task begins.
	virtual void start() = 0;

	// Invoked once a task has ended. Every call to finish() must be preceded
	// by a matching call to start().
	virtual void finish() = 0;

	// Convenience helper: signals start() and then finish() back to back.
	void complete()
	{
		start();
		finish();
	}

protected:
	// Non-public so that implementations cannot be deleted through a
	// TaskEvents pointer.
	virtual ~TaskEvents() = default;
};
51 | |
52 | // WaitGroup is a synchronization primitive that allows you to wait for |
53 | // collection of asynchronous tasks to finish executing. |
54 | // Call add() before each task begins, and then call done() when after each task |
55 | // is finished. |
56 | // At the same time, wait() can be used to block until all tasks have finished. |
57 | // WaitGroup takes its name after Golang's sync.WaitGroup. |
58 | class WaitGroup : public TaskEvents |
59 | { |
60 | public: |
61 | // add() begins a new task. |
62 | void add() |
63 | { |
64 | std::unique_lock<std::mutex> lock(mutex); |
65 | ++count_; |
66 | } |
67 | |
68 | // done() is called when a task of the WaitGroup has been completed. |
69 | // Returns true if there are no more tasks currently running in the |
70 | // WaitGroup. |
71 | bool done() |
72 | { |
73 | std::unique_lock<std::mutex> lock(mutex); |
74 | assert(count_ > 0); |
75 | --count_; |
76 | if(count_ == 0) |
77 | { |
78 | condition.notify_all(); |
79 | } |
80 | return count_ == 0; |
81 | } |
82 | |
83 | // wait() blocks until all the tasks have been finished. |
84 | void wait() |
85 | { |
86 | std::unique_lock<std::mutex> lock(mutex); |
87 | condition.wait(lock, [this] { return count_ == 0; }); |
88 | } |
89 | |
90 | // wait() blocks until all the tasks have been finished or the timeout |
91 | // has been reached, returning true if all tasks have been completed, or |
92 | // false if the timeout has been reached. |
93 | template <class CLOCK, class DURATION> |
94 | bool wait(const std::chrono::time_point<CLOCK, DURATION>& timeout) |
95 | { |
96 | std::unique_lock<std::mutex> lock(mutex); |
97 | return condition.wait_until(lock, timeout, [this] { return count_ == 0; }); |
98 | } |
99 | |
100 | // count() returns the number of times add() has been called without a call |
101 | // to done(). |
102 | // Note: No lock is held after count() returns, so the count may immediately |
103 | // change after returning. |
104 | int32_t count() |
105 | { |
106 | std::unique_lock<std::mutex> lock(mutex); |
107 | return count_; |
108 | } |
109 | |
110 | // TaskEvents compliance |
111 | void start() override { add(); } |
112 | void finish() override { done(); } |
113 | |
114 | private: |
115 | int32_t count_ = 0; // guarded by mutex |
116 | std::mutex mutex; |
117 | std::condition_variable condition; |
118 | }; |
119 | |
// Chan is a thread-safe, unbounded FIFO queue of type T.
// Chan takes its name after Golang's chan.
template <typename T>
class Chan
{
public:
	Chan();

	// take returns the next item in the chan, blocking until an item is
	// available.
	T take();

	// tryTake returns a <T, bool> pair without waiting for an item.
	// If the chan is not empty, then the next item and true are returned.
	// If the chan is empty, then a value-initialized T (T{}) and false are
	// returned.
	std::pair<T, bool> tryTake();

	// put appends a copy of the item to the back of the chan and wakes one
	// thread waiting in take(), if any. The chan has no capacity limit, so put
	// never waits for space; it only waits to acquire the internal mutex.
	void put(const T &v);

	// Returns the number of items in the chan.
	// Note: this may change as soon as the function returns, so it should
	// only be used for debugging.
	size_t count();

private:
	std::queue<T> queue;  // guarded by mutex
	std::mutex mutex;
	std::condition_variable added;  // signalled by put() when an item is queued
};
151 | |
152 | template <typename T> |
153 | Chan<T>::Chan() {} |
154 | |
155 | template <typename T> |
156 | T Chan<T>::take() |
157 | { |
158 | std::unique_lock<std::mutex> lock(mutex); |
159 | // Wait for item to be added. |
160 | added.wait(lock, [this] { return queue.size() > 0; }); |
161 | T out = queue.front(); |
162 | queue.pop(); |
163 | return out; |
164 | } |
165 | |
166 | template <typename T> |
167 | std::pair<T, bool> Chan<T>::tryTake() |
168 | { |
169 | std::unique_lock<std::mutex> lock(mutex); |
170 | if (queue.size() == 0) |
171 | { |
172 | return std::make_pair(T{}, false); |
173 | } |
174 | T out = queue.front(); |
175 | queue.pop(); |
176 | return std::make_pair(out, true); |
177 | } |
178 | |
179 | template <typename T> |
180 | void Chan<T>::put(const T &item) |
181 | { |
182 | std::unique_lock<std::mutex> lock(mutex); |
183 | queue.push(item); |
184 | added.notify_one(); |
185 | } |
186 | |
187 | template <typename T> |
188 | size_t Chan<T>::count() |
189 | { |
190 | std::unique_lock<std::mutex> lock(mutex); |
191 | return queue.size(); |
192 | } |
193 | |
194 | } // namespace sw |
195 | |
196 | #endif // sw_Synchronization_hpp |
197 | |