| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef MATH_UTILS_HPP |
| 18 | #define MATH_UTILS_HPP |
| 19 | |
| 20 | #include <stdint.h> |
| 21 | #include <math.h> |
| 22 | |
| 23 | #include "utils.hpp" |
| 24 | #include "nstl.hpp" |
| 25 | #include "mkldnn_traits.hpp" |
| 26 | |
| 27 | #if defined(MKLDNN_X86_64) |
| 28 | #include "immintrin.h" |
| 29 | #endif |
| 30 | |
| 31 | namespace mkldnn { |
| 32 | namespace impl { |
| 33 | namespace math { |
| 34 | |
/**
 * Converts @p f to an int, rounding as directed by the MXCSR register.
 *
 * On MKLDNN_X86_64 builds the value goes through the SSE scalar
 * float-to-int32 conversion. Other builds fall back to nearbyintf(), on the
 * assumption that its rounding matches what the MXCSR-driven path would do.
 *
 * @param f value to round
 * @return @p f rounded to an integer
 */
inline int mxcsr_round(float f) {
#if !defined(MKLDNN_X86_64)
    // No MXCSR available: optimistically trust the C runtime rounding.
    const float rounded = nearbyintf(f);
    return (int)rounded;
#else
    const __m128 packed = _mm_load_ss(&f);
    return _mm_cvtss_si32(packed);
#endif
}
| 43 | |
| 44 | template <typename data_t, typename acc_t> |
| 45 | inline typename utils::enable_if<!nstl::is_integral<data_t>::value, |
| 46 | typename utils::remove_reference<data_t>::type>::type |
| 47 | saturate(const acc_t &x) { |
| 48 | return (typename utils::remove_reference<data_t>::type)x; |
| 49 | } |
| 50 | |
| 51 | template <typename data_t, typename acc_t> |
| 52 | inline typename utils::enable_if<nstl::is_integral<data_t>::value, |
| 53 | typename utils::remove_reference<data_t>::type>::type |
| 54 | saturate(const acc_t &x) { |
| 55 | acc_t v = x; |
| 56 | if (v < (acc_t)nstl::numeric_limits<data_t>::lowest()) |
| 57 | v = (acc_t)nstl::numeric_limits<data_t>::lowest(); |
| 58 | if (v > (acc_t)nstl::numeric_limits<data_t>::max()) |
| 59 | v = (acc_t)nstl::numeric_limits<data_t>::max(); |
| 60 | return (typename utils::remove_reference<data_t>::type)v; |
| 61 | } |
| 62 | |
| 63 | template <typename data_t> |
| 64 | double saturate(const double &x) { |
| 65 | double v = x; |
| 66 | if (v < (double)nstl::numeric_limits<data_t>::lowest()) |
| 67 | v = (double)nstl::numeric_limits<data_t>::lowest(); |
| 68 | if (v > (double)nstl::numeric_limits<data_t>::max()) |
| 69 | v = (double)nstl::numeric_limits<data_t>::max(); |
| 70 | return v; |
| 71 | } |
| 72 | |
| 73 | template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) { |
| 74 | return x <= 127u ? x : 127; |
| 75 | } |
| 76 | |
| 77 | template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) { |
| 78 | return x >= 0 ? x : 0; |
| 79 | } |
| 80 | |
| 81 | template <typename out_t> |
| 82 | typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type |
| 83 | out_round(float v) { return (out_t)mxcsr_round(v); } |
| 84 | |
| 85 | template <typename out_t> |
| 86 | typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type |
| 87 | out_round(double v) { return (out_t)mxcsr_round((float)v); } |
| 88 | |
| 89 | template <typename out_t> |
| 90 | typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type |
| 91 | out_round(float v) { return v; } |
| 92 | |
| 93 | inline int gcd(int a, int b) { |
| 94 | a = impl::nstl::abs(a); |
| 95 | b = impl::nstl::abs(b); |
| 96 | if (a < b) { int x = a; a = b; b = x; } |
| 97 | |
| 98 | if (b == 0) return a; |
| 99 | |
| 100 | int r; |
| 101 | while ((r = a % b) != 0) { a = b; b = r; } |
| 102 | |
| 103 | return b; |
| 104 | } |
| 105 | |
/**
 * Tells whether @p v is a power of two, i.e. is positive and has exactly one
 * bit set.
 *
 * The bit trick (v & (v - 1)) == 0 alone is also satisfied by 0, which is not
 * a power of two, so non-positive values are rejected first. Testing the sign
 * up front also keeps v - 1 from being evaluated on the most negative value
 * of a signed type.
 *
 * @tparam T integral type
 * @param v value to test
 * @return true iff @p v == 2^k for some k >= 0
 */
template <typename T>
inline bool is_pow2(const T& v) { return v > 0 && (v & (v - 1)) == 0; }
| 108 | |
/**
 * Returns floor(log2(v)), i.e. the zero-based position of the leftmost
 * non-zero bit of @p v.
 *
 * @param v value to inspect
 * @return the bit position, or -1 when @p v is 0
 */
inline int ilog2q(size_t v) {
    if (v == 0) return -1;

    // Binary search over the bit width: whenever v still has a bit at or
    // above position `shift`, drop the low `shift` bits and account for them.
    int position = 0;
    for (int shift = 32; shift >= 1; shift /= 2) {
        if (v >= (1ull << shift)) {
            v >>= shift;
            position += shift;
        }
    }
    return position;
}
| 120 | |
| 121 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 122 | inline U one_m_square(T x) { |
| 123 | return (U)(1 - x) * (1 + x); |
| 124 | } |
| 125 | |
| 126 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 127 | inline U x_m_square(T x) { |
| 128 | return (U)(1 - x) * x; |
| 129 | } |
| 130 | |
| 131 | /* activation */ |
| 132 | template <typename T, typename A, |
| 133 | typename U = typename utils::remove_reference<T>::type> |
| 134 | inline U relu_fwd(T s, A alpha) { |
| 135 | return s > 0 ? s : (U)(s * alpha); |
| 136 | } |
| 137 | template <typename T, typename A, |
| 138 | typename U = typename utils::remove_reference<T>::type> |
| 139 | inline U relu_bwd(T dd, T s, A alpha) { |
| 140 | return s > 0 ? dd : (U)(dd * alpha); |
| 141 | } |
| 142 | |
| 143 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 144 | inline U tanh_fwd(T s) { |
| 145 | const float e = tanhf((float) s); |
| 146 | return (U)e; |
| 147 | } |
| 148 | |
| 149 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 150 | inline U tanh_bwd(T dd, T s) { |
| 151 | const float e = tanh_fwd<float>((float) s); |
| 152 | return (U)(dd * (1 - e) * (1 + e)); |
| 153 | } |
| 154 | |
| 155 | template <typename T, typename A, |
| 156 | typename U = typename utils::remove_reference<T>::type> |
| 157 | inline U elu_fwd(T s, A alpha) { |
| 158 | return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); |
| 159 | } |
| 160 | template <typename T, typename A, |
| 161 | typename U = typename utils::remove_reference<T>::type> |
| 162 | inline U elu_bwd(T dd, T s, A alpha) { |
| 163 | return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); |
| 164 | } |
| 165 | |
| 166 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 167 | inline U square_fwd(T s) { |
| 168 | return s * s; |
| 169 | } |
| 170 | |
| 171 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 172 | inline U square_bwd(T dd, T s) { |
| 173 | return dd * 2 * s; |
| 174 | } |
| 175 | |
| 176 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 177 | inline U abs_fwd(T s) { |
| 178 | return s > 0 ? s : -s; |
| 179 | } |
| 180 | |
| 181 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 182 | inline U abs_bwd(T dd, T s) { |
| 183 | return s > 0 ? dd : s < 0 ? -dd : 0; |
| 184 | } |
| 185 | |
| 186 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 187 | inline U sqrt_fwd(T s) { |
| 188 | return s > 0 ? (U)(::sqrtf((float)(s))) : 0; |
| 189 | } |
| 190 | |
| 191 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 192 | inline U sqrt_bwd(T dd, T s) { |
| 193 | return s > 0 |
| 194 | ? (U)(dd / (2 * ::sqrtf((float)(s)))) |
| 195 | : 0; |
| 196 | } |
| 197 | |
| 198 | template <typename T, typename A, |
| 199 | typename U = typename utils::remove_reference<T>::type> |
| 200 | inline U linear_fwd(T s, A alpha, A beta) { |
| 201 | return (U)(alpha * s + beta); |
| 202 | } |
| 203 | |
| 204 | template <typename T, typename A, |
| 205 | typename U = typename utils::remove_reference<T>::type> |
| 206 | inline U linear_bwd(T dd, T s, A alpha, A beta) { |
| 207 | (void) s; |
| 208 | (void) beta; |
| 209 | return (U)(dd * alpha); |
| 210 | } |
| 211 | |
| 212 | template <typename T, typename A, |
| 213 | typename U = typename utils::remove_reference<T>::type> |
| 214 | inline U bounded_relu_fwd(T s, A alpha) { |
| 215 | s = s > 0 ? s : 0; |
| 216 | return s > alpha ? (U)(alpha) : s; |
| 217 | } |
| 218 | |
| 219 | template <typename T, typename A, |
| 220 | typename U = typename utils::remove_reference<T>::type> |
| 221 | inline U bounded_relu_bwd(T dd, T s, A alpha) { |
| 222 | return dd * (0 < s && s < alpha ? 1 : 0); |
| 223 | } |
| 224 | |
| 225 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 226 | inline U soft_relu_fwd(T s) { |
| 227 | float max_logf = 8.872284e+01; //::logf(FLT_MAX) |
| 228 | return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; |
| 229 | } |
| 230 | |
| 231 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 232 | inline U soft_relu_bwd(T dd, T s) { |
| 233 | return (U)(dd / (1 + ::expf((float)(-s)))); |
| 234 | } |
| 235 | |
| 236 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 237 | inline U logistic_fwd(T s) { |
| 238 | U v = (U)(::expf((float) -s)); |
| 239 | return 1 / (1 + v); |
| 240 | } |
| 241 | |
| 242 | template <typename T, typename U = typename utils::remove_reference<T>::type> |
| 243 | inline U logistic_bwd(T dd, T s) { |
| 244 | U v = logistic_fwd<T, U>(s); |
| 245 | return dd * v * (1 - v); |
| 246 | } |
| 247 | |
| 248 | inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { |
| 249 | using namespace alg_kind; |
| 250 | using namespace utils; |
| 251 | const bool preserves_zero = true |
| 252 | && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) |
| 253 | && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); |
| 254 | return preserves_zero; |
| 255 | } |
| 256 | |
| 257 | inline float get_bias(const char *bias, size_t offset, data_type_t data_type) |
| 258 | { |
| 259 | if (!bias) |
| 260 | return 0.0f; |
| 261 | |
| 262 | #define CASE(dt) \ |
| 263 | case dt: return (float)((const prec_traits<dt>::type *)bias)[offset] |
| 264 | |
| 265 | switch (data_type) { |
| 266 | CASE(data_type::s8); |
| 267 | CASE(data_type::u8); |
| 268 | CASE(data_type::s32); |
| 269 | CASE(data_type::f32); |
| 270 | default: assert(!"unimplemented" ); |
| 271 | } |
| 272 | return 0; // never happens (should probably be a NaN) |
| 273 | #undef CASE |
| 274 | } |
| 275 | |
| 276 | } |
| 277 | } |
| 278 | } |
| 279 | |
| 280 | #endif |
| 281 | |