1 | /* |
2 | * Copyright 2014-present Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef FOLLY_GEN_PARALLELMAP_H_ |
18 | #error This file may only be included from folly/gen/ParallelMap.h |
19 | #endif |
20 | |
21 | #include <atomic> |
22 | #include <cassert> |
23 | #include <thread> |
24 | #include <type_traits> |
25 | #include <utility> |
26 | #include <vector> |
27 | |
28 | #include <folly/MPMCPipeline.h> |
29 | #include <folly/experimental/EventCount.h> |
30 | #include <folly/functional/Invoke.h> |
31 | |
32 | namespace folly { |
33 | namespace gen { |
34 | namespace detail { |
35 | |
/**
 * PMap - Map in parallel (using threads). For producing a sequence of
 * values by passing each value from a source collection through a
 * predicate while running the predicate in parallel in different
 * threads.
 *
 * This type is usually used through the 'pmap' helper function:
 *
 *   auto squares = seq(1, 10) | pmap(square, 4) | sum;
 */
46 | template <class Predicate> |
47 | class PMap : public Operator<PMap<Predicate>> { |
48 | Predicate pred_; |
49 | size_t nThreads_; |
50 | |
51 | public: |
52 | PMap() = default; |
53 | |
54 | PMap(Predicate pred, size_t nThreads) |
55 | : pred_(std::move(pred)), nThreads_(nThreads) {} |
56 | |
57 | template < |
58 | class Value, |
59 | class Source, |
60 | class Input = typename std::decay<Value>::type, |
61 | class Output = |
62 | typename std::decay<invoke_result_t<Predicate, Value>>::type> |
63 | class Generator |
64 | : public GenImpl<Output, Generator<Value, Source, Input, Output>> { |
65 | Source source_; |
66 | Predicate pred_; |
67 | const size_t nThreads_; |
68 | |
69 | class ExecutionPipeline { |
70 | std::vector<std::thread> workers_; |
71 | std::atomic<bool> done_{false}; |
72 | const Predicate& pred_; |
73 | MPMCPipeline<Input, Output> pipeline_; |
74 | EventCount wake_; |
75 | |
76 | public: |
77 | ExecutionPipeline(const Predicate& pred, size_t nThreads) |
78 | : pred_(pred), pipeline_(nThreads, nThreads) { |
79 | workers_.reserve(nThreads); |
80 | for (size_t i = 0; i < nThreads; i++) { |
81 | workers_.push_back(std::thread([this] { this->predApplier(); })); |
82 | } |
83 | } |
84 | |
85 | ~ExecutionPipeline() { |
86 | assert(pipeline_.sizeGuess() == 0); |
87 | assert(done_.load()); |
88 | for (auto& w : workers_) { |
89 | w.join(); |
90 | } |
91 | } |
92 | |
93 | void stop() { |
94 | // prevent workers from consuming more than we produce. |
95 | done_.store(true, std::memory_order_release); |
96 | wake_.notifyAll(); |
97 | } |
98 | |
99 | bool write(Value&& value) { |
100 | bool wrote = pipeline_.write(std::forward<Value>(value)); |
101 | if (wrote) { |
102 | wake_.notify(); |
103 | } |
104 | return wrote; |
105 | } |
106 | |
107 | void blockingWrite(Value&& value) { |
108 | pipeline_.blockingWrite(std::forward<Value>(value)); |
109 | wake_.notify(); |
110 | } |
111 | |
112 | bool read(Output& out) { |
113 | return pipeline_.read(out); |
114 | } |
115 | |
116 | void blockingRead(Output& out) { |
117 | pipeline_.blockingRead(out); |
118 | } |
119 | |
120 | private: |
121 | void predApplier() { |
122 | // Each thread takes a value from the pipeline_, runs the |
123 | // predicate and enqueues the result. The pipeline preserves |
124 | // ordering. NOTE: don't use blockingReadStage<0> to read from |
125 | // the pipeline_ as there may not be any: end-of-data is signaled |
126 | // separately using done_/wake_. |
127 | Input in; |
128 | for (;;) { |
129 | auto key = wake_.prepareWait(); |
130 | |
131 | typename MPMCPipeline<Input, Output>::template Ticket<0> ticket; |
132 | if (pipeline_.template readStage<0>(ticket, in)) { |
133 | wake_.cancelWait(); |
134 | Output out = pred_(std::move(in)); |
135 | pipeline_.template blockingWriteStage<0>(ticket, std::move(out)); |
136 | continue; |
137 | } |
138 | |
139 | if (done_.load(std::memory_order_acquire)) { |
140 | wake_.cancelWait(); |
141 | break; |
142 | } |
143 | |
144 | // Not done_, but no items in the queue. |
145 | wake_.wait(key); |
146 | } |
147 | } |
148 | }; |
149 | |
150 | public: |
151 | Generator(Source source, const Predicate& pred, size_t nThreads) |
152 | : source_(std::move(source)), |
153 | pred_(pred), |
154 | nThreads_(nThreads ? nThreads : sysconf(_SC_NPROCESSORS_ONLN)) {} |
155 | |
156 | template <class Body> |
157 | void foreach(Body&& body) const { |
158 | ExecutionPipeline pipeline(pred_, nThreads_); |
159 | |
160 | size_t wrote = 0; |
161 | size_t read = 0; |
162 | source_.foreach([&](Value value) { |
163 | if (pipeline.write(std::forward<Value>(value))) { |
164 | // input queue not yet full, saturate it before we process |
165 | // anything downstream |
166 | ++wrote; |
167 | return; |
168 | } |
169 | |
170 | // input queue full; drain ready items from the queue |
171 | Output out; |
172 | while (pipeline.read(out)) { |
173 | ++read; |
174 | body(std::move(out)); |
175 | } |
176 | |
177 | // write the value we were going to write before we made room. |
178 | pipeline.blockingWrite(std::forward<Value>(value)); |
179 | ++wrote; |
180 | }); |
181 | |
182 | pipeline.stop(); |
183 | |
184 | // flush the output queue |
185 | while (read < wrote) { |
186 | Output out; |
187 | pipeline.blockingRead(out); |
188 | ++read; |
189 | body(std::move(out)); |
190 | } |
191 | } |
192 | |
193 | template <class Handler> |
194 | bool apply(Handler&& handler) const { |
195 | ExecutionPipeline pipeline(pred_, nThreads_); |
196 | |
197 | size_t wrote = 0; |
198 | size_t read = 0; |
199 | bool more = true; |
200 | source_.apply([&](Value value) { |
201 | if (pipeline.write(std::forward<Value>(value))) { |
202 | // input queue not yet full, saturate it before we process |
203 | // anything downstream |
204 | ++wrote; |
205 | return true; |
206 | } |
207 | |
208 | // input queue full; drain ready items from the queue |
209 | Output out; |
210 | while (pipeline.read(out)) { |
211 | ++read; |
212 | if (!handler(std::move(out))) { |
213 | more = false; |
214 | return false; |
215 | } |
216 | } |
217 | |
218 | // write the value we were going to write before we made room. |
219 | pipeline.blockingWrite(std::forward<Value>(value)); |
220 | ++wrote; |
221 | return true; |
222 | }); |
223 | |
224 | pipeline.stop(); |
225 | |
226 | // flush the output queue |
227 | while (read < wrote) { |
228 | Output out; |
229 | pipeline.blockingRead(out); |
230 | ++read; |
231 | if (more) { |
232 | more = more && handler(std::move(out)); |
233 | } |
234 | } |
235 | return more; |
236 | } |
237 | |
238 | static constexpr bool infinite = Source::infinite; |
239 | }; |
240 | |
241 | template <class Source, class Value, class Gen = Generator<Value, Source>> |
242 | Gen compose(GenImpl<Value, Source>&& source) const { |
243 | return Gen(std::move(source.self()), pred_, nThreads_); |
244 | } |
245 | |
246 | template <class Source, class Value, class Gen = Generator<Value, Source>> |
247 | Gen compose(const GenImpl<Value, Source>& source) const { |
248 | return Gen(source.self(), pred_, nThreads_); |
249 | } |
250 | }; |
251 | } // namespace detail |
252 | } // namespace gen |
253 | } // namespace folly |
254 | |