1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_HPP
18#define UTILS_HPP
19
20#include <stddef.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <assert.h>
24#include <stdint.h>
25
26#if defined(__x86_64__) || defined(_M_X64)
27#define MKLDNN_X86_64
28#endif
29
30#define MSAN_ENABLED 0
31#if defined(__has_feature)
32#if __has_feature(memory_sanitizer)
33#undef MSAN_ENABLED
34#define MSAN_ENABLED 1
35#include <sanitizer/msan_interface.h>
36#endif
37#endif
38
39#include "c_types_map.hpp"
40#include "nstl.hpp"
41#include "z_magic.hpp"
42
43namespace mkldnn {
44namespace impl {
45
46// Sanity check for 64 bits
47static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only");
48
/* Evaluates the status-returning expression @p f exactly once and, if the
 * result differs from status::success, returns that result from the
 * *enclosing* function (which therefore has to return status_t itself).
 *
 * The temporary deliberately has an unusual name: with a plain `status`
 * local, a call such as CHECK(foo(status)) would expand to
 * `status_t status = foo(status);` and read the freshly declared,
 * uninitialized local instead of the caller's variable. The argument is
 * parenthesized so that any expression can be passed safely.
 *
 * The do { } while (0) wrapper makes the macro behave as one statement. */
#define CHECK(f) do { \
    const status_t _status_ = (f); \
    if (_status_ != status::success) \
        return _status_; \
} while (0)
54
/* Logical implication "cause => effect": yields false only when @p cause
 * holds and @p effect does not. @p effect is evaluated only if @p cause is
 * true. */
#define IMPLICATION(cause, effect) (!(cause) ? true : !!(effect))
56
57namespace utils {
58
/* a bunch of std:: analogues to be compliant with any msvs version
 *
 * Rationale: msvs c++ (and even some c) headers contain a special pragma that
 * injects an msvs-version check into object files in order to detect
 * abi-mismatches during static linking. This makes sense if e.g. std::
 * objects are passed between the application and the library, which is not
 * the case for mkl-dnn (since there is no c++-rt dependent stuff,
 * ideally...). */
66
/* SFINAE helper -- analogue of std::enable_if: the nested `type` exists (and
 * names T) only when the boolean argument is true */
template <bool expr, class T = void>
struct enable_if {};

template <class T>
struct enable_if<true, T> {
    using type = T;
};
70
/* analogue of std::conditional: `type` is the first type argument when the
 * flag is true and the second one when it is false */
template <bool, typename, typename>
struct conditional {};

template <typename T, typename F>
struct conditional<true, T, F> {
    using type = T;
};

template <typename T, typename F>
struct conditional<false, T, F> {
    using type = F;
};
77
/* three-way type selection driven by two flags:
 *   (true,  false) -> T
 *   (false, true)  -> FT
 *   (false, false) -> FF
 * There is no specialization for (true, true), so `type` is not defined in
 * that case. */
template <bool, typename, bool, typename, typename>
struct conditional3 {};

template <typename T, typename FT, typename FF>
struct conditional3<true, T, false, FT, FF> {
    using type = T;
};

template <typename T, typename FT, typename FF>
struct conditional3<false, T, true, FT, FF> {
    using type = FT;
};

template <typename T, typename FT, typename FF>
struct conditional3<false, T, false, FT, FF> {
    using type = FF;
};
85
/* value counterpart of conditional: `value` (of type U) equals the first
 * constant when the flag is true and the second one when it is false */
template <bool, typename U, U, U>
struct conditional_v {};

template <typename U, U t, U f>
struct conditional_v<true, U, t, f> {
    static constexpr U value = t;
};

template <typename U, U t, U f>
struct conditional_v<false, U, t, f> {
    static constexpr U value = f;
};
91
/* analogue of std::remove_reference: strips a trailing & or && from T */
template <typename T>
struct remove_reference {
    using type = T;
};

template <typename T>
struct remove_reference<T &> {
    using type = T;
};

template <typename T>
struct remove_reference<T &&> {
    using type = T;
};
95
96template <typename T>
97inline T&& forward(typename utils::remove_reference<T>::type &t)
98{ return static_cast<T&&>(t); }
99template <typename T>
100inline T&& forward(typename utils::remove_reference<T>::type &&t)
101{ return static_cast<T&&>(t); }
102
103template <typename T>
104inline typename remove_reference<T>::type zero()
105{ auto zero = typename remove_reference<T>::type(); return zero; }
106
/* Returns true when every one of the listed items compares equal (via ==)
 * to @p val. Evaluation stops at the first mismatch.
 *
 * Declared constexpr, like one_of(), so the check can also be used in
 * constant expressions such as static_assert. */
template <typename T, typename P>
constexpr bool everyone_is(T val, P item) {
    return val == item;
}

template <typename T, typename P, typename... Args>
constexpr bool everyone_is(T val, P item, Args... item_others) {
    return val == item && everyone_is(val, item_others...);
}
113
/* Returns true when @p val compares equal (via ==) to at least one of the
 * listed items. Items are tested left to right and the search stops at the
 * first match. */
template <typename T, typename P>
constexpr bool one_of(T val, P item) {
    return val == item;
}

template <typename T, typename P, typename... Args>
constexpr bool one_of(T val, P item, Args... item_others) {
    return (val == item) ? true : one_of(val, item_others...);
}
120
/* returns true when at least one of the given pointers compares equal to
 * nullptr */
template <typename... Args>
inline bool any_null(Args... ptrs) {
    return one_of(nullptr, ptrs...);
}
123
/* copies @p size elements from @p src to @p dst, one assignment at a time
 * in ascending index order */
template <typename T>
inline void array_copy(T *dst, const T *src, size_t size) {
    const T *const src_end = src + size;
    while (src != src_end)
        *dst++ = *src++;
}
/* returns true when the first @p size elements of @p a1 and @p a2 are
 * pairwise equal (no pair satisfies !=); true for size == 0 */
template <typename T>
inline bool array_cmp(const T *a1, const T *a2, size_t size) {
    size_t idx = 0;
    // advance while the elements match; stop at the first difference
    while (idx < size && !(a1[idx] != a2[idx]))
        ++idx;
    return idx == size;
}
/* assigns @p val, converted to T, to each of the first @p size elements of
 * @p arr */
template <typename T, typename U>
inline void array_set(T *arr, const U &val, size_t size) {
    for (T *cur = arr, *const end = arr + size; cur != end; ++cur)
        *cur = static_cast<T>(val);
}
137
namespace product_impl {
/* tag type that carries a compile-time count into overload resolution */
template <size_t> struct int2type {};

/* end of the recursion: the product of a single element is the element
 * itself. The result type is T (not int), so the last factor is not
 * truncated for floating point or 64-bit element types. */
template <typename T>
constexpr T product_impl(const T *arr, int2type<0>) { return arr[0]; }

/* multiplies arr[0] by the product of the following `num` elements */
template <typename T, size_t num>
inline T product_impl(const T *arr, int2type<num>) {
    return arr[0] * product_impl(arr + 1, int2type<num - 1>());
}
}

/* Returns the product of the first @p num elements of @p arr, where the
 * element count is a compile-time constant. @p num must be at least 1. */
template <size_t num, typename T>
inline T array_product(const T *arr) {
    // num == 0 would wrap around in `num - 1` and recurse without end
    static_assert(num > 0, "array_product<num> requires num > 0");
    return product_impl::product_impl(arr, product_impl::int2type<num - 1>());
}
153
/* Returns the product of the first @p size elements of @p arr, accumulated
 * in type R (defaults to T). The result is 1 when @p size is 0. */
template <typename T, typename R = T>
inline R array_product(const T *arr, size_t size) {
    R result = 1;
    for (const T *it = arr, *const last = arr + size; it != last; ++it)
        result *= *it;
    return result;
}
160
161/** sorts an array of values using @p comparator. While sorting the array
162 * of value, the function permutes an array of @p keys accordingly.
163 *
164 * @note The arrays of @p keys can be omitted. In this case the function
165 * sorts the array of @vals only.
166 */
167template <typename T, typename U, typename F>
168inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) {
169 if (size == 0) return;
170
171 for (size_t i = 0; i < size - 1; ++i) {
172 bool swapped = false;
173
174 for (size_t j = 0; j < size - i - 1; j++) {
175 if (comparator(vals[j], vals[j + 1]) > 0) {
176 nstl::swap(vals[j], vals[j + 1]);
177 if (keys) nstl::swap(keys[j], keys[j + 1]);
178 swapped = true;
179 }
180 }
181
182 if (swapped == false) break;
183 }
184}
185
186template <typename T, typename U>
187inline typename remove_reference<T>::type div_up(const T a, const U b) {
188 assert(b);
189 return (a + b - 1) / b;
190}
191
192template <typename T, typename U>
193inline typename remove_reference<T>::type rnd_up(const T a, const U b) {
194 return div_up(a, b) * b;
195}
196
197template <typename T, typename U>
198inline typename remove_reference<T>::type rnd_dn(const T a, const U b) {
199 return (a / b) * b;
200}
201
/* Rounds @p ptr up to the nearest address that is a multiple of
 * @p alignment. The mask arithmetic gives that result only when
 * @p alignment is a power of two. */
template <typename T> T *align_ptr(T *ptr, uintptr_t alignment) {
    const uintptr_t mask = alignment - 1;
    const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
    return reinterpret_cast<T *>((address + mask) & ~mask);
}
204
/* Returns the size of the block that starts at @p offset: @p block_size,
 * unless the block would run past @p max, in which case the remaining
 * max - offset is returned. @p offset must be below @p max (asserted). */
template <typename T, typename U, typename V>
inline U this_block_size(const T offset, const U max, const V block_size) {
    assert(offset < max);
    // TODO (Roma): can't use nstl::max() due to circular dependency... we
    // need to fix this
    const T block_end = offset + block_size;
    if (block_end > max) return max - offset;
    return block_size;
}
216
217template<typename T>
218inline T nd_iterator_init(T start) { return start; }
219template<typename T, typename U, typename W, typename... Args>
220inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) {
221 start = nd_iterator_init(start, utils::forward<Args>(tuple)...);
222 x = start % X;
223 return start / X;
224}
225
226inline bool nd_iterator_step() { return true; }
227template<typename U, typename W, typename... Args>
228inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) {
229 if (nd_iterator_step(utils::forward<Args>(tuple)...) ) {
230 x = (x + 1) % X;
231 return x == 0;
232 }
233 return false;
234}
235
/* Single-dimension jump: advances the linear position @p cur towards
 * @p end and the coordinate @p x (extent @p X) accordingly.
 *
 * If the end of the dimension (X - x steps away) is reachable without
 * passing @p end, @p cur moves there, @p x is reset to 0 and true is
 * returned. Otherwise @p cur and @p x both advance by the end - cur steps
 * that are left and false is returned. */
template <typename U, typename W, typename Y>
inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) {
    const U steps_left = end - cur;
    const U steps_to_dim_end = X - x;

    if (steps_to_dim_end <= steps_left) {
        x = 0;
        cur += steps_to_dim_end;
        return true;
    }

    cur += steps_left;
    x += steps_left;
    return false;
}
251template<typename U, typename W, typename Y, typename... Args>
252inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X,
253 Args &&... tuple)
254{
255 if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) {
256 x = (x + 1) % X;
257 return x == 0;
258 }
259 return false;
260}
261
262template <typename T>
263inline T pick(size_t i, const T &x0) { return x0; }
264template <typename T, typename ...Args>
265inline T pick(size_t i, const T &x0, Args &&... args) {
266 return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...);
267}
268
269template <typename T>
270T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference,
271 const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) {
272 switch (prop_kind) {
273 case prop_kind::forward_inference: return val_fwd_inference;
274 case prop_kind::forward_training: return val_fwd_training;
275 case prop_kind::backward_data: return val_bwd_d;
276 case prop_kind::backward_weights: return val_bwd_w;
277 default: assert(!"unsupported prop_kind");
278 }
279 return T();
280}
281
282template <typename T>
283T pick_by_prop_kind(prop_kind_t prop_kind,
284 const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w)
285{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); }
286
/* Presents a flat buffer of Telem as a Tdims-dimensional array.
 *
 * The constructor takes the base pointer followed by the dimension extents.
 * operator() takes one index per dimension and returns a reference to the
 * element at the corresponding flattened offset, the last index varying
 * fastest. The extent of the first dimension is stored but never read by
 * the offset computation. */
template <typename Telem, size_t Tdims>
struct array_offset_calculator {
    template <typename... Targs>
    array_offset_calculator(Telem *base, Targs... Fargs)
        : _base_ptr(base), _dims{ Fargs... } {}

    template <typename... Targs>
    inline Telem &operator()(Targs... Fargs) {
        // dimension 0 needs no multiplier, so accumulation starts at 1
        return _base_ptr[_offset(1, Fargs...)];
    }

private:
    /* a single index is the offset itself */
    inline size_t _offset(size_t const /* dimension */, size_t element) {
        return element;
    }

    /* last step: fold the final index into the accumulated offset */
    inline size_t _offset(size_t const dimension, size_t theta,
            size_t element) {
        return element + (_dims[dimension] * theta);
    }

    /* intermediate step: fold one index into the accumulated offset
     * @p theta and continue with the next dimension */
    template <typename... Targs>
    inline size_t _offset(size_t const dimension, size_t theta,
            size_t element, Targs... Fargs) {
        const size_t accumulated = element + (_dims[dimension] * theta);
        return _offset(dimension + 1, accumulated, Fargs...);
    }

    Telem *_base_ptr;
    const int _dims[Tdims];
};
324
325}
326
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name this presumably adds 'val' to '*dst'
// atomically and returns the previous value -- confirm against the
// definition.
int32_t fetch_and_add(int32_t *dst, int32_t val);
// Currently a no-op: the body is empty.
inline void yield_thread() {}
329
330// Reads an environment variable 'name' and stores its string value in the
331// 'buffer' of 'buffer_size' bytes on success.
332//
333// - Returns the length of the environment variable string value (excluding
334// the terminating 0) if it is set and its contents (including the terminating
335// 0) can be stored in the 'buffer' without truncation.
336//
// - Returns negated length of environment variable string value and writes
// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is too small to
// store the value (including the terminating 0) without truncation.
340//
341// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment
342// variable is not set.
343//
344// - Returns INT_MIN if the 'name' is NULL.
345//
346// - Returns INT_MIN if the 'buffer_size' is negative.
347//
348// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than
349// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to
350// retrieve the length of the environment variable value string.
351//
352int getenv(const char *name, char *buffer, int buffer_size);
// Reads an integer from the environment variable 'name'.
// NOTE(review): 'default_value' (0 unless specified) is presumably what is
// returned when the variable is not set -- confirm against the definition.
int getenv_int(const char *name, int default_value = 0);
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably reports whether dumping of JIT-generated code is
// switched on -- confirm against the definition.
bool jit_dump_enabled();
// mkldnn::impl::fopen -- a function in this namespace with the same
// parameters as ::fopen from <stdio.h>; only the declaration is visible here.
// NOTE(review): presumably a portability wrapper around ::fopen -- confirm
// against the definition.
FILE *fopen(const char *filename, const char *mode);
357
// Compile-time mirror of the MSAN_ENABLED macro: 1 when the compiler reports
// __has_feature(memory_sanitizer), 0 otherwise.
constexpr int msan_enabled = MSAN_ENABLED;
/* Forwards to __msan_unpoison(ptr, size) when the build has MSAN_ENABLED
 * set; compiles to nothing otherwise. */
inline void msan_unpoison(void *ptr, size_t size) {
#if MSAN_ENABLED
    __msan_unpoison(ptr, size);
#else
    // nothing to do without the memory sanitizer
    (void)ptr;
    (void)size;
#endif
}
364
365}
366}
367
368#endif
369
370// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
371