| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef JIT_PRIMITIVE_CONF_HPP |
| 18 | #define JIT_PRIMITIVE_CONF_HPP |
| 19 | |
| 20 | #include <stdint.h> |
| 21 | |
| 22 | #include "common/primitive_attr.hpp" |
| 23 | |
| 24 | namespace mkldnn { |
| 25 | namespace impl { |
| 26 | namespace cpu { |
| 27 | |
| 28 | /* convolution */ |
/* Kernel version tag stored in the `ver` field of the convolution
 * configuration structs.
 * NOTE(review): the enumerator names look like instruction-set flavours;
 * confirm against the kernels that switch on `ver`. */
enum conv_version_t {
    ver_unused,
    ver_fma,
    ver_avx512_core,
    ver_4fma,
    ver_vnni
};
/* Loop-order selector stored in jit_conv_conf_t::loop_order.
 * NOTE(review): the letters after `loop_` presumably name the nesting order
 * of the driver loops; verify against the code that reads `loop_order`. */
enum conv_loop_order_t {
    loop_cgn,
    loop_gnc,
    loop_ngc,
    loop_gncw,
    loop_cwgn,
    loop_ngcw,
    loop_nhwcg,
    loop_nwcg
};
/* Loop-order selector stored in jit_1x1_conv_conf_t::loop_order.
 * NOTE(review): r/b/l presumably stand for the reduce, bcast and load
 * dimensions of the 1x1 configuration; confirm with the 1x1 driver. */
enum conv_1x1_loop_order_t {
    loop_rbl,
    loop_rlb,
    loop_lbr,
    loop_lrb,
    loop_blr,
    loop_brl
};
/* Kernel kind stored in jit_conv_conf_t::kernel_kind.
 * NOTE(review): presumably "embedded" vs. "explicit" broadcast; confirm. */
enum conv_kernel_kind_t {
    embd_bcast,
    expl_bcast
};
| 35 | |
/* Single-bit flag constants.
 *
 * NOTE(review): two groups live in this one anonymous enum and their bit
 * positions overlap (FLAG_ZERO_FILTER == FLAG_MB_FIRST, FLAG_ZERO_BIAS ==
 * FLAG_MB_LAST, FLAG_COMPUTE_BIAS == FLAG_OC_FIRST). Presumably the groups
 * are never combined in the same flag word; confirm before mixing them. */
enum {
    /* FIRST/LAST markers, bits 0..9 */
    FLAG_MB_FIRST = 1 << 0,
    FLAG_MB_LAST = 1 << 1,
    FLAG_OC_FIRST = 1 << 2,
    FLAG_OC_LAST = 1 << 3,
    FLAG_IC_FIRST = 1 << 4,
    FLAG_IC_LAST = 1 << 5,
    FLAG_SP_FIRST = 1 << 6,
    FLAG_SP_LAST = 1 << 7,
    FLAG_REDUCE_FIRST = 1 << 8,
    FLAG_REDUCE_LAST = 1 << 9,

    /* Controls whether the inner kernel skips loading weights data from
     * memory; this needs to happen on the first Group/16 iteration. */
    FLAG_ZERO_FILTER = 1 << 0,
    /* Controls whether the inner kernel skips loading bias data from
     * memory. */
    FLAG_ZERO_BIAS = 1 << 1,
    /* Controls bias computation during the execution pass. */
    FLAG_COMPUTE_BIAS = 1 << 2,
};
| 51 | |
| 52 | struct jit_conv_conf_t { |
| 53 | prop_kind_t prop_kind; |
| 54 | conv_version_t ver; |
| 55 | conv_loop_order_t loop_order; |
| 56 | |
| 57 | int simd_w; |
| 58 | int ndims; |
| 59 | int mb; |
| 60 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
| 61 | int id, ih, iw, od, oh, ow; |
| 62 | int f_pad, l_pad, t_pad; |
| 63 | int back_pad, r_pad, b_pad; |
| 64 | int kd, kh, kw; |
| 65 | int stride_d, stride_h, stride_w; |
| 66 | int dilate_d, dilate_h, dilate_w; |
| 67 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
| 68 | bool with_bias; |
| 69 | bool with_sum; |
| 70 | bool with_eltwise; |
| 71 | |
| 72 | post_ops_t::entry_t::eltwise_t eltwise; |
| 73 | |
| 74 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
| 75 | |
| 76 | int idp, ihp, iwp, ohp, owp; |
| 77 | int nb_ic, ic_block; |
| 78 | int nb_oc, oc_block; |
| 79 | int nb_ow, ow_block; |
| 80 | int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking |
| 81 | into account vector registers distribution */ |
| 82 | int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work |
| 83 | within threads */ |
| 84 | int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work |
| 85 | int nb_ic_L2; |
| 86 | int h_blocking; |
| 87 | int nb_oc_L2; |
| 88 | int ur_h, ur_w; |
| 89 | int ur_w_tail; |
| 90 | bool is_1stconv; |
| 91 | int nonblk_group_off; |
| 92 | /* fma avx512_core */ |
| 93 | conv_kernel_kind_t kernel_kind; |
| 94 | /* 4fma */ |
| 95 | int tr_iw; |
| 96 | int tr_src_num_guard_elems; |
| 97 | /* 1st conv: 4fma */ |
| 98 | int tr_ld; |
| 99 | int kh_step; |
| 100 | /* 4vnni */ |
| 101 | int typesize_in; |
| 102 | int typesize_out; |
| 103 | int typesize_bia; |
| 104 | int typesize_acc; |
| 105 | /* avx512_u8s8u8 */ |
| 106 | int ic_nb1, ic_nb2; |
| 107 | int oc_nb1; |
| 108 | int ur_ow_max, ur_ow, ur_ow_tail; |
| 109 | int ur_ow_nsteps; |
| 110 | data_type_t bia_dt; |
| 111 | data_type_t dst_dt; |
| 112 | /* avx512: max possible value is nregs(32) - aux_regs(4) */ |
| 113 | int src_offsets[28]; |
| 114 | int src_count; |
| 115 | bool expl_bcast; |
| 116 | bool large_spatial; |
| 117 | int is_oc_scale; |
| 118 | int max_regs_ur; // maximum accumulation registers |
| 119 | // dw conv |
| 120 | int nb_ch, ch_block, nb_ch_blocking; |
| 121 | bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; |
| 122 | int aligned_threads; |
| 123 | // large spatial |
| 124 | int oh_blk_size; |
| 125 | // s8s8 convolution |
| 126 | bool signed_input; |
| 127 | float wei_adj_scale; |
| 128 | }; |
| 129 | |
| 130 | struct jit_conv_conf_2x3_wino_t { |
| 131 | conv_version_t ver; |
| 132 | |
| 133 | int m; |
| 134 | int r; |
| 135 | int alpha; |
| 136 | int tile_h, tile_w; |
| 137 | |
| 138 | int mb; |
| 139 | int ngroups, ic, oc, oc_without_padding; |
| 140 | int ih, iw, oh, ow; |
| 141 | int l_pad, t_pad; |
| 142 | int r_pad, b_pad; |
| 143 | int kh, kw; |
| 144 | int stride_h, stride_w; |
| 145 | int dilate_h, dilate_w; |
| 146 | |
| 147 | int nb_ic, ic_block; |
| 148 | int nb_oc, oc_block; |
| 149 | |
| 150 | int w_block_size, h_block_size; |
| 151 | |
| 152 | data_type_t bia_dt; |
| 153 | data_type_t dst_dt; |
| 154 | |
| 155 | int is_oc_scale; |
| 156 | int typesize_in; |
| 157 | int typesize_out; |
| 158 | int typesize_bia; |
| 159 | int typesize_acc; |
| 160 | |
| 161 | format_tag_t src_tag, dst_tag; // temporary workaround |
| 162 | bool with_bias; |
| 163 | bool small_mb; |
| 164 | |
| 165 | int xb, yb; |
| 166 | int inp_stride; |
| 167 | int out_stride; |
| 168 | int wei_stride; |
| 169 | int bia_stride; |
| 170 | |
| 171 | int M, N, K; |
| 172 | int m_block, n_block, k_block; |
| 173 | int n2_block, n_chunks; |
| 174 | int k2_block, k_chunks; |
| 175 | |
| 176 | int mb_block, nb_mb; |
| 177 | |
| 178 | size_t size_wino_src, size_wino_wei, size_wino_dst; |
| 179 | |
| 180 | int nthr; |
| 181 | }; |
| 182 | |
| 183 | /* |
| 184 | Winograd sched policy: |
| 185 | |
| 186 | Computation Unit: |
| 187 | W: weights transform |
| 188 | S: src transform |
| 189 | D: dst transform |
| 190 | G: gemm |
| 191 | |
| 192 | Thread grouping by: |
| 193 | i: nb_ic |
| 194 | o: nb_oc |
| 195 | t: tile_block |
| 196 | e: element in tile |
| 197 | |
Note: 'i' and 'o' are omitted if
i. not combined with t or
ii. with discrete transforms
| 201 | |
| 202 | Current policies supported: |
| 203 | */ |
/* Winograd scheduling policy, stored in
 * jit_conv_winograd_conf_t::sched_policy. Numeric values are spelled out
 * and are the same ones the implicit numbering produced. */
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    /* ---- forward & backward-data ---- */
    /* W_S_G_D implements discrete transforms */
    WSCHED_DATA_W_S_G_D = 1,
    /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2 */
    WSCHED_DATA_W_SGD = 2,

    /* ---- backward-weights ---- */
    WSCHED_WEI_S_D_G_W = 3,
    WSCHED_WEI_SDGtWo = 4,
    WSCHED_WEI_S_D_Giot_W = 5,
    WSCHED_WEI_SDGt_W = 6,
};
| 219 | |
| 220 | struct jit_conv_winograd_conf_t : public jit_conv_conf_t { |
| 221 | int itiles; |
| 222 | int jtiles; |
| 223 | int ntiles; |
| 224 | int ic_simd_block=16; |
| 225 | int tile_4fma_padding; |
| 226 | int tile_4fma; |
| 227 | int oc_simd_block=16; |
| 228 | int oc_reg_block; |
| 229 | int ic_reg_block; |
| 230 | int tile_block; |
| 231 | int tile_block_ur; |
| 232 | int nb_tile_block_ur; |
| 233 | |
| 234 | bool double_buffering; |
| 235 | bool with_relu_postsum; |
| 236 | int zmm_start; |
| 237 | int nb_reg; |
| 238 | |
| 239 | int dimK; |
| 240 | int dimK_4fma; |
| 241 | int dimK_reg_block; |
| 242 | int dimK_block; |
| 243 | int dimK_nb_block; |
| 244 | |
| 245 | int dimM; |
| 246 | int dimM_reg_block; |
| 247 | int dimM_simd_block; |
| 248 | int dimM_block; |
| 249 | int dimM_nb_block; |
| 250 | |
| 251 | int dimN; |
| 252 | int dimN_reg_block; |
| 253 | int dimN_bcast_ur; |
| 254 | int dimN_block; |
| 255 | int dimN_nb_block; |
| 256 | |
| 257 | winograd_sched_t sched_policy; |
| 258 | }; |
| 259 | |
/* Argument record for a convolution kernel call.
 *
 * Data only; no initializers. Field order defines the object layout and
 * must stay as is.
 * NOTE(review): the `_prf` twins presumably carry prefetch arguments;
 * confirm with the kernels that consume them. */
struct jit_conv_call_s {
    /* data pointers */
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *src_prf;
    const void *dst_prf;
    const void *filt_prf;
    const void *bias_prf;
    const void *scales;
    const void *acc_s32;
    const void *compensation;

    /* size_t arguments */
    size_t kd_offset;
    size_t kd_offset_prf;
    size_t d_index;
    size_t d_index_prf;
    size_t d_worksize;
    size_t d_worksize_prf;
    size_t kd_padding;
    size_t kd_padding_prf;
    size_t kh_padding;
    size_t kh_padding_prf;
    size_t owb;
    size_t owb_prf;
    size_t kw_padding;
    size_t channel;
    size_t channel_prf;
    size_t oc_blocks;
    size_t ur_w;
    size_t ur_str_w;
    size_t ch_blocks;
    size_t t_overflow;
    size_t b_overflow;

    int flags;
};
| 295 | |
/* Argument record for a deconvolution kernel call.
 *
 * Data only; no initializers. Field order defines the object layout. */
struct jit_deconv_call_s {
    /* data pointers */
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *scales;
    const void *compensation;

    /* size_t arguments */
    size_t t_overflow;
    size_t b_overflow;
    size_t kh_padding;
    size_t oc_blocks;
};
| 308 | |
/* Argument record whose name indicates the depthwise convolution kernel.
 *
 * Data only; no initializers. Field order defines the object layout. */
struct jit_dw_conv_call_s {
    const void *input;
    const void *output;
    const void *filter;
    const void *bias;

    size_t kh_count;
    size_t oh_count;
    size_t oh_index;
    size_t filter_pad_off;

    /* Flags passed by driver execution to inner kernel */
    unsigned char exec_flags;
};
| 321 | |
/* Argument record whose name indicates the Winograd transform kernels.
 *
 * Data only; no initializers. All pointers are non-const `void *`.
 * Field order defines the object layout. */
struct jit_wino_transform_call_s {
    /* size_t arguments */
    size_t tile_block;
    size_t tile_block_ur;
    size_t nb_tile_block_ur;
    size_t tile_count;
    size_t tj;
    size_t ti;

    /* buffers */
    void *src;
    void *dst;
    void *Mw;
    void *M;
    void *T;
    void *G;
    void *bias;
};
| 337 | |
| 338 | struct jit_1x1_conv_conf_t { |
| 339 | prop_kind_t prop_kind; |
| 340 | conv_version_t ver; |
| 341 | |
| 342 | int mb; |
| 343 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
| 344 | int iw, ih, ow, oh; |
| 345 | int l_pad, t_pad; |
| 346 | int kh, kw; |
| 347 | int stride_h, stride_w; |
| 348 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
| 349 | bool with_bias; |
| 350 | bool with_sum; |
| 351 | bool with_eltwise; |
| 352 | |
| 353 | post_ops_t::entry_t::eltwise_t eltwise; |
| 354 | |
| 355 | int is, os; |
| 356 | int ic_block, oc_block; |
| 357 | |
| 358 | int ur, ur_tail; |
| 359 | |
| 360 | int reduce_dim, reduce_block, nb_reduce, |
| 361 | nb_reduce_blocking, nb_reduce_blocking_max; |
| 362 | int load_dim, load_block, nb_load, |
| 363 | nb_load_blocking, nb_load_blocking_max, nb_load_chunk; |
| 364 | int bcast_dim, bcast_block, nb_bcast, |
| 365 | nb_bcast_blocking, nb_bcast_blocking_max; |
| 366 | |
| 367 | int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; |
| 368 | int load_loop_load_step, load_loop_iter_step; |
| 369 | int bcast_loop_output_step, bcast_loop_output_substep; |
| 370 | int bcast_loop_bcast_step, bcast_loop_bcast_substep; |
| 371 | int fma_step; |
| 372 | int load_grp_count; |
| 373 | conv_1x1_loop_order_t loop_order; |
| 374 | bool use_vmovntps; |
| 375 | /* avx512 core */ |
| 376 | bool expl_bcast; |
| 377 | /* 4vnni */ |
| 378 | int typesize_in; |
| 379 | int typesize_out; |
| 380 | int typesize_bia; |
| 381 | int typesize_acc; |
| 382 | /* 4fma */ |
| 383 | bool transpose_src; |
| 384 | int tr_is; |
| 385 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
| 386 | int is_oc_scale; |
| 387 | data_type_t bia_dt; |
| 388 | data_type_t dst_dt; |
| 389 | bool signed_input; |
| 390 | float wei_adj_scale; |
| 391 | }; |
| 392 | |
| 393 | struct jit_gemm_conv_conf_t { |
| 394 | prop_kind_t prop_kind; |
| 395 | |
| 396 | int mb; |
| 397 | int ngroups, ic, oc; |
| 398 | int iw, ih, id, ow, oh, od; |
| 399 | int l_pad, t_pad, f_pad; |
| 400 | int kh, kw, kd; |
| 401 | int stride_h, stride_w, stride_d; |
| 402 | int dilate_h, dilate_w, dilate_d; |
| 403 | bool with_bias; |
| 404 | |
| 405 | int is, os, ks; |
| 406 | int ic_block, oc_block; |
| 407 | |
| 408 | int nthr; |
| 409 | ptrdiff_t im2col_sz; |
| 410 | bool need_wei_reduction; |
| 411 | bool signed_input; |
| 412 | int oh_block; |
| 413 | int ow_block; |
| 414 | bool outer_threading; |
| 415 | }; |
| 416 | |
/* Argument record for a 1x1 convolution kernel call.
 *
 * Data only; no initializers. Field order defines the object layout. */
struct jit_1x1_conv_call_s {
    /* data pointers */
    const void *bcast_data;
    const void *load_data;
    const void *output_data;
    const void *bias_data; /* used in forward and backward_weights only */
    const void *acc_s32;
    const void *scales;
    const void *compensation;

    /* dimensions */
    size_t load_dim;
    size_t bcast_dim;
    size_t reduce_dim;

    size_t output_stride; /* used in backward_weights only */

    size_t first_last_flag;
};
| 434 | |
| 435 | /* pooling */ |
| 436 | struct jit_pool_conf_t { |
| 437 | int ndims; |
| 438 | int mb, c; |
| 439 | int id, ih, iw, od, oh, ow; |
| 440 | int stride_d, stride_h, stride_w; |
| 441 | int kd, kh, kw; |
| 442 | int f_pad, t_pad, l_pad; |
| 443 | alg_kind_t alg; |
| 444 | bool is_training; |
| 445 | bool pad_w_is_null; |
| 446 | bool is_backward; |
| 447 | bool simple_alg; |
| 448 | data_type_t ind_dt; |
| 449 | |
| 450 | int c_block, c_tail, nb_c; |
| 451 | int ur_c, ur_c_tail; |
| 452 | int ur_w; |
| 453 | int ur_w_tail; |
| 454 | size_t tail[4]; |
| 455 | data_type_t src_dt; |
| 456 | data_type_t dst_dt; |
| 457 | }; |
| 458 | |
/* Argument record for a pooling kernel call.
 *
 * Data only; no initializers. Field order defines the object layout.
 * NOTE(review): the `_prf` twins presumably carry prefetch addresses;
 * confirm with the pooling kernel. */
struct jit_pool_call_s {
    /* data pointers */
    const float *src;
    const float *dst;
    const void *indices;
    const float *src_prf;
    const float *dst_prf;
    const void *indices_prf;

    /* size_t arguments */
    size_t oh;
    size_t kd_padding;
    size_t kh_padding;
    size_t kh_padding_shift;
    size_t kd_padding_shift;
    size_t kw_padding;

    const float *init_value;
    float ker_area_h;
};
| 475 | |
| 476 | |
| 477 | } |
| 478 | } |
| 479 | } |
| 480 | |
| 481 | #endif |
| 482 | |