| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP |
| 18 | #define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP |
| 19 | |
#include <memory>

#include "c_types_map.hpp"
#include "memory_tracking.hpp"

#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_eltwise.hpp"
| 26 | |
| 27 | namespace mkldnn { |
| 28 | namespace impl { |
| 29 | namespace cpu { |
| 30 | |
| 31 | template<typename Vmm> |
| 32 | struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { |
| 33 | |
| 34 | _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, |
| 35 | const primitive_attr_t &attr) |
| 36 | : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) |
| 37 | { |
| 38 | if (jcp.with_eltwise) |
| 39 | eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( |
| 40 | this, jcp.eltwise); |
| 41 | |
| 42 | generate(); |
| 43 | jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); |
| 44 | } |
| 45 | |
| 46 | ~_jit_avx512_common_conv_fwd_kernel() { |
| 47 | delete eltwise_injector_; |
| 48 | } |
| 49 | |
| 50 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) |
| 51 | |
| 52 | jit_conv_conf_t jcp; |
| 53 | const primitive_attr_t &attr_; |
| 54 | void (*jit_ker_)(jit_conv_call_s *); |
| 55 | |
| 56 | private: |
| 57 | using reg64_t = const Xbyak::Reg64; |
| 58 | enum { |
| 59 | typesize = sizeof(float), |
| 60 | ker_reg_base_idx = 28, |
| 61 | }; |
| 62 | |
| 63 | reg64_t param = abi_param1; |
| 64 | reg64_t reg_inp = r8; |
| 65 | reg64_t reg_ker = r9; |
| 66 | reg64_t reg_out = r10; |
| 67 | |
| 68 | reg64_t reg_inp_prf = r11; |
| 69 | reg64_t reg_ker_prf = r12; |
| 70 | reg64_t reg_out_prf = r13; |
| 71 | reg64_t reg_owb = r12; |
| 72 | |
| 73 | reg64_t aux_reg_inp = r14; |
| 74 | reg64_t aux_reg_ker = r15; |
| 75 | |
| 76 | reg64_t aux_reg_inp_prf = rsi; |
| 77 | reg64_t aux_reg_ker_prf = rdx; |
| 78 | |
| 79 | reg64_t reg_channel = rsi; |
| 80 | reg64_t reg_bias = rdx; |
| 81 | |
| 82 | reg64_t aux_reg_ker_d = r9; |
| 83 | reg64_t aux_reg_inp_d = rbx; |
| 84 | reg64_t aux_reg_inp_d_prf = r13; |
| 85 | reg64_t aux_reg_ker_d_prf = abi_not_param1; |
| 86 | reg64_t reg_ki = r10; |
| 87 | |
| 88 | reg64_t reg_kj = rax; |
| 89 | reg64_t reg_relu_ns = rax; |
| 90 | reg64_t reg_oi = rbx; |
| 91 | reg64_t reg_kh = abi_not_param1; |
| 92 | |
| 93 | reg64_t reg_tmp = rbp; |
| 94 | |
| 95 | reg64_t reg_ic_loop = rdx; |
| 96 | reg64_t reg_inp_loop = rsi; |
| 97 | |
| 98 | reg64_t reg_init_flag = r13; |
| 99 | reg64_t reg_bias_ptr = param; |
| 100 | |
| 101 | reg64_t aux_reg_ic = r12; |
| 102 | reg64_t reg_binp = rax; |
| 103 | reg64_t reg_bout = r11; |
| 104 | reg64_t aux1_reg_inp = rbx; |
| 105 | reg64_t aux_reg_out = abi_not_param1; |
| 106 | |
| 107 | reg64_t reg_long_offt = r11; |
| 108 | reg64_t reg_out_long_offt = r14; |
| 109 | |
| 110 | inline Vmm vmm_ker(int i_ic) { |
| 111 | assert(i_ic < 4); |
| 112 | return Vmm(ker_reg_base_idx + i_ic); |
| 113 | } |
| 114 | |
| 115 | inline Vmm vmm_out(int i_ur, int i_oc) { |
| 116 | int idx = i_ur + i_oc * jcp.ur_w; |
| 117 | assert(idx < ker_reg_base_idx); |
| 118 | return Vmm(idx); |
| 119 | } |
| 120 | |
| 121 | inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { |
| 122 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
| 123 | assert(idx < 31); |
| 124 | return Vmm(idx); |
| 125 | } |
| 126 | |
| 127 | Xbyak::Reg64 imm_addr64 = r15; |
| 128 | Vmm vmm_wei = Vmm(31); |
| 129 | |
| 130 | jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; |
| 131 | |
| 132 | inline void prepare_output(int ur_w); |
| 133 | inline void store_output(int ur_w); |
| 134 | inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); |
| 135 | inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); |
| 136 | inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); |
| 137 | inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); |
| 138 | inline void compute_loop(int ur_w, int pad_l, int pad_r); |
| 139 | |
| 140 | void generate(); |
| 141 | |
| 142 | inline size_t get_output_offset(int oi, int n_oc_block) { |
| 143 | return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh |
| 144 | * jcp.ow * jcp.od + oi) * jcp.oc_block; |
| 145 | } |
| 146 | |
| 147 | inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { |
| 148 | size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; |
| 149 | size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id; |
| 150 | return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) |
| 151 | + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); |
| 152 | } |
| 153 | |
| 154 | inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { |
| 155 | return jcp.typesize_in * jcp.oc_block |
| 156 | * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd |
| 157 | + (ic + ker_number) + ki * jcp.ic_block); |
| 158 | } |
| 159 | |
| 160 | inline int get_ow_start(int ki, int pad_l) { |
| 161 | return nstl::max(0, |
| 162 | utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); |
| 163 | } |
| 164 | |
| 165 | inline int get_ow_end(int ur_w, int ki, int pad_r) { |
| 166 | return ur_w - nstl::max(0, utils::div_up(pad_r |
| 167 | - (jcp.kw - 1 - ki) |
| 168 | * (jcp.dilate_w + 1), |
| 169 | jcp.stride_w)); |
| 170 | } |
| 171 | }; |
| 172 | |
| 173 | struct jit_avx512_common_conv_fwd_kernel { |
| 174 | |
| 175 | jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, |
| 176 | const primitive_attr_t &attr) : |
| 177 | jit_ker(nullptr), |
| 178 | zmm_kernel_(nullptr), |
| 179 | xmm_kernel_(nullptr) { |
| 180 | int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; |
| 181 | switch (ch_block) { |
| 182 | case 16: |
| 183 | zmm_kernel_ = |
| 184 | new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>( |
| 185 | ajcp, attr); |
| 186 | jit_ker = zmm_kernel_->jit_ker_; |
| 187 | return; |
| 188 | case 4: |
| 189 | xmm_kernel_ = |
| 190 | new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>( |
| 191 | ajcp, attr); |
| 192 | jit_ker = xmm_kernel_->jit_ker_; |
| 193 | return; |
| 194 | default: |
| 195 | assert(!"invalid channel blocking" ); |
| 196 | } |
| 197 | } |
| 198 | |
| 199 | ~jit_avx512_common_conv_fwd_kernel() { |
| 200 | delete xmm_kernel_; |
| 201 | delete zmm_kernel_; |
| 202 | } |
| 203 | |
| 204 | enum { |
| 205 | typesize = sizeof(float) |
| 206 | }; |
| 207 | |
| 208 | static bool post_ops_ok(jit_conv_conf_t &jcp, |
| 209 | const primitive_attr_t &attr); |
| 210 | static status_t init_conf(jit_conv_conf_t &jcp, |
| 211 | const convolution_desc_t &cd, |
| 212 | memory_desc_t &src_pd, |
| 213 | memory_desc_t &weights_pd, |
| 214 | memory_desc_t &dst_pd, |
| 215 | memory_desc_t &bias_pd, |
| 216 | const primitive_attr_t &attr, |
| 217 | int nthreads); |
| 218 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
| 219 | const jit_conv_conf_t &jcp); |
| 220 | |
| 221 | void(*jit_ker)(jit_conv_call_s *); |
| 222 | _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_; |
| 223 | _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_; |
| 224 | }; |
| 225 | |
| 226 | struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { |
| 227 | |
| 228 | jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) |
| 229 | { |
| 230 | generate(); |
| 231 | jit_ker = (void (*)(jit_conv_call_s *))getCode(); |
| 232 | } |
| 233 | |
| 234 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) |
| 235 | |
| 236 | static status_t init_conf(jit_conv_conf_t &jcp, |
| 237 | const convolution_desc_t &cd, |
| 238 | const memory_desc_wrapper &diff_src_d, |
| 239 | const memory_desc_wrapper &weights_d, |
| 240 | const memory_desc_wrapper &diff_dst_d); |
| 241 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
| 242 | const jit_conv_conf_t &jcp); |
| 243 | |
| 244 | jit_conv_conf_t jcp; |
| 245 | void (*jit_ker)(jit_conv_call_s *); |
| 246 | |
| 247 | private: |
| 248 | using reg64_t = const Xbyak::Reg64; |
| 249 | enum { |
| 250 | typesize = sizeof(float), |
| 251 | ker_reg_base_idx = 28, |
| 252 | }; |
| 253 | |
| 254 | reg64_t param = abi_param1; |
| 255 | reg64_t reg_dst = r8; |
| 256 | reg64_t reg_ker = r9; |
| 257 | reg64_t reg_src = r10; |
| 258 | |
| 259 | reg64_t reg_dst_prf = r11; |
| 260 | reg64_t reg_ker_prf = r12; |
| 261 | reg64_t reg_src_prf = r13; |
| 262 | |
| 263 | reg64_t aux_reg_dst = r14; |
| 264 | reg64_t aux_reg_ker = r15; |
| 265 | |
| 266 | reg64_t aux_reg_dst_prf = rsi; |
| 267 | reg64_t aux_reg_ker_prf = rdx; |
| 268 | |
| 269 | reg64_t aux_reg_dst_d_prf = r13; |
| 270 | reg64_t aux_reg_dst_d = rbx; |
| 271 | reg64_t aux_reg_ker_d_prf = abi_not_param1; |
| 272 | reg64_t aux_reg_ker_d = r9; |
| 273 | reg64_t reg_ki = r10; |
| 274 | |
| 275 | reg64_t reg_kj = rax; |
| 276 | reg64_t reg_oi = rbx; |
| 277 | reg64_t reg_kh = abi_not_param1; |
| 278 | |
| 279 | reg64_t reg_channel = rsi; |
| 280 | |
| 281 | reg64_t reg_tmp = rbp; |
| 282 | reg64_t reg_long_offt = r14; |
| 283 | |
| 284 | inline Xbyak::Zmm zmm_ker(int i_ic) { |
| 285 | assert(i_ic < 4); |
| 286 | return Xbyak::Zmm(ker_reg_base_idx + i_ic); |
| 287 | } |
| 288 | inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { |
| 289 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
| 290 | assert(idx < 31); |
| 291 | return Xbyak::Zmm(idx); |
| 292 | } |
| 293 | inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { |
| 294 | int idx = i_ur + i_oc * jcp.ur_w; |
| 295 | assert(idx < ker_reg_base_idx); |
| 296 | return Xbyak::Zmm(idx); |
| 297 | } |
| 298 | |
| 299 | Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); |
| 300 | |
| 301 | inline void prepare_output(int ur_w); |
| 302 | inline void store_output(int ur_w); |
| 303 | inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); |
| 304 | inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); |
| 305 | inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); |
| 306 | inline void compute_loop(int ur_w, int l_overflow, int r_overflow); |
| 307 | void generate(); |
| 308 | |
| 309 | inline int get_iw_start(int ki, int l_overflow) |
| 310 | { |
| 311 | int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w |
| 312 | + l_overflow * jcp.stride_w |
| 313 | - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); |
| 314 | while (res < 0) |
| 315 | res += jcp.stride_w; |
| 316 | |
| 317 | return res; |
| 318 | } |
| 319 | |
| 320 | inline int get_iw_end(int ur_w, int ki, int r_overflow) |
| 321 | { |
| 322 | if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) |
| 323 | ur_w += nstl::min(0, jcp.r_pad); // remove negative padding |
| 324 | int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w |
| 325 | + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); |
| 326 | while (res < 0) |
| 327 | res += jcp.stride_w; |
| 328 | |
| 329 | return ur_w - res; |
| 330 | } |
| 331 | }; |
| 332 | |
| 333 | struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { |
| 334 | |
| 335 | jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) |
| 336 | : jcp(ajcp) |
| 337 | { |
| 338 | generate(); |
| 339 | jit_ker = (void (*)(jit_conv_call_s *))getCode(); |
| 340 | } |
| 341 | |
| 342 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) |
| 343 | |
| 344 | static status_t init_conf(jit_conv_conf_t &jcp, |
| 345 | const convolution_desc_t &cd, |
| 346 | memory_desc_t &src_md, |
| 347 | memory_desc_t &diff_weights_md, |
| 348 | memory_desc_t &diff_bias_md, |
| 349 | memory_desc_t &diff_dst_md); |
| 350 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
| 351 | const jit_conv_conf_t &jcp); |
| 352 | |
| 353 | jit_conv_conf_t jcp; |
| 354 | void (*jit_ker)(jit_conv_call_s *); |
| 355 | |
| 356 | private: |
| 357 | using reg64_t = const Xbyak::Reg64; |
| 358 | enum {typesize = sizeof(float)}; |
| 359 | static const int max_ur_w; |
| 360 | |
| 361 | reg64_t param = abi_param1; |
| 362 | reg64_t reg_input = rax; |
| 363 | reg64_t reg_kernel = rdx; |
| 364 | reg64_t reg_output = rsi; |
| 365 | reg64_t b_ic = abi_not_param1; |
| 366 | reg64_t kj = r8; |
| 367 | reg64_t reg_kh = r9; |
| 368 | reg64_t reg_ur_w_trips = r10; |
| 369 | reg64_t reg_oj = r15; |
| 370 | reg64_t reg_ih_count = rbx; |
| 371 | reg64_t reg_tmp = r14; |
| 372 | reg64_t reg_long_offt = r14; |
| 373 | |
| 374 | reg64_t ki = r11; |
| 375 | reg64_t reg_kd_count = r12; |
| 376 | reg64_t reg_oi = r12; |
| 377 | reg64_t reg_d_index = r13; |
| 378 | reg64_t reg_input_d = r15; |
| 379 | reg64_t reg_output_d = rbx; |
| 380 | reg64_t aux_reg_input = r12; |
| 381 | reg64_t aux_reg_kernel = r13; |
| 382 | reg64_t reg_bias = rbx; |
| 383 | |
| 384 | inline void bias_kernel(); |
| 385 | inline void maybe_zero_kernel(); |
| 386 | inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, |
| 387 | int max_ur_w); |
| 388 | inline void od_step_comeback_pointers(); |
| 389 | inline void oh_step_comeback_pointers(); |
| 390 | inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); |
| 391 | inline void compute_ic_block_step(int ur_w, |
| 392 | int pad_l, int pad_r, int ic_block_step, |
| 393 | int input_offset, int kernel_offset, int output_offset, |
| 394 | bool input_wraparound = false); |
| 395 | inline void compute_ic_block_step_fma(int ur_w, |
| 396 | int pad_l, int pad_r, int ic_block_step, |
| 397 | int input_offset, int kernel_offset, int output_offset, |
| 398 | bool input_wraparound); |
| 399 | inline void compute_ic_block_step_4fma(int ur_w, |
| 400 | int pad_l, int pad_r, int ic_block_step, |
| 401 | int input_offset, int kernel_offset, int output_offset, |
| 402 | bool input_wraparound); |
| 403 | inline void compute_oh_step_common(int ic_block_step, int max_ur_w); |
| 404 | inline void compute_oh_step_disp(); |
| 405 | inline void compute_oh_loop_common(); |
| 406 | inline void compute_d_loop_common(); |
| 407 | |
| 408 | inline bool compute_full_spat_loop(); |
| 409 | inline bool flat_4ops_compute(); |
| 410 | |
| 411 | inline void compute_loop(); |
| 412 | |
| 413 | void generate(); |
| 414 | |
| 415 | static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, |
| 416 | int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); |
| 417 | }; |
| 418 | |
| 419 | } |
| 420 | } |
| 421 | } |
| 422 | |
| 423 | #endif |
| 424 | |