1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP |
18 | #define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP |
19 | |
20 | #include "c_types_map.hpp" |
21 | #include "memory_tracking.hpp" |
22 | |
23 | #include "jit_generator.hpp" |
24 | #include "jit_primitive_conf.hpp" |
25 | #include "jit_uni_eltwise.hpp" |
26 | |
27 | namespace mkldnn { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | template<typename Vmm> |
32 | struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { |
33 | |
34 | _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, |
35 | const primitive_attr_t &attr) |
36 | : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) |
37 | { |
38 | if (jcp.with_eltwise) |
39 | eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( |
40 | this, jcp.eltwise); |
41 | |
42 | generate(); |
43 | jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); |
44 | } |
45 | |
46 | ~_jit_avx512_common_conv_fwd_kernel() { |
47 | delete eltwise_injector_; |
48 | } |
49 | |
50 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) |
51 | |
52 | jit_conv_conf_t jcp; |
53 | const primitive_attr_t &attr_; |
54 | void (*jit_ker_)(jit_conv_call_s *); |
55 | |
56 | private: |
57 | using reg64_t = const Xbyak::Reg64; |
58 | enum { |
59 | typesize = sizeof(float), |
60 | ker_reg_base_idx = 28, |
61 | }; |
62 | |
63 | reg64_t param = abi_param1; |
64 | reg64_t reg_inp = r8; |
65 | reg64_t reg_ker = r9; |
66 | reg64_t reg_out = r10; |
67 | |
68 | reg64_t reg_inp_prf = r11; |
69 | reg64_t reg_ker_prf = r12; |
70 | reg64_t reg_out_prf = r13; |
71 | reg64_t reg_owb = r12; |
72 | |
73 | reg64_t aux_reg_inp = r14; |
74 | reg64_t aux_reg_ker = r15; |
75 | |
76 | reg64_t aux_reg_inp_prf = rsi; |
77 | reg64_t aux_reg_ker_prf = rdx; |
78 | |
79 | reg64_t reg_channel = rsi; |
80 | reg64_t reg_bias = rdx; |
81 | |
82 | reg64_t aux_reg_ker_d = r9; |
83 | reg64_t aux_reg_inp_d = rbx; |
84 | reg64_t aux_reg_inp_d_prf = r13; |
85 | reg64_t aux_reg_ker_d_prf = abi_not_param1; |
86 | reg64_t reg_ki = r10; |
87 | |
88 | reg64_t reg_kj = rax; |
89 | reg64_t reg_relu_ns = rax; |
90 | reg64_t reg_oi = rbx; |
91 | reg64_t reg_kh = abi_not_param1; |
92 | |
93 | reg64_t reg_tmp = rbp; |
94 | |
95 | reg64_t reg_ic_loop = rdx; |
96 | reg64_t reg_inp_loop = rsi; |
97 | |
98 | reg64_t reg_init_flag = r13; |
99 | reg64_t reg_bias_ptr = param; |
100 | |
101 | reg64_t aux_reg_ic = r12; |
102 | reg64_t reg_binp = rax; |
103 | reg64_t reg_bout = r11; |
104 | reg64_t aux1_reg_inp = rbx; |
105 | reg64_t aux_reg_out = abi_not_param1; |
106 | |
107 | reg64_t reg_long_offt = r11; |
108 | reg64_t reg_out_long_offt = r14; |
109 | |
110 | inline Vmm vmm_ker(int i_ic) { |
111 | assert(i_ic < 4); |
112 | return Vmm(ker_reg_base_idx + i_ic); |
113 | } |
114 | |
115 | inline Vmm vmm_out(int i_ur, int i_oc) { |
116 | int idx = i_ur + i_oc * jcp.ur_w; |
117 | assert(idx < ker_reg_base_idx); |
118 | return Vmm(idx); |
119 | } |
120 | |
121 | inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { |
122 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
123 | assert(idx < 31); |
124 | return Vmm(idx); |
125 | } |
126 | |
127 | Xbyak::Reg64 imm_addr64 = r15; |
128 | Vmm vmm_wei = Vmm(31); |
129 | |
130 | jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; |
131 | |
132 | inline void prepare_output(int ur_w); |
133 | inline void store_output(int ur_w); |
134 | inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); |
135 | inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); |
136 | inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); |
137 | inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); |
138 | inline void compute_loop(int ur_w, int pad_l, int pad_r); |
139 | |
140 | void generate(); |
141 | |
142 | inline size_t get_output_offset(int oi, int n_oc_block) { |
143 | return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh |
144 | * jcp.ow * jcp.od + oi) * jcp.oc_block; |
145 | } |
146 | |
147 | inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { |
148 | size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; |
149 | size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id; |
150 | return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) |
151 | + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); |
152 | } |
153 | |
154 | inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { |
155 | return jcp.typesize_in * jcp.oc_block |
156 | * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd |
157 | + (ic + ker_number) + ki * jcp.ic_block); |
158 | } |
159 | |
160 | inline int get_ow_start(int ki, int pad_l) { |
161 | return nstl::max(0, |
162 | utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); |
163 | } |
164 | |
165 | inline int get_ow_end(int ur_w, int ki, int pad_r) { |
166 | return ur_w - nstl::max(0, utils::div_up(pad_r |
167 | - (jcp.kw - 1 - ki) |
168 | * (jcp.dilate_w + 1), |
169 | jcp.stride_w)); |
170 | } |
171 | }; |
172 | |
173 | struct jit_avx512_common_conv_fwd_kernel { |
174 | |
175 | jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, |
176 | const primitive_attr_t &attr) : |
177 | jit_ker(nullptr), |
178 | zmm_kernel_(nullptr), |
179 | xmm_kernel_(nullptr) { |
180 | int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; |
181 | switch (ch_block) { |
182 | case 16: |
183 | zmm_kernel_ = |
184 | new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>( |
185 | ajcp, attr); |
186 | jit_ker = zmm_kernel_->jit_ker_; |
187 | return; |
188 | case 4: |
189 | xmm_kernel_ = |
190 | new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>( |
191 | ajcp, attr); |
192 | jit_ker = xmm_kernel_->jit_ker_; |
193 | return; |
194 | default: |
195 | assert(!"invalid channel blocking" ); |
196 | } |
197 | } |
198 | |
199 | ~jit_avx512_common_conv_fwd_kernel() { |
200 | delete xmm_kernel_; |
201 | delete zmm_kernel_; |
202 | } |
203 | |
204 | enum { |
205 | typesize = sizeof(float) |
206 | }; |
207 | |
208 | static bool post_ops_ok(jit_conv_conf_t &jcp, |
209 | const primitive_attr_t &attr); |
210 | static status_t init_conf(jit_conv_conf_t &jcp, |
211 | const convolution_desc_t &cd, |
212 | memory_desc_t &src_pd, |
213 | memory_desc_t &weights_pd, |
214 | memory_desc_t &dst_pd, |
215 | memory_desc_t &bias_pd, |
216 | const primitive_attr_t &attr, |
217 | int nthreads); |
218 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
219 | const jit_conv_conf_t &jcp); |
220 | |
221 | void(*jit_ker)(jit_conv_call_s *); |
222 | _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_; |
223 | _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_; |
224 | }; |
225 | |
226 | struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { |
227 | |
228 | jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) |
229 | { |
230 | generate(); |
231 | jit_ker = (void (*)(jit_conv_call_s *))getCode(); |
232 | } |
233 | |
234 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) |
235 | |
236 | static status_t init_conf(jit_conv_conf_t &jcp, |
237 | const convolution_desc_t &cd, |
238 | const memory_desc_wrapper &diff_src_d, |
239 | const memory_desc_wrapper &weights_d, |
240 | const memory_desc_wrapper &diff_dst_d); |
241 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
242 | const jit_conv_conf_t &jcp); |
243 | |
244 | jit_conv_conf_t jcp; |
245 | void (*jit_ker)(jit_conv_call_s *); |
246 | |
247 | private: |
248 | using reg64_t = const Xbyak::Reg64; |
249 | enum { |
250 | typesize = sizeof(float), |
251 | ker_reg_base_idx = 28, |
252 | }; |
253 | |
254 | reg64_t param = abi_param1; |
255 | reg64_t reg_dst = r8; |
256 | reg64_t reg_ker = r9; |
257 | reg64_t reg_src = r10; |
258 | |
259 | reg64_t reg_dst_prf = r11; |
260 | reg64_t reg_ker_prf = r12; |
261 | reg64_t reg_src_prf = r13; |
262 | |
263 | reg64_t aux_reg_dst = r14; |
264 | reg64_t aux_reg_ker = r15; |
265 | |
266 | reg64_t aux_reg_dst_prf = rsi; |
267 | reg64_t aux_reg_ker_prf = rdx; |
268 | |
269 | reg64_t aux_reg_dst_d_prf = r13; |
270 | reg64_t aux_reg_dst_d = rbx; |
271 | reg64_t aux_reg_ker_d_prf = abi_not_param1; |
272 | reg64_t aux_reg_ker_d = r9; |
273 | reg64_t reg_ki = r10; |
274 | |
275 | reg64_t reg_kj = rax; |
276 | reg64_t reg_oi = rbx; |
277 | reg64_t reg_kh = abi_not_param1; |
278 | |
279 | reg64_t reg_channel = rsi; |
280 | |
281 | reg64_t reg_tmp = rbp; |
282 | reg64_t reg_long_offt = r14; |
283 | |
284 | inline Xbyak::Zmm zmm_ker(int i_ic) { |
285 | assert(i_ic < 4); |
286 | return Xbyak::Zmm(ker_reg_base_idx + i_ic); |
287 | } |
288 | inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { |
289 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
290 | assert(idx < 31); |
291 | return Xbyak::Zmm(idx); |
292 | } |
293 | inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { |
294 | int idx = i_ur + i_oc * jcp.ur_w; |
295 | assert(idx < ker_reg_base_idx); |
296 | return Xbyak::Zmm(idx); |
297 | } |
298 | |
299 | Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); |
300 | |
301 | inline void prepare_output(int ur_w); |
302 | inline void store_output(int ur_w); |
303 | inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); |
304 | inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); |
305 | inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); |
306 | inline void compute_loop(int ur_w, int l_overflow, int r_overflow); |
307 | void generate(); |
308 | |
309 | inline int get_iw_start(int ki, int l_overflow) |
310 | { |
311 | int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w |
312 | + l_overflow * jcp.stride_w |
313 | - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); |
314 | while (res < 0) |
315 | res += jcp.stride_w; |
316 | |
317 | return res; |
318 | } |
319 | |
320 | inline int get_iw_end(int ur_w, int ki, int r_overflow) |
321 | { |
322 | if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) |
323 | ur_w += nstl::min(0, jcp.r_pad); // remove negative padding |
324 | int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w |
325 | + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); |
326 | while (res < 0) |
327 | res += jcp.stride_w; |
328 | |
329 | return ur_w - res; |
330 | } |
331 | }; |
332 | |
333 | struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { |
334 | |
335 | jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) |
336 | : jcp(ajcp) |
337 | { |
338 | generate(); |
339 | jit_ker = (void (*)(jit_conv_call_s *))getCode(); |
340 | } |
341 | |
342 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) |
343 | |
344 | static status_t init_conf(jit_conv_conf_t &jcp, |
345 | const convolution_desc_t &cd, |
346 | memory_desc_t &src_md, |
347 | memory_desc_t &diff_weights_md, |
348 | memory_desc_t &diff_bias_md, |
349 | memory_desc_t &diff_dst_md); |
350 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
351 | const jit_conv_conf_t &jcp); |
352 | |
353 | jit_conv_conf_t jcp; |
354 | void (*jit_ker)(jit_conv_call_s *); |
355 | |
356 | private: |
357 | using reg64_t = const Xbyak::Reg64; |
358 | enum {typesize = sizeof(float)}; |
359 | static const int max_ur_w; |
360 | |
361 | reg64_t param = abi_param1; |
362 | reg64_t reg_input = rax; |
363 | reg64_t reg_kernel = rdx; |
364 | reg64_t reg_output = rsi; |
365 | reg64_t b_ic = abi_not_param1; |
366 | reg64_t kj = r8; |
367 | reg64_t reg_kh = r9; |
368 | reg64_t reg_ur_w_trips = r10; |
369 | reg64_t reg_oj = r15; |
370 | reg64_t reg_ih_count = rbx; |
371 | reg64_t reg_tmp = r14; |
372 | reg64_t reg_long_offt = r14; |
373 | |
374 | reg64_t ki = r11; |
375 | reg64_t reg_kd_count = r12; |
376 | reg64_t reg_oi = r12; |
377 | reg64_t reg_d_index = r13; |
378 | reg64_t reg_input_d = r15; |
379 | reg64_t reg_output_d = rbx; |
380 | reg64_t aux_reg_input = r12; |
381 | reg64_t aux_reg_kernel = r13; |
382 | reg64_t reg_bias = rbx; |
383 | |
384 | inline void bias_kernel(); |
385 | inline void maybe_zero_kernel(); |
386 | inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, |
387 | int max_ur_w); |
388 | inline void od_step_comeback_pointers(); |
389 | inline void oh_step_comeback_pointers(); |
390 | inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); |
391 | inline void compute_ic_block_step(int ur_w, |
392 | int pad_l, int pad_r, int ic_block_step, |
393 | int input_offset, int kernel_offset, int output_offset, |
394 | bool input_wraparound = false); |
395 | inline void compute_ic_block_step_fma(int ur_w, |
396 | int pad_l, int pad_r, int ic_block_step, |
397 | int input_offset, int kernel_offset, int output_offset, |
398 | bool input_wraparound); |
399 | inline void compute_ic_block_step_4fma(int ur_w, |
400 | int pad_l, int pad_r, int ic_block_step, |
401 | int input_offset, int kernel_offset, int output_offset, |
402 | bool input_wraparound); |
403 | inline void compute_oh_step_common(int ic_block_step, int max_ur_w); |
404 | inline void compute_oh_step_disp(); |
405 | inline void compute_oh_loop_common(); |
406 | inline void compute_d_loop_common(); |
407 | |
408 | inline bool compute_full_spat_loop(); |
409 | inline bool flat_4ops_compute(); |
410 | |
411 | inline void compute_loop(); |
412 | |
413 | void generate(); |
414 | |
415 | static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, |
416 | int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); |
417 | }; |
418 | |
419 | } |
420 | } |
421 | } |
422 | |
423 | #endif |
424 | |