1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "c_types_map.hpp"
18#include "mkldnn_thread.hpp"
19#include "nstl.hpp"
20#include "type_helpers.hpp"
21#include "utils.hpp"
22#include "cpu_memory.hpp"
23
24#include <math.h>
25
26#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
27
28#ifndef KERNEL_SIZE_THRESHOLD
29#define KERNEL_SIZE_THRESHOLD 16
30#endif
31
32#define MIN_REQUIRED_DIMN_REG_BLOCK 14
33
34namespace mkldnn {
35namespace impl {
36namespace cpu {
37
38namespace {
39
40using namespace mkldnn::impl::utils;
41
// Cache sizes queried once, during static initialization of this translation
// unit. The check_cond* heuristics compare byte counts (element counts times
// sizeof(float)) against fractions of L1_cache_size / L2_cache_size, and the
// data GEMM kernel compares its output size against LLC_data_size to decide
// whether to use non-temporal stores.
// NOTE(review): the boolean argument of get_cache_size presumably selects a
// per-core (true) vs. shared/total (false) size -- confirm against its
// definition.
unsigned int L1_cache_size = get_cache_size(1, true);
unsigned int L2_cache_size = get_cache_size(2, true);
unsigned int LLC_data_size = get_cache_size(3, false);
45
46// the test funtion takes jcp, the candidate and the current best.
47// it returns true if the new candidate is better
48int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
49 int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
50{
51 int best_divisor = default_best;
52 auto test_num
53 = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
54 if (test(jcp, num, best_divisor)) {
55 best_divisor = num;
56 }
57 };
58
59 for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
60 if (number % divisor == 0) {
61 test_num(jcp, divisor);
62 test_num(jcp, number / divisor);
63 }
64 }
65
66 return best_divisor;
67}
68
69namespace {
70bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
71 if (jcp.ver == ver_4fma)
72 return jcp.mb >= 32;
73 else
74 return jcp.mb >= 16;
75}
76}
77
78/* assumes 512 bits registers */
79/* TODO: add support for strides */
80/* TODO: handle the prefetch distance automatically */
81typedef enum cache_t_ { L1, L2, L3 } cache_t;
82
83template <typename data_t>
84struct prefetcher_t {
85 prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
86 cache_t cache_type, size_t block_size, /* in number of elements*/
87 int nb_instructions_in_block, int fma_ipc)
88 : cg_(generator)
89 , reg_base_addr_(reg_base_addr)
90 , cache_type_(cache_type)
91 , cache_block_size_(block_size)
92 {
93 nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
94 prefetch_spread_
95 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
96 prefetch_blk_
97 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
98
99 /* assumption: when fetch in Li, data is already in L(i+1) */
100 int cache_latency;
101 switch (cache_type_) {
102 case L1: cache_latency = 14; break;
103 case L2:
104 case L3:
105 default: cache_latency = 250; break;
106 }
107
108 prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
109 }
110
111 void prefetch(int instruction_number)
112 {
113 if (instruction_number % prefetch_spread_ == 0) {
114 for (int i = 0; (i < prefetch_blk_)
115 && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
116 i++, prefetches_issued_++) {
117 prefetch_inst_(cg_->EVEX_compress_addr(
118 reg_base_addr_, (cache_block_size_ * prefetch_distance_)
119 * sizeof(data_t)
120 + (prefetches_issued_ * 64)));
121 }
122 }
123 }
124
125private:
126 void prefetch_inst_(const Xbyak::Address &addr)
127 {
128 switch (cache_type_) {
129 case L1: cg_->prefetcht0(addr); break;
130 case L2: cg_->prefetcht1(addr); break;
131 case L3: cg_->prefetcht2(addr); break;
132 default:
133 break; // TODO: raise an exception or put an assert
134 }
135 }
136
137 jit_generator *cg_;
138 Xbyak::Reg64 reg_base_addr_;
139 cache_t cache_type_;
140 int cache_block_size_ = 0;
141 int nb_cache_lines_to_prefetch_ = 0;
142 int prefetches_issued_ = 0;
143 int prefetch_spread_ = 0;
144 int prefetch_blk_ = 0;
145 int prefetch_distance_ = 0;
146};
147
148// utilities to support kernel parameter selection
149bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
150 int dimM_block, int dimM_simd_block, float C)
151{
152 float lhs = (dimM_block * dimN_reg_block * dimM_simd_block
153 + dimM_block * dimK_block * dimK_reg_block
154 * dimM_simd_block
155 + dimK_block * dimN_reg_block * dimK_reg_block)
156 * (float)sizeof(float);
157 float rhs = C * L1_cache_size;
158 return (lhs < rhs);
159}
160
161bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
162 int dimM_block, int dimM_simd_block, float C)
163{
164 float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block
165 + dimK_block * dimN_reg_block * dimK_reg_block)
166 * (float)sizeof(float);
167 float rhs = C * L1_cache_size;
168 return (lhs < rhs);
169}
170
171bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
172 int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block,
173 float C)
174{
175 float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
176 + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
177 * dimM_simd_block
178 + nb_dimN_reg_block * dimK_nb_block * dimK_block
179 * dimN_reg_block * dimK_reg_block)
180 * (float)sizeof(float);
181 float rhs = C * L2_cache_size;
182 return (lhs < rhs);
183}
184}
185
186using namespace mkldnn::impl::format_tag;
187using namespace mkldnn::impl::utils;
188using namespace Xbyak;
189
// Emits the GEMM micro-kernel of the Winograd data kernel:
//   C[dimM_block][tile] += A[dimM_block][dimK_block][dimK_reg_block]
//                          * broadcast(B[dimK_block][tile][dimK_reg_block])
// with A read from reg_srcA, B from reg_srcB and C at reg_dstC.
//
// Register usage: A vectors are loaded into zmm0 upwards (inc_dimK_reg_block
// registers per step, two alternating groups when jcp.double_buffering is
// set); the dimN_reg_block accumulators are zmm(jcp.zmm_start + tile).
//
// is_beta_zero: when true the accumulators start at zero, otherwise they are
// first loaded from reg_dstC so the kernel accumulates into existing output.
void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
        bool is_beta_zero)
{
    // Reference loop nest (const int dimK_simd_block = jcp.dimK_reg_block;):
    //
    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    //  for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    //   for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    //        dimK_reg_block++)
    //    for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    //     C[dimM_block][tile] +=
    //         A[dimM_block][dimK_block][dimK_reg_block] *
    //         broadcast(B[dimK_block][tile][dimK_reg_block]);
    // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
    //    so we load it before the loop on tile.
    // 2) The loop on tile must be fully unrolled. The loop on dimK_reg_block
    //    is unrolled as well.

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;
        // 4FMA consumes 4 A registers per instruction, plain FMA one.
        const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
        const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;

        // Both prefetchers target the B operand (reg_srcB); one B block is
        // dimN_reg_block * dimK_reg_block floats, consumed by
        // dimK_reg_block * dimN_reg_block / inc_dimK_reg_block FMAs.
        prefetcher_t<float> L1_pf(this, reg_srcB, L1,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);
        prefetcher_t<float> L2_pf(this, reg_srcB, L2,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);

        // Runtime loop over dimM blocks (omitted when there is only one).
        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }
        {
            // First, we zero the accumulators if first nb_ic iteration,
            // otherwise we load them
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm(jcp.zmm_start + tile);
                if (is_beta_zero)
                    vpxord(zmm, zmm, zmm);
                else
                    vmovups(zmm, zword[reg_dstC + 64 * tile]);
            }

            // Runtime loop over dimK blocks (omitted when there is only one).
            if (jcp.dimK_block > 1) {
                mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                L(dimK_block_loop);
            }
            {
                // Loads inc_dimK_reg_block consecutive 64-byte A vectors,
                // starting at vector `offset`, into zmm[reg_idx ...].
                auto load_A = [=](int reg_idx, int offset) {
                    for (int i = 0; i < inc_dimK_reg_block; i++)
                        vmovups(Zmm(reg_idx + i),
                                zword[reg_srcA + 64 * (offset + i)]);
                };

                // Used when doing double buffering
                int next = 0;
                if (jcp.double_buffering) {
                    load_A(next, 0);
                }
                for (int dimK_reg_block = 0;
                        dimK_reg_block < jcp.dimK_reg_block;
                        dimK_reg_block += inc_dimK_reg_block) {
                    int current;
                    /* Loading the next vector from A */
                    current = next;
                    if (jcp.double_buffering) {
                        // Preload the next step into the other buffer while
                        // the FMAs below consume `current`.
                        // NOTE(review): on the last step this reads the A
                        // vectors just past the current dimK block (offset
                        // jcp.dimK_reg_block); for the final block that is
                        // past the end of A -- confirm the buffer is padded.
                        next = (dimK_reg_block + inc_dimK_reg_block)
                                % (2 * inc_dimK_reg_block);
                        load_A(next, dimK_reg_block + inc_dimK_reg_block);
                    } else {
                        next = 0;
                        load_A(next, dimK_reg_block);
                    }
                    /* Performing the fmas */
                    for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                        Zmm zmm(jcp.zmm_start + tile);
                        // init_conf_common never selects ver_avx512_core, so
                        // for configurations built in this file the
                        // prefetches are always emitted.
                        if (jcp.ver != ver_avx512_core)
                            L1_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                        if (jcp.ver == ver_4fma)
                            v4fmaddps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                            64 * tile + dimK_reg_block * 4));
                        else
                            // The trailing `true` presumably requests an
                            // embedded broadcast of the B element.
                            vfmadd231ps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                            64 * tile + dimK_reg_block * 4,
                                            true));
                        if (jcp.ver != ver_avx512_core)
                            L2_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                    }
                }

                // Advance A and B to the next dimK block.
                add(reg_srcA, jcp.dimK_reg_block * 64);
                add(reg_srcB, jcp.dimN_reg_block * 64);
                if (jcp.dimK_block > 1) {
                    sub(reg_dimK_block_loop_cnt, 1);
                    jnz(dimK_block_loop);
                }
            }

            // Writes the accumulators back. Non-temporal stores are used only
            // for 64-byte aligned output, when this is the single dimK pass
            // (dimK_nb_block == 1) and the whole output exceeds twice the LLC.
            auto store_output = [=](bool output_is_aligned) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    Zmm zmm(jcp.zmm_start + tile);
                    if (output_is_aligned
                        && jcp.dimK_nb_block == 1
                        && (jcp.dimN * jcp.dimM * alpha * alpha
                            * sizeof(float) > 2 * LLC_data_size))
                        vmovntps(zword[reg_dstC + 64 * tile], zmm);
                    else
                        vmovups(zword[reg_dstC + 64 * tile], zmm);
                }
            };

            // Runtime alignment check of reg_dstC selects the store variant.
            Label unaligned_store, end_store;
            test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
            jnz(unaligned_store, T_NEAR);
            store_output(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                store_output(false);
            }
            L(end_store);

            // Next dimM block: rewind B to its start, advance C. A keeps
            // advancing, as A is laid out per dimM block.
            if (jcp.dimM_block > 1) {
                sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
                add(reg_dstC, jcp.dimN_reg_block * 64);
                sub(reg_dimM_block_loop_cnt, 1);
                jnz(dimM_block_loop);
            }
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
339
340status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
341 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
342 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
343 const memory_desc_wrapper &dst_d)
344{
345
346 if (mayiuse(avx512_core))
347 return status::unimplemented;
348 else if (!mayiuse(avx512_common))
349 return status::unimplemented;
350 else if (mayiuse(avx512_mic_4ops))
351 jcp.ver = ver_4fma;
352 else
353 jcp.ver = ver_fma;
354
355 jcp.nthr = mkldnn_get_max_threads();
356
357 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
358
359 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
360 jcp.mb = src_d.dims()[0];
361 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
362 jcp.oc_without_padding = jcp.oc;
363 jcp.ic = src_d.dims()[1] / jcp.ngroups;
364 jcp.ih = src_d.dims()[2];
365 jcp.iw = src_d.dims()[3];
366 jcp.oh = dst_d.dims()[2];
367 jcp.ow = dst_d.dims()[3];
368 jcp.kh = weights_d.dims()[with_groups + 2];
369 jcp.kw = weights_d.dims()[with_groups + 3];
370 jcp.t_pad = cd.padding[0][0];
371 jcp.l_pad = cd.padding[0][1];
372 jcp.stride_h = cd.strides[0];
373 jcp.stride_w = cd.strides[1];
374 jcp.dilate_h = cd.dilates[0];
375 jcp.dilate_w = cd.dilates[1];
376 jcp.r_pad = nstl::max(
377 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
378 jcp.b_pad = nstl::max(
379 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
380 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
381 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
382 jcp.ohp = jcp.oh;
383 jcp.owp = jcp.ow;
384
385 bool ok_to_pad_channels = jcp.ngroups == 1;
386 if (ok_to_pad_channels) {
387 jcp.oc = rnd_up(jcp.oc, simd_w);
388 jcp.ic = rnd_up(jcp.ic, simd_w);
389 }
390
391 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
392 is_winograd_faster_than_direct(jcp)))
393 return status::unimplemented;
394
395 // Checking conditions not supported by these kernels
396 if (jcp.ngroups != 1)
397 return status::unimplemented;
398 if ((jcp.kh != 3) || (jcp.kw != 3))
399 return status::unimplemented;
400 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
401 return status::unimplemented;
402 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
403 return status::unimplemented;
404 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
405 return status::unimplemented;
406
407 format_tag_t dat_tag = nChw16c;
408 format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
409 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
410 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
411 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
412
413 if (jcp.src_tag != dat_tag) return status::unimplemented;
414 if (jcp.wei_tag != wei_tag) return status::unimplemented;
415 if (jcp.dst_tag != dat_tag) return status::unimplemented;
416
417 bool layout_consistency = true
418 && jcp.ic <= src_d.padded_dims()[1]
419 && jcp.oc <= dst_d.padded_dims()[1]
420 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
421 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
422 if (!layout_consistency) return status::unimplemented;
423
424 return status::success;
425}
426
427
428status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
429
430 auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
431 int dimN_reg_block, int current_best) {
432 return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
433 && (dimN_reg_block < jcp.nb_reg)
434 && (dimN_reg_block < current_best);
435 };
436 jcp.dimN_reg_block = get_divisor_satisfying_cond(
437 jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
438
439 if (jcp.dimN_reg_block >= jcp.nb_reg) {
440 auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
441 int dimN_reg_block, int current_best) {
442 return (dimN_reg_block < jcp.nb_reg)
443 && (dimN_reg_block > current_best);
444 };
445
446 jcp.dimN_reg_block = get_divisor_satisfying_cond(
447 jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
448 }
449
450 //********************* Choosing dimK_block **********************//
451 auto test_cond1_dimK_block = [](
452 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
453 return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
454 1, jcp.dimM_simd_block, .75f)
455 && (dimK_block > current_best);
456 };
457
458 auto test_cond1_bis_dimK_block = [](
459 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
460 return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
461 jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f)
462 && (dimK_block > current_best);
463 };
464
465 jcp.dimK_block = get_divisor_satisfying_cond(
466 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
467 // If we are not able to use streams, we fall back to condition [1]
468 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
469 jcp.dimK_block = get_divisor_satisfying_cond(
470 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
471 jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
472
473 //********************* Choosing dimM_block **********************//
474 jcp.dimM_simd_block = 16;
475 /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/
476 auto test_cond1_dimM_block = [](
477 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
478 return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
479 jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f)
480 && (dimM_block > current_best);
481 };
482
483 auto test_cond1_bis_dimM_block = [](
484 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
485 return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
486 jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f)
487 && (dimM_block > current_best);
488 };
489
490 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
491 jcp.dimM_block = get_divisor_satisfying_cond(
492 jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
493 else
494 jcp.dimM_block = get_divisor_satisfying_cond(jcp,
495 jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block);
496 jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
497
498 //******************* Choosing dimN_block *******************//
499 auto test_cond2_dimN_block = [](
500 jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
501 return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
502 jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
503 jcp.dimM_simd_block, .5f)
504 && (dimN_block > current_best);
505 };
506
507 jcp.dimN_block = get_divisor_satisfying_cond(
508 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
509 jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
510 jcp.sched_policy = WSCHED_DATA_W_S_G_D;
511 return status::success;
512}
513
514status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
515 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
516{
517 jcp.dimK_reg_block = 16;
518 jcp.dimM_simd_block = 16;
519
520 // TODO: replace double buffering with nuple buffering to maximize register
521 // usage.
522 // the choice of the number of buffers will then come after choosing
523 // dimN_reg_block
524 jcp.double_buffering = true;
525 if (jcp.double_buffering)
526 jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2);
527 else
528 jcp.zmm_start = 1;
529 jcp.nb_reg = 32 - jcp.zmm_start;
530
531 jcp.dimN = dimN;
532 jcp.dimK = dimK;
533 jcp.dimM = dimM;
534
535 jcp.sched_policy = WSCHED_INVALID;
536 set_wsched_DATA_W_S_G_D_avx512_common(jcp);
537
538 assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
539 return status::success;
540}
541
542bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
543 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
544 const auto &p = attr.post_ops_;
545
546 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
547 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
548
549 switch (p.len_) {
550 case 0: return true; // no post_ops
551 case 1: return is_relu(0) || is_sum(0); // relu or sum
552 case 2: return (is_sum(0) && is_relu(1)) ||
553 (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
554 case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
555 default: return false;
556 }
557
558 return false;
559}
560
561status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
562 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
563 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
564 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
565 status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
566
567 if (st != status::success)
568 return st;
569
570 // Winograd specific initialization
571 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
572 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
573 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
574
575 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
576
577 if (!post_ops_ok(jcp, attr))
578 return status::unimplemented;
579
580 const auto &p = attr.post_ops_;
581 const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
582 jcp.with_eltwise = eltwise_ind != -1;
583 if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
584 jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
585
586 status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
587 jcp.ic_simd_block = jcp.dimK_reg_block;
588 jcp.ic_block = jcp.dimK_block;
589 jcp.nb_ic = jcp.dimK_nb_block;
590 jcp.oc_simd_block = jcp.dimM_simd_block;
591 jcp.oc_block = jcp.dimM_block;
592 jcp.nb_oc = jcp.dimM_nb_block;
593 jcp.tile_block_ur = jcp.dimN_reg_block;
594 jcp.nb_tile_block_ur = jcp.dimN_block;
595 jcp.tile_block = jcp.dimN_nb_block;
596 jcp.tile_4fma_padding = 0; // only relevant for backward weights
597
598 return res;
599}
600
601status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
602 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
603 const memory_desc_wrapper &diff_src_d,
604 const memory_desc_wrapper &weights_d,
605 const memory_desc_wrapper &diff_dst_d)
606{
607 status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
608
609 if (st != status::success)
610 return st;
611
612 jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
613 jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
614 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
615
616 status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
617 jcp.oc_simd_block = jcp.dimK_reg_block;
618 jcp.oc_block = jcp.dimK_block;
619 jcp.nb_oc = jcp.dimK_nb_block;
620 jcp.ic_simd_block = jcp.dimM_simd_block;
621 jcp.ic_block = jcp.dimM_block;
622 jcp.nb_ic = jcp.dimM_nb_block;
623 jcp.tile_block_ur = jcp.dimN_reg_block;
624 jcp.nb_tile_block_ur = jcp.dimN_block;
625 jcp.tile_block = jcp.dimN_nb_block;
626 jcp.tile_4fma_padding = 0; // only relevant for backward weights
627
628 return res;
629}
630
631void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
632{
633 auto load_B = [=](int reg_idx, int offset) {
634 for (int i = 0; i < 4; i++) {
635 vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
636 }
637 };
638
639 preamble();
640 int curr = 0;
641 for (int j = 0; j < alpha; j++) {
642 for (int i = 0; i < alpha; i++) {
643 int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
644 size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
645 jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
646 jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
647 mov(reg_transB_idx, transB_offset);
648 for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
649 /*double buffering to hide load latencies*/
650 int next = (curr + 4) % 8;
651 if (i == 0 && tb == 0) {
652 load_B(0, origB_offset);
653 }
654 if (tb + 4 < (jcp.dimK_4fma -1)) {
655 load_B(next, origB_offset + 4);
656 } else if (i < alpha - 1) {
657 load_B(next, origB_offset + jcp.dimK_4fma);
658 }
659
660 vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
661 vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
662 vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
663 vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));
664
665 vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
666 vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));
667
668 vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
669 vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));
670
671 vmovntps(zword[reg_transB + reg_transB_idx
672 + sizeof(float) * tb * jcp.dimN_reg_block],
673 Zmm(curr+2));
674 vmovntps(zword[reg_transB + reg_transB_idx
675 + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
676 Zmm(curr+3));
677 vmovntps(zword[reg_transB + reg_transB_idx
678 + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
679 Zmm(8));
680 vmovntps(zword[reg_transB + reg_transB_idx
681 + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
682 Zmm(9));
683 curr = next;
684
685 }
686 }
687 }
688 postamble();
689 ret();
690}
// Emits the backward-weights GEMM micro-kernel (see the reference loop nest
// below): vectors of M are loaded from reg_srcA, elements of V are read from
// reg_srcB, and U is accumulated at reg_dstC.
//
// Register usage: A vectors are loaded into zmm0 upwards (inc_fma registers
// per step, two alternating groups when jcp.double_buffering is set);
// accumulators are zmm(jcp.zmm_start + dimN_reg_block).
// NOTE(review): assumes jcp.zmm_start >= 2 * inc_fma; zmm_start is set
// outside this view -- confirm.
//
// is_first_tile: when true the accumulators start at zero, otherwise the
// partial sums stored at reg_dstC are loaded and accumulated into.
void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
        bool is_first_tile)
{
    // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
    // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
    // for (int nb_tile_block_ur = 0; nb_tile_block_ur <
    //     jcp.nb_tile_block_ur; nb_tile_block_ur++)
    // for (int tile_block_ur = 0; tile_block_ur <
    //     jcp.tile_block_ur; tile_block_ur++)
    // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
    // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
    //     M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
    //     *
    //     broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
    auto inner_loops = [=]() {
        // 4FMA consumes 4 A registers per instruction, plain FMA one.
        int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
        const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
        // Both prefetchers target the B operand (reg_srcB); one B block is
        // dimK_reg_block * dimN_reg_block * dimK_4fma floats.
        prefetcher_t<float> L1_pf(this, reg_srcB, L1,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
                        / inc_fma,
                fma_ipc);
        prefetcher_t<float> L2_pf(this, reg_srcB, L2,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
                        / inc_fma,
                fma_ipc);

        // Loads inc_fma consecutive A vectors (dimM_simd_block floats each),
        // starting at vector `offset`, into zmm[reg_idx ...].
        auto load_A = [=](int reg_idx, int offset) {
            for (int i = 0; i < inc_fma; i++) {
                vmovups(Zmm(reg_idx + i),
                        zword[reg_srcA +
                        sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
            }
        };

        Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }
        { /************* OC_block (M) loop ***********/
            if (jcp.dimN_block > 1) {
                mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
                L(dimN_block_loop);
            }
            { /*************** IC_block (N) loop *********/
                // Zero the accumulators on the first tile, otherwise reload
                // the partial sums from reg_dstC.
                for (int dimN_reg_block = 0;
                        dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
                    Zmm zmm(jcp.zmm_start + dimN_reg_block);
                    if (is_first_tile)
                        vpxord(zmm, zmm, zmm);
                    else
                        vmovups(zmm, zword[reg_dstC +
                                dimN_reg_block * jcp.dimM_simd_block *
                                sizeof(float)]);
                }

                if (jcp.dimK_block > 1) {
                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                }
                { /************* nb_tile_ur(K) loop ********/
                    int next = 0;
                    if (jcp.double_buffering) {
                        load_A(next, 0);
                    }
                    for (int dimK_reg_block = 0;
                            dimK_reg_block < jcp.dimK_reg_block;
                            dimK_reg_block++) {
                        int srcB_offset = dimK_reg_block * jcp.dimK_4fma
                                * jcp.dimN_reg_block;
                        for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
                                dimK_4fma += inc_fma) {
                            int current = next;
                            if (jcp.double_buffering) {
                                // Preload the next step into the other
                                // buffer while `current` is consumed.
                                // NOTE(review): the final step reads A
                                // vectors just past the current dimK block
                                // -- confirm the buffer is large enough.
                                next = (dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma + inc_fma)
                                        % (2 * inc_fma);
                                load_A(next, dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma + inc_fma);
                            } else {
                                next = 0;
                                load_A(next, dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma);
                            }
                            for (int dimN_reg_block = 0;
                                    dimN_reg_block < jcp.dimN_reg_block;
                                    ++dimN_reg_block) {
                                L1_pf.prefetch(srcB_offset / inc_fma
                                        + dimK_4fma / inc_fma
                                                * jcp.dimN_reg_block
                                        + dimN_reg_block);
                                L2_pf.prefetch(srcB_offset / inc_fma
                                        + dimK_4fma / inc_fma
                                                * jcp.dimN_reg_block
                                        + dimN_reg_block);
                                if (jcp.ver == ver_4fma) {
                                    // B is addressed in the transposed layout
                                    // (presumably the one produced by
                                    // transpose_ker_generate).
                                    int srcB_trans_offset = (dimK_4fma / 4) * 64
                                            + dimK_4fma % 4;
                                    v4fmaddps(
                                            Zmm(jcp.zmm_start + dimN_reg_block),
                                            Zmm(current),
                                            EVEX_compress_addr(reg_srcB,
                                                    sizeof(float) * (
                                                        srcB_offset +
                                                        srcB_trans_offset +
                                                        (dimN_reg_block % 4) * 16 +
                                                        (dimN_reg_block / 4) * 4)));
                                } else {
                                    // The trailing `true` presumably requests
                                    // an embedded broadcast of the B element.
                                    vfmadd231ps(
                                            Zmm(jcp.zmm_start + dimN_reg_block),
                                            Zmm(current),
                                            EVEX_compress_addr(reg_srcB,
                                                sizeof(float) * (srcB_offset + dimN_reg_block),
                                                true));
                                }
                            }
                        }
                    }
                }

                // Advance A and B to the next dimK block.
                add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
                        * jcp.dimM_simd_block * sizeof(float));
                add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
                        * jcp.dimK_4fma * sizeof(float));
                if (jcp.dimK_block > 1) {
                    sub(reg_dimK_block_loop_cnt, 1);
                    jnz(dimK_block_loop);
                }

                /******** Write C back to memory *******/
                for (int dimN_reg_block = 0;
                        dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
                    Zmm zmm(jcp.zmm_start + dimN_reg_block);
                    vmovups(zword[reg_dstC +
                            dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
                            zmm);
                }

                // Next dimN block: rewind A to the start of this M block
                // (A is shared by all dimN blocks), advance C. B keeps
                // advancing, as it is laid out per dimN block.
                sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
                        jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
                add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
                        * sizeof(float));
                if (jcp.dimN_block > 1) {
                    sub(reg_dimN_block_loop_cnt, 1);
                    jnz(dimN_block_loop);
                }
            }

            // Next dimM block: rewind B over all dimN blocks and move A to
            // the next M block.
            if (jcp.dimM_block > 1) {
                sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
                        * jcp.dimK_reg_block * jcp.dimN_reg_block
                        * jcp.dimK_4fma * sizeof(float));
                add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                        * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
                sub(reg_dimM_block_loop_cnt, 1);
                jnz(dimM_block_loop);
            }
        }
    };

    /* Preamble */
    // Copy the A base pointer from reg_srcA_const into the working register
    // reg_srcA, which the loops above advance and rewind.
    // NOTE(review): the original comment here read "register used to handle
    // long fma encoding"; its intent is unclear from this view.
    preamble();
    mov(reg_srcA, reg_srcA_const);
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
863
864namespace {
865bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block,
866 int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
867{
868 float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw;
869 lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw;
870 lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
871 lhs *= sizeof(float);
872 float rhs = C * L1_cache_size;
873 return (lhs <= rhs);
874}
875
876bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
877 int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
878{
879 float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma
880 * dimM_simdw;
881 lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
882 lhs *= sizeof(float);
883 float rhs = C * L1_cache_size;
884 return (lhs <= rhs);
885}
886
887bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
888 int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
889 float C)
890{
891 float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block
892 * dimK_4fma;
893 lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
894 * dimN_reg_block;
895 lhs *= sizeof(float);
896 float rhs = C * L2_cache_size;
897 return (lhs <= rhs);
898}
899
900bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
901 int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
902 float C)
903{
904 float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block;
905 lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma;
906 lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
907 * dimN_reg_block;
908 lhs *= sizeof(float);
909 float rhs = C * L2_cache_size;
910 return (lhs <= rhs);
911}
912} // namespace
913
914status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
915{
916 /*************** Choose dimN_reg_block (ic_simd_block)
917 * *******************************/
918 jcp.dimN = jcp.ic;
919 /*Hardcoded to 16 because N = ic for bwd weights and
920 innermost dimension for ic is assumed 16 in src transforms. This
921 choice covers load latencies while maintaining simplicity of kernel
922 for POR topologies. FIXME in future??: Will not work for future topologies
923 when ic%16 != 0*/
924 jcp.dimN_reg_block = jcp.ic_simd_block;
925
926 /****************************** Choose dimK_block
927 * **************************/
928 // No freedom for choosing dimM_simd_block because ic_simd_block
929 // is determined by input data format
930 jcp.dimM_simd_block = jcp.oc_simd_block;
931
932 auto test_cond1bis_dimK_block = [](
933 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
934 return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
935 jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
936 && (dimK_block > current_best);
937 };
938
939 auto test_cond1_dimK_block = [](
940 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
941 return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1,
942 jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
943 && (dimK_block > current_best);
944 };
945
946 auto test_cond2bis_dimK_block = [](
947 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
948 return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
949 jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f)
950 && (dimK_block > current_best);
951 };
952
953 auto test_cond2_dimK_block = [](
954 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
955 return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1,
956 jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f)
957 && (dimK_block > current_best);
958 };
959
960 jcp.dimK_block = get_divisor_satisfying_cond(
961 jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block);
962 if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma)
963 jcp.dimK_block = get_divisor_satisfying_cond(
964 jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block);
965
966 jcp.dimK_reg_block = get_divisor_satisfying_cond(
967 jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block);
968 if (jcp.dimK_reg_block < jcp.dimK_block) {
969 jcp.dimK_reg_block = get_divisor_satisfying_cond(
970 jcp, jcp.dimK_block, 1, test_cond1_dimK_block);
971 }
972 jcp.dimK_block /= jcp.dimK_reg_block;
973 jcp.dimK_nb_block
974 = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block;
975 jcp.tile_block_ur = jcp.dimK_reg_block;
976 jcp.nb_tile_block_ur = jcp.dimK_block;
977 jcp.tile_block = jcp.dimK_nb_block;
978
979 /***************************** Chose dimN block
980 * ****************************/
981 auto test_cond2_dimN_block = [](
982 jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
983 return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block,
984 jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block,
985 jcp.dimN_reg_block, 0.5f)
986 && (dimN_block > current_best);
987 };
988
989 jcp.dimN_block = get_divisor_satisfying_cond(
990 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
991 jcp.ic_block = jcp.dimN_block;
992 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block;
993 jcp.nb_ic = jcp.dimN_nb_block;
994
995 /********************************* Choose dimM block
996 * ************************/
997 jcp.dimM = jcp.oc;
998
999 auto test_cond1_dimM_block = [](
1000 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1001 return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1,
1002 jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block,
1003 1.0f)
1004 && (dimM_block > current_best)
1005 && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2;
1006 };
1007
1008 jcp.dimM_block = get_divisor_satisfying_cond(
1009 jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
1010 jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
1011
1012 jcp.sched_policy = WSCHED_WEI_S_D_G_W;
1013 return status::success;
1014}
1015
1016status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
1017 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1018 const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
1019 const memory_desc_wrapper &diff_weights_d)
1020{
1021 jcp.nthr = mkldnn_get_max_threads();
1022
1023 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1024
1025 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
1026 jcp.mb = src_d.dims()[0];
1027 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1028 jcp.oc_without_padding = jcp.oc;
1029 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1030 jcp.ih = src_d.dims()[2];
1031 jcp.iw = src_d.dims()[3];
1032 jcp.oh = diff_dst_d.dims()[2];
1033 jcp.ow = diff_dst_d.dims()[3];
1034 jcp.kh = diff_weights_d.dims()[with_groups + 2];
1035 jcp.kw = diff_weights_d.dims()[with_groups + 3];
1036 jcp.t_pad = cd.padding[0][0];
1037 jcp.l_pad = cd.padding[0][1];
1038 jcp.stride_h = cd.strides[0];
1039 jcp.stride_w = cd.strides[1];
1040 jcp.r_pad = nstl::max(
1041 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1042 jcp.b_pad = nstl::max(
1043 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1044 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1045 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1046 jcp.ohp = jcp.oh;
1047 jcp.owp = jcp.ow;
1048 jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
1049 jcp.dilate_h = cd.dilates[0];
1050 jcp.dilate_w = cd.dilates[1];
1051
1052 bool ok_to_pad_channels = jcp.ngroups == 1;
1053 if (ok_to_pad_channels) {
1054 jcp.oc = rnd_up(jcp.oc, simd_w);
1055 jcp.ic = rnd_up(jcp.ic, simd_w);
1056 }
1057
1058 if (mayiuse(avx512_core))
1059 return status::unimplemented;
1060 if (!mayiuse(avx512_common))
1061 return status::unimplemented;
1062 else if (mayiuse(avx512_mic_4ops))
1063 jcp.ver = ver_4fma;
1064 else
1065 jcp.ver = ver_fma;
1066
1067 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
1068 is_winograd_faster_than_direct(jcp)))
1069 return status::unimplemented;
1070 // Winograd specific initialization
1071 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
1072 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
1073 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1074
1075 // Winograd kernel works only for 3x3 convolution with stride 1
1076 if (jcp.ngroups != 1)
1077 return status::unimplemented;
1078 if ((jcp.kh != 3) || (jcp.kw != 3))
1079 return status::unimplemented;
1080 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1081 return status::unimplemented;
1082 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1083 return status::unimplemented;
1084 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1085 return status::unimplemented;
1086
1087 format_tag_t dat_tag = nChw16c;
1088 format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
1089 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
1090 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
1091 jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
1092
1093 if (jcp.src_tag != dat_tag) return status::unimplemented;
1094 if (jcp.wei_tag != wei_tag) return status::unimplemented;
1095 if (jcp.dst_tag != dat_tag) return status::unimplemented;
1096
1097 bool layout_consistency = true
1098 && jcp.ic <= src_d.padded_dims()[1]
1099 && jcp.oc <= diff_dst_d.padded_dims()[1]
1100 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
1101 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
1102 if (!layout_consistency) return status::unimplemented;
1103
1104 /*************************** New Kernel Parameters
1105 * *****************************/
1106 jcp.ic_simd_block = simd_w;
1107 jcp.oc_simd_block = simd_w;
1108 jcp.dimK_4fma = 1;
1109 jcp.tile_4fma_padding = 0;
1110
1111#define MAX_4FMA_UR 8
1112 if (jcp.ver == ver_4fma) {
1113 auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma,
1114 int current_best) {
1115 return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR)
1116 && (dimK_4fma > current_best);
1117 };
1118 jcp.dimK_4fma = get_divisor_satisfying_cond(
1119 jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma);
1120 if (jcp.dimK_4fma == 1)
1121 jcp.dimK_4fma = 4;
1122 if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0)
1123 jcp.tile_4fma_padding = jcp.dimK_4fma
1124 - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma);
1125 }
1126
1127 jcp.tile_4fma = jcp.dimK_4fma;
1128 /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src
1129 * transform
1130 * will not work correctly, this is solved by applying padding.*/
1131 jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
1132 jcp.dimN = jcp.ic;
1133 jcp.dimM = jcp.oc;
1134
1135 jcp.double_buffering = true;
1136 if (jcp.double_buffering)
1137 jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2;
1138 else
1139 jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
1140 jcp.nb_reg = 32 - jcp.zmm_start;
1141
1142 jcp.sched_policy = WSCHED_INVALID;
1143 status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
1144 assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
1145
1146 jcp.tile_block_ur = jcp.dimK_reg_block;
1147 jcp.nb_tile_block_ur = jcp.dimK_block;
1148 jcp.tile_block = jcp.dimK_nb_block;
1149
1150 jcp.ic_block = jcp.dimN_block;
1151 jcp.nb_ic = jcp.dimN_nb_block;
1152
1153 jcp.oc_block = jcp.dimM_block;
1154 jcp.nb_oc = jcp.dimM_nb_block;
1155
1156 return res;
1157
1158}
1159}
1160}
1161}
1162
1163// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
1164