1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef JIT_PRIMITIVE_CONF_HPP
18#define JIT_PRIMITIVE_CONF_HPP
19
20#include <stdint.h>
21
22#include "common/primitive_attr.hpp"
23
24namespace mkldnn {
25namespace impl {
26namespace cpu {
27
28/* convolution */
/* Convolution kernel version selector; stored in the `ver` member of the
 * configuration structs below. Enumerators keep their implicit values,
 * ver_unused == 0 through ver_vnni == 4. */
enum conv_version_t {
    ver_unused,
    ver_fma,
    ver_avx512_core,
    ver_4fma,
    ver_vnni
};
/* Loop-order selector held in jit_conv_conf_t::loop_order. Enumerators keep
 * their implicit values, loop_cgn == 0 through loop_nwcg == 7.
 * NOTE(review): the letter suffixes presumably name the nesting order of the
 * loops (c/g/n/h/w) -- confirm against the drivers that consume this. */
enum conv_loop_order_t {
    loop_cgn,
    loop_gnc,
    loop_ngc,
    loop_gncw,
    loop_cwgn,
    loop_ngcw,
    loop_nhwcg,
    loop_nwcg
};
/* Loop-order selector held in jit_1x1_conv_conf_t::loop_order. Enumerators
 * keep their implicit values, loop_rbl == 0 through loop_brl == 5.
 * NOTE(review): r/b/l presumably stand for the reduce/bcast/load dimensions
 * of jit_1x1_conv_conf_t -- confirm. */
enum conv_1x1_loop_order_t {
    loop_rbl,
    loop_rlb,
    loop_lbr,
    loop_lrb,
    loop_blr,
    loop_brl
};
/* Kernel kind stored in jit_conv_conf_t::kernel_kind:
 * embd_bcast == 0, expl_bcast == 1. */
enum conv_kernel_kind_t {
    embd_bcast,
    expl_bcast
};
35
/* Bit flags.
 * The *_FIRST / *_LAST pairs occupy bits 0..9. The last three flags reuse
 * bits 0..2, so numerically FLAG_ZERO_FILTER == FLAG_MB_FIRST,
 * FLAG_ZERO_BIAS == FLAG_MB_LAST and FLAG_COMPUTE_BIAS == FLAG_OC_FIRST.
 * NOTE(review): presumably the two groups are never combined in the same
 * flag word -- confirm against the callers. */
enum {
    FLAG_MB_FIRST = 1 << 0,
    FLAG_MB_LAST = 1 << 1,
    FLAG_OC_FIRST = 1 << 2,
    FLAG_OC_LAST = 1 << 3,
    FLAG_IC_FIRST = 1 << 4,
    FLAG_IC_LAST = 1 << 5,
    FLAG_SP_FIRST = 1 << 6,
    FLAG_SP_LAST = 1 << 7,
    FLAG_REDUCE_FIRST = 1 << 8,
    FLAG_REDUCE_LAST = 1 << 9,

    /* Controls whether the inner kernel skips loading weights data from
     * memory; this needs to happen on the first Group/16 iteration. */
    FLAG_ZERO_FILTER = 1 << 0,
    /* Controls whether the inner kernel skips loading bias data from
     * memory. */
    FLAG_ZERO_BIAS = 1 << 1,
    /* Controls bias computation during the execution pass. */
    FLAG_COMPUTE_BIAS = 1 << 2,
};
51
52struct jit_conv_conf_t {
53 prop_kind_t prop_kind;
54 conv_version_t ver;
55 conv_loop_order_t loop_order;
56
57 int simd_w;
58 int ndims;
59 int mb;
60 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
61 int id, ih, iw, od, oh, ow;
62 int f_pad, l_pad, t_pad;
63 int back_pad, r_pad, b_pad;
64 int kd, kh, kw;
65 int stride_d, stride_h, stride_w;
66 int dilate_d, dilate_h, dilate_w;
67 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
68 bool with_bias;
69 bool with_sum;
70 bool with_eltwise;
71
72 post_ops_t::entry_t::eltwise_t eltwise;
73
74 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
75
76 int idp, ihp, iwp, ohp, owp;
77 int nb_ic, ic_block;
78 int nb_oc, oc_block;
79 int nb_ow, ow_block;
80 int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking
81 into account vector registers distribution */
82 int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work
83 within threads */
84 int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
85 int nb_ic_L2;
86 int h_blocking;
87 int nb_oc_L2;
88 int ur_h, ur_w;
89 int ur_w_tail;
90 bool is_1stconv;
91 int nonblk_group_off;
92 /* fma avx512_core */
93 conv_kernel_kind_t kernel_kind;
94 /* 4fma */
95 int tr_iw;
96 int tr_src_num_guard_elems;
97 /* 1st conv: 4fma */
98 int tr_ld;
99 int kh_step;
100 /* 4vnni */
101 int typesize_in;
102 int typesize_out;
103 int typesize_bia;
104 int typesize_acc;
105 /* avx512_u8s8u8 */
106 int ic_nb1, ic_nb2;
107 int oc_nb1;
108 int ur_ow_max, ur_ow, ur_ow_tail;
109 int ur_ow_nsteps;
110 data_type_t bia_dt;
111 data_type_t dst_dt;
112 /* avx512: max possible value is nregs(32) - aux_regs(4) */
113 int src_offsets[28];
114 int src_count;
115 bool expl_bcast;
116 bool large_spatial;
117 int is_oc_scale;
118 int max_regs_ur; // maximum accumulation registers
119 // dw conv
120 int nb_ch, ch_block, nb_ch_blocking;
121 bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
122 int aligned_threads;
123 // large spatial
124 int oh_blk_size;
125 // s8s8 convolution
126 bool signed_input;
127 float wei_adj_scale;
128};
129
130struct jit_conv_conf_2x3_wino_t {
131 conv_version_t ver;
132
133 int m;
134 int r;
135 int alpha;
136 int tile_h, tile_w;
137
138 int mb;
139 int ngroups, ic, oc, oc_without_padding;
140 int ih, iw, oh, ow;
141 int l_pad, t_pad;
142 int r_pad, b_pad;
143 int kh, kw;
144 int stride_h, stride_w;
145 int dilate_h, dilate_w;
146
147 int nb_ic, ic_block;
148 int nb_oc, oc_block;
149
150 int w_block_size, h_block_size;
151
152 data_type_t bia_dt;
153 data_type_t dst_dt;
154
155 int is_oc_scale;
156 int typesize_in;
157 int typesize_out;
158 int typesize_bia;
159 int typesize_acc;
160
161 format_tag_t src_tag, dst_tag; // temporary workaround
162 bool with_bias;
163 bool small_mb;
164
165 int xb, yb;
166 int inp_stride;
167 int out_stride;
168 int wei_stride;
169 int bia_stride;
170
171 int M, N, K;
172 int m_block, n_block, k_block;
173 int n2_block, n_chunks;
174 int k2_block, k_chunks;
175
176 int mb_block, nb_mb;
177
178 size_t size_wino_src, size_wino_wei, size_wino_dst;
179
180 int nthr;
181};
182
/*
 * Winograd scheduling policy.
 *
 * The policy names are built from the following letters.
 *
 * Computation units:
 *   W - weights transform
 *   S - src transform
 *   D - dst transform
 *   G - gemm
 *
 * Thread grouping by:
 *   i - nb_ic
 *   o - nb_oc
 *   t - tile_block
 *   e - element in tile
 *
 * Note: 'i' and 'o' are omitted if
 *   i.  not combined with t, or
 *   ii. with discrete transforms
 *
 * Currently supported policies are listed in the enum below.
 */
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    /* ---- Forward & backward-data ---- */

    /* W_S_G_D implements discrete transforms */
    WSCHED_DATA_W_S_G_D,

    /* W_SGD implements tiled transforms such that GEMM can reuse data
     * in L2 */
    WSCHED_DATA_W_SGD,

    /* ---- Backward-weights ---- */

    WSCHED_WEI_S_D_G_W,
    WSCHED_WEI_SDGtWo,
    WSCHED_WEI_S_D_Giot_W,
    WSCHED_WEI_SDGt_W,
};
219
220struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
221 int itiles;
222 int jtiles;
223 int ntiles;
224 int ic_simd_block=16;
225 int tile_4fma_padding;
226 int tile_4fma;
227 int oc_simd_block=16;
228 int oc_reg_block;
229 int ic_reg_block;
230 int tile_block;
231 int tile_block_ur;
232 int nb_tile_block_ur;
233
234 bool double_buffering;
235 bool with_relu_postsum;
236 int zmm_start;
237 int nb_reg;
238
239 int dimK;
240 int dimK_4fma;
241 int dimK_reg_block;
242 int dimK_block;
243 int dimK_nb_block;
244
245 int dimM;
246 int dimM_reg_block;
247 int dimM_simd_block;
248 int dimM_block;
249 int dimM_nb_block;
250
251 int dimN;
252 int dimN_reg_block;
253 int dimN_bcast_ur;
254 int dimN_block;
255 int dimN_nb_block;
256
257 winograd_sched_t sched_policy;
258};
259
/* Argument record for a convolution kernel call: eleven data pointers,
 * twenty-one size_t values and a trailing int `flags`. Member order is
 * unchanged, so field offsets are the same as before.
 * NOTE(review): the `_prf` members presumably carry the prefetch counterpart
 * of the member they follow -- confirm against the kernels. */
struct jit_conv_call_s {
    /* The first four pointers are declared const as a hack; the note on each
     * says for which pass the data is actually non-const. */
    const void *src;  /* non-const for backward_data */
    const void *dst;  /* non-const for forward */
    const void *filt; /* non-const for backward_weights */
    const void *bias; /* non-const for backward_bias */

    const void *src_prf, *dst_prf, *filt_prf, *bias_prf;

    const void *scales, *acc_s32, *compensation;

    size_t kd_offset, kd_offset_prf;
    size_t d_index, d_index_prf;
    size_t d_worksize, d_worksize_prf;
    size_t kd_padding, kd_padding_prf;
    size_t kh_padding, kh_padding_prf;
    size_t owb, owb_prf;
    size_t kw_padding;
    size_t channel, channel_prf;
    size_t oc_blocks;
    size_t ur_w, ur_str_w;
    size_t ch_blocks;
    size_t t_overflow, b_overflow;

    int flags;
};
295
/* Argument record for a deconvolution kernel call: six data pointers
 * followed by four size_t values. Member order is unchanged. */
struct jit_deconv_call_s {
    /* The first four pointers are declared const as a hack; the note on each
     * says for which pass the data is actually non-const. */
    const void *src;  /* non-const for backward_data */
    const void *dst;  /* non-const for forward */
    const void *filt; /* non-const for backward_weights */
    const void *bias; /* non-const for backward_bias */
    const void *scales, *compensation;

    size_t t_overflow, b_overflow;
    size_t kh_padding, oc_blocks;
};
308
/* Argument record for a depthwise-convolution kernel call (per the struct
 * name): four data pointers, four size_t values and one flag byte. Member
 * order is unchanged. */
struct jit_dw_conv_call_s {
    const void *input, *output, *filter, *bias;

    size_t kh_count;
    size_t oh_count, oh_index;
    size_t filter_pad_off;

    /* Flags passed by driver execution to inner kernel */
    unsigned char exec_flags;
};
321
/* Argument record for a Winograd transform call (per the struct name): six
 * size_t values followed by seven mutable data pointers. Member order is
 * unchanged. */
struct jit_wino_transform_call_s {
    size_t tile_block, tile_block_ur, nb_tile_block_ur;
    size_t tile_count;
    size_t tj, ti;

    void *src, *dst;
    void *Mw, *M, *T, *G;
    void *bias;
};
337
338struct jit_1x1_conv_conf_t {
339 prop_kind_t prop_kind;
340 conv_version_t ver;
341
342 int mb;
343 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
344 int iw, ih, ow, oh;
345 int l_pad, t_pad;
346 int kh, kw;
347 int stride_h, stride_w;
348 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
349 bool with_bias;
350 bool with_sum;
351 bool with_eltwise;
352
353 post_ops_t::entry_t::eltwise_t eltwise;
354
355 int is, os;
356 int ic_block, oc_block;
357
358 int ur, ur_tail;
359
360 int reduce_dim, reduce_block, nb_reduce,
361 nb_reduce_blocking, nb_reduce_blocking_max;
362 int load_dim, load_block, nb_load,
363 nb_load_blocking, nb_load_blocking_max, nb_load_chunk;
364 int bcast_dim, bcast_block, nb_bcast,
365 nb_bcast_blocking, nb_bcast_blocking_max;
366
367 int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
368 int load_loop_load_step, load_loop_iter_step;
369 int bcast_loop_output_step, bcast_loop_output_substep;
370 int bcast_loop_bcast_step, bcast_loop_bcast_substep;
371 int fma_step;
372 int load_grp_count;
373 conv_1x1_loop_order_t loop_order;
374 bool use_vmovntps;
375 /* avx512 core */
376 bool expl_bcast;
377 /* 4vnni */
378 int typesize_in;
379 int typesize_out;
380 int typesize_bia;
381 int typesize_acc;
382 /* 4fma */
383 bool transpose_src;
384 int tr_is;
385 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
386 int is_oc_scale;
387 data_type_t bia_dt;
388 data_type_t dst_dt;
389 bool signed_input;
390 float wei_adj_scale;
391};
392
393struct jit_gemm_conv_conf_t {
394 prop_kind_t prop_kind;
395
396 int mb;
397 int ngroups, ic, oc;
398 int iw, ih, id, ow, oh, od;
399 int l_pad, t_pad, f_pad;
400 int kh, kw, kd;
401 int stride_h, stride_w, stride_d;
402 int dilate_h, dilate_w, dilate_d;
403 bool with_bias;
404
405 int is, os, ks;
406 int ic_block, oc_block;
407
408 int nthr;
409 ptrdiff_t im2col_sz;
410 bool need_wei_reduction;
411 bool signed_input;
412 int oh_block;
413 int ow_block;
414 bool outer_threading;
415};
416
/* Argument record for a 1x1 convolution kernel call: seven data pointers
 * followed by five size_t values. Member order is unchanged. */
struct jit_1x1_conv_call_s {
    const void *bcast_data, *load_data, *output_data;
    /* used in forward and backward_weights only */
    const void *bias_data;
    const void *acc_s32, *scales, *compensation;

    size_t load_dim, bcast_dim, reduce_dim;

    /* used in backward_weights only */
    size_t output_stride;

    size_t first_last_flag;
};
434
435/* pooling */
436struct jit_pool_conf_t {
437 int ndims;
438 int mb, c;
439 int id, ih, iw, od, oh, ow;
440 int stride_d, stride_h, stride_w;
441 int kd, kh, kw;
442 int f_pad, t_pad, l_pad;
443 alg_kind_t alg;
444 bool is_training;
445 bool pad_w_is_null;
446 bool is_backward;
447 bool simple_alg;
448 data_type_t ind_dt;
449
450 int c_block, c_tail, nb_c;
451 int ur_c, ur_c_tail;
452 int ur_w;
453 int ur_w_tail;
454 size_t tail[4];
455 data_type_t src_dt;
456 data_type_t dst_dt;
457};
458
/* Argument record for a pooling kernel call: six data pointers, six size_t
 * values, one more float pointer and a trailing float. Member order is
 * unchanged.
 * NOTE(review): the `_prf` members presumably carry the prefetch counterpart
 * of src/dst/indices -- confirm against the kernel. */
struct jit_pool_call_s {
    const float *src, *dst;
    const void *indices;
    const float *src_prf, *dst_prf;
    const void *indices_prf;

    size_t oh;
    size_t kd_padding, kh_padding;
    size_t kh_padding_shift, kd_padding_shift;
    size_t kw_padding;

    const float *init_value;
    float ker_area_h;
};
475
476
477}
478}
479}
480
481#endif
482