1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef JIT_PRIMITIVE_CONF_HPP |
18 | #define JIT_PRIMITIVE_CONF_HPP |
19 | |
20 | #include <stdint.h> |
21 | |
22 | #include "common/primitive_attr.hpp" |
23 | |
24 | namespace mkldnn { |
25 | namespace impl { |
26 | namespace cpu { |
27 | |
/* convolution */

/* Code-generation variant selected for a convolution kernel
 * (named after ISA features: FMA, AVX-512 core, 4FMA, VNNI). */
enum conv_version_t {
    ver_unused = 0,
    ver_fma = 1,
    ver_avx512_core = 2,
    ver_4fma = 3,
    ver_vnni = 4,
};

/* Loop-nest orders used by the direct convolution drivers.
 * NOTE(review): the letters appear to name the nested dimensions
 * (g/n/c/h/w), presumably outermost first -- confirm against the drivers. */
enum conv_loop_order_t {
    loop_cgn = 0,
    loop_gnc = 1,
    loop_ngc = 2,
    loop_gncw = 3,
    loop_cwgn = 4,
    loop_ngcw = 5,
    loop_nhwcg = 6,
    loop_nwcg = 7,
};

/* Loop-nest orders used by the 1x1 convolution drivers.
 * NOTE(review): the letters presumably stand for the reduce (r), bcast (b)
 * and load (l) dimensions of jit_1x1_conv_conf_t -- confirm. */
enum conv_1x1_loop_order_t {
    loop_rbl = 0,
    loop_rlb = 1,
    loop_lbr = 2,
    loop_lrb = 3,
    loop_blr = 4,
    loop_brl = 5,
};

/* Broadcast strategy of a kernel (presumably embedded vs. explicit
 * broadcast of source values -- confirm in the kernel generators). */
enum conv_kernel_kind_t {
    embd_bcast = 0,
    expl_bcast = 1,
};
35 | |
/* Bit flags used by the convolution kernels.
 *
 * Two independent flag sets live in this one enum and therefore overlap
 * in value:
 *   - FLAG_{MB,OC,IC,SP,REDUCE}_{FIRST,LAST}  -> bits 0..9
 *   - FLAG_ZERO_FILTER / FLAG_ZERO_BIAS / FLAG_COMPUTE_BIAS -> bits 0..2,
 *     i.e. the same values as FLAG_MB_FIRST / FLAG_MB_LAST / FLAG_OC_FIRST.
 * NOTE(review): flags from the two sets should not be mixed in one mask. */
enum {
    /* first/last iteration markers */
    FLAG_MB_FIRST     = 1 << 0,
    FLAG_MB_LAST      = 1 << 1,
    FLAG_OC_FIRST     = 1 << 2,
    FLAG_OC_LAST      = 1 << 3,
    FLAG_IC_FIRST     = 1 << 4,
    FLAG_IC_LAST      = 1 << 5,
    FLAG_SP_FIRST     = 1 << 6,
    FLAG_SP_LAST      = 1 << 7,
    FLAG_REDUCE_FIRST = 1 << 8,
    FLAG_REDUCE_LAST  = 1 << 9,

    /* Controls whether the inner kernel skips loading weights data from
     * memory; this needs to happen on the first Group/16 iteration. */
    FLAG_ZERO_FILTER  = 1 << 0,
    /* Controls whether the inner kernel skips loading bias data from
     * memory. */
    FLAG_ZERO_BIAS    = 1 << 1,
    /* Controls bias computation during the execution pass. */
    FLAG_COMPUTE_BIAS = 1 << 2,
};
51 | |
52 | struct jit_conv_conf_t { |
53 | prop_kind_t prop_kind; |
54 | conv_version_t ver; |
55 | conv_loop_order_t loop_order; |
56 | |
57 | int simd_w; |
58 | int ndims; |
59 | int mb; |
60 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
61 | int id, ih, iw, od, oh, ow; |
62 | int f_pad, l_pad, t_pad; |
63 | int back_pad, r_pad, b_pad; |
64 | int kd, kh, kw; |
65 | int stride_d, stride_h, stride_w; |
66 | int dilate_d, dilate_h, dilate_w; |
67 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
68 | bool with_bias; |
69 | bool with_sum; |
70 | bool with_eltwise; |
71 | |
72 | post_ops_t::entry_t::eltwise_t eltwise; |
73 | |
74 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
75 | |
76 | int idp, ihp, iwp, ohp, owp; |
77 | int nb_ic, ic_block; |
78 | int nb_oc, oc_block; |
79 | int nb_ow, ow_block; |
80 | int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking |
81 | into account vector registers distribution */ |
82 | int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work |
83 | within threads */ |
84 | int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work |
85 | int nb_ic_L2; |
86 | int h_blocking; |
87 | int nb_oc_L2; |
88 | int ur_h, ur_w; |
89 | int ur_w_tail; |
90 | bool is_1stconv; |
91 | int nonblk_group_off; |
92 | /* fma avx512_core */ |
93 | conv_kernel_kind_t kernel_kind; |
94 | /* 4fma */ |
95 | int tr_iw; |
96 | int tr_src_num_guard_elems; |
97 | /* 1st conv: 4fma */ |
98 | int tr_ld; |
99 | int kh_step; |
100 | /* 4vnni */ |
101 | int typesize_in; |
102 | int typesize_out; |
103 | int typesize_bia; |
104 | int typesize_acc; |
105 | /* avx512_u8s8u8 */ |
106 | int ic_nb1, ic_nb2; |
107 | int oc_nb1; |
108 | int ur_ow_max, ur_ow, ur_ow_tail; |
109 | int ur_ow_nsteps; |
110 | data_type_t bia_dt; |
111 | data_type_t dst_dt; |
112 | /* avx512: max possible value is nregs(32) - aux_regs(4) */ |
113 | int src_offsets[28]; |
114 | int src_count; |
115 | bool expl_bcast; |
116 | bool large_spatial; |
117 | int is_oc_scale; |
118 | int max_regs_ur; // maximum accumulation registers |
119 | // dw conv |
120 | int nb_ch, ch_block, nb_ch_blocking; |
121 | bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; |
122 | int aligned_threads; |
123 | // large spatial |
124 | int oh_blk_size; |
125 | // s8s8 convolution |
126 | bool signed_input; |
127 | float wei_adj_scale; |
128 | }; |
129 | |
130 | struct jit_conv_conf_2x3_wino_t { |
131 | conv_version_t ver; |
132 | |
133 | int m; |
134 | int r; |
135 | int alpha; |
136 | int tile_h, tile_w; |
137 | |
138 | int mb; |
139 | int ngroups, ic, oc, oc_without_padding; |
140 | int ih, iw, oh, ow; |
141 | int l_pad, t_pad; |
142 | int r_pad, b_pad; |
143 | int kh, kw; |
144 | int stride_h, stride_w; |
145 | int dilate_h, dilate_w; |
146 | |
147 | int nb_ic, ic_block; |
148 | int nb_oc, oc_block; |
149 | |
150 | int w_block_size, h_block_size; |
151 | |
152 | data_type_t bia_dt; |
153 | data_type_t dst_dt; |
154 | |
155 | int is_oc_scale; |
156 | int typesize_in; |
157 | int typesize_out; |
158 | int typesize_bia; |
159 | int typesize_acc; |
160 | |
161 | format_tag_t src_tag, dst_tag; // temporary workaround |
162 | bool with_bias; |
163 | bool small_mb; |
164 | |
165 | int xb, yb; |
166 | int inp_stride; |
167 | int out_stride; |
168 | int wei_stride; |
169 | int bia_stride; |
170 | |
171 | int M, N, K; |
172 | int m_block, n_block, k_block; |
173 | int n2_block, n_chunks; |
174 | int k2_block, k_chunks; |
175 | |
176 | int mb_block, nb_mb; |
177 | |
178 | size_t size_wino_src, size_wino_wei, size_wino_dst; |
179 | |
180 | int nthr; |
181 | }; |
182 | |
/*
 * Winograd scheduling policy.
 *
 * Computation units:
 *   W: weights transform
 *   S: src transform
 *   D: dst transform
 *   G: gemm
 *
 * Thread grouping by:
 *   i: nb_ic
 *   o: nb_oc
 *   t: tile_block
 *   e: element in tile
 *
 * Note: 'i' and 'o' are omitted if
 *   (i)  they are not combined with 't', or
 *   (ii) the transforms are discrete.
 *
 * The currently supported policies are listed below.
 */
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    /* Forward & backward-data */
    /* W_S_G_D implements discrete transforms */
    WSCHED_DATA_W_S_G_D = 1,
    /* W_SGD implements tiled transforms so that GEMM can reuse data in L2 */
    WSCHED_DATA_W_SGD = 2,

    /* Backward-weights */
    WSCHED_WEI_S_D_G_W = 3,
    WSCHED_WEI_SDGtWo = 4,
    WSCHED_WEI_S_D_Giot_W = 5,
    WSCHED_WEI_SDGt_W = 6,
};
219 | |
220 | struct jit_conv_winograd_conf_t : public jit_conv_conf_t { |
221 | int itiles; |
222 | int jtiles; |
223 | int ntiles; |
224 | int ic_simd_block=16; |
225 | int tile_4fma_padding; |
226 | int tile_4fma; |
227 | int oc_simd_block=16; |
228 | int oc_reg_block; |
229 | int ic_reg_block; |
230 | int tile_block; |
231 | int tile_block_ur; |
232 | int nb_tile_block_ur; |
233 | |
234 | bool double_buffering; |
235 | bool with_relu_postsum; |
236 | int zmm_start; |
237 | int nb_reg; |
238 | |
239 | int dimK; |
240 | int dimK_4fma; |
241 | int dimK_reg_block; |
242 | int dimK_block; |
243 | int dimK_nb_block; |
244 | |
245 | int dimM; |
246 | int dimM_reg_block; |
247 | int dimM_simd_block; |
248 | int dimM_block; |
249 | int dimM_nb_block; |
250 | |
251 | int dimN; |
252 | int dimN_reg_block; |
253 | int dimN_bcast_ur; |
254 | int dimN_block; |
255 | int dimN_nb_block; |
256 | |
257 | winograd_sched_t sched_policy; |
258 | }; |
259 | |
/* Argument block for a JIT convolution kernel call (judging by the name).
 * The data pointers are declared const even though, depending on the
 * propagation kind, the kernel writes through one of them (see the "hack"
 * notes). Member order is kept identical to the original. */
struct jit_conv_call_s {
    const void *src;          /* hack, non-const for backward_data */
    const void *dst;          /* hack, non-const for forward */
    const void *filt;         /* hack, non-const for backward_weights */
    const void *bias;         /* hack, non-const for backward_bias */
    /* NOTE(review): the *_prf members presumably carry prefetch
     * counterparts of the fields they are named after -- confirm. */
    const void *src_prf;
    const void *dst_prf;
    const void *filt_prf;
    const void *bias_prf;
    const void *scales;
    const void *acc_s32;
    const void *compensation;

    size_t kd_offset;
    size_t kd_offset_prf;
    size_t d_index;
    size_t d_index_prf;
    size_t d_worksize;
    size_t d_worksize_prf;
    size_t kd_padding;
    size_t kd_padding_prf;
    size_t kh_padding;
    size_t kh_padding_prf;
    size_t owb;
    size_t owb_prf;
    size_t kw_padding;
    size_t channel;
    size_t channel_prf;
    size_t oc_blocks;
    size_t ur_w;
    size_t ur_str_w;
    size_t ch_blocks;
    size_t t_overflow;
    size_t b_overflow;

    int flags;                /* presumably a mask of FLAG_* bits -- confirm */
};
295 | |
/* Argument block for a JIT deconvolution kernel call (judging by the name).
 * Member order is kept identical to the original. */
struct jit_deconv_call_s {
    /* data pointers (const-ness is a hack, see per-field notes) */
    const void *src;          /* hack, non-const for backward_data */
    const void *dst;          /* hack, non-const for forward */
    const void *filt;         /* hack, non-const for backward_weights */
    const void *bias;         /* hack, non-const for backward_bias */
    const void *scales;
    const void *compensation;

    /* per-call extents */
    size_t t_overflow;
    size_t b_overflow;
    size_t kh_padding;
    size_t oc_blocks;
};
308 | |
/* Argument block for a JIT depthwise-convolution kernel call (judging by
 * the name). Member order is kept identical to the original. */
struct jit_dw_conv_call_s {
    /* data pointers */
    const void *input;
    const void *output;
    const void *filter;
    const void *bias;

    /* per-call extents and offsets */
    size_t kh_count;
    size_t oh_count;
    size_t oh_index;
    size_t filter_pad_off;

    /* Flags passed by the driver to the inner kernel; presumably a mask of
     * FLAG_ZERO_FILTER / FLAG_ZERO_BIAS / FLAG_COMPUTE_BIAS -- confirm. */
    unsigned char exec_flags;
};
321 | |
/* Argument block for a Winograd transform kernel call (judging by the
 * name): tile counters/indices first, then untyped buffer pointers.
 * Member order is kept identical to the original. */
struct jit_wino_transform_call_s {
    /* tiling counters and tile indices */
    size_t tile_block;
    size_t tile_block_ur;
    size_t nb_tile_block_ur;
    size_t tile_count;
    size_t tj;
    size_t ti;

    /* buffers */
    void *src;
    void *dst;
    void *Mw;
    void *M;
    void *T;
    void *G;
    void *bias;
};
337 | |
338 | struct jit_1x1_conv_conf_t { |
339 | prop_kind_t prop_kind; |
340 | conv_version_t ver; |
341 | |
342 | int mb; |
343 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
344 | int iw, ih, ow, oh; |
345 | int l_pad, t_pad; |
346 | int kh, kw; |
347 | int stride_h, stride_w; |
348 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
349 | bool with_bias; |
350 | bool with_sum; |
351 | bool with_eltwise; |
352 | |
353 | post_ops_t::entry_t::eltwise_t eltwise; |
354 | |
355 | int is, os; |
356 | int ic_block, oc_block; |
357 | |
358 | int ur, ur_tail; |
359 | |
360 | int reduce_dim, reduce_block, nb_reduce, |
361 | nb_reduce_blocking, nb_reduce_blocking_max; |
362 | int load_dim, load_block, nb_load, |
363 | nb_load_blocking, nb_load_blocking_max, nb_load_chunk; |
364 | int bcast_dim, bcast_block, nb_bcast, |
365 | nb_bcast_blocking, nb_bcast_blocking_max; |
366 | |
367 | int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; |
368 | int load_loop_load_step, load_loop_iter_step; |
369 | int bcast_loop_output_step, bcast_loop_output_substep; |
370 | int bcast_loop_bcast_step, bcast_loop_bcast_substep; |
371 | int fma_step; |
372 | int load_grp_count; |
373 | conv_1x1_loop_order_t loop_order; |
374 | bool use_vmovntps; |
375 | /* avx512 core */ |
376 | bool expl_bcast; |
377 | /* 4vnni */ |
378 | int typesize_in; |
379 | int typesize_out; |
380 | int typesize_bia; |
381 | int typesize_acc; |
382 | /* 4fma */ |
383 | bool transpose_src; |
384 | int tr_is; |
385 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
386 | int is_oc_scale; |
387 | data_type_t bia_dt; |
388 | data_type_t dst_dt; |
389 | bool signed_input; |
390 | float wei_adj_scale; |
391 | }; |
392 | |
393 | struct jit_gemm_conv_conf_t { |
394 | prop_kind_t prop_kind; |
395 | |
396 | int mb; |
397 | int ngroups, ic, oc; |
398 | int iw, ih, id, ow, oh, od; |
399 | int l_pad, t_pad, f_pad; |
400 | int kh, kw, kd; |
401 | int stride_h, stride_w, stride_d; |
402 | int dilate_h, dilate_w, dilate_d; |
403 | bool with_bias; |
404 | |
405 | int is, os, ks; |
406 | int ic_block, oc_block; |
407 | |
408 | int nthr; |
409 | ptrdiff_t im2col_sz; |
410 | bool need_wei_reduction; |
411 | bool signed_input; |
412 | int oh_block; |
413 | int ow_block; |
414 | bool outer_threading; |
415 | }; |
416 | |
/* Argument block for a 1x1 JIT convolution kernel call (judging by the
 * name). Member order is kept identical to the original. */
struct jit_1x1_conv_call_s {
    /* data pointers */
    const void *bcast_data;
    const void *load_data;
    const void *output_data;
    const void *bias_data;    /* used in forward and backward_weights only */
    const void *acc_s32;
    const void *scales;
    const void *compensation;

    /* per-call extents */
    size_t load_dim;
    size_t bcast_dim;
    size_t reduce_dim;

    size_t output_stride;     /* used in backward_weights only */

    /* presumably a mask of FLAG_* bits (e.g. FLAG_REDUCE_FIRST/LAST) */
    size_t first_last_flag;
};
434 | |
435 | /* pooling */ |
436 | struct jit_pool_conf_t { |
437 | int ndims; |
438 | int mb, c; |
439 | int id, ih, iw, od, oh, ow; |
440 | int stride_d, stride_h, stride_w; |
441 | int kd, kh, kw; |
442 | int f_pad, t_pad, l_pad; |
443 | alg_kind_t alg; |
444 | bool is_training; |
445 | bool pad_w_is_null; |
446 | bool is_backward; |
447 | bool simple_alg; |
448 | data_type_t ind_dt; |
449 | |
450 | int c_block, c_tail, nb_c; |
451 | int ur_c, ur_c_tail; |
452 | int ur_w; |
453 | int ur_w_tail; |
454 | size_t tail[4]; |
455 | data_type_t src_dt; |
456 | data_type_t dst_dt; |
457 | }; |
458 | |
/* Argument block for a JIT pooling kernel call (judging by the name).
 * Member order is kept identical to the original; note that
 * kh_padding_shift precedes kd_padding_shift. */
struct jit_pool_call_s {
    /* data pointers */
    const float *src;
    const float *dst;
    const void *indices;
    /* NOTE(review): presumably prefetch counterparts -- confirm. */
    const float *src_prf;
    const float *dst_prf;
    const void *indices_prf;

    /* per-call extents */
    size_t oh;
    size_t kd_padding;
    size_t kh_padding;
    size_t kh_padding_shift;
    size_t kd_padding_shift;
    size_t kw_padding;

    const float *init_value;
    float ker_area_h;
};
475 | |
476 | |
477 | } |
478 | } |
479 | } |
480 | |
481 | #endif |
482 | |