1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef MKLDNN_THREAD_HPP
18#define MKLDNN_THREAD_HPP
19
20#include "utils.hpp"
21#include "z_magic.hpp"
22
23#define MKLDNN_THR_SEQ 0
24#define MKLDNN_THR_OMP 1
25#define MKLDNN_THR_TBB 2
26
/* Ideally the condition below should never be true (if the library is built
 * using the regular cmake). For 3rd-party projects that build the library
 * from the sources on their own, fall back to TBB threading. */
30#if !defined(MKLDNN_THR)
31# define MKLDNN_THR MKLDNN_THR_TBB
32#endif
33
34#if MKLDNN_THR == MKLDNN_THR_SEQ
35#define MKLDNN_THR_SYNC 1
/* Sequential build: every threading query below reports a fixed answer. */

/* Maximum number of threads: always 1. */
inline int mkldnn_get_max_threads() {
    return 1;
}

/* Number of threads in the current team: always 1. */
inline int mkldnn_get_num_threads() {
    return 1;
}

/* Index of the calling thread: always 0. */
inline int mkldnn_get_thread_num() {
    return 0;
}

/* Always reports 0. */
inline int mkldnn_in_parallel() {
    return 0;
}

/* Nothing to synchronize with: the barrier is a no-op. */
inline void mkldnn_thr_barrier() {
}
41
42#define PRAGMA_OMP(...)
43
44#elif MKLDNN_THR == MKLDNN_THR_OMP
45#include <omp.h>
46#define MKLDNN_THR_SYNC 1
47
48inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
49inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
50inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
51inline int mkldnn_in_parallel() { return omp_in_parallel(); }
52inline void mkldnn_thr_barrier() {
53# pragma omp barrier
54}
55
56#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
57
58#elif MKLDNN_THR == MKLDNN_THR_TBB
59#include "tbb/task_arena.h"
60#include "tbb/parallel_for.h"
61#define MKLDNN_THR_SYNC 0
62
63inline int mkldnn_get_max_threads()
64{ return tbb::this_task_arena::max_concurrency(); }
65inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
66inline int mkldnn_get_thread_num()
67{ return tbb::this_task_arena::current_thread_index(); }
68inline int mkldnn_in_parallel() { return 0; }
69inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
70
71#define PRAGMA_OMP(...)
72
73#endif
74
75/* MSVC still supports omp 2.0 only */
76#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
77# define collapse(x)
78# define PRAGMA_OMP_SIMD(...)
79#else
80# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
82
83namespace mkldnn {
84namespace impl {
85
86inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
87
88template <typename T, typename U>
89inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
90 T n_min = 1;
91 T &n_my = n_end;
92 if (team <= 1 || n == 0) {
93 n_start = 0;
94 n_my = n;
95 } else if (n_min == 1) {
96 // team = T1 + T2
97 // n = T1*n1 + T2*n2 (n1 - n2 = 1)
98 T n1 = utils::div_up(n, (T)team);
99 T n2 = n1 - 1;
100 T T1 = n - n2 * (T)team;
101 n_my = (T)tid < T1 ? n1 : n2;
102 n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
103 }
104
105 n_end += n_start;
106}
107
108} // namespace impl
109} // namespace mkldnn
110
111#include "mkldnn_thread_parallel_nd.hpp"
112
113#endif
114
115// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
116