1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "c_types_map.hpp"
18#include "nstl.hpp"
19#include "type_helpers.hpp"
20#include "utils.hpp"
21
22#include "cpu_barrier.hpp"
23
24#include "jit_avx512_common_conv_kernel.hpp"
25
26#define GET_OFF(field) offsetof(jit_conv_call_s, field)
27#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
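// GET_OFF(field) yields the byte offset of `field` inside jit_conv_call_s so
// that the generated code can load its arguments directly from the call
// structure pointed to by param1.
// KNx_L2_EFFECTIVE_CAPACITY is the per-core L2 budget assumed by the
// Knights-family (KNx) blocking heuristics: 512 KB minus a 64 KB reserve.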
28
29namespace mkldnn {
30namespace impl {
31namespace cpu {
32
33using namespace mkldnn::impl::format_tag;
34using namespace mkldnn::impl::memory_tracking::names;
35using namespace mkldnn::impl::utils;
36using namespace Xbyak;
37
38namespace {
39
40constexpr auto small_spatial = 14;
41unsigned int L1_cache_size = get_cache_size(1, true);
42
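// Picks the order of the parallel driver loops: when the spatial extent of
// the primary dimension is small (<= small_spatial), the channel-outermost
// orders (loop_cgn / loop_cwgn) are used, otherwise the group/minibatch
// outermost orders (loop_gnc / loop_gncw).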
43inline void pick_loop_order(jit_conv_conf_t &jcp) {
44 using namespace prop_kind;
45 assert(one_of(jcp.prop_kind,
46 forward_training, forward_inference, backward_data));
47 auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
48 auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
49
    // ow-threading is currently implemented for the forward pass only
    // TODO: unify the fwd and bwd code paths once ow-threading is supported
    // for bwd (the previously meaningless switch has been removed)
53 if (jcp.prop_kind == backward_data) {
54 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
55 ? loop_cgn : loop_gnc;
56 } else {
57 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
58 ? loop_cwgn : loop_gncw;
59 }
60}
61
62inline bool is_1stconv(const jit_conv_conf_t &jcp) {
63 if (mayiuse(avx512_core))
64 return (jcp.ic < 16 && jcp.ngroups == 1);
65 else
66 return one_of(jcp.ic, 1, 3);
67}
68
69inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
70 return (jcp.nb_ow > 1);
71}
72
73inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
74 return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
75}
76
77}
78
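// prepare_output() zero-initializes the ur_w x nb_oc_blocking accumulator
// registers and, unless ow-block prefetching is active, issues T1 prefetches
// through reg_out_prf for the offsets that store_output() will later write.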
79template<typename Vmm>
80void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
81{
82 for (int k = 0; k < jcp.nb_oc_blocking; k++)
83 for (int j = 0; j < ur_w; j++) {
84 Vmm vmm = vmm_out(j, k);
85 vpxord(vmm, vmm, vmm);
86 if (!is_owb_prefetching(jcp)) {
87 size_t aux_output_offset = get_output_offset(j, k);
88 mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
89 aux_output_offset, reg_out_long_offt));
90 }
91 }
92}
93
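// store_output() finalizes one strip of accumulators: on the first ic chunk
// (channel == 0) it adds the bias (if any), on later chunks it accumulates
// the partial results already stored in memory, and with the sum post-op the
// destination is always accumulated. The eltwise post-op, if any, is applied
// only after the last ic chunk (channel == nb_ic - 1), right before the
// final vmovups stores.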
94template<typename Vmm>
95void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
96{
97 Label no_update_label, store_label, eltwise_label;
98
99 mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
100 if (jcp.with_bias) {
101 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
102 }
103
104 if (!jcp.with_sum) {
105 cmp(reg_channel, 0);
106 je(no_update_label, T_NEAR);
107 }
108
109 for (int k = 0; k < jcp.nb_oc_blocking; k++)
110 for (int j = 0; j < ur_w; j++) {
111 Vmm vmm = vmm_out(j, k);
112 size_t aux_output_offset = get_output_offset(j, k);
113 vaddps(vmm,
114 make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
115 }
116
117 if (!jcp.with_sum) {
118 jmp(eltwise_label, T_NEAR);
119 } else {
120 cmp(reg_channel, 0);
121 jne(eltwise_label, T_NEAR);
122 }
123
124 L(no_update_label);
125 if (jcp.with_bias) {
126 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
127 int bias_offset = jcp.typesize_out * k * jcp.oc_block;
128 for (int j = 0; j < ur_w; j++) {
129 Vmm vmm = vmm_out(j, k);
130 vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset));
131 }
132 mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
133 }
134 }
135
136 L(eltwise_label);
137 if (jcp.with_eltwise) {
138 cmp(reg_channel, jcp.nb_ic - 1);
139 jl(store_label, T_NEAR);
140
141 if (ur_w == jcp.ur_w) {
142 eltwise_injector_->compute_vector_range(0,
143 jcp.nb_oc_blocking * jcp.ur_w);
144 } else {
145 for (int k = 0; k < jcp.nb_oc_blocking; k++)
146 eltwise_injector_->compute_vector_range(k * jcp.ur_w,
147 k * jcp.ur_w + ur_w);
148 }
149 }
150
151 L(store_label);
152 for (int k = 0; k < jcp.nb_oc_blocking; k++)
153 for (int j = 0; j < ur_w; j++) {
154 Vmm vmm = vmm_out(j, k);
155 size_t aux_output_offset = (size_t)typesize *
156 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
157 vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
158 reg_out_long_offt), vmm);
159 if (!is_owb_prefetching(jcp))
160 mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
161 aux_output_offset, reg_out_long_offt));
162 }
163}
164
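// compute_loop_4fma_1st(): 4FMA inner loop for the first convolution, where
// the source is in plain ncw/nchw/ncdhw layout. The generic Vmm template is
// intentionally empty; only the Zmm specialization is used, since v4fmaddps
// consumes a block of four consecutive zmm kernel registers. Kernel taps are
// processed four at a time (ki += 4), with zeroed registers standing in for
// positions past kw when kw is not a multiple of four.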
165template<typename Vmm>
166void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
167 int pad_l, int pad_r)
168{
169}
170
171template<>
172void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
173 int pad_l, int pad_r)
174{
175 assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
176
177 int iw = jcp.iw;
178 int ih = jcp.ih;
179 int kw = jcp.kw;
180 int stride_w = jcp.stride_w;
181 int ic_block = jcp.ic_block;
182 int oc_block = jcp.oc_block;
183
184 Label kh_label, kd_label;
185
186 if (one_of(jcp.ndims, 3, 4)) {
187 mov(aux_reg_inp, reg_inp);
188 mov(aux_reg_ker, reg_ker);
189 mov(aux_reg_inp_prf, reg_inp_prf);
190 }
191
192 size_t max_input_offset = (size_t)jcp.typesize_in
193 * ((size_t)(kw + ur_w * stride_w - pad_l)
194 + (size_t)ic_block * iw * ih * jcp.id);
195 assert(reg_inp_prf == reg_long_offt);
196 if (max_input_offset > INT_MAX) push(reg_inp_prf);
197
198 if (jcp.ndims == 5) {
199 push(reg_out_prf);
200 push(reg_out);
201
202 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
203 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
204 mov(aux_reg_inp_d, reg_inp);
205 mov(aux_reg_inp_d_prf, reg_inp_prf);
206
207 L(kd_label);
208 }
209 mov(reg_kj, reg_kh);
210 if (jcp.ndims == 5) {
211 mov(aux_reg_inp, aux_reg_inp_d);
212 mov(aux_reg_ker, aux_reg_ker_d);
213 mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
214 }
215
216 L(kh_label);
217 for (int ki = 0; ki < kw; ki += 4) {
218 for (int ic = 0; ic < ic_block; ic++) {
219 for (int i = 0; i < 4; i++) {
220 int aux_ker_offset
221 = jcp.typesize_in
222 * ((ki + i) * oc_block
223 + ic * kw * jcp.kh * jcp.kd * oc_block);
224 if (ki + i < kw)
225 vmovups(vmm_ker(i),
226 EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
227 else
228 vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
229 }
230
231 int j_start = get_ow_start(ki, pad_l);
232 int j_end = get_ow_end(ur_w, ki, pad_r);
233
234 for (int j = j_start, prf_count=0; j < j_end; j++) {
235 size_t aux_input_offset = (size_t)jcp.typesize_in
236 * ((size_t)(ki + j * stride_w
237 - pad_l) + (size_t)ic * iw * ih * jcp.id);
238 v4fmaddps(vmm_out(j, 0), vmm_ker(0),
239 EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
240 reg_long_offt));
241 if (ki + prf_count < kw && prf_count < 4
242 && ((ki < 2 && j % 4) || j % 2)) {
243 int aux_ker_offset = jcp.typesize_in
244 * ((ki + prf_count) * oc_block
245 + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block);
246 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
247 aux_ker_offset));
248 prf_count++;
249 }
250 if (ki == 0
251 && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
252 mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
253 aux_input_offset, reg_long_offt));
254 }
255 if (ki == 1
256 && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
257 mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
258 aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
259 }
260 }
261 }
262 }
263 add(aux_reg_ker, jcp.typesize_in * kw * oc_block);
264 add(aux_reg_inp, jcp.typesize_in * iw);
265 add(aux_reg_inp_prf, jcp.typesize_in * iw);
266
267 dec(reg_kj);
268 cmp(reg_kj, 0);
269 jg(kh_label, T_NEAR);
270
271 if (jcp.ndims == 5) {
272 add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
273 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
274 add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw);
275
276 dec(reg_ki);
277 cmp(reg_ki, 0);
278 jg(kd_label, T_NEAR);
279
280 pop(reg_out);
281 pop(reg_out_prf);
282 }
283
284 if (max_input_offset > INT_MAX) pop(reg_inp_prf);
285}
286
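// compute_loop_4fma(): 4FMA inner loop for blocked input. Kernel values are
// loaded four input channels at a time (kernel_loads) and consumed by
// v4fmaddps, with software prefetches for the kernel and for the next input
// row interleaved into the FMA stream. For kh > 3 the final kh iteration is
// peeled off (last_iter_label) so that its prefetches can use different
// target offsets.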
287template<typename Vmm>
288void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
289 int pad_l, int pad_r)
290{
291}
292
293template<>
294void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
295 int pad_l, int pad_r)
296{
297 int stride_w = jcp.stride_w;
298 int ic_block = jcp.ic_block;
299 int oc_block = jcp.oc_block;
300 Label kh_label, last_iter_label, loop_end_label, kd_label;
301 int ker_load_number = 4;
302 int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
303 int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
304
305 bool check_last_kh = (jcp.kh > 3);
306 bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);
307
308 int oi_ipref_t0 = get_ow_start(0, pad_l);
309 int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);
310
311 assert(jcp.oc % jcp.nb_oc_blocking == 0);
312
313 auto kernel_offset = [=](int ocb, int ic, int ki) {
314 int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
315 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
316 int ic_offset = ic * jcp.oc_block;
317 return typesize * (blk_offset + ic_offset);
318 };
319 auto kernel_loads = [=](int ki, int ic, int kk) {
320 for (int ii = 0; ii < ker_load_number; ii++) {
321 int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
322 vmovups(vmm_ker(ii),
323 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
324 }
325 };
326 auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
327 if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
328 && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
329 int aux_inp_offset
330 = typesize
331 * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
332 + (jcp.dilate_h + 1) * jcp.iw * ic_block);
333 prefetcht0(EVEX_compress_addr(aux_reg_inp,
334 aux_inp_offset));
335 oi_ipref_t0++;
336 }
337 };
338
339 if (one_of(jcp.ndims, 3, 4)) {
340 mov(aux_reg_inp, reg_inp);
341 mov(aux_reg_ker, reg_ker);
342 mov(aux_reg_ker_prf, reg_ker_prf);
343 mov(aux_reg_inp_prf, reg_inp_prf);
344 }
345
346 if (jcp.ndims == 5) {
347 push(reg_out_prf);
348 push(reg_out);
349
350 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
351 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
352 mov(aux_reg_inp_d, reg_inp);
353 mov(aux_reg_inp_d_prf, reg_inp_prf);
354 mov(aux_reg_ker_d_prf, reg_ker_prf);
355 L(kd_label);
356 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
357 } else {
358 mov(reg_kj, reg_kh);
359 }
360 if (jcp.ndims == 5) {
361 mov(aux_reg_inp, aux_reg_inp_d);
362 mov(aux_reg_ker, aux_reg_ker_d);
363 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
364 mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
365 }
366
367 align(16);
368 L(kh_label);
369 int kw = jcp.kw;
370 if (check_last_kh) {
371 for (int ki = 0; ki < kw; ki++)
372 for (int ic = 0; ic < ic_block; ic += 4)
373 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
374 bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
375 && ki == kw - 1 && (ic + 4) == ic_block);
376
377 if (last_kernel_loads) {
378 cmp(reg_kj, 1);
379 je(last_iter_label, T_NEAR);
380 }
381
382 kernel_loads(ki, ic, kk);
383 for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
384 prf_count_t0 = 0;
385 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
386 int aux_input_offset = typesize
387 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
388 - pad_l) * ic_block
389 + ic);
390 v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
391 EVEX_compress_addr(aux_reg_inp, aux_input_offset));
392
393 if (oi % 2) {
394 if (prf_count_t0 < 4) {
395 int aux_kernel_prf;
396 if (last_kernel_loads)
397 aux_kernel_prf= kernel_offset(0,
398 prf_count_t0 + ic + 4
399 - ic_block, 0) + typesize * kw
400 * oc_block * ic_block;
401 else
402 aux_kernel_prf = kernel_offset(kk, ic + 4
403 + prf_count_t0, ki);
404 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
405 aux_kernel_prf));
406 prf_count_t0++;
407 } else if (prf_count_t1 < 4) {
408 mic_prefetcht1(EVEX_compress_addr(
409 aux_reg_ker_prf, kernel_offset(kk, ic
410 + prf_count_t1, ki)));
411 prf_count_t1++;
412 }
413 } else
414 prefetch_inp_next_kh(ki, 2, prf_count_t0,
415 prf_count_t1);
416 }
417
418 if (last_kernel_loads) {
419 jmp(loop_end_label, T_NEAR);
420
421 L(last_iter_label);
422
423 kernel_loads(ki, ic, kk);
424 for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
425 prf_count_t0 = 0;
426 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
427 int aux_input_offset = typesize
428 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
429 - pad_l) * ic_block
430 + ic);
431 v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
432 EVEX_compress_addr(aux_reg_inp,
433 aux_input_offset));
434 if (oi % 2) {
435 if (prf_count_t0 < 4) {
436 mic_prefetcht0(EVEX_compress_addr(
437 aux_reg_ker_prf, kernel_offset(0,
438 prf_count_t0, 0)));
439 prf_count_t0++;
440 } else if (prf_count_t1 < 4) {
441 mic_prefetcht1(EVEX_compress_addr(
442 aux_reg_ker_prf, kernel_offset(kk,
443 ic + prf_count_t1, ki)));
444 prf_count_t1++;
445 }
446 }
447 }
448 L(loop_end_label);
449 }
450 }
451 } else {
452 for (int ki = 0; ki < kw; ki++)
453 for (int ic = 0; ic < ic_block; ic += 4)
454 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
455 kernel_loads(ki, ic, kk);
456 for (int oi = get_ow_start(ki, pad_l),
457 prf_count_t1 = 0, prf_count_t0 = 0;
458 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
459 int aux_input_offset = typesize
460 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
461 - pad_l) * ic_block + ic);
462 v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
463 EVEX_compress_addr(aux_reg_inp,
464 aux_input_offset));
465
466 if (!is_owb_prefetching(jcp)) {
467 if ((oi % 2) && (prf_count_t1 < 4)) {
468 mic_prefetcht1(EVEX_compress_addr(
469 aux_reg_ker_prf, kernel_offset(kk,
470 ic + prf_count_t1, ki)));
471 prf_count_t1++;
472 }
473 } else {
474 if (!(ki == 0 && ic == 0)
475 && !(ki == kw-1 && ic == 0) &&
476 (oi % 2) && (prf_count_t1 < 4)
477 ) {
478 mic_prefetcht0(EVEX_compress_addr(
479 aux_reg_ker, kernel_offset(kk,
480 ic + 4 + prf_count_t0, ki)));
481 prf_count_t0++;
482 }
483 }
484 if (!is_owb_prefetching(jcp)) {
485 if (pref_current_inp) {
486 if (ki == 0 && ic == 0 && kk == 0)
487 mic_prefetcht0(EVEX_compress_addr(
488 aux_reg_inp,
489 aux_input_offset + shift_input_ptr));
490 } else {
491 if (ki == 1 && ic == 0 && kk == 0)
492 mic_prefetcht1(EVEX_compress_addr(
493 aux_reg_inp_prf, aux_input_offset));
494 }
495 } else {
496 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
497 int inp_shift
498 = jcp.typesize_in * ur_w * stride_w * inp_mult;
499 bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
500 if (ki == 0 && ic == 0 && kk_pref_slot)
501 mic_prefetcht1(EVEX_compress_addr(
502 aux_reg_inp,
503 aux_input_offset + inp_shift));
504
505 if (ki == kw - 1 && ic == 0 && kk_pref_slot)
506 mic_prefetcht0(EVEX_compress_addr(
507 aux_reg_inp,
508 aux_input_offset + inp_shift));
509 }
510 }
511 }
512 }
513
514 add(aux_reg_ker, shift_kernel_ptr);
515 add(aux_reg_inp, shift_input_ptr);
516 add(aux_reg_ker_prf, shift_kernel_ptr);
517 add(aux_reg_inp_prf, shift_input_ptr);
518
519 dec(reg_kj);
520 cmp(reg_kj, 0);
521 jg(kh_label, T_NEAR);
522
523 if (jcp.ndims == 5) {
524 add(aux_reg_inp_d,
525 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
526 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
527 * jcp.ic_block);
528 add(aux_reg_inp_d_prf,
529 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
530 add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
531 * jcp.ic_block);
532
533 dec(reg_ki);
534 cmp(reg_ki, 0);
535 jg(kd_label, T_NEAR);
536
537 pop(reg_out);
538 pop(reg_out_prf);
539 }
540}
541
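// compute_loop_fma(): FMA inner loop with software-pipelined kernel loads.
// The taps flow through ker_pipeline_depth (= 4) registers: the first step
// preloads all four, and every later step reloads only the slot that is
// ker_pipeline_depth - 1 loads ahead of the tap currently being used.
// Prefetches for the kernel and the input are spread evenly over the FMA
// stream: one prefetch roughly every num_fmas / num_prfs FMAs
// (prf_inst_spacing).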
542template<typename Vmm>
543void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
544 int pad_l, int pad_r)
545{
546 bool prf_ker = true;
547 bool prf_inp = true;
548 int ih = jcp.ih;
549 int stride_w = jcp.stride_w;
550 int id = jcp.id;
551 int iw = jcp.iw;
552 int kw = jcp.kw;
553 int ic_block = jcp.ic_block;
554 int oc_block = jcp.oc_block;
555 int nb_oc_block = jcp.nb_oc_blocking;
556 Label kh_label, kd_label;
557
558 int ker_pipeline_depth = 4;
559 assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
560 assert(oc_block >= ker_pipeline_depth);
561
562 int num_ker_loads = ic_block * nb_oc_block * kw;
563 int num_ker_prfs = prf_ker ? num_ker_loads : 0;
564 int num_inp_prfs = prf_inp ?
565 ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
566 0;
567 if (jcp.is_1stconv && prf_inp) {
568 num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
569 }
570 int num_prfs = num_ker_prfs + num_inp_prfs;
571 int num_fmas = num_ker_loads * ur_w;
572 int prf_inst_spacing
573 = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
574 int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
575 int inp_mul = !jcp.is_1stconv ? ic_block : 1;
576
577 if (one_of(jcp.ndims, 3, 4)) {
578 mov(aux_reg_inp, reg_inp);
579 mov(aux_reg_ker, reg_ker);
580 mov(aux_reg_inp_prf, reg_inp_prf);
581 mov(aux_reg_ker_prf, reg_ker_prf);
582 }
583
584 size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
585 assert(reg_inp_prf == reg_long_offt);
586 if (max_input_offset > INT_MAX) push(reg_inp_prf);
587
588
589 if (jcp.ndims == 5) {
590 push(reg_out_prf);
591 push(reg_out);
592
593 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
594 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
595 mov(aux_reg_inp_d, reg_inp);
596 mov(aux_reg_inp_d_prf, reg_inp_prf);
597 mov(aux_reg_ker_d_prf, reg_ker_prf);
598
599 L(kd_label);
600 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
601 } else {
602 mov(reg_kj, reg_kh);
603 }
604
605 if (jcp.ndims == 5) {
606 mov(aux_reg_inp, aux_reg_inp_d);
607 mov(aux_reg_ker, aux_reg_ker_d);
608 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
609 mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
610 }
611
612 align(16);
613 L(kh_label);
614 {
615 int step = 0;
616 int ker_prfs = 0;
617 for (int ki = 0; ki < kw; ki++) {
618 for (int ic = 0; ic < ic_block; ic++) {
619 int aux_kernel_offset = 0;
620 if (step == 0) {
621 for (int i = 0; i < ker_pipeline_depth; i++) {
622 aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
623 vmovups(vmm_ker(i), EVEX_compress_addr(
624 aux_reg_ker, aux_kernel_offset));
625 }
626 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
627 int load_offset = ker_pipeline_depth - 1;
628 int ker_load_reg_idx
629 = (step + load_offset) % ker_pipeline_depth;
630 aux_kernel_offset
631 = get_kernel_offset(ki, ic, 0, load_offset);
632 vmovups(vmm_ker(ker_load_reg_idx),
633 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
634 }
635
636 bool ker_prf_inserted = false;
637 Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
638 int j_start = get_ow_start(ki, pad_l);
639 int j_end = get_ow_end(ur_w, ki, pad_r);
640 for (int j = j_start; j < j_end; j++) {
641 size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
642 auto addr = EVEX_compress_addr_safe(aux_reg_inp,
643 aux_input_offset, reg_long_offt, true);
644 vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
645 int fma_idx = step * ur_w + j;
646 int prf_slot_idx = fma_idx / prf_inst_spacing;
647 if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
648 if (prf_ker && !ker_prf_inserted
649 && ker_prfs < num_ker_prfs) {
650 int ker_prf_offset
651 = jcp.typesize_in * ker_prfs * jcp.oc_block;
652 mic_prefetcht2(EVEX_compress_addr(
653 aux_reg_ker_prf, ker_prf_offset));
654 ker_prf_inserted = true;
655 ker_prfs++;
656 } else if (prf_inp) {
657 int inp_prf_idx = prf_slot_idx - ker_prfs;
658 if (inp_prf_idx < num_inp_prfs) {
659 size_t inp_prf_stride = nstl::max(kw, stride_w);
660 size_t inp_prf_offset;
661 if (!jcp.is_1stconv) {
662 inp_prf_offset
663 = ic_block * jcp.typesize_in
664 * ((inp_prf_idx / kw)
665 * inp_prf_stride
666 + (inp_prf_idx % kw));
667 } else {
668 size_t ic_prf_stride =
669 (size_t)jcp.typesize_in * iw * ih * id;
670 size_t iw_prf_stride
671 = jcp.typesize_in * jcp.simd_w;
672 inp_prf_offset = ((inp_prf_idx / ic_block)
673 * iw_prf_stride
674 + (inp_prf_idx % ic_block)
675 * ic_prf_stride);
676 }
677 mic_prefetcht0(EVEX_compress_addr_safe(
678 aux_reg_inp_prf, inp_prf_offset,
679 reg_long_offt));
680 }
681 }
682 }
683 }
684 step++;
685 }
686 }
687 add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
688 if (prf_ker)
689 add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
690 add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
691 if (prf_inp)
692 add(aux_reg_inp_prf,
693 jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
694 dec(reg_kj);
695 cmp(reg_kj, 0);
696 jg(kh_label, T_NEAR);
697 }
698
699
700 if (jcp.ndims == 5) {
701 add(aux_reg_inp_d,
702 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
703 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
704 * jcp.ic_block);
705 add(aux_reg_inp_d_prf,
706 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
707 add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
708 * jcp.ic_block);
709
710 dec(reg_ki);
711 cmp(reg_ki, 0);
712 jg(kd_label, T_NEAR);
713
714 pop(reg_out);
715 pop(reg_out_prf);
716 }
717 if (max_input_offset > INT_MAX) pop(reg_inp_prf);
718}
719
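// compute_loop_fma_core(): simpler FMA loop without the MIC-oriented
// software prefetching. With the expl_bcast kernel kind the input value is
// first broadcast into a register (vbroadcastss) and reused across all
// output-channel blocks; otherwise the broadcast happens inside vfmadd231ps
// through its memory operand.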
720template<typename Vmm>
721void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
722 int pad_l, int pad_r)
723{
724 int kw = jcp.kw;
725 int stride_w = jcp.stride_w;
726 int ic_block = jcp.ic_block;
727 int oc_block = jcp.oc_block;
728 int nb_oc_block = jcp.nb_oc_blocking;
729 Label kh_label, kd_label;
730 int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
731 * jcp.ic_block;
732 int inp_mul = !jcp.is_1stconv ? ic_block : 1;
733 int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
734 * inp_mul;
735
736
737 auto input_offset = [=](int oi, int ic, int ki) {
738 return (size_t)jcp.typesize_in
739 * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
740 * inp_mul + (size_t)ic
741 * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
742 };
743
744 if (one_of(jcp.ndims, 3, 4)) {
745 mov(aux_reg_inp, reg_inp);
746 mov(aux_reg_ker, reg_ker);
747 }
748
749 if (jcp.ndims == 5) {
750 push(reg_out);
751
752 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
753 mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
754 mov(aux_reg_inp_d, reg_inp);
755
756 L(kd_label);
757 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
758 } else {
759 mov(reg_kj, reg_kh);
760 }
761
762 if (jcp.ndims == 5) {
763 mov(aux_reg_inp, aux_reg_inp_d);
764 mov(aux_reg_ker, aux_reg_ker_d);
765 }
766
767 L(kh_label);
768 {
769 for (int ki = 0; ki < kw; ki++) {
770 int jj_start = get_ow_start(ki, pad_l);
771 int jj_end = get_ow_end(ur_w, ki, pad_r);
772 for (int ic = 0; ic < ic_block; ic++) {
773 if (jcp.kernel_kind == expl_bcast) {
774 for (int jj = jj_start; jj < jj_end; jj++) {
775 size_t aux_input_offset = input_offset(jj, ic, ki);
776 vbroadcastss(vmm_inp(jj, nb_oc_block),
777 EVEX_compress_addr_safe(aux_reg_inp,
778 aux_input_offset, reg_long_offt));
779 }
780 }
781 for (int ii = 0; ii < nb_oc_block; ii++) {
782 int aux_kernel_offset = jcp.typesize_in
783 * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
784 * oc_block + ki * ic_block * oc_block + ic * oc_block);
785 if (jj_end - jj_start > 0)
786 vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
787 aux_kernel_offset));
788 for (int jj = jj_start; jj < jj_end; jj++)
789 if (jcp.kernel_kind == expl_bcast)
790 vfmadd231ps(vmm_out(jj, ii),
791 vmm_inp(jj, nb_oc_block), vmm_wei);
792 else {
793 size_t aux_input_offset = input_offset(jj, ic, ki);
794 vfmadd231ps(vmm_out(jj, ii), vmm_wei,
795 EVEX_compress_addr_safe(aux_reg_inp,
796 aux_input_offset, reg_long_offt, true));
797 }
798 }
799 }
800 }
801 add(aux_reg_ker, shift_kernel_ptr);
802 add(aux_reg_inp, shift_input_ptr);
803 dec(reg_kj);
804 cmp(reg_kj, 0);
805 jg(kh_label, T_NEAR);
806 }
807
808 if (jcp.ndims == 5) {
809 add(aux_reg_inp_d,
810 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
811 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
812 * jcp.ic_block);
813
814 dec(reg_ki);
815 cmp(reg_ki, 0);
816 jg(kd_label, T_NEAR);
817
818 pop(reg_out);
819 }
820}
821
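// compute_loop(): dispatches one ur_w-wide strip to the appropriate inner
// loop (4fma vs. fma, 1st convolution vs. blocked) and wraps it with the
// prepare_output()/store_output() pair. The inner loop is skipped entirely
// when the padded kd/kh range is empty for this spatial position.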
822template<typename Vmm>
823void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
824 int pad_l, int pad_r)
825{
826 if (jcp.ndims == 5) push(reg_oi);
827
828 prepare_output(ur_w);
829
830 Label skip_compute_loop;
831 if (jcp.ndims == 5) {
832 if ((jcp.dilate_d >= jcp.id)
833 || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
834 mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
835 cmp(reg_kj, 0);
836 je(skip_compute_loop, T_NEAR);
837 }
838 }
839 if ((jcp.dilate_h >= jcp.ih)
840 || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
841 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
842 cmp(reg_kj, 0);
843 je(skip_compute_loop, T_NEAR);
844 }
845
846 if (jcp.ver == ver_4fma)
847 if(jcp.is_1stconv)
848 compute_loop_4fma_1st(ur_w, pad_l, pad_r);
849 else
850 compute_loop_4fma(ur_w, pad_l, pad_r);
851 else if (jcp.ver == ver_fma)
852 if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
853 || mayiuse(avx512_mic))
854 compute_loop_fma(ur_w, pad_l, pad_r);
855 else
856 if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
857 compute_loop_fma(ur_w, pad_l, pad_r);
858 else
859 compute_loop_fma_core(ur_w, pad_l, pad_r);
860 else
861 assert(!"unknown convolution version");
862
863 L(skip_compute_loop);
864 store_output(ur_w);
865 if (jcp.ndims == 5) pop(reg_oi);
866}
867
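// generate(): emits the driver over the ow dimension. Without ow-threading
// the whole row is unrolled into an optional left-padded strip, n_oi plain
// strips, an optional right-padded strip and the ur_w tail. With
// ow-threading the same structure is selected at run time from the owb
// argument, since left padding can only occur in the first ow-block and
// right padding only in the last (or next-to-last) one.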
868template<typename Vmm>
869void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
870{
871 int iw = jcp.iw;
872 int ow = jcp.ow;
873 int ow_block = jcp.ow_block;
874 int nb_ow = jcp.nb_ow;
875 int kw = jcp.kw;
876 int l_pad = jcp.l_pad;
877 int ur_w = jcp.ur_w;
878 int ur_w_tail = jcp.ur_w_tail;
879 int dilate_w = jcp.dilate_w + 1;
880 int stride_w = jcp.stride_w;
881
882 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
883 int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
884 int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
885 int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
886 int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;
887
888 preamble();
889 mov(reg_inp, ptr[param1 + GET_OFF(src)]);
890 mov(reg_out, ptr[param1 + GET_OFF(dst)]);
891 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
892 mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
893 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
894
895 int r_pad = nstl::max(
896 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
897 int n_oi = ow / ur_w;
898 int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
899 - (iw + l_pad - 1);
900
901 if (!is_ow_threading_on(jcp)) {
902 // ow is being processed as a whole - with left and right paddings
903 if (r_pad1 > 0) n_oi--;
904
905 if (ow == ur_w) {
906 mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]);
907 mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]);
908 compute_loop(ur_w, l_pad, r_pad);
909 } else {
910 mov(reg_inp_prf, reg_inp);
911 mov(reg_out_prf, reg_out);
912 if (n_oi == 0) {
913 add(reg_inp_prf, inp_shift_pad);
914 add(reg_out_prf, out_shift);
915 compute_loop(ur_w, l_pad, r_pad1);
916 add(reg_inp, inp_shift_pad);
917 add(reg_out, out_shift);
918 if (ur_w_tail != 0) {
919 add(reg_inp_prf, inp_shift);
920 add(reg_out_prf, out_shift);
921 compute_loop(ur_w_tail, 0, r_pad);
922 }
923 } else {
924 xor_(reg_oi, reg_oi);
925 if (l_pad > 0) {
926 add(reg_inp_prf, inp_shift_pad);
927 add(reg_out_prf, out_shift);
928 compute_loop(ur_w, l_pad, 0);
929 add(reg_inp, inp_shift_pad);
930 add(reg_out, out_shift);
931 inc(reg_oi);
932 }
933 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
934 Label ow_loop_label;
935 L(ow_loop_label);
936 {
937 add(reg_inp_prf, inp_shift);
938 add(reg_out_prf, out_shift);
939 compute_loop(ur_w, 0, 0);
940 add(reg_inp, inp_shift);
941 add(reg_out, out_shift);
942 inc(reg_oi);
943 cmp(reg_oi, n_oi);
944 jl(ow_loop_label, T_NEAR);
945 }
946 }
947 if (r_pad1 > 0) {
948 add(reg_inp_prf, inp_shift);
949 add(reg_out_prf, out_shift);
950 compute_loop(ur_w, 0, r_pad1);
951 add(reg_inp, inp_shift);
952 add(reg_out, out_shift);
953 }
954 if (ur_w_tail != 0) {
955 add(reg_inp_prf, inp_shift);
956 add(reg_out_prf, out_shift);
957 compute_loop(ur_w_tail, 0, r_pad);
958 }
959 }
960 }
961 } else {
        // Only one ow-block is processed per kernel invocation.
        // The block index is passed in the owb parameter, and the
        // padding handling depends on that index.
965
966 Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
967 Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
968
969 assert(ow_block % ur_w == 0);
970 int n_oi_not_last_ow_block = ow_block / ur_w;
        // To simplify the code (and general-purpose register usage),
        // the size of an ow-block must be >= 2 * ur_w
973 assert(n_oi_not_last_ow_block > 1);
974 int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
975 int n_oi_first_ow_block = n_oi_not_last_ow_block;
976
977 int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
978
979 // prepare right padding
980 bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
981 bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
982 bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
983
984 if (last_ow_block_padded) n_oi_last_ow_block--;
985 else if (first_ow_block_padded) n_oi_first_ow_block--;
986 else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
987
988 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is this the first ow-block?
990 jg(middle_ow_blocks_label, T_NEAR);
991
992 // the first ow block, compute left padding
993
994 mov(reg_oi, n_oi_first_ow_block);
995 mov(reg_inp_prf, reg_inp);
996 mov(reg_out_prf, reg_out);
997
998 if (l_pad > 0) {
999 mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1000 add(reg_inp_prf, inp_shift_pad);
1001 add(reg_out_prf, out_shift);
1002 compute_loop(ur_w, l_pad, 0);
1003 add(reg_inp, inp_shift_pad);
1004 add(reg_out, out_shift);
1005 dec(reg_oi);
1006 }
1007 jmp(oi_loop_label, T_NEAR);
1008
1009 // middle or last ow block entry
1010
1011 L(middle_ow_blocks_label);
1012
1013 if (l_pad > 0) {
            // account for the left padding, but do not compute it
1015 add(reg_inp, inp_shift_pad_second_block);
1016 add(reg_inp_prf, inp_shift_pad_second_block);
1017 }
1018
        // set the number of iterations for the oi-loop
1020 cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
1021 mov(reg_oi, n_oi_last_ow_block);
1022 je(oi_loop_label, T_NEAR);
1023 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1024 mov(reg_oi, n_oi_next_last_ow_block);
1025 je(oi_loop_label, T_NEAR);
1026 mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
1027
1028 // oi loop w/o padding
1029 L(oi_loop_label);
1030 mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1031 L(oi_loop_start_label);
1032 cmp(reg_oi, 0);
1033 jle(oi_loop_end_label, T_NEAR);
1034
1035 add(reg_inp_prf, inp_shift);
1036 add(reg_out_prf, out_shift);
1037 compute_loop(ur_w, 0, 0);
1038 add(reg_inp, inp_shift);
1039 add(reg_out, out_shift);
1040 dec(reg_oi);
1041 jmp(oi_loop_start_label, T_NEAR);
1042 L(oi_loop_end_label);
1043
1044 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1045
1046 cmp(reg_owb, 0); // first ow-block ?
1047 if (first_ow_block_padded) {
1048 je(last_oi_label, T_NEAR);
1049 } else {
1050 je(end_label, T_NEAR);
1051 }
1052 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1053 jl(end_label, T_NEAR);
1054 if (next_last_ow_block_padded) {
1055 je(last_oi_label, T_NEAR);
1056 } else {
1057 je(end_label, T_NEAR);
1058 }
        // this is the last ow-block
1060 if (!last_ow_block_padded) {
1061 jmp(tail_label, T_NEAR);
1062 }
1063
1064 // last oi block with right padding
1065 L(last_oi_label);
1066 mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1067 add(reg_inp_prf, inp_shift);
1068 add(reg_out_prf, out_shift);
1069 compute_loop(ur_w, 0, r_pad1);
1070 add(reg_inp, inp_shift);
1071 add(reg_out, out_shift);
1072
1073 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1074 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
1075 jl(end_label, T_NEAR);
1076
1077 L(tail_label);
1078 mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1079 if (ur_w_tail != 0) {
1080 add(reg_inp_prf, inp_shift);
1081 add(reg_out_prf, out_shift);
1082 compute_loop(ur_w_tail, 0, r_pad);
1083 }
1084 L(end_label);
1085 }
1086 postamble();
1087
1088 if (jcp.with_eltwise)
1089 eltwise_injector_->prepare_table();
1090}
1091
1092bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
1093 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1094 const auto &p = attr.post_ops_;
1095
1096 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
1097 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1098
1099 switch (p.len_) {
1100 case 0: return true; // no post_ops
1101 case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
1102 case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
1103 default: return false;
1104 }
1105
1106 return false;
1107}
1108
1109status_t jit_avx512_common_conv_fwd_kernel::init_conf(
1110 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1111 memory_desc_t &src_md, memory_desc_t &weights_md,
1112 memory_desc_t &dst_md, memory_desc_t &bias_md,
1113 const primitive_attr_t &attr, int nthreads)
1114{
1115 using namespace prop_kind;
1116
1117 if (!mayiuse(avx512_common))
1118 return status::unimplemented;
1119
1120 const memory_desc_wrapper src_d(&src_md);
1121 const memory_desc_wrapper weights_d(&weights_md);
1122 const memory_desc_wrapper dst_d(&dst_md);
1123 const memory_desc_wrapper bias_d(&bias_md);
1124
1125 const int regs = 28;
1126 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1127 int ndims = src_d.ndims();
1128
1129 jcp = zero<decltype(jcp)>();
1130 jcp.ndims = ndims;
1131 jcp.prop_kind = cd.prop_kind;
1132 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1133 jcp.mb = src_d.dims()[0];
1134 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1135 jcp.oc_without_padding = jcp.oc;
1136 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1137 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1138 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
1139 jcp.iw = src_d.dims()[ndims-1];
1140 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
1141 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
1142 jcp.ow = dst_d.dims()[ndims-1];
1143 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1144 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
1145 jcp.kw = weights_d.dims()[with_groups + ndims-1];
1146 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1147 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
1148 jcp.l_pad = cd.padding[0][ndims-3];
1149 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1150 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
1151 jcp.stride_w = cd.strides[ndims-3];
1152
1153 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1154 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
1155 jcp.dilate_w = cd.dilates[ndims-3];
1156
1157 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
1158 - (jcp.ih + jcp.t_pad - 1);
1159 jcp.back_pad = (jcp.od - 1) * jcp.stride_d
1160 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
1161
1162 jcp.is_1stconv = is_1stconv(jcp);
1163
1164 bool ok_to_pad_channels = true
1165 && jcp.ngroups == 1
1166 && src_d.data_type() == data_type::f32;
1167
1168 const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
1169 jcp.simd_w = full_simd_w;
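    // On avx512_core, fall back to 4-wide (xmm) vectors when the channels
    // cannot be padded and are not multiples of the full simd width (nor of
    // 8); this keeps the blocked 4c/4i4o layouts usable for such shapes.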
1170 bool ok_to_try_xmm = true
1171 && mayiuse(avx512_core)
1172 && src_d.data_type() == data_type::f32
1173 && !jcp.is_1stconv
1174 && !ok_to_pad_channels
1175 && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
1176 && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
1177 if (ok_to_try_xmm)
1178 jcp.simd_w = 4;
1179
1180 jcp.oc_block = jcp.simd_w;
1181 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
1182 jcp.aligned_threads = 0;
1183
1184 if (ok_to_pad_channels) {
1185 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1186 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1187 }
1188 bool args_ok = true
1189 && jcp.oc % jcp.oc_block == 0
1190 && jcp.ic % jcp.ic_block == 0;
1191 if (!args_ok)
1192 return status::unimplemented;
1193
1194 if (!post_ops_ok(jcp, attr))
1195 return status::unimplemented;
1196
1197 const auto &p = attr.post_ops_;
1198 jcp.with_sum = p.find(primitive_kind::sum) != -1;
1199 const int eltwise_ind = p.find(primitive_kind::eltwise);
1200 jcp.with_eltwise = eltwise_ind != -1;
1201 if (jcp.with_eltwise) {
1202 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
1203 if (dst_d.data_type() == data_type::s32) return status::unimplemented;
1204 }
1205
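    // Pick memory format tags matching the chosen simd width: plain
    // ncw/nchw/ncdhw source for the 1st convolution, 4c/16c blocked layouts
    // otherwise, with the corresponding 4i4o/16i16o weights.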
1206 auto src_tag = jcp.is_1stconv
1207 ? pick(ndims - 3, ncw, nchw, ncdhw)
1208 : ((jcp.simd_w == 4)
1209 ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
1210 : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
1211 auto dst_tag = (jcp.simd_w == 4)
1212 ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
1213 : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1214 auto wei_tag = with_groups
1215 ? ((jcp.simd_w == 4)
1216 ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
1217 : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
1218 : ((jcp.simd_w == 4)
1219 ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
1220 : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
1221
1222 if (src_d.format_kind() == format_kind::any) {
1223 CHECK(memory_desc_init_by_tag(src_md, src_tag));
1224 jcp.src_tag = src_tag;
1225 } else {
1226 jcp.src_tag = src_d.matches_one_of_tag(src_tag);
1227 }
1228 if (jcp.src_tag != src_tag)
1229 return status::unimplemented;
1230
1231 if (dst_d.format_kind() == format_kind::any) {
1232 CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
1233 jcp.dst_tag = dst_tag;
1234 } else {
1235 jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag);
1236 }
1237 if (jcp.dst_tag != dst_tag)
1238 return status::unimplemented;
1239
1240 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
1241 if (jcp.with_bias) {
1242 if (bias_d.format_kind() == format_kind::any)
1243 CHECK(memory_desc_init_by_tag(bias_md, x));
1244 }
1245
1246 if (mayiuse(avx512_common) &&
1247 src_d.data_type() == data_type::f32
1248 && weights_d.data_type() == data_type::f32
1249 && dst_d.data_type() == data_type::f32) {
1250 jcp.ver = ver_fma;
1251 jcp.typesize_in = sizeof(float);
1252 jcp.typesize_out = sizeof(float);
1253 if (mayiuse(avx512_mic_4ops))
1254 jcp.ver = ver_4fma;
1255
1256 if (jcp.is_1stconv) {
1257 // TODO: fix & remove constraints below
1258 bool not_for_4fma
1259 = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad),
1260 nstl::max(jcp.kw, jcp.kh) < 7);
1261 bool is_dilated
1262 = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
1263 if (one_of(true, not_for_4fma, is_dilated))
1264 jcp.ver = ver_fma;
1265 if (jcp.ver == ver_4fma) {
1266 wei_tag = with_groups
1267 ? ((jcp.simd_w == 4)
1268 ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
1269 : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
1270 : ((jcp.simd_w == 4)
1271 ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
1272 : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
1273 } else {
1274 wei_tag = with_groups
1275 ? ((jcp.simd_w == 4)
1276 ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
1277 : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
1278 : ((jcp.simd_w == 4)
1279 ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
1280 : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
1281 }
1282 }
1283 } else {
1284 return status::unimplemented;
1285 }
1286
1287 if (weights_d.format_kind() == format_kind::any) {
1288 CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
1289 jcp.wei_tag = wei_tag;
1290 } else {
1291 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
1292 }
1293 if (jcp.wei_tag != wei_tag)
1294 return status::unimplemented;
1295
1296 if (jcp.is_1stconv) {
1297 jcp.ur_w = nstl::min(jcp.ow, regs);
1298 } else {
1299 // avx512_core guard - just to avoid possible regression for other archs
1300 if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
1301 jcp.ur_w = nstl::min(jcp.ow, regs);
1302 } else {
1303 for (int ur_w = regs; ur_w > 0; --ur_w) {
1304 if (jcp.ow % ur_w == 0) {
1305 jcp.ur_w = ur_w;
1306 break;
1307 }
1308 }
1309 }
1310 if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
1311 jcp.ur_w = nstl::min(jcp.ow, regs);
1312 }
1313 }
    // TODO (Tanya): currently applied to SegNet convolutions only;
    // needs to be tried on other topologies
1316 if (jcp.ow > 150 && jcp.ur_w < regs/2)
1317 jcp.ur_w = regs;
1318
1319 int n_oi = (jcp.ow / jcp.ur_w);
1320 int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w
1321 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
1322 if (jcp.l_pad > 0 && r_pad > 0)
1323 n_oi--;
1324
1325 bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
1326 && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
1327 if (large_code_size) {
1328 const int max_code_size = 24 * 1024;
1329 const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
1330 int mult = 1;
1331 if (jcp.l_pad > 0) mult += 1;
1332 if (r_pad > 0) mult += 1;
1333 for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
1334 if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
1335 jcp.ur_w = ur_w;
1336 break;
1337 }
1338 }
1339 }
1340
1341 /* Grouped channel offset to support 'non-blocked data' format for
1342 * convolution sizes with '(input_channel / ngroups) < simd' */
1343 jcp.nonblk_group_off
1344 = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ?
1345 jcp.ic :
1346 1;
1347
1348 jcp.nb_ic = jcp.ic / jcp.ic_block;
1349 jcp.nb_oc = jcp.oc / jcp.oc_block;
1350 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1351
1352 auto is_ow_threading_applicable = [=]() {
1353 return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
1354 && IMPLICATION(mayiuse(avx512_mic),
1355 jcp.ver == ver_4fma
1356 && IMPLICATION(jcp.mb != 1,
1357 jcp.ih == 1 && jcp.kh == 1)));
1358 };
1359
1360 if (jcp.ver == ver_4fma && !jcp.is_1stconv) {
1361 if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
1362 && jcp.oh <= 8 && jcp.ow == jcp.oh)
1363 || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
1364 if (jcp.nb_oc % 2 == 0) {
1365 jcp.nb_oc_blocking = 2;
1366 jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
1367 }
1368 } else {
1369 for (int i = jcp.nb_oc; i > 0; i--)
1370 if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
1371 jcp.nb_oc_blocking = i;
1372 break;
1373 }
1374 }
1375 if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
1376 if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
1377 && jcp.ow != 2 * jcp.ur_w) {
1378 jcp.nb_oc_blocking = 2;
1379 jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
1380 }
1381 }
1382 }
1383
1384 jcp.ow_block = jcp.ow;
1385
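    // Threading efficiency model used for ow-blocking: thr_eff is the
    // fraction of thread capacity doing useful work,
    //     disbalance * work_amount / rnd_up(work_amount, nthreads),
    // where disbalance = ow / rnd_up(ow, ow_block) penalizes a ragged last
    // ow-block. E.g. with (hypothetical) ow = 56 and ow_block = 24, the
    // disbalance factor is 56 / 72, roughly 0.78.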
1386 auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
1387 int nb_ow = div_up(jcp.ow, ow_block);
1388 int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
1389 int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
1390 float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
1391 float thr_eff = disbalance * (float)work_amount
1392 / rnd_up(work_amount, nthreads);
1393 return thr_eff;
1394 };
1395
1396 auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
1397 int res_ow_block = jcp.ow;
1398 eff = get_thr_eff(nb_oc_blocking, res_ow_block);
1399 if (!is_ow_threading_applicable())
1400 return res_ow_block;
1401
1402 int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
1403 if (jcp.ver == ver_4fma)
1404 L2_part /= 2;
1405 int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
1406 int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
1407 int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
1408 * jcp.kw * jcp.kh;
1409 int nurw_cache = (L2_part - 2 * size_wei_chunk)
1410 / (2 * size_dst_chunk + 2 * size_src_chunk);
1411 // current design of generate() requires ow_block >= 2 * ur_w
1412 int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
1413
1414 int ow_block_thr = ow_block_cache;
1415 eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
1416
1417 int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
1418 int start_nb_ow = div_up(jcp.ow, ow_block_thr);
1419 for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
1420 int ow_block
1421 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
1422 float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
1423 if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
1424 break;
1425 if (div_up(jcp.ow, ow_block) != nb_ow)
1426 continue;
1427 float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
1428 float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
1429 if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
1430 ow_block_thr = ow_block;
1431 eff = thr_eff;
1432 }
1433 eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
1434 if (eff > eff_threshold)
1435 break;
1436 }
1437 res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
1438 eff = get_thr_eff(nb_oc_blocking, res_ow_block);
1439 return res_ow_block;
1440 };
1441
1442
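    // avx512_core tuning: pick between the embd_bcast and expl_bcast kernel
    // kinds, and the matching ur_w / nb_oc_blocking, based on the estimated
    // per-iteration working set (ker_total_size) relative to L1 and on the
    // threading efficiency model above.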
1443 if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
1444 int try_nb_oc_blocking = 2;
1445 unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
1446 * jcp.ic_block * jcp.kh * jcp.kd;
1447 unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block
1448 * try_nb_oc_blocking;
1449 unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
1450 * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
1451 unsigned int ker_total_size = ker_inp_size + ker_out_size
1452 + ker_wei_size;
1453
1454 bool embd_bcast_condition = true
1455 && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size)
1456 && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
1457 && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
1458
1459 if (jcp.mb == 1) {
1460 unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
1461 * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
1462 unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
1463
            // Estimate whether the number of threads should be limited and,
            // if so, compute that limit. This is a heuristic.
1466 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1467 int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
1468 int job_size_min = work_amount / nthreads;
1469 int job_size_max = div_up(work_amount, nthreads);
1470 int ch_max = rnd_up(jcp.oh, job_size_max);
1471 int ch_min = (job_size_min == 0)
1472 ? jcp.oh
1473 : rnd_up(jcp.oh, job_size_min);
1474 bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
1475 && (jcp.oh != 8 || ch_max / jcp.oh > 1);
1476 bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
1477 && (jcp.oh != 8 || ch_min / jcp.oh > 1);
1478 bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
1479 || nthreads > oc_chunks;
1480 if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
1481 && wei_size / inp_size > 24
1482 && (not_aligned_max || not_aligned_min)
1483 && eligible_case) {
                // Try to find a thread count i in (nthreads / 2, nthreads]
                // such that oc_chunks is a multiple of i or i is a multiple
                // of oc_chunks. Otherwise, keep the default value.
                // TODO: implement a task-based alternative without throttling.
1488 jcp.aligned_threads = nthreads;
1489 for (int i = nthreads; i > nthreads / 2; i--) {
1490 if (oc_chunks % i == 0 || i % oc_chunks == 0) {
1491 jcp.aligned_threads = i;
1492 break;
1493 }
1494 }
1495 }
1496 }
1497
1498 if (jcp.kw > 3
1499 || (jcp.stride_w == 1 && jcp.stride_h == 1
1500 && embd_bcast_condition)
1501 || ((jcp.stride_w != 1 || jcp.stride_h != 1)
1502 && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
1503 && embd_bcast_condition)))
1504 || (jcp.mb == 1
1505 && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
1506 || (jcp.ow <= 147 && jcp.oc <= 96)))) {
1507 jcp.kernel_kind = embd_bcast;
1508 jcp.ur_w = nstl::min(jcp.ow, regs);
1509 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1510 if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
1511 && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
1512 && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
1513 && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
1514 jcp.nb_oc_blocking = try_nb_oc_blocking;
1515 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1516 }
1517 } else {
1518 jcp.kernel_kind = expl_bcast;
1519 jcp.nb_ic_blocking = 1;
1520 if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
1521 float best_thr_eff = 0.f;
1522 int best_nb_oc_blocking = 1;
1523 for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
1524 if (jcp.nb_oc % i == 0) {
1525 float thr_eff;
1526 int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
1527 get_ow_block(i, ur_w, thr_eff);
1528 if (thr_eff > 1.05f * best_thr_eff) {
1529 best_nb_oc_blocking = i;
1530 best_thr_eff = thr_eff;
1531 }
1532 }
1533 }
1534 jcp.nb_oc_blocking = best_nb_oc_blocking;
1535 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1536 }
1537 }
1538 }
1539
1540 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1541
1542 args_ok = true
1543 && jcp.l_pad <= jcp.ur_w
1544 && jcp.ic <= src_d.padded_dims()[1]
1545 && jcp.oc <= dst_d.padded_dims()[1]
1546 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1547 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1548 if (!args_ok)
1549 return status::unimplemented;
1550
1551 int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
1552 + (jcp.kw - 1) * (jcp.dilate_w + 1)
1553 - (jcp.iw + jcp.l_pad - 1));
1554 if (r_pad_no_tail > jcp.ur_w)
1555 return status::unimplemented;
1556
1557 pick_loop_order(jcp);
1558
1559 jcp.nb_ic_L2 = jcp.nb_ic;
1560
1561 float thr_eff;
1562 jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
1563 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1564
1565 const int L2_size = get_cache_size(2, true) / sizeof(float);
    // The source and output data need to fit in L2,
    // leaving some space for weights and prefetching.
1568 int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
1569 - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
1570 / (jcp.stride_h * jcp.iw + jcp.ow));
1571 jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
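    // h_L2 above is the largest h satisfying
    //     simd_w * (h * (stride_h * iw + ow) + min(0, kh - stride_h) * iw)
    //         <= 0.6 * L2_size,
    // i.e. the number of output rows whose input and output rows (one simd
    // block deep) fit in roughly 60% of L2.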
1572
1573 if (jcp.ver == ver_4fma) {
1574 if (!is_ow_threading_on(jcp)) {
1575 for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
1576 divf++) {
1577 size_t l2_src
1578 = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
1579 size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
1580 * jcp.oh * jcp.od;
1581 size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
1582 * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
1583 if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
1584 if (jcp.kh == 3 && jcp.oh == 7) {
1585 jcp.nb_ic_L2 = 1;
1586 break;
1587 }
1588 temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf
1589 : jcp.nb_ic_L2);
1590 } else {
1591 jcp.nb_ic_L2 = temp_nb;
1592 break;
1593 }
1594 }
1595 } else if (jcp.ic > 64) {
            jcp.nb_ic_L2 = 2; /* according to performance data */
1597 }
1598 }
1599
1600 return status::success;
1601}
1602
1603void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
1604 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1605 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
1606 scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
1607}
1608
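// The backward-data kernel below mirrors the forward one with the roles of
// src and dst swapped: it accumulates into ic_block-wide diff_src tiles, and
// the "overflow" arguments of its inner loops (l_overflow / r_overflow)
// describe how far the kernel window runs past the diff_dst row, playing the
// role that l_pad / r_pad play on the forward path.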
1609void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
1610{
1611 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1612 for (int j = 0; j < ur_w; j++) {
1613 Zmm zmm = zmm_out(j, k);
1614 vpxord(zmm, zmm, zmm);
1615 size_t aux_src_offset
1616 = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
1617 * jcp.ic_block;
1618 mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
1619 reg_long_offt));
1620 }
1621 }
1622}
1623
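// store_output() (bwd data): when channel != 0 (not the first oc chunk) the
// partial diff_src already in memory is accumulated first, then the
// registers are stored; there are no bias or post-op stages on this path.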
1624void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
1625{
1626 Label no_update_label;
1627
1628 mov(reg_channel, ptr[param + GET_OFF(channel)]);
1629 cmp(reg_channel, 0);
1630 je(no_update_label, T_NEAR);
1631 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1632 for (int j = 0; j < ur_w; j++) {
1633 Zmm zmm = zmm_out(j, k);
1634 size_t aux_src_offset = (size_t)typesize
1635 * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
1636 vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
1637 reg_long_offt));
1638 }
1639 }
1640
1641 L(no_update_label);
1642 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1643 for (int j = 0; j < ur_w; j++) {
1644 Zmm zmm = zmm_out(j, k);
1645 size_t aux_src_offset = (size_t)typesize
1646 * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
1647 vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
1648 reg_long_offt), zmm);
1649 mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
1650 reg_long_offt));
1651 }
1652 }
1653}
1654
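// compute_loop_4fma() (bwd data): same 4FMA pipelining as the forward loop,
// but iterating over oc in groups of four and stepping aux_reg_dst backwards
// by one diff_dst row per kh iteration.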
1655void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
1656 int ur_w, int l_overflow, int r_overflow)
1657{
1658 int ow = jcp.ow;
1659 int kw = jcp.kw;
1660 int ic_block = jcp.ic_block;
1661 int oc_block = jcp.oc_block;
1662 Label kh_label, last_iter_label, loop_end_label, kd_label;
1663 int ker_load_number = 4;
1664 int shift_ker_ptr = typesize * kw * oc_block * ic_block;
1665 int shift_dst_ptr = typesize * ow * oc_block;
1666 int ii_dpref_t0 = get_iw_start(0, l_overflow);
1667 int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow);
1668
1669 bool check_last_kh = (jcp.kh > 3);
1670 auto kernel_offset = [=](int icb, int oc, int ki) {
1671 int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
1672 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
1673 int oc_offset = oc * jcp.oc_block;
1674 return typesize * (blk_offset + oc_offset);
1675 };
1676 auto kernel_loads = [=](int ki, int oc, int kk) {
1677 for (int ii = 0; ii < ker_load_number; ii++) {
1678 int aux_kernel_offset = kernel_offset(kk, oc + ii, ki);
1679 vmovups(zmm_ker(ii),
1680 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1681 }
1682 };
1683 auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
1684 if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
1685 && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) {
1686 int aux_dst_offset = typesize * ((ii_dpref_t0
1687 + jcp.l_pad) * oc_block + jcp.ow * oc_block);
1688 prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1689 ii_dpref_t0++;
1690 }
1691 };
1692
1693 if (one_of(jcp.ndims, 3, 4)) {
1694 mov(aux_reg_dst, reg_dst);
1695 mov(aux_reg_ker, reg_ker);
1696 mov(aux_reg_dst_prf, reg_dst_prf);
1697 mov(aux_reg_ker_prf, reg_ker_prf);
1698 }
1699
1700 if (jcp.ndims == 5) {
1701 push(reg_src_prf);
1702 push(reg_src);
1703
1704 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1705 mov(aux_reg_dst_d, reg_dst);
1706 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1707 mov(aux_reg_dst_d_prf, reg_dst_prf);
1708 mov(aux_reg_ker_d_prf, reg_ker_prf);
1709
1710 L(kd_label);
1711 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1712 } else {
1713 mov(reg_kj, reg_kh);
1714 }
1715
1716 if (jcp.ndims == 5) {
1717 mov(aux_reg_dst, aux_reg_dst_d);
1718 mov(aux_reg_ker, aux_reg_ker_d);
1719 mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
1720 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
1721 }
1722
1723 align(16);
1724 L(kh_label);
1725 if (check_last_kh) {
1726 for (int ki = 0; ki < kw; ki++)
1727 for (int oc = 0; oc < oc_block; oc += 4)
1728 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1729 bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1
1730 && ki == kw - 1 && (oc + 4) == oc_block);
1731
1732 if (last_kernel_loads) {
1733 cmp(reg_kj, 1);
1734 je(last_iter_label, T_NEAR);
1735 }
1736
1737 kernel_loads(ki, oc, kk);
1738 for (int ii = get_iw_start(ki, l_overflow),
1739 prf_count_t0 = 0, prf_count_t1 = 0;
1740 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
1741 int aux_dst_offset = typesize
1742 * ((ii + jcp.l_pad - ki) * oc_block + oc);
1743 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
1744 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1745
1746 if (ii % 2) {
1747 if (prf_count_t0 < 4) {
1748 int aux_kernel_prf;
1749 if (last_kernel_loads)
1750 aux_kernel_prf= kernel_offset(0, prf_count_t0
1751 + oc + 4 - oc_block, 0) + typesize * kw
1752 * oc_block * ic_block;
1753 else
1754 aux_kernel_prf = kernel_offset(kk, oc + 4
1755 + prf_count_t0, ki);
1756 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
1757 aux_kernel_prf));
1758 prf_count_t0++;
1759 } else if (prf_count_t1 < 4) {
1760 mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
1761 kernel_offset(kk, oc + prf_count_t1, ki)));
1762 prf_count_t1++;
1763 }
1764 } else
1765 prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1);
1766 }
1767 if (last_kernel_loads) {
1768 jmp(loop_end_label, T_NEAR);
1769
1770 L(last_iter_label);
1771
1772 kernel_loads(ki, oc, kk);
1773 for (int ii = get_iw_start(ki, l_overflow),
1774 prf_count_t0 = 0, prf_count_t1 = 0;
1775 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
1776 int aux_dst_offset = typesize
1777 * ((ii + jcp.l_pad - ki) * oc_block + oc);
1778 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
1779 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1780 if (ii % 2) {
1781 if (prf_count_t0 < 4) {
1782 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
1783 kernel_offset(0, prf_count_t0, 0)));
1784 prf_count_t0++;
1785 } else if (prf_count_t1 < 4) {
1786 mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
1787 kernel_offset(kk, oc + prf_count_t1, ki)));
1788 prf_count_t1++;
1789 }
1790 }
1791 }
1792 L(loop_end_label);
1793 }
1794 }
1795 } else {
1796 for (int ki = 0; ki < kw; ki++)
1797 for (int oc = 0; oc < oc_block; oc += 4)
1798 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1799 kernel_loads(ki, oc, kk);
1800
1801 for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0;
1802 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
1803 int aux_dst_offset = typesize
1804 * ((ii + jcp.l_pad - ki) * oc_block + oc);
1805 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
1806 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1807 if ((ii % 2) && (prf_count_t1 < 4)) {
1808 mic_prefetcht1(EVEX_compress_addr(
1809 aux_reg_ker_prf, kernel_offset(kk,
1810 oc + prf_count_t1, ki)));
1811 prf_count_t1++;
1812 }
                    if (ki == 1 && oc == 0 && kk == 0)
1814 mic_prefetcht1(EVEX_compress_addr(
1815 aux_reg_dst_prf, aux_dst_offset));
1816 }
1817 }
1818 }
1819
1820 add(aux_reg_ker, shift_ker_ptr);
1821 sub(aux_reg_dst, shift_dst_ptr);
1822 add(aux_reg_ker_prf, shift_ker_ptr);
1823 sub(aux_reg_dst_prf, shift_dst_ptr);
1824
1825 dec(reg_kj);
1826 cmp(reg_kj, 0);
1827 jg(kh_label, T_NEAR);
1828
1829 if (jcp.ndims == 5) {
1830 sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block);
1831 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
1832 sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block);
        add(aux_reg_ker_d_prf,
                typesize * jcp.kw * jcp.kh * oc_block * ic_block);
1834
1835 dec(reg_ki);
1836 cmp(reg_ki, 0);
1837 jg(kd_label, T_NEAR);
1838
1839 pop(reg_src);
1840 pop(reg_src_prf);
1841 }
1842}
1843
1844void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
1845 int ur_w, int l_overflow, int r_overflow)
1846{
1847 Label kh_label, kd_label;
1848 int kw = jcp.kw;
1849 int ow = jcp.ow;
1850
1851 int ic_block = jcp.ic_block;
1852 int oc_block = jcp.oc_block;
1853 int l_pad = jcp.l_pad;
1854 int dilate_w = jcp.dilate_w + 1;
1855 int stride_w = jcp.stride_w;
1856 int stride_h = jcp.stride_h;
1857
1858 int ker_pipeline_depth = 4;
1859 assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
1860 assert(oc_block >= ker_pipeline_depth);
1861
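    // Prefetch scheduling: kernel and input prefetches are spread evenly
    // over the FMAs of one kh iteration; a prefetch is issued whenever
    // fma_idx % prf_inst_spacing == prf_inst_trigger, kernel prefetches
    // first, then input prefetches.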
1862 int num_ker_loads = oc_block * kw;
1863 int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
1864 + nstl::max(0, kw - stride_w);
1865 int num_prfs = num_ker_loads + num_inp_prfs;
1866 int num_fmas = num_ker_loads * ur_w / stride_w;
1867 int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
1868 int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
1869
1870 if (one_of(jcp.ndims, 3, 4)) {
1871 mov(aux_reg_dst, reg_dst);
1872 mov(aux_reg_ker, reg_ker);
1873
1874 mov(aux_reg_dst_prf, reg_dst_prf);
1875 mov(aux_reg_ker_prf, reg_ker_prf);
1876 }
1877
1878 if (jcp.ndims == 5) {
1879 push(reg_src_prf);
1880 push(reg_src);
1881
1882 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1883 mov(aux_reg_dst_d, reg_dst);
1884 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1885 mov(aux_reg_dst_d_prf, reg_dst_prf);
1886 mov(aux_reg_ker_d_prf, reg_ker_prf);
1887
1888 L(kd_label);
1889 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1890 } else {
1891 mov(reg_kj, reg_kh);
1892 }
1893
1894 if (jcp.ndims == 5) {
1895 mov(aux_reg_dst, aux_reg_dst_d);
1896 mov(aux_reg_ker, aux_reg_ker_d);
1897 mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
1898 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
1899 }
1900
1901 L(kh_label); {
1902 int step = 0;
1903 int ker_prfs = 0;
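        // Kernel loads are software-pipelined over ker_pipeline_depth (4)
        // zmm registers: step 0 preloads all four, and every later step
        // reloads only the register whose value will be consumed
        // ker_pipeline_depth - 1 steps ahead.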
1904 for (int ki = 0; ki < kw; ki++) {
1905 for (int oc = 0; oc < oc_block; oc++) {
1906 if (step == 0) {
1907 for (int i = 0; i < ker_pipeline_depth; i++) {
1908 int aux_kernel_offset = typesize * ((oc + i) * oc_block
1909 + ki * ic_block * oc_block);
1910 vmovups(zmm_ker(i), EVEX_compress_addr(
1911 aux_reg_ker, aux_kernel_offset));
1912 }
1913 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
1914 int load_offset = ker_pipeline_depth - 1;
1915 int ker_load_reg_idx
1916 = (step + load_offset) % ker_pipeline_depth;
1917 int aux_kernel_offset = typesize * ((oc + load_offset)
1918 * oc_block + ki * ic_block * oc_block);
1919 vmovups(zmm_ker(ker_load_reg_idx),
1920 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1921 }
1922
1923 bool ker_prf_inserted = false;
1924 auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);
1925
1926 int jj_start = get_iw_start(ki, l_overflow);
1927 int jj_end = get_iw_end(ur_w, ki, r_overflow);
1928 assert(stride_w != 1
1929 || jj_start == nstl::max(0,
1930 l_overflow - (kw - 1 - ki) * dilate_w));
1931 assert(stride_w != 1
1932 || jj_end == ur_w - nstl::max(0,
1933 r_overflow - ki * dilate_w));
1934
1935 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1936 assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
1937 int aux_dst_offset = typesize *
1938 (((jj + l_pad - ki * dilate_w)
1939 / stride_w) * jcp.oc_block + oc);
1940 vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
1941 EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));
1942
1943 int fma_idx = (step * ur_w + jj) / stride_w;
1944 int prf_slot_idx = fma_idx / prf_inst_spacing;
1945 if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
1946 if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
1947 int ker_prf_offset = typesize
1948 * ker_prfs * jcp.oc_block;
1949 mic_prefetcht1(EVEX_compress_addr(
1950 aux_reg_ker_prf, ker_prf_offset));
1951 ker_prf_inserted = true;
1952 ker_prfs++;
1953 } else {
1954 int inp_prf_idx = prf_slot_idx - ker_prfs;
1955 if (inp_prf_idx < num_inp_prfs) {
1956 int inp_prf_offset
1957 = ic_block * typesize
1958 * ((inp_prf_idx / kw) * kw
1959 + (inp_prf_idx % kw));
1960 mic_prefetcht0(EVEX_compress_addr(
1961 aux_reg_dst_prf, inp_prf_offset));
1962 }
1963 }
1964 }
1965 }
1966 step++;
1967 }
1968 }
1969
1970 add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
1971 sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
1972 add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
1973 sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);
1974
1975 dec(reg_kj);
1976 cmp(reg_kj, 0);
1977 jg(kh_label, T_NEAR);
1978 }
1979 if (jcp.ndims == 5) {
1980 sub(aux_reg_dst_d,
1981 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
1982 add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
1983 * oc_block * ic_block);
1984 sub(aux_reg_dst_d_prf,
1985 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
1986 add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
1987 * oc_block * ic_block);
1988
1989 dec(reg_ki);
1990 cmp(reg_ki, 0);
1991 jg(kd_label, T_NEAR);
1992 }
1993
1994 if (jcp.ndims == 5)
1995 {
1996 pop(reg_src);
1997 pop(reg_src_prf);
1998 }
1999}
2000
2001void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
2002 int ur_w, int l_overflow, int r_overflow)
2003{
2004 int kw = jcp.kw;
2005 int ow = jcp.ow;
2006 int dilate_w = jcp.dilate_w + 1;
2007 int stride_w = jcp.stride_w;
2008 int ic_block = jcp.ic_block;
2009 int oc_block = jcp.oc_block;
2010 int nb_ic_block = jcp.nb_ic_blocking;
2011 Label kh_label, kd_label;
2012
2013 int shift_ker_ptr = typesize * kw * oc_block * ic_block;
2014 int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
2015
2016 auto output_offset = [=](int oi, int oc, int ki) {
2017 return typesize *
2018 (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc);
2019 };
2020 auto kernel_offset = [=](int icb, int oc, int ki) {
2021 int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
2022 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
2023 int oc_offset = oc * jcp.oc_block;
2024 return typesize * (blk_offset + oc_offset);
2025 };
2026
2027 if (one_of(jcp.ndims, 3, 4)) {
2028 mov(aux_reg_dst, reg_dst);
2029 mov(aux_reg_ker, reg_ker);
2030 }
2031
2032 if (jcp.ndims == 5) {
2033 push(reg_src_prf);
2034 push(reg_src);
2035
2036 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
2037 mov(aux_reg_dst_d, reg_dst);
2038 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
2039
2040 L(kd_label);
2041 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2042 } else {
2043 mov(reg_kj, reg_kh);
2044 }
2045
2046 if (jcp.ndims == 5) {
2047 mov(aux_reg_dst, aux_reg_dst_d);
2048 mov(aux_reg_ker, aux_reg_ker_d);
2049 }
2050
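    // Inner body: with expl_bcast the diff_dst values are broadcast into
    // registers up front and the FMAs take register operands only; otherwise
    // each FMA reads the diff_dst value through an EVEX embedded-broadcast
    // memory operand.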
2051 L(kh_label);
2052 {
2053 for (int ki = 0; ki < kw; ki++) {
2054 int jj_start = get_iw_start(ki, l_overflow);
2055 int jj_end = get_iw_end(ur_w, ki, r_overflow);
2056 for (int oc = 0; oc < oc_block; oc++) {
2057 if (jcp.kernel_kind == expl_bcast) {
2058 for (int jj = jj_start; jj < jj_end; jj++) {
2059 int aux_output_offset = output_offset(jj, oc, ki);
2060 vbroadcastss(zmm_inp(jj, nb_ic_block),
2061 ptr[aux_reg_dst + aux_output_offset]);
2062 }
2063 }
2064 for (int ii = 0; ii < nb_ic_block; ii++) {
2065 int aux_kernel_offset = kernel_offset(ii, oc, ki);
2066 if (jj_end - jj_start > 0)
2067 vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
2068 aux_kernel_offset));
2069 for (int jj = jj_start; jj < jj_end; jj += stride_w)
2070 if (jcp.kernel_kind == expl_bcast)
2071 vfmadd231ps(zmm_out(jj, ii),
2072 zmm_inp(jj, nb_ic_block), zmm_wei);
2073 else
2074 vfmadd231ps(zmm_out(jj, ii), zmm_wei,
2075 EVEX_compress_addr(aux_reg_dst,
2076 output_offset(jj, oc, ki), true));
2077 }
2078 }
2079 }
2080 add(aux_reg_ker, shift_ker_ptr);
2081 sub(aux_reg_dst, shift_dst_ptr);
2082 dec(reg_kj);
2083 cmp(reg_kj, 0);
2084 jg(kh_label, T_NEAR);
2085 }
2086
2087 if (jcp.ndims == 5) {
2088 sub(aux_reg_dst_d,
2089 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2090 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2091
2092 dec(reg_ki);
2093 cmp(reg_ki, 0);
2094 jg(kd_label, T_NEAR);
2095
2096 pop(reg_src);
2097 pop(reg_src_prf);
2098 }
2099}
2100
2101inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
2102 int ur_w, int l_overflow, int r_overflow)
2103{
2104 if (jcp.ndims == 5) push(reg_oi);
2105
2106 prepare_output(ur_w);
2107
2108 Label skip_compute_loop;
2109 if (jcp.ndims == 5) {
2110 mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
2111 cmp(reg_kj, 0);
2112 je(skip_compute_loop, T_NEAR);
2113 }
2114 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2115 cmp(reg_kj, 0);
2116 je(skip_compute_loop, T_NEAR);
2117
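    // Dispatch by kernel version: compute_loop_4fma for ver_4fma, the
    // pipelined fma loop for avx512_mic (and for embd_bcast with a single
    // ic block), and compute_loop_fma_core for the remaining avx512_core
    // cases.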
2118 if (jcp.ver == ver_4fma)
2119 compute_loop_4fma(ur_w, l_overflow, r_overflow);
2120 else if (jcp.ver == ver_fma)
2121 if (mayiuse(avx512_mic))
2122 compute_loop_fma(ur_w, l_overflow, r_overflow);
2123 else
2124 if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
2125 compute_loop_fma(ur_w, l_overflow, r_overflow);
2126 else
2127 compute_loop_fma_core(ur_w, l_overflow, r_overflow);
2128 else
        assert(!"unknown convolution version");
2130
2131 L(skip_compute_loop);
2132 store_output(ur_w);
2133 if (jcp.ndims == 5) pop(reg_oi);
2134}
2135
2136void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
2137{
2138 int iw = jcp.iw;
2139 int kw = jcp.kw;
2140 int ur_w = jcp.ur_w;
2141 int ic_block = jcp.ic_block;
2142 int oc_block = jcp.oc_block;
2143 int ur_w_tail = jcp.ur_w_tail;
2144 int dilate_w = jcp.dilate_w + 1;
2145 int stride_w = jcp.stride_w;
2146
2147 int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
2148 int src_shift = jcp.typesize_out * ur_w * oc_block;
2149
2150 preamble();
2151
2152 mov(reg_src, ptr[param + GET_OFF(src)]);
2153 mov(reg_dst, ptr[param + GET_OFF(dst)]);
2154 mov(reg_ker, ptr[param + GET_OFF(filt)]);
2155
2156 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
2157 mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]);
2158 mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]);
2159 mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]);
2160
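    // l_overflow / r_overflow: number of diff_src pixels at the left/right
    // border of a ur_w block for which some kernel taps would address
    // diff_dst values lying in the padding; those taps are skipped via
    // get_iw_start / get_iw_end. r_overflow1 is the right-border count for
    // the block preceding the ur_w_tail.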
2161 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
2162 int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
2163 - nstl::max(0, jcp.r_pad)) / stride_w);
2164 int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
2165 - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
2166
2167 int n_oi = iw / ur_w;
2168 if (r_overflow1 > 0) n_oi--;
2169
2170 if (ur_w == iw) {
2171 compute_loop(ur_w, l_overflow, r_overflow);
2172 } else if (n_oi == 0) {
2173 compute_loop(ur_w, l_overflow, r_overflow1);
2174 add(reg_src, src_shift);
2175 add(reg_dst, dst_shift);
2176 add(reg_src_prf, src_shift);
2177 add(reg_dst_prf, dst_shift);
2178 if (ur_w_tail != 0)
2179 compute_loop(ur_w_tail, 0, r_overflow);
2180 } else {
2181 xor_(reg_oi, reg_oi);
2182 if (l_overflow > 0) {
2183 compute_loop(ur_w, l_overflow, 0);
2184 add(reg_src, src_shift);
2185 add(reg_dst, dst_shift);
2186 add(reg_src_prf, src_shift);
2187 add(reg_dst_prf, dst_shift);
2188
2189 inc(reg_oi);
2190 }
2191 if ((l_overflow <= 0 && n_oi > 0)
2192 || (l_overflow > 0 && n_oi > 1)) {
2193 Label ow_loop_label;
2194 L(ow_loop_label); {
2195 compute_loop(ur_w, 0, 0);
2196 add(reg_src, src_shift);
2197 add(reg_dst, dst_shift);
2198 add(reg_src_prf, src_shift);
2199 add(reg_dst_prf, dst_shift);
2200
2201 inc(reg_oi);
2202 cmp(reg_oi, n_oi);
2203 jl(ow_loop_label, T_NEAR);
2204 }
2205 }
2206 if (r_overflow1 > 0) {
2207 compute_loop(ur_w, 0, r_overflow1);
2208 add(reg_src, src_shift);
2209 add(reg_dst, dst_shift);
2210 add(reg_src_prf, src_shift);
2211 add(reg_dst_prf, dst_shift);
2212 }
2213 if (ur_w_tail != 0) {
2214 compute_loop(ur_w_tail, 0, r_overflow);
2215 }
2216 }
2217
2218 postamble();
2219}
2220
2221status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
2222 jit_conv_conf_t &jcp,
2223 const convolution_desc_t &cd,
2224 const memory_desc_wrapper &diff_src_d,
2225 const memory_desc_wrapper &weights_d,
2226 const memory_desc_wrapper &diff_dst_d)
2227{
2228 if (!mayiuse(avx512_common)) return status::unimplemented;
2229
2230 jcp = zero<decltype(jcp)>();
2231
2232 jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
2233 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
2234 int ndims = diff_src_d.ndims();
2235
2236 jcp.ndims = ndims;
2237 jcp.prop_kind = cd.prop_kind;
2238
2239 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2240 jcp.mb = diff_src_d.dims()[0];
2241
2242 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2243 jcp.oc_without_padding = jcp.oc;
2244 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
2245
2246 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
2247 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
2248 jcp.iw = diff_src_d.dims()[ndims-1];
2249 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
2250 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
2251 jcp.ow = diff_dst_d.dims()[ndims-1];
2252
2253 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
2254 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
2255 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2256
2257 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
2258 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
2259 jcp.l_pad = cd.padding[0][ndims-3];
2260
2261 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
2262 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
2263 jcp.stride_w = cd.strides[ndims-3];
2264
2265 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
2266 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
2267 jcp.dilate_w = cd.dilates[ndims-3];
2268 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
2269 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
2270 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
2271 return status::unimplemented;
2272
2273 jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
2274 - (jcp.iw + jcp.l_pad - 1);
2275 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
2276 - (jcp.ih + jcp.t_pad - 1);
2277 jcp.back_pad = (jcp.od - 1) * jcp.stride_d
2278 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
2279
2280 jcp.aligned_threads = 0;
2281
2282 jcp.is_1stconv = false;
2283
2284 jcp.oc_block = jcp.simd_w;
2285 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
2286
2287 bool ok_to_pad_channels = true
2288 && jcp.ngroups == 1
2289 && diff_src_d.data_type() == data_type::f32;
2290
2291 if (ok_to_pad_channels) {
2292 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2293 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2294 }
2295
2296 auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
2297 auto wei_tag = with_groups
2298 ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
2299 : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
2300 jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
2301 jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
2302
2303 bool args_ok = true
2304 && jcp.oc % jcp.oc_block == 0
2305 && jcp.ic % jcp.ic_block == 0
2306 && jcp.src_tag == dat_tag
2307 && jcp.dst_tag == dat_tag;
2308 if (!args_ok)
2309 return status::unimplemented;
2310
2311 jcp.nb_ic = jcp.ic / jcp.ic_block;
2312 jcp.nb_oc = jcp.oc / jcp.oc_block;
2313
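    // ur_w: widest unroll of the diff_src row that fits the register budget
    // (regs = 28 accumulators); when iw does not fit entirely, the unroll is
    // also kept a multiple of stride_w so every block starts at the same
    // phase relative to the stride.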
2314 jcp.ur_w = jcp.stride_w;
2315
2316 int regs = 28;
2317 if (jcp.iw <= regs)
2318 jcp.ur_w = jcp.iw;
2319 else {
2320 for (int ur_w = regs; ur_w > 0; --ur_w)
2321 if (ur_w % jcp.stride_w == 0) {
2322 jcp.ur_w = ur_w;
2323 break;
2324 }
2325 }
2326 int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2327 - jcp.l_pad) / jcp.stride_w);
2328 int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2329 - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
2330 int n_oi = jcp.iw / jcp.ur_w;
2331 if (r_overflow1 > 0) n_oi--;
2332
2333 if (mayiuse(avx512_common)
2334 && diff_dst_d.data_type() == data_type::f32
2335 && weights_d.data_type() == data_type::f32
2336 && diff_src_d.data_type() == data_type::f32) {
2337 jcp.ver = ver_fma;
2338 jcp.typesize_in = sizeof(float);
2339 jcp.typesize_out = sizeof(float);
2340 if (mayiuse(avx512_mic_4ops)
2341 && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
2342 jcp.ver = ver_4fma;
2343 }
2344 } else {
2345 return status::unimplemented;
2346 }
2347
2348 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
2349 if (jcp.wei_tag != wei_tag)
2350 return status::unimplemented;
2351
2352 if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
2353 && jcp.ver != ver_fma)
2354 return status::unimplemented;
2355
2356 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2357 if (jcp.ver == ver_4fma) {
2358 if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
2359 jcp.nb_ic_blocking = 2;
2360 } else {
2361 for (int i = jcp.nb_ic; i > 0; i--)
2362 if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
2363 jcp.nb_ic_blocking = i;
2364 break;
2365 }
2366 }
2367 }
2368
2369 jcp.loop_order = loop_gnc;
2370
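    // Code-size guard: when the kernel would contain several fully unrolled
    // ur_w bodies (left edge, steady state, right edge), shrink ur_w until
    // the estimated size (roughly 9.2 bytes per generated operation) stays
    // under the 24 KB budget.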
    bool large_code_size = (jcp.ur_w != jcp.ow)
            && ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1))
            && (r_overflow1 > 0) && (l_overflow > 0);
2374 if (large_code_size) {
2375 const int max_code_size = 24 * 1024;
2376 const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
2377 int mult = 1;
2378 if (l_overflow > 0) mult += 1;
2379 if (r_overflow1 > 0) mult += 1;
2380 for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
2381 if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
2382 < max_code_size) {
2383 if (ur_w % jcp.stride_w == 0) {
2384 jcp.ur_w = ur_w;
2385 break;
2386 }
2387 }
2388 }
2389 }
2390
2391 if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
2392 int try_nb_ic_blocking = 2;
2393 unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
2394 * try_nb_ic_blocking * jcp.kh;
2395 unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
2396 unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
2397 * jcp.oc_block * try_nb_ic_blocking;
2398 unsigned int ker_total_size = ker_inp_size + ker_out_size
2399 + ker_wei_size;
2400 if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
2401 || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
2402 || ker_total_size > L1_cache_size )))
2403 || jcp.stride_h > 1 || jcp.stride_d > 1) {
2404 jcp.kernel_kind = embd_bcast;
2405 jcp.ur_w = nstl::min(jcp.iw, regs);
2406 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2407 if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
2408 && jcp.ow > 8)) && jcp.stride_h == 1)
2409 if (jcp.nb_ic % try_nb_ic_blocking == 0) {
2410 jcp.nb_ic_blocking = try_nb_ic_blocking;
2411 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2412 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2413 }
2414 } else {
2415 jcp.kernel_kind = expl_bcast;
2416 jcp.nb_oc_blocking = 1;
2417 jcp.nb_ic_blocking = 4;
2418 if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2419 if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2420 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2421 if (jcp.nb_ic % i == 0) {
2422 jcp.nb_ic_blocking = i;
2423 break;
2424 }
2425 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2426 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2427 }
2428 }
2429 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
2430
2431 if (l_overflow * jcp.stride_w > jcp.ur_w)
2432 return status::unimplemented;
2433 int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2434 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
2435 if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
2436 return status::unimplemented;
2437 if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2438 return status::unimplemented;
2439
2440 pick_loop_order(jcp);
2441
2442 jcp.nb_oc_L2 = jcp.nb_oc;
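    // nb_oc_L2: for 4fma with small kernels, reduce the number of oc blocks
    // kept resident per iteration until the f32 footprint of one block's
    // src, dst and filter fits the effective KNx L2 capacity (with a hard
    // fallback to a single block for the kh == 3, ih == 7 shape).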
2443 if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
2444 for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
2445 divf++) {
2446 size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
2447 * jcp.id;
2448 size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
2449 size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
2450 * jcp.kd * jcp.nb_ic_blocking * temp_nb;
2451 if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
2452 if (jcp.kh == 3 && jcp.ih == 7) {
2453 jcp.nb_oc_L2 = 1;
2454 break;
2455 }
2456 temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
2457 : jcp.nb_oc_L2);
2458 } else {
2459 jcp.nb_oc_L2 = temp_nb;
2460 break;
2461 }
2462 }
2463 }
2464
2465 args_ok = true
2466 && jcp.ic <= diff_src_d.padded_dims()[1]
2467 && jcp.oc <= diff_dst_d.padded_dims()[1]
2468 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
2469 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
2470 if (!args_ok) return status::unimplemented;
2471
2472 return status::success;
2473}
2474
2475void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
2476 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
2477 UNUSED(scratchpad);
2478 UNUSED(jcp);
2479}
2480
2481const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
2482
2483void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
2484{
2485 Label kd_comeback_label;
2486
2487 /* 'depth' loop count bound by 'kd_work_size' */
2488 mov(kj, reg_kd_count);
2489 L(kd_comeback_label); {
2490 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2491 int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
2492 sub(reg_input,
2493 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
2494 sub(reg_kernel,
2495 jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
2496 dec(kj);
2497 cmp(kj, 0);
2498 jg(kd_comeback_label, T_NEAR);
2499 }
2500}
2501
2502void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
2503{
2504 Label kh_comeback_label, kd_comeback_label;
2505 mov(kj, reg_kh);
2506 L(kh_comeback_label); {
2507 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2508 int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
2509 sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
2510 sub(reg_kernel,
2511 jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
2512 dec(kj);
2513 cmp(kj, 0);
2514 jg(kh_comeback_label, T_NEAR);
2515 }
2516}
2517
2518void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
2519 int ur_w, int pad_l, int pad_r,
2520 int ic_block_step, int input_offset, int kernel_offset,
2521 int output_offset, bool input_wraparound)
2522{
2523
2524 int kw = jcp.kw;
2525 int ic_block = jcp.ic_block;
2526 int oc_block = jcp.oc_block;
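    // fma weight-gradient micro-kernel: kw * ic_block_step weight
    // accumulators are loaded from memory, four zmm registers rotate over
    // consecutive diff_dst pixels, and every FMA multiplies a diff_dst
    // register by a src value broadcast straight from memory; the
    // accumulators are stored back at the end.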
2527 for (int i_kw = 0; i_kw < kw; i_kw++)
2528 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2529 vmovups(Zmm(i_kw * ic_block_step + i_ic),
2530 EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block
2531 + i_ic) * jcp.oc_block + kernel_offset));
2532
2533 for (int i_ur = 0; i_ur < ur_w; i_ur++) {
2534 if (i_ur == 0) {
2535 vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4),
2536 EVEX_compress_addr(reg_output, typesize * (i_ur + 0)
2537 * oc_block + output_offset));
2538 if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4),
2539 EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block
2540 + output_offset));
2541 if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4),
2542 EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block
2543 + output_offset));
2544 if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2545 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2546 + output_offset));
2547 } else if (i_ur + 3 < ur_w)
2548 vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2549 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2550 + output_offset));
2551
2552 for (int i_kw = 0; i_kw < kw; i_kw++) {
2553 int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
2554 if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
2555 (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
2556 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2557 const size_t i_offset = (size_t)input_offset
2558 + (size_t)typesize * (jcp.ver == ver_4fma
2559 ? (i_iw - pad_l + i_ic * jcp.tr_iw)
2560 : (jcp.is_1stconv
2561 ? (i_iw - pad_l) + (size_t)i_ic
2562 * ((size_t)jcp.ih*jcp.iw*jcp.id)
2563 : (i_iw - pad_l) * ic_block + i_ic));
2564 vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
2565 Zmm(kw * ic_block_step + i_ur % 4),
2566 EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
2567 true));
2568 }
2569 }
2570 }
2571
2572 for (int i_kw = 0; i_kw < kw; i_kw++)
2573 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2574 vmovups(EVEX_compress_addr(reg_kernel, typesize
2575 * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset),
2576 Zmm(i_kw * ic_block_step + i_ic));
2577}
2578
2579void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma(
2580 int ur_w, int pad_l, int pad_r,
2581 int ic_block_step, int input_offset, int kernel_offset,
2582 int output_offset, bool input_wraparound)
2583{
2584 // TODO: add prefetches to fma version as well
2585
2586 assert(jcp.ver == ver_4fma);
2587
2588 int kw = jcp.kw;
2589 int ic_block = jcp.ic_block;
2590 int oc_block = jcp.oc_block;
2591
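    // 4fma weight-gradient micro-kernel: the kernel accumulators start at
    // zero and are added back to the weights in memory at the end. Each
    // v4fmaddps takes the four consecutive diff_dst registers starting at
    // zmm_out(i_ur), multiplies each of them by one of four consecutive
    // fp32 src values at inp_addr, and accumulates all four products into
    // a single kernel accumulator.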
2592 auto zmm_ker = [=](int i_kw, int i_ic) {
2593 return Zmm(i_kw * ic_block_step + i_ic);
2594 };
2595
2596 auto ker_addr = [=](int i_kw, int i_ic) {
2597 size_t local_offset
2598 = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
2599 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2600 };
2601
2602 auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) {
2603 int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
2604 int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
2605 return EVEX_compress_addr(reg_input,
2606 local_offset + input_offset + extra_offset);
2607 };
2608
2609 auto zmm_out = [=](int i_iw) {
2610 // TODO: move reg calc to global member funcs
2611 const int out_zmm_base_idx = 28;
2612 return Zmm(out_zmm_base_idx + i_iw % 4);
2613 };
2614
2615 auto out_addr = [=](int i_ur) {
2616 return EVEX_compress_addr(reg_output,
2617 jcp.typesize_in * i_ur * oc_block + output_offset);
2618 };
2619
2620 auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
2621 assert(i_ur % 4 == 0);
2622 if (i_ur == 0)
2623 prefetcht1(ker_addr(i_kw, i_ic));
2624 if (i_ur + 4 >= ur_w)
2625 prefetcht0(ker_addr(i_kw, i_ic));
2626
2627 const ptrdiff_t next_input_block_offset
2628 = jcp.typesize_in * ic_block_step * jcp.tr_iw;
2629 if (i_ur % 16 == 4 && i_kw == 0) {
2630 if (i_ur + 16 < ur_w)
2631 prefetcht0(inp_addr(i_ur + 16, i_ic));
2632 else
2633 prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
2634 }
2635 if (i_ur % 16 == 4 && i_kw == 1) {
2636 if (input_wraparound)
2637 prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
2638 else
2639 prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
2640 }
2641 };
2642
2643 for (int i_kw = 0; i_kw < kw; i_kw++)
2644 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2645 auto zmm = zmm_ker(i_kw, i_ic);
2646 vpxord(zmm, zmm, zmm);
2647 }
2648
2649 for (int i_ur = 0; i_ur < ur_w; i_ur += 4) {
2650
2651 for (int i = 0; i < 4; i++) {
2652 auto zmm = zmm_out(i_ur + i);
2653 if (i_ur + i < ur_w)
2654 vmovups(zmm, out_addr(i_ur + i));
2655 else
2656 vpxord(zmm, zmm, zmm);
2657 prefetcht0(out_addr(i_ur + i + 4));
2658 }
2659
2660 for (int i_kw = 0; i_kw < kw; i_kw++)
2661 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2662 int i_iw = i_ur + i_kw;
2663 v4fmaddps(zmm_ker(i_kw, i_ic),
2664 zmm_out(i_ur), inp_addr(i_iw, i_ic));
2665 pf_callback(i_ur, i_kw, i_ic);
2666 }
2667 }
2668
2669 for (int i_kw = 0; i_kw < kw; i_kw++)
2670 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2671 auto addr = ker_addr(i_kw, i_ic);
2672 auto zmm = zmm_ker(i_kw, i_ic);
2673 vaddps(zmm, zmm, addr);
2674 vmovups(addr, zmm);
2675 }
2676}
2677
2678void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
2679 int ur_w, int pad_l, int pad_r,
2680 int ic_block_step, int input_offset, int kernel_offset,
2681 int output_offset, bool input_wraparound)
2682{
2683 if (jcp.ver == ver_4fma)
2684 compute_ic_block_step_4fma(ur_w, pad_l, pad_r,
2685 ic_block_step, input_offset, kernel_offset, output_offset,
2686 input_wraparound);
2687 else if (jcp.ver == ver_fma)
2688 compute_ic_block_step_fma(ur_w, pad_l, pad_r,
2689 ic_block_step, input_offset, kernel_offset, output_offset,
2690 input_wraparound);
2691 else
2692 assert(!"unknown convolution version");
2693}
2694
2695void jit_avx512_common_conv_bwd_weights_kernel_f32
2696 ::compute_oh_step_unroll_ow_icblock(
2697 int ic_block_step, int max_ur_w)
2698{
2699 UNUSED(max_ur_w);
2700
2701 Label kh_label, kd_label;
2702
2703 int ic_block = jcp.ic_block;
2704 int oc_block = jcp.oc_block;
2705 int inp_mul = !jcp.is_1stconv ? ic_block : 1;
2706 int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
2707 int ow = jcp.ow;
2708
2709 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
2710 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
2711 int l_pad = jcp.l_pad;
2712
2713 if (jcp.ndims == 5) {
2714 L(kd_label);
2715 mov(reg_input, aux_reg_input);
2716 mov(reg_kernel, aux_reg_kernel);
2717 }
2718
2719 mov(kj, reg_kh);
2720 L(kh_label);
2721 {
2722 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
2723 const int input_offset = jcp.typesize_in
2724 * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic);
2725 compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
2726 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
2727 i_b_ic + ic_block_step >= jcp.ic_block);
2728 }
2729 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
2730 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
2731 dec(kj);
2732 cmp(kj, 0);
2733 jg(kh_label, T_NEAR);
2734 }
2735
2736 if (jcp.ndims == 5) {
2737 add(aux_reg_input,
2738 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
2739 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
2740 * oc_block);
2741 dec(ki);
2742 cmp(ki, 0);
2743 jg(kd_label, T_NEAR);
2744 }
2745}
2746
2747void jit_avx512_common_conv_bwd_weights_kernel_f32
2748 ::compute_oh_step_unroll_ow(
2749 int ic_block_step, int max_ur_w)
2750{
2751 Label kh_label, ic_block_label, kd_label;
2752
2753 UNUSED(max_ur_w);
2754
2755 int ic_block = jcp.ic_block;
2756 int oc_block = jcp.oc_block;
2757
2758 int ow = jcp.ow;
2759
2760 int r_pad = nstl::max(0,
2761 (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
2762 - (jcp.iw + jcp.l_pad - 1));
2763 int l_pad = jcp.l_pad;
2764
2765 if (jcp.ndims == 5) {
2766 L(kd_label);
2767 mov(reg_input, aux_reg_input);
2768 mov(reg_kernel, aux_reg_kernel);
2769 }
2770
2771 mov(kj, reg_kh);
2772 L(kh_label);
2773 {
2774 xor_(b_ic, b_ic);
2775 L(ic_block_label); {
2776 compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
2777 0, 0, 0);
2778 size_t inp_icblk_stride = jcp.is_1stconv
2779 ? (size_t)jcp.ih * jcp.iw * jcp.id
2780 : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
2781 size_t input_offset
2782 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
2783 safe_add(reg_input, input_offset, reg_long_offt);
2784 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
2785 add(b_ic, ic_block_step);
2786 cmp(b_ic, jcp.ic_block);
2787 jl(ic_block_label, T_NEAR);
2788 }
2789
2790 if (jcp.is_1stconv) {
2791 size_t input_offset
2792 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
2793 safe_sub(reg_input, input_offset, reg_long_offt);
2794 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
2795 } else if (jcp.ver != ver_4fma) {
2796 add(reg_input, jcp.typesize_in
2797 * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
2798 }
2799 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
2800 dec(kj);
2801 cmp(kj, 0);
2802 jg(kh_label, T_NEAR);
2803 }
2804 if (jcp.ndims == 5) {
2805 add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
2806 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
2807 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
2808 * oc_block);
2809 dec(ki);
2810 cmp(ki, 0);
2811 jg(kd_label, T_NEAR);
2812 }
2813}
2814
2815void jit_avx512_common_conv_bwd_weights_kernel_f32
2816 ::compute_oh_step_common(
2817 int ic_block_step, int max_ur_w)
2818{
2819 Label kh_label, ic_block_label, ow_block_label, kd_label;
2820
2821 int ic_block = jcp.ic_block;
2822 int oc_block = jcp.oc_block;
2823
2824 int ow = jcp.ow;
2825 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
2826 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
2827 int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad;
2828
2829 int ur_w = nstl::min(ow, max_ur_w);
2830 int ur_w_trips = ow / ur_w;
2831 int ur_w_tail = ow % ur_w;
2832 if ((ur_w_tail == 0 && r_pad != 0)
2833 || r_pad >= ur_w_tail) {
2834 if (ur_w_trips > 1) {
2835 ur_w_tail += ur_w;
2836 ur_w_trips--;
2837 } else {
2838 ur_w_tail += (ur_w - ur_w / 2);
2839 ur_w = ur_w / 2;
2840 }
2841 }
2842
2843 int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block;
2844 int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult;
2845 int output_comeback = ur_w_trips * ur_w * oc_block;
2846
2847 if (jcp.ndims == 5) {
2848 L(kd_label);
2849 mov(reg_input, aux_reg_input);
2850 mov(reg_kernel, aux_reg_kernel);
2851 }
2852
2853 mov(kj, reg_kh);
2854 L(kh_label); {
2855 xor_(b_ic, b_ic);
2856 L(ic_block_label); {
2857 if (l_pad != 0) {
2858 ur_w_trips--;
2859 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
2860 add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad)
2861 * inp_mult);
2862 add(reg_output, jcp.typesize_in * ur_w * oc_block);
2863 }
2864
2865 if (ur_w_trips > 0) {
2866 xor_(reg_ur_w_trips, reg_ur_w_trips);
2867 L(ow_block_label); {
2868 compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
2869 add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w
2870 * inp_mult);
2871 add(reg_output, jcp.typesize_in * ur_w * oc_block);
2872
2873 inc(reg_ur_w_trips);
2874 cmp(reg_ur_w_trips, ur_w_trips);
2875 jl(ow_block_label, T_NEAR);
2876 }
2877 }
2878
2879 if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad,
2880 ic_block_step, 0, 0, 0);
2881
2882 sub(reg_input, jcp.typesize_in * input_comeback);
2883 sub(reg_output, jcp.typesize_in * output_comeback);
2884 int inp_icblk_stride = jcp.is_1stconv
2885 ? jcp.ih * jcp.iw * jcp.id
2886 : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
2887 size_t input_offset
2888 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
2889 safe_add(reg_input, input_offset, reg_long_offt);
2890 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
2891
2892 add(b_ic, ic_block_step);
2893 cmp(b_ic, jcp.ic_block);
2894 jl(ic_block_label, T_NEAR);
2895 }
2896 if (jcp.is_1stconv) {
2897 size_t input_offset
2898 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
2899 safe_sub(reg_input, input_offset, reg_long_offt);
2900 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
2901 } else if (jcp.ver != ver_4fma) {
2902 add(reg_input, jcp.typesize_in
                    * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
2904 }
2905 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
2906 dec(kj);
2907 cmp(kj, 0);
2908 jg(kh_label, T_NEAR);
2909 }
2910 if (jcp.ndims == 5) {
2911 add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
2912 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
2913 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
2914 * oc_block);
2915 dec(ki);
2916 cmp(ki, 0);
2917 jg(kd_label, T_NEAR);
2918 }
2919}
2920
2921void jit_avx512_common_conv_bwd_weights_kernel_f32
2922 ::compute_oh_step_disp()
2923{
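    // ic_block_step: input channels handled per micro-kernel call. The
    // register budget is kw * ic_block_step weight accumulators plus four
    // diff_dst registers, which must fit into the 32 zmm registers; hence
    // 8 / 4 / 2 for kw <= 3 / kw <= 7 / wider kernels (the 1st convolution
    // is special-cased below).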
2924 int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
2925 if (jcp.is_1stconv) {
2926 bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
2927 ic_block_step
2928 = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1;
2929 }
2930
2931 bool too_large_to_unroll
2932 = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
2933 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
2934
2935 int ow = jcp.ow;
2936 if (jcp.ndims == 5) {
        /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of
         * 'mov' instructions must be preserved. */
2939 mov(ki, reg_kd_count);
2940 push(reg_kd_count);
2941 mov(aux_reg_input, reg_input);
2942 mov(aux_reg_kernel, reg_kernel);
2943 }
2944
2945 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
2946 compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
2947 else if (ow <= max_ur_w)
2948 compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
2949 else
2950 compute_oh_step_common(ic_block_step, max_ur_w);
2951
2952 if (jcp.ndims == 5) {
2953 mov(reg_input, aux_reg_input);
2954 mov(reg_kernel, aux_reg_kernel);
2955 pop(reg_kd_count);
2956 od_step_comeback_pointers();
2957 } else {
2958 oh_step_comeback_pointers();
2959 }
2960}
2961
2962void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
2963{
2964 Label skip_zeroing, zeroing_loop;
2965
2966 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
2967 cmp(reg_tmp, 0);
2968 jz(skip_zeroing, T_NEAR);
2969
2970 Zmm zero = Zmm(0);
2971 vpxord(zero, zero, zero);
2972 xor_(reg_tmp, reg_tmp);
2973 L(zeroing_loop); {
2974 assert(jcp.oc_block * jcp.typesize_out
2975 == cpu_isa_traits<avx512_common>::vlen);
2976 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
2977 vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
2978 * jcp.typesize_out], zero);
2979 add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
2980 cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd
2981 * jcp.typesize_out);
2982 jnz(zeroing_loop);
2983 }
2984
2985 L(skip_zeroing);
2986}
2987
2988void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
2989{
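    // Bias reduction: skipped when the 'flags' word is non-zero. Zmm(1)
    // starts from the stored bias when 'channel' is zero and from zero
    // otherwise, then accumulates one diff_dst vector per output pixel over
    // the [d_index, d_worksize) x oh x ow range before being stored back.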
2990 Label skip_bias, bias_loop, skip_load_bias;
2991
2992 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
    test(reg_tmp, reg_tmp);
2994 jne(skip_bias, T_NEAR);
2995
2996 mov(reg_bias, ptr[param + GET_OFF(bias)]);
2997 mov(reg_output, ptr[param + GET_OFF(dst)]);
2998 vpxord(Zmm(1), Zmm(1), Zmm(1));
2999
3000 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3001 cmp(reg_tmp, 0);
3002 jne(skip_load_bias, T_NEAR);
3003 vmovups(Zmm(1), ptr[reg_bias]);
3004
3005 L(skip_load_bias);
3006
3007 mov(reg_oi, ptr[param + GET_OFF(d_worksize)]);
3008 sub(reg_oi, ptr[param + GET_OFF(d_index)]);
3009 mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out);
3010 imul(reg_oi, reg_tmp);
3011
3012 xor_(reg_tmp, reg_tmp);
3013 L(bias_loop); {
3014 vmovups(Zmm(0), ptr[reg_output + reg_tmp]);
3015 vaddps(Zmm(1), Zmm(1), Zmm(0));
3016 add(reg_tmp, jcp.oc_block * jcp.typesize_out);
3017 cmp(reg_tmp, reg_oi);
3018 jl(bias_loop);
3019 }
    vmovups(EVEX_compress_addr(reg_bias, 0), Zmm(1));
3021
3022 L(skip_bias);
3023}
3024
3025void jit_avx512_common_conv_bwd_weights_kernel_f32
3026 ::compute_oh_loop_common()
3027{
3028 int b_pad = jcp.b_pad;
3029 int t_pad = jcp.t_pad;
3030 bool is_dilated = jcp.dilate_h != 0;
3031 int dilate_h = jcp.dilate_h + 1;
3032 int stride_h = jcp.stride_h;
3033 const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
3034 int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
3035 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
3036 oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
3037 oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end;
3038
3039 int ow = jcp.ow;
3040
3041 mov(reg_kh, jcp.kh);
3042 xor_(reg_ih_count, reg_ih_count);
3043 xor_(reg_oj, reg_oj);
3044 /* Compute 'top' edge */
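    // While the kernel still overlaps the top padding only part of it hits
    // the input: reg_kh holds that partial overlap, and every output row
    // processed here grows it by stride_h while rewinding the kernel pointer
    // by the same number of rows (the dilated case shifts the input instead
    // until a new kernel row comes into play).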
3045 if (t_pad > 0) {
3046 const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
3047 const int overflow
3048 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3049 const int underflow = div_up(t_pad, dilate_h);
3050 const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
3051 mov(reg_kh, initial_inp_ker_overlap);
3052 add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
3053 * jcp.oc_block);
3054 // generate loop to process kernel while it remains within t_pad + ih
3055 if (kh_range < t_pad + jcp.ih) {
3056 if (is_dilated) {
3057 const int tail = t_pad % dilate_h;
3058 const int shift = tail == 0 ? 0 : dilate_h - tail;
3059 mov(reg_tmp, shift);
3060 if (tail != 0)
3061 add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
3062 }
3063 L(oh_tpad_label); {
3064 compute_oh_step_disp();
3065 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3066 if (is_dilated) {
3067 inc(reg_tmp);
3068 cmp(reg_tmp, dilate_h);
3069 jl(oh_dilate_label_shift, T_NEAR);
3070 // unshift input as new kernel element enters
3071 sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
3072 xor_(reg_tmp, reg_tmp);
3073 }
3074 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3075 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3076 * jcp.ic_block * jcp.oc_block);
3077 add(reg_kh, stride_h);
3078 if (is_dilated) {
3079 jmp(oh_dilate_label_noshift, T_NEAR);
3080 L(oh_dilate_label_shift);
3081 // shift input as old kernel element progresses
3082 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3083 L(oh_dilate_label_noshift);
3084 }
3085 inc(reg_oj);
3086 add(reg_ih_count, stride_h);
3087
3088 // final number of kernel elements that overlap with input
3089 const int final_inp_ker_overlap
3090 = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
3091 cmp(reg_kh, final_inp_ker_overlap);
3092 jl(oh_tpad_label, T_NEAR);
3093 }
3094 }
3095 // need second loop to process kernel if it is larger than the input
3096 // (does not apply to dilations as they must have unit stride)
3097 if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
3098 t_pad % stride_h)) {
3099 assert(!is_dilated);
3100 mov(reg_kh, jcp.ih);
3101 L(oh_tpad_tail_label); {
3102 compute_oh_step_disp();
3103 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3104 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3105 * jcp.ic_block * jcp.oc_block);
3106
3107 inc(reg_oj);
3108 add(reg_ih_count, stride_h);
3109
3110 cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
3111 jl(oh_tpad_tail_label, T_NEAR);
3112 }
3113 }
3114 // correct any excess shifts to kernel and input
3115 // (does not apply to dilations as they must have unit stride,
3116 // kernel must fit inside input, and padding is smaller than input)
3117 if (t_pad <= jcp.oh * stride_h) {
3118 // kernel has moved beyond padding (adjust for stride effects)
3119 if (t_pad % stride_h != 0) {
3120 assert(!is_dilated);
3121 int inp_corr = stride_h - t_pad % stride_h;
3122 add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
3123 * jcp.ic_block * jcp.oc_block);
3124 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
3125 }
3126 } else {
3127 // kernel still overlaps padding (complete reset)
3128 assert(!is_dilated);
3129 sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
3130 * jcp.kw * jcp.ic_block * jcp.oc_block);
3131 }
3132 }
3133
3134 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3135 jge(oh_label_end, T_NEAR);
3136 cmp(reg_oj, jcp.oh);
3137 jge(oh_label, T_NEAR);
3138
3139 /* Compute middle block(s) */
3140 mov(reg_kh, jcp.kh);
3141 L(oh_label); {
3142 compute_oh_step_disp();
3143 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3144 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3145
3146 inc(reg_oj);
3147 add(reg_ih_count, stride_h);
3148
3149 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3150 jge(oh_label_end, T_NEAR);
3151
3152 cmp(reg_oj, jcp.oh);
3153 jl(oh_label, T_NEAR);
3154 }
3155 L(oh_label_end);
3156
3157 /* Compute bottom edge */
3158 if (b_pad > 0) {
3159 cmp(reg_oj, jcp.oh);
3160 jge(oh_bpad_label_end, T_NEAR);
3161
3162 if (is_dilated) {
3163 mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
3164 mov(reg_tmp, 0);
3165 } else {
3166 mov(reg_kh, jcp.ihp - b_pad);
3167 sub(reg_kh, reg_ih_count);
3168 }
3169 L(oh_bpad_label);
3170 {
3171 compute_oh_step_disp();
3172 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3173 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3174 if (is_dilated) {
3175 inc(reg_tmp);
3176 cmp(reg_tmp, dilate_h);
3177 jl(oh_dilate_label_end, T_NEAR);
3178 xor_(reg_tmp, reg_tmp);
3179 }
3180 sub(reg_kh, stride_h);
3181 cmp(reg_kh, 0);
3182 jle(oh_bpad_label_end, T_NEAR);
3183 if (is_dilated)
3184 L(oh_dilate_label_end);
3185
3186 inc(reg_oj);
3187 cmp(reg_oj, jcp.oh);
3188 jl(oh_bpad_label, T_NEAR);
3189 }
3190 L(oh_bpad_label_end);
3191 }
3192}
3193
3194void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() {
3195 int ic_block = jcp.ic_block;
3196 int oc_block = jcp.oc_block;
3197 const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
3198 int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
3199 int ow = jcp.ow;
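    // input_backpad_overlap: first output index along depth whose filter
    // footprint reaches into the back padding; the loop below prepares the
    // reduced kd_count one iteration ahead and then shrinks it by stride_d
    // per step so only kernel rows overlapping real input are computed.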
3200 const int input_backpad_overlap
3201 = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
3202
3203 const size_t filter_shift
3204 = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block;
3205 const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult;
3206 const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block;
3207
3208 Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
3209 backpad_end_label, backpad_label;
3210
3211 if (jcp.with_bias) bias_kernel();
3212
3213 /* initially offset 'kd' by f_pad */
3214 add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3215
3216 mov(reg_input_d, ptr[param + GET_OFF(src)]);
3217 mov(reg_output_d, ptr[param + GET_OFF(dst)]);
3218 mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
3219 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3220
3221 cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
3222 jge(loop_end_label, T_NEAR);
3223
3224 L(d_loop_label);
3225
3226 mov(reg_input, reg_input_d);
3227 mov(reg_output, reg_output_d);
3228
3229 push(reg_input_d);
3230 push(reg_output_d);
3231 push(reg_d_index);
3232
3233 compute_oh_loop_common();
3234
3235 pop(reg_d_index);
3236 pop(reg_output_d);
3237 pop(reg_input_d);
3238
3239 /* Compute 'front' edge */
3240 if (jcp.f_pad > 0) {
3241
3242 /* Check if within fpad region */
3243 cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
3244 jge(fpad_end_label, T_NEAR);
3245
3246 /* Fpad steps */
3247 sub(reg_kernel, filter_shift * jcp.stride_d);
3248 add(reg_kd_count, jcp.stride_d);
3249
3250 /* Final number of kernel elements that overlap with input */
3251 const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id);
3252 cmp(reg_kd_count, inp_ker_overlap);
3253 jl(common_block_label, T_NEAR);
3254
3255 /* Correct any excess shifts to kernel and input */
3256 if (jcp.f_pad <= jcp.od * jcp.stride_d) {
3257 /* Filter has moved beyond padding (adjust for stride effects) */
3258 if (jcp.f_pad % jcp.stride_d != 0) {
3259 int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
3260 add(reg_kernel, filter_shift * inp_corr);
3261 add(reg_input_d, input_shift * inp_corr);
3262 }
3263 } else {
3264 /* Filter still overlaps padding (complete reset) */
3265 sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
3266 }
3267
3268 /* Apply correction */
3269 mov(reg_kd_count, jcp.kd);
3270 jmp(common_block_label);
3271
3272 L(fpad_end_label);
3273 }
3274
    /* Compute 'back' edge */
3276 if (jcp.back_pad > 0) {
3277
3278 /* Check if within back_pad region */
3279 cmp(reg_d_index, input_backpad_overlap - 1);
3280 jl(backpad_end_label, T_NEAR);
3281 jg(backpad_label, T_NEAR);
3282
3283 /* Execute overlap correction between the filter and the initial
3284 * back_pad region. */
3285 mov(reg_kd_count,
3286 jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d);
3287 jmp(backpad_end_label, T_NEAR);
3288
3289 L(backpad_label);
3290 sub(reg_kd_count, jcp.stride_d);
3291 cmp(reg_kd_count, 0);
3292 jle(loop_end_label, T_NEAR);
3293
3294 L(backpad_end_label);
3295 }
3296
3297 /* Compute middle block */
3298 add(reg_input_d, input_shift * jcp.stride_d);
3299
3300 /* Execute common block and loop */
3301 L(common_block_label);
3302 add(reg_output_d, output_shift);
3303 inc(reg_d_index);
3304 cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
3305 jl(d_loop_label, T_NEAR);
3306
3307 L(loop_end_label);
3308}
3309
3310bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() {
3311 // FIXME: use register mapping from the class declaration
3312 bool ok = jcp.ver == ver_4fma
3313 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
3314 && everyone_is(1, jcp.stride_h, jcp.stride_w);
3315 if (!ok) return false;
3316 if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
3317 return false;
3318
3319 // General code layout:
3320 //
3321 // Blocking over OH -- top level
3322 // (Reduces L2 pressure; not very useful right now)
3323 // Loop over all KHxKW kernel -- emit_kh_kw_loop()
3324 // Loop over OH block -- emit_h_loop()
3325 // Loop over OW blocks -- emit_fma_block()
3326 // (Supports both fully unrolled and partially unrolled versions to
3327 // reduce code size)
3328 // Loop over OW block -- emit_fma_step()
3329
3330 int max_working_set_size = 128 * 1024;
3331 int pad_ow = jcp.ow;
3332
3333 int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
3334 int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in;
3335 int row_size = inp_row_size + out_row_size;
3336
3337 int h_block_size = jcp.oh;
3338 int working_set_size = row_size * h_block_size;
3339
3340 if (working_set_size > max_working_set_size) {
3341 int opt_working_set_size = 48 * 1024;
3342 assert(opt_working_set_size < max_working_set_size);
3343
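        // Shrink h_block_size by its smallest divisor (halving it when it
        // is prime) until one block of transposed input rows plus diff_dst
        // rows fits the ~48 KB target working set; give up if even a single
        // row does not fit.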
3344 while (working_set_size > opt_working_set_size) {
3345 for (int i = 2; i <= h_block_size; i++)
3346 if (i == h_block_size)
3347 h_block_size = h_block_size / 2;
3348 else if (h_block_size % i == 0) {
3349 h_block_size = h_block_size / i;
3350 break;
3351 }
3352 working_set_size = row_size * h_block_size;
3353
3354 if (h_block_size == 1 && working_set_size > opt_working_set_size)
3355 return false;
3356 }
3357 }
3358
3359 // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below)
3360 if (h_block_size < nstl::max(1, jcp.t_pad)
3361 || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size
3362 : jcp.oh % h_block_size))
3363 return false;
3364
3365 // check that we can use simple arithmetic for prefetch address
3366 // calculations
3367 // TODO: we need some traits for this check (Roma)
3368 int cache_line_size = 64;
3369 assert(jcp.ic_block * typesize == 64);
3370 assert(jcp.oc_block * typesize == 64);
3371
3372 int num_inp_l2_pfs = jcp.tr_iw * h_block_size;
3373 int avg_h_loop_len = h_block_size;
3374 int num_inp_l2_pfs_per_fma_block
3375 = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3376 int num_out_l2_pfs = pad_ow * h_block_size;
3377 int num_out_l2_pfs_per_fma_block
3378 = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3379
3380 Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors
3381 Reg64 reg_kh = rax;
3382 Reg64 reg_kw = rbx;
3383 Reg64 reg_tmp = abi_not_param1;
3384 Reg32 reg_tmp_w = reg_tmp.cvt32();
3385 Reg64 reg_ohs = rdx;
3386 Reg64 reg_ihs = rsi;
3387 Reg64 reg_h = r8;
3388 Reg64 reg_i = r9;
3389 Reg64 reg_j = r10;
3390
3391 Reg64 reg_inp = r13;
3392 Reg64 reg_out = r14;
3393 Reg64 reg_ker = r15;
3394
3395 Reg64 reg_inp_pf_l1 = rbp;
3396
3397 Reg64 reg_inp_pf_l2 = r11;
3398 Reg64 reg_out_pf_l2 = r12;
3399
3400 Xmm reg_inp_pf_save = xmm17;
3401 Xmm reg_out_pf_save = xmm18;
3402
3403 Reg64 reg_inp_save = abi_param1;
3404 Reg64 reg_out_save = reg_tmp;
3405
3406 auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); };
3407 auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3408 auto inp_addr = [&](int oi, int ic1) {
3409 return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
3410 };
3411 auto out_addr = [&](int oi, int oj = 0) {
3412 assert(jcp.ver == ver_4fma);
3413 return ptr[reg_out
3414 + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in];
3415 };
3416 auto ker_addr = [&](int ic1) {
3417 return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out];
3418 };
3419
3420 auto emit_block = [&](int h_block_size,
3421 bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row)
3422 {
3423 // TODO: add an fma version (Roma)
3424 auto pad_ow = jcp.ow;
3425
3426 int ow4u = rnd_up(pad_ow, 4);
3427 int def_step_size = 16;
3428
3429 bool has_w_tail = (pad_ow % def_step_size != 0
3430 || pad_ow % 4 != 0);
3431 bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3432
3433 auto emit_step = [&](int ur_ow,
3434 int num_inp_l1_pfs_per_fma_step,
3435 int num_inp_l2_pfs_per_fma_step,
3436 int num_out_l2_pfs_per_fma_step, bool is_w_tail)
3437 {
3438 bool block_wraparound = is_w_tail && is_last_row;
3439
3440 assert(ur_ow % 4 == 0);
3441 int tail_size = ow4u % ur_ow;
3442 int this_ur_ow
3443 = (is_w_tail && tail_size) ? tail_size : ur_ow;
3444 int ow_last_chunk4 = pad_ow % 4;
3445 int ow_zero_tail4 = ow_last_chunk4
3446 ? 4 - ow_last_chunk4 : 0;
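            // ow_zero_tail4: dummy trailing positions that round the row up
            // to a multiple of four; their diff_dst registers are zeroed
            // below so the final v4fmaddps group adds nothing for them.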
3447
3448 auto emit_out_pf = [&](int oi) {
3449#if 1
3450 if (oi + def_step_size < ur_ow || !block_wraparound)
3451 mic_prefetcht0(ptr[reg_out
3452 + ((def_step_size + oi)
3453 * jcp.oc_block * jcp.typesize_in)]);
3454 else {
3455 assert(block_wraparound);
3456 assert(oi + def_step_size >= ur_ow);
3457 mic_prefetcht0(ptr[reg_out_save
3458 + ((oi + def_step_size - ur_ow)
3459 * jcp.oc_block * jcp.typesize_in)]);
3460 }
3461#else
3462 // XXX: This is an alternative prefetching strategy that
3463 // always prefetches the next row. Keeping it here for
3464 // future experiments (Roma)
3465 if (!block_wraparound)
3466 mic_prefetcht0(ptr[reg_out
3467 + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]);
3468 else
3469 mic_prefetcht0(ptr[reg_out + reg_ohs
3470 - ((h_block_size - 1) * jcp.ow
3471 - oi) * jcp.oc_block * jcp.typesize_in]);
3472#endif
3473 if (oi < num_out_l2_pfs_per_fma_step)
3474 mic_prefetcht1(ptr[reg_out_pf_l2
3475 + oi * jcp.oc_block * jcp.typesize_in]);
3476 };
3477
3478 auto emit_inp_pf = [&](int oi4, int ic1) {
3479 int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block;
3480 int num_pf_slots = jcp.ic_block * ur_ow / 4;
3481
3482 int num_pfs = num_inp_l1_pfs_per_fma_step
3483 + num_inp_l2_pfs_per_fma_step;
3484 int pf_freq = nstl::max(1, num_pf_slots / num_pfs);
3485
3486 if (pf_slot_idx % pf_freq)
3487 return;
3488
3489 int pf_idx = pf_slot_idx / pf_freq;
3490
3491 if (pf_idx < num_inp_l2_pfs_per_fma_step)
3492 mic_prefetcht1(ptr[reg_inp_pf_l2
3493 + pf_idx * jcp.ic_block * jcp.typesize_in]);
3494 else {
3495 pf_idx -= num_inp_l2_pfs_per_fma_step;
3496 // prefetch the 'tail' of the cache line because most of
3497 // the accesses are not aligned
3498 mic_prefetcht0(ptr[reg_inp_pf_l1
3499 + pf_idx * jcp.ic_block * jcp.typesize_in
3500 + cache_line_size - jcp.typesize_in]);
3501 }
3502 };
3503
3504 auto numloads = 4;
3505
3506 int steps = this_ur_ow;
3507 for (int oi4 = 0; oi4 < steps; oi4 += numloads) {
3508 for (int oi1 = 0; oi1 < numloads; oi1++) {
3509 int oi = oi4 + oi1;
3510 if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) {
3511 vmovups(zmm_out(oi), out_addr(oi));
3512 emit_out_pf(oi);
3513 } else {
3514 auto zmm = zmm_out(oi);
3515 vpxord(zmm, zmm, zmm);
3516 }
3517 }
3518
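                // Each v4fmaddps below consumes the block of 4 output
                // vectors loaded above (zmm_out(oi4)..zmm_out(oi4 + 3))
                // together with 4 consecutive fp32 elements of the
                // transposed source at (ic1, oi4), accumulating all four
                // products into the single weight accumulator zmm_ker(ic1).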
3519 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3520 if (jcp.ver == ver_4fma) {
3521 v4fmaddps(zmm_ker(ic1),
3522 zmm_out(oi4), inp_addr(oi4, ic1));
3523 } else {
3524 assert(!"unknown convolution version");
3525 }
3526 emit_inp_pf(oi4, ic1);
3527 }
3528 }
3529 };
3530
    // Input is transposed and padded, but we only access about jcp.iw
    // elements per row, so use that to compute the number of cache lines
    // to prefetch for each 'row'
3533 int num_inp_l1_pfs
3534 = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block;
3535
3536 if (full_w_unroll) {
3537 emit_step(ow4u, num_inp_l1_pfs,
3538 num_inp_l2_pfs_per_fma_block,
3539 num_out_l2_pfs_per_fma_block, true);
3540 add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size);
3541 add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size);
3542 } else {
3543 Label w_loop;
3544 int num_w_iters = pad_ow / def_step_size;
3545 int num_w_iters_full = num_w_iters + has_w_tail;
3546 int num_inp_l1_pfs_per_fma_step
3547 = div_up(num_inp_l1_pfs, num_w_iters_full);
3548 int num_inp_l2_pfs_per_fma_step
3549 = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full);
3550 int num_out_l2_pfs_per_fma_step
3551 = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full);
3552 mov(reg_i, num_w_iters);
3553 L(w_loop); {
3554 emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
3555 num_inp_l2_pfs_per_fma_step,
3556 num_out_l2_pfs_per_fma_step, false);
3557 add(reg_inp, def_step_size * jcp.typesize_in);
3558 add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in);
3559 add(reg_inp_pf_l1,
3560 num_inp_l1_pfs_per_fma_step * cache_line_size);
3561 add(reg_inp_pf_l2,
3562 num_inp_l2_pfs_per_fma_step * cache_line_size);
3563 add(reg_out_pf_l2,
3564 num_out_l2_pfs_per_fma_step * cache_line_size);
3565 sub(reg_i, 1);
3566 jnz(w_loop);
3567 }
3568 if (has_w_tail) {
3569 emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
3570 num_inp_l2_pfs_per_fma_step,
3571 num_out_l2_pfs_per_fma_step, true);
3572 add(reg_inp_pf_l2,
3573 num_inp_l2_pfs_per_fma_step * cache_line_size);
3574 add(reg_out_pf_l2,
3575 num_out_l2_pfs_per_fma_step * cache_line_size);
3576 }
3577 // reset reg_inp and reg_out because emit_h_loop expects
3578 // unmodified pointers
3579 int w_offset = num_w_iters * def_step_size;
3580 sub(reg_inp, w_offset * jcp.typesize_in);
3581 sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in);
3582 }
3583 };
3584
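    // emit_h_loop generates reg_h - 1 'middle' output rows in a runtime loop
    // and emits the last row separately with is_last_row = true so that its
    // trailing output prefetches can wrap around to the next h-block (see
    // emit_out_pf above).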
3585 auto emit_h_loop = [&](int h_block_size,
3586 bool is_last_block, bool is_last_kh_kw_iter)
3587 {
3588 Label h_loop, skip_h_loop;
3589 mov(reg_j, 1);
3590 cmp(reg_j, reg_h);
3591 je(skip_h_loop, T_NEAR);
3592 L(h_loop); {
3593
3594 lea(reg_inp_pf_l1,
3595 ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]);
3596 emit_block(h_block_size,
3597 is_last_block, is_last_kh_kw_iter, false);
3598
3599 add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
3600 add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in);
3601 add(reg_j, 1);
3602 cmp(reg_j, reg_h);
3603 jb(h_loop);
3604 }
3605
3606 L(skip_h_loop);
3607
3608 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3609 mic_prefetcht0(ker_addr(ic1));
3610
3611 lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]);
3612 emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true);
3613 };
3614
3615 auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block,
3616 int h_block_size)
3617 {
3618 xor_(reg_kh, reg_kh);
3619 Label kh_loop, kh_loop_end;
3620
3621 int last_oh_block_size
3622 = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size);
3623 int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size;
3624 // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size
3625 int ih_block_size = oh_block_size - 1 + jcp.kh
3626 - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad;
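        // With the unit stride this path is generated for, a block of
        // oh_block_size output rows touches oh_block_size - 1 + kh input
        // rows, shrunk by t_pad for the first block and by b_pad for the
        // last one.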
3627
3628 L(kh_loop); {
3629 // determine starting indices for this block
3630 if (is_first_block) {
3631 xor_(reg_tmp, reg_tmp);
3632 mov(reg_ohs, jcp.t_pad);
3633 sub(reg_ohs, reg_kh);
3634 cmovb(reg_ohs, reg_tmp);
3635
3636 mov(reg_ihs, reg_ohs);
3637 sub(reg_ihs, jcp.t_pad);
3638 add(reg_ihs, reg_kh);
3639 } else {
3640 xor_(reg_ohs, reg_ohs);
3641 mov(reg_ihs, reg_kh);
3642 }
3643
3644 // determine effective size of block based on padding
3645 mov(reg_tmp, oh_block_size);
3646 sub(reg_tmp, reg_ohs);
3647 mov(reg_h, ih_block_size);
3648 sub(reg_h, reg_ihs);
3649 cmp(reg_tmp, reg_h);
3650 cmovb(reg_h, reg_tmp);
3651
3652 Label kh_loop_work;
3653 cmp(reg_h, 0);
3654 jg(kh_loop_work, T_NEAR);
3655
3656 // empty h loop for this jcp.kh:
3657 // - set the output to 0 if necessary
            // - move the ker pointer
3659 // - jump to the end
3660 sub(reg_h, 1);
3661 Label skip_ker_zeroing;
3662
            // The reg_ker ptr has its lowest bit set if the output needs to
            // be zeroed. Those who have byte-aligned their data will suffer
            // the consequences :(
            // TODO: move the flag to a mask register? (Roma)
3667 test(reg_ker, 1);
3668 jz(skip_ker_zeroing, T_NEAR);
3669
3670 Label zeroing_loop;
3671 vpxord(zmm0, zmm0, zmm0);
3672 and_(reg_ker, ~1); // temporarily clear the zeroing flag
3673 mov(reg_tmp, jcp.kw);
3674 L(zeroing_loop); {
3675 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3676 vmovups(ker_addr(ic1), zmm0);
3677 add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out);
3678 sub(reg_tmp, 1);
3679 jnz(zeroing_loop, T_NEAR);
3680 }
3681 // restore the zeroing flag (it will be cleared after the end of
3682 // emit_kh_kw_loop, but we may need it until then)
3683 or_(reg_ker, 1);
3684 jmp(kh_loop_end, T_NEAR);
3685
3686 L(skip_ker_zeroing);
3687 add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw
3688 * jcp.typesize_out);
3689 jmp(kh_loop_end, T_NEAR);
3690
3691 L(kh_loop_work);
3692
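            // convert the starting row indices into byte offsets: reg_ihs
            // becomes an offset into the transposed source, reg_ohs into
            // the output (diff_dst) block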
3693 mul_by_const(reg_ihs, reg_tmp,
3694 jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
3695 mul_by_const(reg_ohs, reg_tmp,
3696 pad_ow * jcp.oc_block * jcp.typesize_in);
3697
3698 add(reg_inp, reg_ihs);
3699 add(reg_out, reg_ohs);
3700
3701 Label kw_loop;
3702 xor_(reg_kw, reg_kw);
3703 L(kw_loop); {
3704 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3705 auto zmm = zmm_ker(ic1);
3706 vpxord(zmm, zmm, zmm);
3707 mic_prefetcht1(ker_addr(ic1));
3708 }
3709
3710 mov(reg_out_save, reg_out);
3711 mov(reg_inp_save, reg_inp);
3712 lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]);
3713
3714#if 0
3715 // XXX: Generate code with special prefetches when switching
3716 // blocks or at the end of the last block. Disabled to reduce
3717 // code size and because there's no performance benefit (Roma)
3718 Label regular_h_loop, end_h_loop;
3719 cmp(reg_kw, jcp.kw - 1);
3720 jne(regular_h_loop, T_NEAR);
3721 cmp(reg_kh, jcp.kh - 1);
3722 jne(regular_h_loop, T_NEAR);
3723
3724 emit_h_loop(oh_block_size, is_last_block, true);
3725 jmp(end_h_loop, T_NEAR);
3726
3727 L(regular_h_loop);
3728 emit_h_loop(oh_block_size, is_last_block, false);
3729
3730 L(end_h_loop);
3731#else
3732 emit_h_loop(oh_block_size, is_last_block, false);
3733#endif
3734
3735 mov(reg_out, reg_out_save);
3736 mov(reg_inp, reg_inp_save);
3737
3738 Label do_store;
                // The reg_ker ptr has its lowest bit set if the output needs
                // to be zeroed. Those who have byte-aligned their data will
                // suffer the consequences :(
3742 mov(reg_tmp, reg_ker);
3743 and_(reg_ker, ~1);
3744 test(reg_tmp, 1);
3745 jnz(do_store, T_NEAR);
3746
3747 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3748 auto zmm = zmm_ker(ic1);
3749 if (jcp.ver == ver_4fma) {
3750 vaddps(zmm, ker_addr(ic1));
3751 } else {
3752 assert(!"unknown convolution version");
3753 }
3754 }
3755
3756 L(do_store);
3757 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3758 auto zmm = zmm_ker(ic1);
3759 vmovups(ker_addr(ic1), zmm);
3760 }
3761
3762 mov(reg_ker, reg_tmp);
3763 add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
3764 add(reg_kw, 1);
3765 cmp(reg_kw, jcp.kw);
3766 jl(kw_loop);
3767 }
3768
3769 sub(reg_inp, reg_ihs);
3770 sub(reg_out, reg_ohs);
3771
3772
3773 L(kh_loop_end);
3774 add(reg_kh, 1);
3775 cmp(reg_kh, jcp.kh);
3776 jl(kh_loop);
3777 }
3778 };
3779
3780 mov(reg_inp, ptr[param + GET_OFF(src)]);
3781 mov(reg_out, ptr[param + GET_OFF(dst)]);
3782 mov(reg_ker, ptr[param + GET_OFF(filt)]);
3783 mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]);
3784 mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]);
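    // GET_OFF(channel) is used here as a 0/1 'zero the weights first' flag;
    // OR-ing it into the low bit of reg_ker is what the flag handling in
    // emit_kh_kw_loop above tests for.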
3785 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3786 or_(reg_ker, reg_tmp);
3787
3788 bool single_kh_kw_loop = (h_block_size == jcp.oh);
3789
3790 size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in;
3791 size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad);
3792 size_t inp_block_step = inp_row_step * h_block_size;
3793 size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in
3794 * h_block_size;
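    // The spatial domain is processed in blocks of h_block_size output rows.
    // Between blocks the data pointers advance by one block worth of rows,
    // while the L2 prefetch pointers are kept one block ahead of the
    // corresponding data pointers (set up below and refreshed after each
    // block).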
3795
3796 if (!single_kh_kw_loop) {
3797 // Save the original prefetch pointers from the OpenMP driver
3798 vmovq(reg_inp_pf_save, reg_inp_pf_l2);
3799 vmovq(reg_out_pf_save, reg_out_pf_l2);
3800 mov(reg_inp_pf_l2, reg_inp);
3801 add(reg_inp_pf_l2, first_inp_block_step);
3802 mov(reg_out_pf_l2, reg_out);
3803 add(reg_out_pf_l2, out_block_step);
3804 }
3805 emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size);
3806
3807 if (!single_kh_kw_loop) {
3808 size_t ker_reset_offset
3809 = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh;
3810 sub(reg_ker, ker_reset_offset);
3811 and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
3812
3813 add(reg_inp, first_inp_block_step);
3814 add(reg_out, out_block_step);
3815 mov(reg_inp_pf_l2, reg_inp);
3816 add(reg_inp_pf_l2, inp_block_step);
3817 mov(reg_out_pf_l2, reg_out);
3818 add(reg_out_pf_l2, out_block_step);
3819
3820 int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2;
3821 if (num_innermost_iters > 0) {
3822 Label h_block_loop;
3823
3824 mov(reg_tmp_w, num_innermost_iters);
3825 kmovw(reg_h_block, reg_tmp_w);
3826 L(h_block_loop); {
3827 emit_kh_kw_loop(false, false, h_block_size);
3828 sub(reg_ker, ker_reset_offset);
3829 add(reg_inp, inp_row_step * h_block_size);
3830 add(reg_out, out_block_step);
3831 mov(reg_inp_pf_l2, reg_inp);
3832 add(reg_inp_pf_l2, inp_block_step);
3833 mov(reg_out_pf_l2, reg_out);
3834 add(reg_out_pf_l2, out_block_step);
3835 kmovw(reg_tmp_w, reg_h_block);
3836 sub(reg_tmp_w, 1);
3837 kmovw(reg_h_block, reg_tmp_w);
3838 jnz(h_block_loop);
3839 }
3840 }
3841
3842 // Restore the original prefetch pointers that came from the OpenMP
3843 // driver
3844 vmovq(reg_inp_pf_l2, reg_inp_pf_save);
3845 vmovq(reg_out_pf_l2, reg_out_pf_save);
3846 emit_kh_kw_loop(false, true, h_block_size);
3847 }
3848
3849 return true;
3850}
3851
3852bool jit_avx512_common_conv_bwd_weights_kernel_f32
3853 ::flat_4ops_compute() {
3854 const auto &j = jcp;
3855 const bool ok = j.ver == ver_4fma && j.is_1stconv
3856 && everyone_is(0, j.dilate_h, j.dilate_w);
3857 if (!ok) return false;
3858
3859 Reg64 reg_ptr_tr_src = r8;
3860 Reg64 reg_ptr_dst = r9;
3861 Reg64 reg_ptr_wei = r10;
3862 Reg64 reg_ptr_bia = r11;
3863
3864 Reg64 reg_kh_step = rax;
3865 Reg64 reg_oh = abi_not_param1;
3866 Reg64 reg_kh = rdx;
3867
3868 Reg32 reg_flag_save = ebx;
3869 Reg32 reg_flag = esi;
3870
3871 Zmm vbia(31);
3872
3873 auto zmm_wei = [&](int kh, int kw) {
3874 return Zmm(8 + kh * j.kw + kw);
3875 };
3876 auto zmm_dst = [&](int ow) {
3877 return Zmm(ow % 8);
3878 };
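    // Register map for this flat (1st convolution) 4fma kernel: zmm0..zmm7
    // (zmm_dst) hold a rotating window of 8 output vectors, zmm8 and up
    // (zmm_wei) accumulate one weight vector per (kh, kw) tap of the current
    // kh_step x kw tile, and zmm31 (vbia) accumulates the bias reduction.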
3879
3880 auto addr_tr_src = [&](int kh, int iw) {
3881 return ptr[reg_ptr_tr_src
3882 + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in];
3883 };
3884 auto addr_dst = [&](int ow) {
3885 return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in];
3886 };
3887 auto addr_wei = [&](int kh, int kw) {
3888 return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block
3889 * jcp.typesize_out];
3890 };
3891
3892 auto emit_fma_block = [&](int kh_step) {
3893 for (int kh = 0; kh < kh_step; ++kh) {
3894 for (int kw = 0; kw < j.kw; ++kw) {
3895 auto vwei = zmm_wei(kh, kw);
3896 vpxord(vwei, vwei, vwei);
3897 }
3898 }
3899
3900 for (int ow = 0; ow < j.ow; ow += 4) {
3901 for (int _ow = ow; _ow < ow + 4; ++_ow) {
3902 auto vdst = zmm_dst(_ow);
3903 if (_ow < j.ow)
3904 vmovups(vdst, addr_dst(_ow));
3905 else
3906 vpxord(vdst, vdst, vdst);
3907 }
3908
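            // The transposed source groups input columns by their residue
            // modulo stride_w: original column c of row kh is stored at
            // sub-row (c % stride_w), position (c / stride_w).  For output
            // pixel ow and filter column kw the input column is
            // ow * stride_w + kw, which yields the iw index below;
            // consecutive positions within a sub-row then correspond to
            // consecutive output pixels, which is exactly what the
            // v4fmaddps memory operand requires.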
3909 for (int kh = 0; kh < kh_step; ++kh) {
3910 for (int kw = 0; kw < j.kw; ++kw) {
3911 const int iw = ow + (kw % j.stride_w) * j.tr_ld
3912 + (kw / j.stride_w);
3913 v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow),
3914 addr_tr_src(kh, iw));
                    if (kh == 0 && kw < 4) {
3916 prefetcht1(ptr[reg_ptr_dst
3917 + (j.ow + ow + kw) * jcp.oc_block
3918 * jcp.typesize_in]);
3919 }
3920 if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */
3921 const int off = kw + 4 - j.kw;
3922 if (off >= 0 && ow + off < j.ow)
3923 vaddps(vbia, vbia, zmm_dst(ow + off));
3924 }
3925 }
3926 }
3927 }
3928
3929 Label l_store;
3930 test(reg_flag, FLAG_MB_FIRST);
3931 jnz(l_store, T_NEAR);
3932 for (int kh = 0; kh < kh_step; ++kh) {
3933 for (int kw = 0; kw < j.kw; ++kw)
3934 vaddps(zmm_wei(kh, kw), addr_wei(kh, kw));
3935 }
3936 L(l_store);
3937 for (int kh = 0; kh < kh_step; ++kh) {
3938 for (int kw = 0; kw < j.kw; ++kw)
3939 vmovups(addr_wei(kh, kw), zmm_wei(kh, kw));
3940 }
3941 };
3942
3943 auto emit_kh_loop = [&]() {
3944 const int kh_step_rem = j.kh % j.kh_step;
3945 xor_(reg_kh, reg_kh);
3946 mov(reg_kh_step, j.kh_step);
3947
3948 Label l_kh_loop;
3949 L(l_kh_loop); {
3950 Label l_done;
3951
3952 if (kh_step_rem != 0) {
3953 Label l_keep_kh_step;
3954 cmp(reg_kh, j.kh - j.kh_step);
3955 jle(l_keep_kh_step, T_NEAR);
3956
3957 mov(reg_kh_step, kh_step_rem);
3958 emit_fma_block(kh_step_rem);
3959 jmp(l_done, T_NEAR);
3960
3961 L(l_keep_kh_step);
3962 }
3963
3964 emit_fma_block(j.kh_step);
3965
3966 L(l_done);
3967
3968 add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld
3969 * jcp.typesize_in);
3970 add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out);
3971 add(reg_kh, j.kh_step);
3972
3973 cmp(reg_kh, j.kh);
3974 jl(l_kh_loop, T_NEAR);
3975 }
3976
3977 const int kh_steps = rnd_up(j.kh, j.kh_step);
3978 sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in);
3979 sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out);
3980 };
3981
3982 auto emit_oh_loop = [&]() {
3983 mov(reg_oh, j.oh);
3984
3985 Label l_oh_loop;
3986 L(l_oh_loop); {
3987 Label l_restore_mb_flag, l_jump;
3988
3989 cmp(reg_oh, j.oh);
3990 je(l_restore_mb_flag, T_NEAR);
3991
3992 and_(reg_flag, ~FLAG_MB_FIRST);
3993 jmp(l_jump, T_NEAR);
3994
3995 L(l_restore_mb_flag);
3996 mov(reg_flag, reg_flag_save);
3997
3998 L(l_jump);
3999
4000 emit_kh_loop();
4001
4002 add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld
4003 * jcp.typesize_in);
4004 add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in);
4005
4006 dec(reg_oh);
4007 jnz(l_oh_loop, T_NEAR);
4008 }
4009 };
4010
4011 auto emit_bia_store = [&]() {
4012 if (!j.with_bias) return;
4013
4014 Label l_bia_store, l_bia_skip;
4015 test(reg_flag, FLAG_IC_FIRST);
4016 jz(l_bia_skip);
4017
4018 test(reg_flag, FLAG_MB_FIRST);
4019 jnz(l_bia_store, T_NEAR);
4020 vaddps(vbia, ptr[reg_ptr_bia]);
4021 L(l_bia_store);
4022 vmovups(ptr[reg_ptr_bia], vbia);
4023 L(l_bia_skip);
4024 };
4025
4026 mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]);
4027 mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]);
4028 mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]);
4029 mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]);
4030 mov(reg_flag_save, ptr[param + GET_OFF(flags)]);
4031
4032 vpxord(vbia, vbia, vbia);
4033 emit_oh_loop();
4034 emit_bia_store();
4035
4036 return true;
4037}
4038
4039void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop()
4040{
4041 if (flat_4ops_compute())
4042 return;
4043 if (compute_full_spat_loop())
4044 return;
4045
4046 maybe_zero_kernel();
4047
4048 if (jcp.ndims == 5) compute_d_loop_common();
4049 else compute_oh_loop_common();
4050}
4051
4052void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
4053{
4054 preamble();
4055
4056 mov(reg_input, ptr[param + GET_OFF(src)]);
4057 mov(reg_output, ptr[param + GET_OFF(dst)]);
4058 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4059
4060 compute_loop();
4061
4062 postamble();
4063}
4064
4065status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
4066 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4067 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4068 memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) {
4069 if (!mayiuse(avx512_common))
4070 return status::unimplemented;
4071
4072 const memory_desc_wrapper src_d(&src_md);
4073 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4074 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4075 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4076
4077 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4078 int ndims = src_d.ndims();
4079
4080 jcp = zero<decltype(jcp)>();
4081
4082 jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
4083 jcp.ndims = ndims;
4084 jcp.prop_kind = cd.prop_kind;
4085
4086 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4087 jcp.mb = src_d.dims()[0];
4088
4089 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4090 jcp.oc_without_padding = jcp.oc;
4091 jcp.ic = src_d.dims()[1] / jcp.ngroups;
4092
4093 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4094 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
4095 jcp.iw = src_d.dims()[ndims-1];
4096 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4097 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
4098 jcp.ow = diff_dst_d.dims()[ndims-1];
4099
4100 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4101 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
4102 jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
4103
4104 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4105 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
4106 jcp.l_pad = cd.padding[0][ndims-3];
4107
4108 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4109 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
4110 jcp.stride_w = cd.strides[ndims-3];
4111
4112 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4113 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
4114 jcp.dilate_w = cd.dilates[ndims-3];
4115
4116 const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
4117 bool ok = true
4118 // general condition to simplify dilations
4119 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4120 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4121 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4122 // special condition to simplify dilations in compute_oh_loop_common
4123 && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
4124 if (!ok)
4125 return status::unimplemented;
4126
4127 jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
4128 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
4129 jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
4130 + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
4131 jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
4132 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));
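    // Illustrative example (hypothetical shape): iw = 13, kw = 3,
    // stride_w = 1, dilate_w = 0, l_pad = 1, ow = 13 gives
    // r_pad = max(0, 12 * 1 + 2 * 1 - (13 + 1 - 1)) = 1, i.e. symmetric
    // 'same' padding on the right.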
4133
4134 /* XXX: currently, does not support dilation_d > 0 */
4135 if (ndims == 5)
4136 if (jcp.dilate_d > 0)
4137 return status::unimplemented;
4138
4139 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4140 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4141 jcp.ohp = jcp.oh;
4142 jcp.owp = jcp.ow;
4143 jcp.aligned_threads = 0;
4144
4145 /* check for the 1st convolution */
4146 jcp.is_1stconv = is_1stconv(jcp);
4147
4148 jcp.oc_block = jcp.simd_w;
4149
4150 bool ok_to_pad_channels = true
4151 && jcp.ngroups == 1
4152 && src_d.data_type() == data_type::f32;
4153
4154 if (ok_to_pad_channels)
4155 jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
4156
4157 if (jcp.oc % jcp.oc_block)
4158 return status::unimplemented;
4159
4160 auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4161 auto wei_tag = with_groups
4162 ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
4163 : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
4164
4165 if (diff_dst_d.format_kind() == format_kind::any) {
4166 CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
4167 jcp.dst_tag = dst_tag;
4168 } else {
4169 jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag);
4170 }
4171 if (jcp.dst_tag != dst_tag)
4172 return status::unimplemented;
4173
4174 /* conditions on bias memory */
4175 jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
4176 if (jcp.with_bias) {
4177 if (diff_bias_d.format_kind() == format_kind::any)
4178 CHECK(memory_desc_init_by_tag(diff_bias_md, x));
4179 }
4180
4181 jcp.nb_oc = jcp.oc / jcp.oc_block;
4182
4183 /* kernel applicability check wrt boundaries
4184 * the conditions are quite general across the kernels we have,
4185 * but ideally the check should belong to a specific kernel... */
4186 const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
4187 const bool boundaries_ok = true
4188 && jcp.t_pad <= max_pad
4189 && jcp.b_pad <= max_pad
4190 && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad)
4191 && jcp.f_pad < jcp.kd;
4192 if (!boundaries_ok)
4193 return status::unimplemented;
4194
    /* a common restriction across these kernels: filter width at most 14 */
4196 if (jcp.kw > 14)
4197 return status::unimplemented;
4198
    /* register strategy: pick the largest ur_w (at most max_ur_w)
     * that evenly divides ow */
4200 for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
4201 if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
4202 }
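    // For instance, if max_ur_w were 16 and ow = 28, the loop above would
    // pick ur_w = 14, the largest divisor of ow not exceeding the cap
    // (values used here purely for illustration).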
4203
4204 if (jcp.is_1stconv) {
4205 auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw);
4206 if (src_d.format_kind() == format_kind::any) {
4207 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4208 jcp.src_tag = src_tag;
4209 } else {
4210 jcp.src_tag = src_d.matches_one_of_tag(src_tag);
4211 if (jcp.ic == 1 && jcp.src_tag != src_tag)
4212 jcp.src_tag = src_d.matches_one_of_tag(
4213 pick(ndims - 3, nwc, nhwc, ndhwc));
4214 }
4215 if (jcp.src_tag == format_tag::undef)
4216 return status::unimplemented;
4217
4218 const bool src_ok = true
4219 && utils::everyone_is(data_type::f32,
4220 src_d.data_type(), diff_weights_d.data_type(),
4221 diff_dst_d.data_type())
4222 && one_of(jcp.ic, 1, 2, 3)
4223 && jcp.ngroups == 1;
4224 if (!src_ok)
4225 return status::unimplemented;
4226
4227 const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad,
4228 jcp.stride_w), 16);
4229 const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1);
4230 const int kh_step_rem = jcp.kh % kh_step;
4231
4232 const auto wei_4fma_tag = with_groups
4233 ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
4234 : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
4235
4236 auto current_wei_tag = format_tag::undef;
4237 if (diff_weights_d.format_kind() != format_kind::any)
4238 current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag);
4239
4240 const bool use_4fma = true
4241 && one_of(ndims, 3, 4)
4242 && mayiuse(avx512_mic_4ops)
4243 && mkldnn_thr_syncable()
4244 && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
4245 && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
4246 && jcp.kw <= 28 - jcp.with_bias
4247 && jcp.stride_w == 4
4248 && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
4249 && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
4250 && IMPLICATION(diff_weights_d.format_kind() != format_kind::any,
4251 current_wei_tag == wei_4fma_tag);
4252
4253 if (use_4fma) {
4254 jcp.ver = ver_4fma;
4255 jcp.kh_step = kh_step;
4256 jcp.tr_ld = tr_ld;
4257 jcp.ic_block = 1;
4258 if (diff_weights_d.format_kind() == format_kind::any)
4259 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag));
4260 jcp.wei_tag = wei_4fma_tag;
4261 } else {
4262 jcp.ver = ver_fma;
4263 jcp.ic_block = jcp.ic;
4264
4265 wei_tag = with_groups
4266 ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
4267 : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
4268
4269 if (diff_weights_d.format_kind() == format_kind::any) {
4270 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4271 jcp.wei_tag = wei_tag;
4272 } else {
4273 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4274 }
4275 if (jcp.wei_tag != wei_tag)
4276 return status::unimplemented;
4277 }
4278
4279 jcp.nb_ic = jcp.ic / jcp.ic_block;
4280 } else {
4281 auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4282 if (src_d.format_kind() == format_kind::any) {
4283 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4284 jcp.src_tag = src_tag;
4285 } else {
4286 jcp.src_tag = src_d.matches_one_of_tag(src_tag);
4287 }
4288 if (jcp.src_tag != src_tag)
4289 return status::unimplemented;
4290
4291 if (diff_weights_d.format_kind() == format_kind::any) {
4292 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4293 jcp.wei_tag = wei_tag;
4294 } else {
4295 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4296 }
4297 if (jcp.wei_tag != wei_tag)
4298 return status::unimplemented;
4299
4300 jcp.ic_block = jcp.simd_w;
4301 if (ok_to_pad_channels)
4302 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4303 jcp.nb_ic = jcp.ic / jcp.ic_block;
4304 if ((mayiuse(avx512_mic) || mayiuse(avx512_core))
4305 && utils::everyone_is(data_type::f32,
4306 src_d.data_type(), diff_weights_d.data_type(),
4307 diff_dst_d.data_type())) {
4308 jcp.ver = ver_fma;
4309 if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 &&
4310 everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) &&
4311 mkldnn_thr_syncable()) {
4312 jcp.ver = ver_4fma;
4313 }
4314 } else {
4315 return status::unimplemented;
4316 }
4317 if (jcp.ver == ver_4fma) {
4318 jcp.ur_w = jcp.ow;
4319 // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to
4320 // cross the right boundary. The only requirement is not to have
4321 // NaNs there because another multiplicand is always guaranteed to
4322 // be zero. This also may require the top-level driver to allocate
4323 // four extra guarding elements at the very end of the buffer.
4324 // I'm not proud of this hack, but it improves performance by
4325 // about 5-10% depending on the dimensions (Roma)
4326
4327 const int tr_round = 4;
4328
4329 jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round);
4330 jcp.tr_src_num_guard_elems = tr_round; // upper bound
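            // e.g. (illustrative): iw = 13, kw = 3 gives
            // tr_iw = rnd_up(15, 4) = 16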
4331 }
4332 }
4333
4334 if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) {
4335 jcp.typesize_in = sizeof(float);
4336 jcp.typesize_out = sizeof(float);
4337 } else
4338 return status::unimplemented;
4339
4340 bool args_ok = true
4341 && jcp.ic % jcp.ic_block == 0
4342 && jcp.oc % jcp.oc_block == 0
4343 && jcp.ic <= src_d.padded_dims()[1]
4344 && jcp.oc <= diff_dst_d.padded_dims()[1]
4345 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
4346 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
4347 if (!args_ok) return status::unimplemented;
4348
4349 { // balancing
4350 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4351 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
4352 jcp.nthr = nthr;
4353 jcp.nthr_mb = nthr_mb;
4354 jcp.nthr_g = nthr_g;
4355 jcp.nthr_oc_b = nthr_oc_b;
4356 jcp.nthr_ic_b = nthr_ic_b;
4357 }
4358
4359 return status::success;
4360}
4361
4362void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
4363 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4364 if (jcp.ver == ver_4fma) {
4365 if (jcp.is_1stconv) {
4366 const size_t tr_src_size =
4367 jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
4368 scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
4369 } else {
4370 // XXX: See the comment about tr_iw and guarding elements in
4371 // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
4372 const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
4373 const size_t min_tr_src_size_per_thr
4374 = jcp.ih * jcp.ic_block * jcp.tr_iw;
4375 const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
4376 + jcp.tr_src_num_guard_elems;
4377 scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
4378 }
4379
4380 /* prepare synchronization contexts */
4381 if (jcp.nthr_oc_b > 1) {
4382 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
4383 scratchpad.book(key_conv_tr_src_bctx,
4384 sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
4385 }
4386 }
4387
4388 if (jcp.nthr_mb > 1) {
4389 const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
4390 * jcp.kh * jcp.kw * jcp.kd;
4391 const int bia_size = jcp.ngroups * jcp.oc;
4392 const size_t wei_bia_reduction_size = wei_size + bia_size;
4393
4394 scratchpad.book(key_conv_wei_bia_reduction,
4395 jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
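        // Note: only nthr_mb - 1 private copies are booked; presumably the
        // remaining minibatch slice accumulates directly into the user
        // diff_weights / diff_bias memory.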
4396 scratchpad.book(key_conv_wei_bia_reduction_bctx,
4397 sizeof(simple_barrier::ctx_t));
4398 }
4399
4400 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
4401 scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
4402}
4403
4404void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
4405 const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4406 int &nthr_oc_b_, int &nthr_ic_b_)
4407{
4408 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4409
4410 const int max_threads = mkldnn_get_max_threads();
4411
4412 if (max_threads < j.ngroups) {
4413 /* simplification... fortunately it doesn't hurt much */
4414 return;
4415 }
4416
4417 if (!mkldnn_thr_syncable() && j.ver == ver_4fma) {
4418 // should not happen -- the driver is not ready
4419 // for TBB-like non-synchronous threading yet
4420 return;
4421 }
4422
4423 if (j.ver == ver_4fma && j.is_1stconv) {
4424 nthr_g_ = 1;
4425 nthr_oc_b_ = 1;
4426 nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
4427 nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
4428 nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
4429 return;
4430 }
4431
4432 nthr_g_ = j.ngroups;
4433 const int nthr = max_threads / nthr_g_;
4434
4435 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per-thread memory cost (reads/writes). The high-level
         * optimizer tries to minimize memory consumption. A few notes:
         * (n1) unclear why, but that essentially helps the first
         *      convolution...
         * (n2) assuming the reduction over the minibatch is always there:
         *   - instead of 8 the coefficient should be 5 here (a write costs
         *     roughly 2 reads): the kernel does 1 write to a temporary
         *     workspace; the reduction does 1 read from the workspace and
         *     1 write to diff_wei
         *   - but experiments showed 8 works better than 5 or 6... */
4444
4445 const int src_coef = j.ver == ver_4fma ? 4 : 1;
4446 const int dst_coef = 1;
4447 const int wei_coef = 8;
4448
4449 return 0
4450 + src_coef
4451 * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
4452 * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
4453 / j.stride_d / j.stride_h / j.stride_w /* (n1) */
4454 + dst_coef
4455 * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
4456 * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
4457 + wei_coef /* (n2) */
4458 * div_up(j.ngroups, nthr_g_)
4459 * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
4460 * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
4461 };
4462
4463 int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4464
4465 /* step 1: find the best thread distribution with lowest memory cost */
4466 const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
4467 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4468 const int nthr_par = nthr / nthr_mb;
4469 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4470 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4471 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4472
4473 int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4474 if (mem_cost <= best_mem_cost) {
4475 best_mem_cost = mem_cost;
4476 nthr_mb_ = nthr_mb;
4477 nthr_oc_b_ = nthr_oc_b;
4478 nthr_ic_b_ = nthr_ic_b;
4479 }
4480 }
4481
4482 if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
4483 }
4484
4485 if (!mayiuse(avx512_mic)) {
4486 auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4487 return 1
4488 * div_up(j.mb, nthr_mb)
4489 * div_up(j.ngroups, nthr_g_)
4490 * div_up(j.nb_oc, nthr_oc_b)
4491 * div_up(j.nb_ic, nthr_ic_b);
4492 };
4493
        /* step 2: search for a thread distribution with lower compute cost.
         * the constraints:
         * - memory cost cannot exceed 110% of the best found in step 1
         * - unless the compute cost is at least 25% lower than the current
         *   best (4 * comp_cost <= 3 * best_comp_cost), in which case the
         *   memory constraint is waived
         * note: both constants were found empirically */
4499 int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4500 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4501 const int nthr_par = nthr / nthr_mb;
4502 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4503 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4504 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4505 int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4506 int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4507
4508 const bool opt1 = comp_cost <= best_comp_cost
4509 && mem_cost < 1.1 * best_mem_cost;
4510 const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
4511
4512 if (opt1 || opt2) {
4513 best_comp_cost = comp_cost;
4514 nthr_mb_ = nthr_mb;
4515 nthr_oc_b_ = nthr_oc_b;
4516 nthr_ic_b_ = nthr_ic_b;
4517 }
4518 }
4519
4520 if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
4521 }
4522 }
4523
4524 if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
4525 nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
4526 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
4527
4528 assert(nthr_ <= max_threads);
4529 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
4530}
4531
4532template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
4533template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
4534
4535}
4536}
4537}
4538
4539// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
4540