1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
18#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
19
20#include "c_types_map.hpp"
21#include "memory_tracking.hpp"
22
23#include "jit_generator.hpp"
24#include "jit_primitive_conf.hpp"
25#include "jit_uni_eltwise.hpp"
26
27namespace mkldnn {
28namespace impl {
29namespace cpu {
30
31template<typename Vmm>
32struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
33
34 _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
35 const primitive_attr_t &attr)
36 : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
37 {
38 if (jcp.with_eltwise)
39 eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
40 this, jcp.eltwise);
41
42 generate();
43 jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
44 }
45
46 ~_jit_avx512_common_conv_fwd_kernel() {
47 delete eltwise_injector_;
48 }
49
50 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
51
52 jit_conv_conf_t jcp;
53 const primitive_attr_t &attr_;
54 void (*jit_ker_)(jit_conv_call_s *);
55
56private:
57 using reg64_t = const Xbyak::Reg64;
58 enum {
59 typesize = sizeof(float),
60 ker_reg_base_idx = 28,
61 };
62
63 reg64_t param = abi_param1;
64 reg64_t reg_inp = r8;
65 reg64_t reg_ker = r9;
66 reg64_t reg_out = r10;
67
68 reg64_t reg_inp_prf = r11;
69 reg64_t reg_ker_prf = r12;
70 reg64_t reg_out_prf = r13;
71 reg64_t reg_owb = r12;
72
73 reg64_t aux_reg_inp = r14;
74 reg64_t aux_reg_ker = r15;
75
76 reg64_t aux_reg_inp_prf = rsi;
77 reg64_t aux_reg_ker_prf = rdx;
78
79 reg64_t reg_channel = rsi;
80 reg64_t reg_bias = rdx;
81
82 reg64_t aux_reg_ker_d = r9;
83 reg64_t aux_reg_inp_d = rbx;
84 reg64_t aux_reg_inp_d_prf = r13;
85 reg64_t aux_reg_ker_d_prf = abi_not_param1;
86 reg64_t reg_ki = r10;
87
88 reg64_t reg_kj = rax;
89 reg64_t reg_relu_ns = rax;
90 reg64_t reg_oi = rbx;
91 reg64_t reg_kh = abi_not_param1;
92
93 reg64_t reg_tmp = rbp;
94
95 reg64_t reg_ic_loop = rdx;
96 reg64_t reg_inp_loop = rsi;
97
98 reg64_t reg_init_flag = r13;
99 reg64_t reg_bias_ptr = param;
100
101 reg64_t aux_reg_ic = r12;
102 reg64_t reg_binp = rax;
103 reg64_t reg_bout = r11;
104 reg64_t aux1_reg_inp = rbx;
105 reg64_t aux_reg_out = abi_not_param1;
106
107 reg64_t reg_long_offt = r11;
108 reg64_t reg_out_long_offt = r14;
109
110 inline Vmm vmm_ker(int i_ic) {
111 assert(i_ic < 4);
112 return Vmm(ker_reg_base_idx + i_ic);
113 }
114
115 inline Vmm vmm_out(int i_ur, int i_oc) {
116 int idx = i_ur + i_oc * jcp.ur_w;
117 assert(idx < ker_reg_base_idx);
118 return Vmm(idx);
119 }
120
121 inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
122 int idx = i_ic + nb_x_blocking * jcp.ur_w;
123 assert(idx < 31);
124 return Vmm(idx);
125 }
126
127 Xbyak::Reg64 imm_addr64 = r15;
128 Vmm vmm_wei = Vmm(31);
129
130 jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
131
132 inline void prepare_output(int ur_w);
133 inline void store_output(int ur_w);
134 inline void compute_loop_fma(int ur_w, int pad_l, int pad_r);
135 inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
136 inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r);
137 inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r);
138 inline void compute_loop(int ur_w, int pad_l, int pad_r);
139
140 void generate();
141
142 inline size_t get_output_offset(int oi, int n_oc_block) {
143 return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh
144 * jcp.ow * jcp.od + oi) * jcp.oc_block;
145 }
146
147 inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
148 size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1;
149 size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id;
150 return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1)
151 + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str);
152 }
153
154 inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) {
155 return jcp.typesize_in * jcp.oc_block
156 * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd
157 + (ic + ker_number) + ki * jcp.ic_block);
158 }
159
160 inline int get_ow_start(int ki, int pad_l) {
161 return nstl::max(0,
162 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
163 }
164
165 inline int get_ow_end(int ur_w, int ki, int pad_r) {
166 return ur_w - nstl::max(0, utils::div_up(pad_r
167 - (jcp.kw - 1 - ki)
168 * (jcp.dilate_w + 1),
169 jcp.stride_w));
170 }
171};
172
173struct jit_avx512_common_conv_fwd_kernel {
174
175 jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
176 const primitive_attr_t &attr) :
177 jit_ker(nullptr),
178 zmm_kernel_(nullptr),
179 xmm_kernel_(nullptr) {
180 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block;
181 switch (ch_block) {
182 case 16:
183 zmm_kernel_ =
184 new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
185 ajcp, attr);
186 jit_ker = zmm_kernel_->jit_ker_;
187 return;
188 case 4:
189 xmm_kernel_ =
190 new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
191 ajcp, attr);
192 jit_ker = xmm_kernel_->jit_ker_;
193 return;
194 default:
195 assert(!"invalid channel blocking");
196 }
197 }
198
199 ~jit_avx512_common_conv_fwd_kernel() {
200 delete xmm_kernel_;
201 delete zmm_kernel_;
202 }
203
204 enum {
205 typesize = sizeof(float)
206 };
207
208 static bool post_ops_ok(jit_conv_conf_t &jcp,
209 const primitive_attr_t &attr);
210 static status_t init_conf(jit_conv_conf_t &jcp,
211 const convolution_desc_t &cd,
212 memory_desc_t &src_pd,
213 memory_desc_t &weights_pd,
214 memory_desc_t &dst_pd,
215 memory_desc_t &bias_pd,
216 const primitive_attr_t &attr,
217 int nthreads);
218 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
219 const jit_conv_conf_t &jcp);
220
221 void(*jit_ker)(jit_conv_call_s *);
222 _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
223 _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
224};
225
226struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
227
228 jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
229 {
230 generate();
231 jit_ker = (void (*)(jit_conv_call_s *))getCode();
232 }
233
234 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32)
235
236 static status_t init_conf(jit_conv_conf_t &jcp,
237 const convolution_desc_t &cd,
238 const memory_desc_wrapper &diff_src_d,
239 const memory_desc_wrapper &weights_d,
240 const memory_desc_wrapper &diff_dst_d);
241 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
242 const jit_conv_conf_t &jcp);
243
244 jit_conv_conf_t jcp;
245 void (*jit_ker)(jit_conv_call_s *);
246
247private:
248 using reg64_t = const Xbyak::Reg64;
249 enum {
250 typesize = sizeof(float),
251 ker_reg_base_idx = 28,
252 };
253
254 reg64_t param = abi_param1;
255 reg64_t reg_dst = r8;
256 reg64_t reg_ker = r9;
257 reg64_t reg_src = r10;
258
259 reg64_t reg_dst_prf = r11;
260 reg64_t reg_ker_prf = r12;
261 reg64_t reg_src_prf = r13;
262
263 reg64_t aux_reg_dst = r14;
264 reg64_t aux_reg_ker = r15;
265
266 reg64_t aux_reg_dst_prf = rsi;
267 reg64_t aux_reg_ker_prf = rdx;
268
269 reg64_t aux_reg_dst_d_prf = r13;
270 reg64_t aux_reg_dst_d = rbx;
271 reg64_t aux_reg_ker_d_prf = abi_not_param1;
272 reg64_t aux_reg_ker_d = r9;
273 reg64_t reg_ki = r10;
274
275 reg64_t reg_kj = rax;
276 reg64_t reg_oi = rbx;
277 reg64_t reg_kh = abi_not_param1;
278
279 reg64_t reg_channel = rsi;
280
281 reg64_t reg_tmp = rbp;
282 reg64_t reg_long_offt = r14;
283
284 inline Xbyak::Zmm zmm_ker(int i_ic) {
285 assert(i_ic < 4);
286 return Xbyak::Zmm(ker_reg_base_idx + i_ic);
287 }
288 inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
289 int idx = i_ic + nb_x_blocking * jcp.ur_w;
290 assert(idx < 31);
291 return Xbyak::Zmm(idx);
292 }
293 inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
294 int idx = i_ur + i_oc * jcp.ur_w;
295 assert(idx < ker_reg_base_idx);
296 return Xbyak::Zmm(idx);
297 }
298
299 Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
300
301 inline void prepare_output(int ur_w);
302 inline void store_output(int ur_w);
303 inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow);
304 inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow);
305 inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow);
306 inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
307 void generate();
308
309 inline int get_iw_start(int ki, int l_overflow)
310 {
311 int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
312 + l_overflow * jcp.stride_w
313 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
314 while (res < 0)
315 res += jcp.stride_w;
316
317 return res;
318 }
319
320 inline int get_iw_end(int ur_w, int ki, int r_overflow)
321 {
322 if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
323 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
324 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
325 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
326 while (res < 0)
327 res += jcp.stride_w;
328
329 return ur_w - res;
330 }
331};
332
333struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator {
334
335 jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp)
336 : jcp(ajcp)
337 {
338 generate();
339 jit_ker = (void (*)(jit_conv_call_s *))getCode();
340 }
341
342 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32)
343
344 static status_t init_conf(jit_conv_conf_t &jcp,
345 const convolution_desc_t &cd,
346 memory_desc_t &src_md,
347 memory_desc_t &diff_weights_md,
348 memory_desc_t &diff_bias_md,
349 memory_desc_t &diff_dst_md);
350 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
351 const jit_conv_conf_t &jcp);
352
353 jit_conv_conf_t jcp;
354 void (*jit_ker)(jit_conv_call_s *);
355
356private:
357 using reg64_t = const Xbyak::Reg64;
358 enum {typesize = sizeof(float)};
359 static const int max_ur_w;
360
361 reg64_t param = abi_param1;
362 reg64_t reg_input = rax;
363 reg64_t reg_kernel = rdx;
364 reg64_t reg_output = rsi;
365 reg64_t b_ic = abi_not_param1;
366 reg64_t kj = r8;
367 reg64_t reg_kh = r9;
368 reg64_t reg_ur_w_trips = r10;
369 reg64_t reg_oj = r15;
370 reg64_t reg_ih_count = rbx;
371 reg64_t reg_tmp = r14;
372 reg64_t reg_long_offt = r14;
373
374 reg64_t ki = r11;
375 reg64_t reg_kd_count = r12;
376 reg64_t reg_oi = r12;
377 reg64_t reg_d_index = r13;
378 reg64_t reg_input_d = r15;
379 reg64_t reg_output_d = rbx;
380 reg64_t aux_reg_input = r12;
381 reg64_t aux_reg_kernel = r13;
382 reg64_t reg_bias = rbx;
383
384 inline void bias_kernel();
385 inline void maybe_zero_kernel();
386 inline void compute_oh_step_unroll_ow_icblock(int ic_block_step,
387 int max_ur_w);
388 inline void od_step_comeback_pointers();
389 inline void oh_step_comeback_pointers();
390 inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
391 inline void compute_ic_block_step(int ur_w,
392 int pad_l, int pad_r, int ic_block_step,
393 int input_offset, int kernel_offset, int output_offset,
394 bool input_wraparound = false);
395 inline void compute_ic_block_step_fma(int ur_w,
396 int pad_l, int pad_r, int ic_block_step,
397 int input_offset, int kernel_offset, int output_offset,
398 bool input_wraparound);
399 inline void compute_ic_block_step_4fma(int ur_w,
400 int pad_l, int pad_r, int ic_block_step,
401 int input_offset, int kernel_offset, int output_offset,
402 bool input_wraparound);
403 inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
404 inline void compute_oh_step_disp();
405 inline void compute_oh_loop_common();
406 inline void compute_d_loop_common();
407
408 inline bool compute_full_spat_loop();
409 inline bool flat_4ops_compute();
410
411 inline void compute_loop();
412
413 void generate();
414
415 static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
416 int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
417};
418
419}
420}
421}
422
423#endif
424