1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef MATH_UTILS_HPP
18#define MATH_UTILS_HPP
19
20#include <stdint.h>
21#include <math.h>
22
23#include "utils.hpp"
24#include "nstl.hpp"
25#include "mkldnn_traits.hpp"
26
27#if defined(MKLDNN_X86_64)
28#include "immintrin.h"
29#endif
30
31namespace mkldnn {
32namespace impl {
33namespace math {
34
/** Rounds @p f to an int using the processor's current rounding mode.
 *
 * On x86-64 the conversion goes through cvtss2si (_mm_cvtss_si32), which
 * rounds according to the MXCSR rounding-control field. On other targets
 * nearbyintf() is used, which rounds according to the current floating-point
 * environment; that is assumed (not verified) to match the x86 behavior. */
inline int mxcsr_round(float f) {
#if defined(MKLDNN_X86_64)
    const __m128 lane = _mm_load_ss(&f);
    return _mm_cvtss_si32(lane);
#else
    return static_cast<int>(nearbyintf(f));
#endif
}
43
44template <typename data_t, typename acc_t>
45inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
46 typename utils::remove_reference<data_t>::type>::type
47saturate(const acc_t &x) {
48 return (typename utils::remove_reference<data_t>::type)x;
49}
50
51template <typename data_t, typename acc_t>
52inline typename utils::enable_if<nstl::is_integral<data_t>::value,
53 typename utils::remove_reference<data_t>::type>::type
54saturate(const acc_t &x) {
55 acc_t v = x;
56 if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
57 v = (acc_t)nstl::numeric_limits<data_t>::lowest();
58 if (v > (acc_t)nstl::numeric_limits<data_t>::max())
59 v = (acc_t)nstl::numeric_limits<data_t>::max();
60 return (typename utils::remove_reference<data_t>::type)v;
61}
62
63template <typename data_t>
64double saturate(const double &x) {
65 double v = x;
66 if (v < (double)nstl::numeric_limits<data_t>::lowest())
67 v = (double)nstl::numeric_limits<data_t>::lowest();
68 if (v > (double)nstl::numeric_limits<data_t>::max())
69 v = (double)nstl::numeric_limits<data_t>::max();
70 return v;
71}
72
73template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
74 return x <= 127u ? x : 127;
75}
76
77template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
78 return x >= 0 ? x : 0;
79}
80
81template <typename out_t>
82typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
83out_round(float v) { return (out_t)mxcsr_round(v); }
84
85template <typename out_t>
86typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
87out_round(double v) { return (out_t)mxcsr_round((float)v); }
88
89template <typename out_t>
90typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
91out_round(float v) { return v; }
92
93inline int gcd(int a, int b) {
94 a = impl::nstl::abs(a);
95 b = impl::nstl::abs(b);
96 if (a < b) { int x = a; a = b; b = x; }
97
98 if (b == 0) return a;
99
100 int r;
101 while ((r = a % b) != 0) { a = b; b = r; }
102
103 return b;
104}
105
/** Returns true iff @p v is a positive integral power of two (1, 2, 4, ...).
 *
 * Zero and negative values are not powers of two. The `v > 0` guard also
 * prevents `v - 1` from being evaluated on a signed minimum value, where it
 * would overflow. */
template <typename T>
inline bool is_pow2(const T& v) { return (v > 0) && ((v & (v - 1)) == 0); }
108
/** Returns floor(log2(@p v)), i.e. the index of the most significant set bit
 * of @p v (bit 0 being the least significant), or -1 when @p v is 0.
 *
 * Binary search over the bit position: try shifts of 32, 16, ..., 1 and keep
 * each shift that still leaves a non-zero value. */
inline int ilog2q(size_t v) {
    if (v == 0)
        return -1;

    int msb = 0;
    for (int shift = 32; shift > 0; shift /= 2) {
        if (v >= (1ull << shift)) {
            v >>= shift;
            msb += shift;
        }
    }
    return msb;
}
120
121template <typename T, typename U = typename utils::remove_reference<T>::type>
122inline U one_m_square(T x) {
123 return (U)(1 - x) * (1 + x);
124}
125
126template <typename T, typename U = typename utils::remove_reference<T>::type>
127inline U x_m_square(T x) {
128 return (U)(1 - x) * x;
129}
130
131/* activation */
132template <typename T, typename A,
133 typename U = typename utils::remove_reference<T>::type>
134inline U relu_fwd(T s, A alpha) {
135 return s > 0 ? s : (U)(s * alpha);
136}
137template <typename T, typename A,
138 typename U = typename utils::remove_reference<T>::type>
139inline U relu_bwd(T dd, T s, A alpha) {
140 return s > 0 ? dd : (U)(dd * alpha);
141}
142
143template <typename T, typename U = typename utils::remove_reference<T>::type>
144inline U tanh_fwd(T s) {
145 const float e = tanhf((float) s);
146 return (U)e;
147}
148
149template <typename T, typename U = typename utils::remove_reference<T>::type>
150inline U tanh_bwd(T dd, T s) {
151 const float e = tanh_fwd<float>((float) s);
152 return (U)(dd * (1 - e) * (1 + e));
153}
154
155template <typename T, typename A,
156 typename U = typename utils::remove_reference<T>::type>
157inline U elu_fwd(T s, A alpha) {
158 return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
159}
160template <typename T, typename A,
161 typename U = typename utils::remove_reference<T>::type>
162 inline U elu_bwd(T dd, T s, A alpha) {
163 return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
164}
165
166template <typename T, typename U = typename utils::remove_reference<T>::type>
167inline U square_fwd(T s) {
168 return s * s;
169}
170
171template <typename T, typename U = typename utils::remove_reference<T>::type>
172inline U square_bwd(T dd, T s) {
173 return dd * 2 * s;
174}
175
176template <typename T, typename U = typename utils::remove_reference<T>::type>
177inline U abs_fwd(T s) {
178 return s > 0 ? s : -s;
179}
180
181template <typename T, typename U = typename utils::remove_reference<T>::type>
182inline U abs_bwd(T dd, T s) {
183 return s > 0 ? dd : s < 0 ? -dd : 0;
184}
185
186template <typename T, typename U = typename utils::remove_reference<T>::type>
187inline U sqrt_fwd(T s) {
188 return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
189}
190
191template <typename T, typename U = typename utils::remove_reference<T>::type>
192inline U sqrt_bwd(T dd, T s) {
193 return s > 0
194 ? (U)(dd / (2 * ::sqrtf((float)(s))))
195 : 0;
196}
197
198template <typename T, typename A,
199 typename U = typename utils::remove_reference<T>::type>
200inline U linear_fwd(T s, A alpha, A beta) {
201 return (U)(alpha * s + beta);
202}
203
204template <typename T, typename A,
205 typename U = typename utils::remove_reference<T>::type>
206inline U linear_bwd(T dd, T s, A alpha, A beta) {
207 (void) s;
208 (void) beta;
209 return (U)(dd * alpha);
210}
211
212template <typename T, typename A,
213 typename U = typename utils::remove_reference<T>::type>
214inline U bounded_relu_fwd(T s, A alpha) {
215 s = s > 0 ? s : 0;
216 return s > alpha ? (U)(alpha) : s;
217}
218
219template <typename T, typename A,
220 typename U = typename utils::remove_reference<T>::type>
221inline U bounded_relu_bwd(T dd, T s, A alpha) {
222 return dd * (0 < s && s < alpha ? 1 : 0);
223}
224
225template <typename T, typename U = typename utils::remove_reference<T>::type>
226inline U soft_relu_fwd(T s) {
227 float max_logf = 8.872284e+01; //::logf(FLT_MAX)
228 return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
229}
230
231template <typename T, typename U = typename utils::remove_reference<T>::type>
232inline U soft_relu_bwd(T dd, T s) {
233 return (U)(dd / (1 + ::expf((float)(-s))));
234}
235
236template <typename T, typename U = typename utils::remove_reference<T>::type>
237inline U logistic_fwd(T s) {
238 U v = (U)(::expf((float) -s));
239 return 1 / (1 + v);
240}
241
242template <typename T, typename U = typename utils::remove_reference<T>::type>
243inline U logistic_bwd(T dd, T s) {
244 U v = logistic_fwd<T, U>(s);
245 return dd * v * (1 - v);
246}
247
248inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
249 using namespace alg_kind;
250 using namespace utils;
251 const bool preserves_zero = true
252 && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
253 && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
254 return preserves_zero;
255}
256
257inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
258{
259 if (!bias)
260 return 0.0f;
261
262#define CASE(dt) \
263 case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
264
265 switch (data_type) {
266 CASE(data_type::s8);
267 CASE(data_type::u8);
268 CASE(data_type::s32);
269 CASE(data_type::f32);
270 default: assert(!"unimplemented");
271 }
272 return 0; // never happens (should probably be a NaN)
273#undef CASE
274}
275
276}
277}
278}
279
280#endif
281