1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef MATH_UTILS_HPP |
18 | #define MATH_UTILS_HPP |
19 | |
20 | #include <stdint.h> |
21 | #include <math.h> |
22 | |
23 | #include "utils.hpp" |
24 | #include "nstl.hpp" |
25 | #include "mkldnn_traits.hpp" |
26 | |
27 | #if defined(MKLDNN_X86_64) |
28 | #include "immintrin.h" |
29 | #endif |
30 | |
31 | namespace mkldnn { |
32 | namespace impl { |
33 | namespace math { |
34 | |
/** Converts @p f to an int, rounding as dictated by the mxcsr register. */
inline int mxcsr_round(float f) {
#if defined(MKLDNN_X86_64)
    // Load the scalar into an SSE register and let the hardware convert it.
    const __m128 packed = _mm_load_ss(&f);
    return _mm_cvtss_si32(packed);
#else
    // Portable fallback: nearbyintf() rounds using the current floating-point
    // rounding mode, which is optimistically expected to match.
    const float rounded = nearbyintf(f);
    return static_cast<int>(rounded);
#endif
}
43 | |
44 | template <typename data_t, typename acc_t> |
45 | inline typename utils::enable_if<!nstl::is_integral<data_t>::value, |
46 | typename utils::remove_reference<data_t>::type>::type |
47 | saturate(const acc_t &x) { |
48 | return (typename utils::remove_reference<data_t>::type)x; |
49 | } |
50 | |
51 | template <typename data_t, typename acc_t> |
52 | inline typename utils::enable_if<nstl::is_integral<data_t>::value, |
53 | typename utils::remove_reference<data_t>::type>::type |
54 | saturate(const acc_t &x) { |
55 | acc_t v = x; |
56 | if (v < (acc_t)nstl::numeric_limits<data_t>::lowest()) |
57 | v = (acc_t)nstl::numeric_limits<data_t>::lowest(); |
58 | if (v > (acc_t)nstl::numeric_limits<data_t>::max()) |
59 | v = (acc_t)nstl::numeric_limits<data_t>::max(); |
60 | return (typename utils::remove_reference<data_t>::type)v; |
61 | } |
62 | |
63 | template <typename data_t> |
64 | double saturate(const double &x) { |
65 | double v = x; |
66 | if (v < (double)nstl::numeric_limits<data_t>::lowest()) |
67 | v = (double)nstl::numeric_limits<data_t>::lowest(); |
68 | if (v > (double)nstl::numeric_limits<data_t>::max()) |
69 | v = (double)nstl::numeric_limits<data_t>::max(); |
70 | return v; |
71 | } |
72 | |
73 | template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) { |
74 | return x <= 127u ? x : 127; |
75 | } |
76 | |
77 | template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) { |
78 | return x >= 0 ? x : 0; |
79 | } |
80 | |
81 | template <typename out_t> |
82 | typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type |
83 | out_round(float v) { return (out_t)mxcsr_round(v); } |
84 | |
85 | template <typename out_t> |
86 | typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type |
87 | out_round(double v) { return (out_t)mxcsr_round((float)v); } |
88 | |
89 | template <typename out_t> |
90 | typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type |
91 | out_round(float v) { return v; } |
92 | |
93 | inline int gcd(int a, int b) { |
94 | a = impl::nstl::abs(a); |
95 | b = impl::nstl::abs(b); |
96 | if (a < b) { int x = a; a = b; b = x; } |
97 | |
98 | if (b == 0) return a; |
99 | |
100 | int r; |
101 | while ((r = a % b) != 0) { a = b; b = r; } |
102 | |
103 | return b; |
104 | } |
105 | |
/** Returns true iff @p v is a positive power of two (1, 2, 4, ...).
 *
 * The `v > 0` guard rejects zero (which the bit trick alone would accept)
 * and negative values, and also prevents evaluating `v - 1` for the minimum
 * value of a signed type. */
template <typename T>
inline bool is_pow2(const T& v) {
    // A power of two has exactly one bit set, so clearing the lowest set bit
    // (v & (v - 1)) must leave nothing behind.
    return (v > 0) && ((v & (v - 1)) == 0);
}
108 | |
/** Returns floor(log2(v)), i.e. the index of the most significant set bit
 * of @p v, or -1 when @p v is 0. */
inline int ilog2q(size_t v) {
    if (v == 0)
        return -1;

    int pos = 0;
    // Binary search over the bit index: try shifts of 32, 16, 8, 4, 2, 1 and
    // accumulate every shift that still leaves a non-zero value.
    for (int shift = 32; shift >= 1; shift >>= 1) {
        if (v >= (1ull << shift)) {
            v >>= shift;
            pos += shift;
        }
    }
    return pos;
}
120 | |
121 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
122 | inline U one_m_square(T x) { |
123 | return (U)(1 - x) * (1 + x); |
124 | } |
125 | |
126 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
127 | inline U x_m_square(T x) { |
128 | return (U)(1 - x) * x; |
129 | } |
130 | |
131 | /* activation */ |
132 | template <typename T, typename A, |
133 | typename U = typename utils::remove_reference<T>::type> |
134 | inline U relu_fwd(T s, A alpha) { |
135 | return s > 0 ? s : (U)(s * alpha); |
136 | } |
137 | template <typename T, typename A, |
138 | typename U = typename utils::remove_reference<T>::type> |
139 | inline U relu_bwd(T dd, T s, A alpha) { |
140 | return s > 0 ? dd : (U)(dd * alpha); |
141 | } |
142 | |
143 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
144 | inline U tanh_fwd(T s) { |
145 | const float e = tanhf((float) s); |
146 | return (U)e; |
147 | } |
148 | |
149 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
150 | inline U tanh_bwd(T dd, T s) { |
151 | const float e = tanh_fwd<float>((float) s); |
152 | return (U)(dd * (1 - e) * (1 + e)); |
153 | } |
154 | |
155 | template <typename T, typename A, |
156 | typename U = typename utils::remove_reference<T>::type> |
157 | inline U elu_fwd(T s, A alpha) { |
158 | return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); |
159 | } |
160 | template <typename T, typename A, |
161 | typename U = typename utils::remove_reference<T>::type> |
162 | inline U elu_bwd(T dd, T s, A alpha) { |
163 | return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); |
164 | } |
165 | |
166 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
167 | inline U square_fwd(T s) { |
168 | return s * s; |
169 | } |
170 | |
171 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
172 | inline U square_bwd(T dd, T s) { |
173 | return dd * 2 * s; |
174 | } |
175 | |
176 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
177 | inline U abs_fwd(T s) { |
178 | return s > 0 ? s : -s; |
179 | } |
180 | |
181 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
182 | inline U abs_bwd(T dd, T s) { |
183 | return s > 0 ? dd : s < 0 ? -dd : 0; |
184 | } |
185 | |
186 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
187 | inline U sqrt_fwd(T s) { |
188 | return s > 0 ? (U)(::sqrtf((float)(s))) : 0; |
189 | } |
190 | |
191 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
192 | inline U sqrt_bwd(T dd, T s) { |
193 | return s > 0 |
194 | ? (U)(dd / (2 * ::sqrtf((float)(s)))) |
195 | : 0; |
196 | } |
197 | |
198 | template <typename T, typename A, |
199 | typename U = typename utils::remove_reference<T>::type> |
200 | inline U linear_fwd(T s, A alpha, A beta) { |
201 | return (U)(alpha * s + beta); |
202 | } |
203 | |
204 | template <typename T, typename A, |
205 | typename U = typename utils::remove_reference<T>::type> |
206 | inline U linear_bwd(T dd, T s, A alpha, A beta) { |
207 | (void) s; |
208 | (void) beta; |
209 | return (U)(dd * alpha); |
210 | } |
211 | |
212 | template <typename T, typename A, |
213 | typename U = typename utils::remove_reference<T>::type> |
214 | inline U bounded_relu_fwd(T s, A alpha) { |
215 | s = s > 0 ? s : 0; |
216 | return s > alpha ? (U)(alpha) : s; |
217 | } |
218 | |
219 | template <typename T, typename A, |
220 | typename U = typename utils::remove_reference<T>::type> |
221 | inline U bounded_relu_bwd(T dd, T s, A alpha) { |
222 | return dd * (0 < s && s < alpha ? 1 : 0); |
223 | } |
224 | |
225 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
226 | inline U soft_relu_fwd(T s) { |
227 | float max_logf = 8.872284e+01; //::logf(FLT_MAX) |
228 | return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; |
229 | } |
230 | |
231 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
232 | inline U soft_relu_bwd(T dd, T s) { |
233 | return (U)(dd / (1 + ::expf((float)(-s)))); |
234 | } |
235 | |
236 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
237 | inline U logistic_fwd(T s) { |
238 | U v = (U)(::expf((float) -s)); |
239 | return 1 / (1 + v); |
240 | } |
241 | |
242 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
243 | inline U logistic_bwd(T dd, T s) { |
244 | U v = logistic_fwd<T, U>(s); |
245 | return dd * v * (1 - v); |
246 | } |
247 | |
248 | inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { |
249 | using namespace alg_kind; |
250 | using namespace utils; |
251 | const bool preserves_zero = true |
252 | && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) |
253 | && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); |
254 | return preserves_zero; |
255 | } |
256 | |
257 | inline float get_bias(const char *bias, size_t offset, data_type_t data_type) |
258 | { |
259 | if (!bias) |
260 | return 0.0f; |
261 | |
262 | #define CASE(dt) \ |
263 | case dt: return (float)((const prec_traits<dt>::type *)bias)[offset] |
264 | |
265 | switch (data_type) { |
266 | CASE(data_type::s8); |
267 | CASE(data_type::u8); |
268 | CASE(data_type::s32); |
269 | CASE(data_type::f32); |
270 | default: assert(!"unimplemented" ); |
271 | } |
272 | return 0; // never happens (should probably be a NaN) |
273 | #undef CASE |
274 | } |
275 | |
276 | } |
277 | } |
278 | } |
279 | |
280 | #endif |
281 | |