1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
18#define MKLDNN_THREAD_PARALLEL_ND_HPP
19
20/* This header must be included by mkldnn_thread.hpp only */
21
/* Functions:
 * - parallel(nthr, f)                - executes f in parallel using at most
 *                                      nthr threads. If nthr equals 0,
 *                                      mkldnn_get_max_threads() threads are
 *                                      used
 * - for_nd(ithr, nthr, dims..., f)   - multidimensional for loop for already
 *                                      created threads
 * - parallel_nd(dims..., f)          - creates a parallel section and then
 *                                      calls for_nd (the TBB build splits the
 *                                      work with tbb::parallel_for instead)
 * - parallel_nd_in_omp(dims..., f)   - queries current nthr and ithr and then
 *                                      calls for_nd (mostly for convenience;
 *                                      not supported with TBB)
 */
34
35namespace mkldnn {
36namespace impl {
37
38/* general parallelization */
39template <typename F>
40void parallel(int nthr, F f) {
41 if (nthr == 0) nthr = mkldnn_get_max_threads();
42#if MKLDNN_THR == MKLDNN_THR_SEQ
43 assert(nthr == 1);
44 f(0, 1);
45#elif MKLDNN_THR == MKLDNN_THR_OMP
46 if (nthr == 1) { f(0, 1); return; }
47# pragma omp parallel num_threads(nthr)
48 f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
49#elif MKLDNN_THR == MKLDNN_THR_TBB
50 if (nthr == 1) { f(0, 1); return; }
51 tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
52#endif
53}
54
55/* for_nd section */
56
/* 1D for_nd: calls f(d0) for every d0 in thread ithr's sub-range of [0, D0),
 * as computed by balance211 for a team of nthr threads. */
template <typename T0, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
    // No flattening is needed in 1D: split the single dimension directly.
    T0 pos{0}, stop{0};
    balance211(D0, nthr, ithr, pos, stop);
    while (pos < stop) {
        f(pos);
        ++pos;
    }
}
63
64template <typename T0, typename T1, typename F>
65void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
66 const size_t work_amount = (size_t)D0 * D1;
67 if (work_amount == 0) return;
68 size_t start{0}, end{0};
69 balance211(work_amount, nthr, ithr, start, end);
70
71 T0 d0{0}; T1 d1{0};
72 utils::nd_iterator_init(start, d0, D0, d1, D1);
73 for (size_t iwork = start; iwork < end; ++iwork) {
74 f(d0, d1);
75 utils::nd_iterator_step(d0, D0, d1, D1);
76 }
77}
78
79template <typename T0, typename T1, typename T2, typename F>
80void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
81 const T2 &D2, F f) {
82 const size_t work_amount = (size_t)D0 * D1 * D2;
83 if (work_amount == 0) return;
84 size_t start{0}, end{0};
85 balance211(work_amount, nthr, ithr, start, end);
86
87 T0 d0{0}; T1 d1{0}; T2 d2{0};
88 utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
89 for (size_t iwork = start; iwork < end; ++iwork) {
90 f(d0, d1, d2);
91 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
92 }
93}
94
95template <typename T0, typename T1, typename T2, typename T3, typename F>
96void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
97 const T2 &D2, const T3 &D3, F f) {
98 const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
99 if (work_amount == 0) return;
100 size_t start{0}, end{0};
101 balance211(work_amount, nthr, ithr, start, end);
102
103 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
104 utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
105 for (size_t iwork = start; iwork < end; ++iwork) {
106 f(d0, d1, d2, d3);
107 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
108 }
109}
110
111template <typename T0, typename T1, typename T2, typename T3, typename T4,
112 typename F>
113void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
114 const T2 &D2, const T3 &D3, const T4 &D4, F f) {
115 const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
116 if (work_amount == 0) return;
117 size_t start{0}, end{0};
118 balance211(work_amount, nthr, ithr, start, end);
119
120 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
121 utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
122 for (size_t iwork = start; iwork < end; ++iwork) {
123 f(d0, d1, d2, d3, d4);
124 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
125 }
126}
127
128template <typename T0, typename T1, typename T2, typename T3, typename T4,
129 typename T5, typename F>
130void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
131 const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
132 const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
133 if (work_amount == 0) return;
134 size_t start{0}, end{0};
135 balance211(work_amount, nthr, ithr, start, end);
136
137 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
138 utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
139 d5, D5);
140 for (size_t iwork = start; iwork < end; ++iwork) {
141 f(d0, d1, d2, d3, d4, d5);
142 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
143 }
144}
145
// Product of every argument except the last one, as size_t.
// parallel_nd() passes its whole parameter pack here: the leading arguments
// are the loop dimensions and the trailing functor is absorbed (and ignored)
// by the one-argument base case. The values are only read, so the pack is
// passed on as lvalues; no forwarding is needed.
template <typename T>
constexpr size_t get_work_amount(const T &) { return 1; }
template <typename T, typename ...Args>
constexpr size_t get_work_amount(const T &v, Args &&...args)
{ return static_cast<size_t>(v) * get_work_amount(args...); }
152
153/* parallel_nd and parallel_nd_in_omp section */
154
155#if MKLDNN_THR != MKLDNN_THR_TBB
156template <typename ...Args>
157void parallel_nd(Args &&...args) {
158#if MKLDNN_THR == MKLDNN_THR_SEQ
159 for_nd(0, 1, utils::forward<Args>(args)...);
160#elif MKLDNN_THR == MKLDNN_THR_OMP
161 const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
162# pragma omp parallel if (do_parallel)
163 {
164 const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
165 const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
166 for_nd(ithr, nthr, utils::forward<Args>(args)...);
167 }
168#endif
169}
170#else // MKLDNN_THR != MKLDNN_THR_TBB
171
// gcc 4.8 has a bug with capturing parameter packs in lambdas, so each
// number of dimensions gets its own explicitly written overload.
174
175template <typename T0, typename F>
176void parallel_nd(const T0 &D0, F f) {
177 const size_t work_amount = (size_t)D0;
178 if (work_amount == 0) return;
179 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
180 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
181 f(T0(iwork));
182 }
183 }, tbb::static_partitioner());
184}
185
186template <typename T0, typename T1, typename F>
187void parallel_nd(const T0 &D0, const T1 &D1, F f) {
188 const size_t work_amount = (size_t)D0 * D1;
189 if (work_amount == 0) return;
190 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
191 T0 d0{0}; T1 d1{0};
192 utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
193 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
194 f(d0, d1);
195 utils::nd_iterator_step(d0, D0, d1, D1);
196 }
197 }, tbb::static_partitioner());
198}
199
200template <typename T0, typename T1, typename T2, typename F>
201void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
202 const size_t work_amount = (size_t)D0 * D1 * D2;
203 if (work_amount == 0) return;
204 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
205 T0 d0{0}; T1 d1{0}; T2 d2{0};
206 utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
207 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
208 f(d0, d1, d2);
209 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
210 }
211 }, tbb::static_partitioner());
212}
213
214template <typename T0, typename T1, typename T2, typename T3, typename F>
215void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
216 const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
217 if (work_amount == 0) return;
218 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
219 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
220 utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
221 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
222 f(d0, d1, d2, d3);
223 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
224 }
225 }, tbb::static_partitioner());
226}
227
228template <typename T0, typename T1, typename T2, typename T3, typename T4,
229 typename F>
230void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
231 const T4 &D4, F f) {
232 const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
233 if (work_amount == 0) return;
234 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
235 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
236 utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
237 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
238 f(d0, d1, d2, d3, d4);
239 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
240 }
241 }, tbb::static_partitioner());
242}
243
244template <typename T0, typename T1, typename T2, typename T3, typename T4,
245 typename T5, typename F>
246void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
247 const T4 &D4, const T5 &D5, F f) {
248 const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
249 if (work_amount == 0) return;
250 tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
251 T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
252 utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
253 d5, D5);
254 for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
255 f(d0, d1, d2, d3, d4, d5);
256 utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
257 }
258 }, tbb::static_partitioner());
259}
260#endif
261
262template <typename ...Args>
263void parallel_nd_in_omp(Args &&...args) {
264#if MKLDNN_THR == MKLDNN_THR_SEQ
265 for_nd(0, 1, utils::forward<Args>(args)...);
266#elif MKLDNN_THR == MKLDNN_THR_OMP
267 for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
268 utils::forward<Args>(args)...);
269#elif MKLDNN_THR == MKLDNN_THR_TBB
270 assert(!"unsupported parallel_nd_in_omp()");
271#endif
272}
273
274} // namespace impl
275} // namespace mkldnn
276
277#endif
278