/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "cpu_memory.hpp"

#include "jit_avx2_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::format_tag;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace Xbyak;

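/* Register allocation convention shared by the forward compute helpers below
 * (derived from the code, stated here for readability):
 *   Ymm(ur_w * ii + jj)        - accumulators, one per (oc block ii, column jj)
 *   Ymm(oc_blocks * ur_w + jj) - broadcast source value for output column jj
 *   ymm15                      - current weights vector
 *   ytmp                       - scratch on the Intel AVX (no FMA) path */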
void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
        int pad_l, int pad_r, int oc_blocks)
{
    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int nb_ic = jcp.nb_ic;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
        int jj_end = ur_w
            - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off;
                if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
                    inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw
                            + (ki*dilate_w + jj*stride_w - pad_l));
                else
                    inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w
                                - pad_l)*ic_blk + ifm2);
                vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                        make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
            }

            for (int ii = 0; ii < oc_blocks; ii++) {
                int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk
                        + ki * ic_blk * oc_blk + ifm2 * oc_blk;
                vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]);
                for (int jj = jj_start; jj < jj_end; jj++)
                    if (mayiuse(avx2))
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(oc_blocks * ur_w + jj), ymm15);
                    else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                    }
            }
        }
    }
}

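/* Same multiply-accumulate body as oh_step_unroll_kw(), but with the kw
 * dimension iterated in a run-time loop and without left/right padding
 * handling; used by width_blk_step() when kw >= 5 and the current width
 * block needs no padding. */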
void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
        int pad_l, int pad_r, char pad_tag,
        int oc_blocks, char oc_blocks_tag)
{
    Label kw_loop;

    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int nb_ic = jcp.nb_ic;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;

    xor_(ki_iter, ki_iter);
    L(kw_loop);
    {
        int jj_start = 0;
        int jj_end = ur_w;
        for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
            for (int jj = jj_start; jj < jj_end; jj++) {
                size_t inp_off;
                if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
                    inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw
                            + (jj * stride_w - pad_l));
                else
                    inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk
                            + ifm2);
                vbroadcastss(Ymm(oc_blocks * ur_w + jj),
                        make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
            }
            for (int ii = 0; ii < oc_blocks; ii++) {
                int aux_kernel_offset =
                    ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk;
                vmovups(ymm15, ptr[aux_reg_kernel
                        + sizeof(float) * aux_kernel_offset]);
                for (int jj = jj_start; jj < jj_end; jj++)
                    if (mayiuse(avx2))
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(oc_blocks * ur_w + jj), ymm15);
                    else { // Intel AVX support
                        vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
                        vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
                    }
            }
        }
        add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
        add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
                ? dilate_w : ic_blk * dilate_w));

        inc(ki_iter);
        cmp(ki_iter, kw);
        jl(kw_loop, T_NEAR);
    }
}

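/* Computes a block of ur_w output columns for oc_blocks output-channel blocks:
 * the accumulators are initialized from bias, zero, or the existing output
 * (for the sum post-op), the kd/kh loops accumulate the convolution, the
 * eltwise post-op is applied on the last input-channel chunk, and the results
 * are stored back to the output. */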
void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
        int pad_l, int pad_r, char pad_tag,
        int oc_blocks, char oc_blocks_tag)
{
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ow = jcp.ow;
    int oh = jcp.oh;
    int od = jcp.od;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;
    const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
        ? 1 : ic_blk;
    const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw)
        ? dilate_w : ic_blk * dilate_w;

    Label init_done, init_first;

    if (!jcp.with_sum) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        jne(init_first, T_NEAR);
    }

    for (int ii = 0; ii < oc_blocks; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt =
                sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
            vmovups(Ymm(ur_w * ii + jj),
                    make_safe_addr(reg_output, offt, reg_long_offt));
        }
    }

    if (jcp.with_sum && jcp.with_bias) {
        test(reg_ci_flag, FLAG_IC_FIRST);
        je(init_done, T_NEAR);

        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                        yword[reg_bias + sizeof(float) * ii * oc_blk]);
    }

    jmp(init_done);

    L(init_first);
    if (this->jcp.with_bias) {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                vmovups(Ymm(ur_w * ii + jj),
                        yword[reg_bias + sizeof(float) * ii * oc_blk]);
    } else {
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                        Ymm(ur_w * ii + jj));
    }

    L(init_done);

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
    }

    Label skip_kh_loop, skip_kd_loop, kd_loop;
    if (jcp.ndims == 5) {
        push(reg_output);
        push(oi_iter);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_input);

        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_loop);
        mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_input, aux_reg_inp_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    Label kh_loop;
    L(kh_loop);
    {
        if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
            oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
                    oc_blocks_tag);
            sub(aux_reg_input, sizeof(float) * kw * inp_off);
            add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
        } else {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
            add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
            add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
        }

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d,
                sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult);
        add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
        pop(reg_output);
    }

    Label regular_store;

    if (jcp.with_eltwise) {
        test(reg_ci_flag, FLAG_IC_LAST);
        je(regular_store, T_NEAR);

        eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w);

        L(regular_store);
    }

    for (int ii = 0; ii < oc_blocks; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            const size_t o_off
                    = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
            Ymm reg_out = Ymm(ur_w * ii + jj);
            vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
        }
    }
}

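/* Splits the output width into (up to) four kinds of ur_w-wide blocks and
 * emits a width_blk_step() for each: one block that covers the left padding,
 * a run-time loop over the "middle" blocks that need no padding, one block
 * that covers the right padding, and a tail block of ur_w_tail columns. */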
inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
        int oc_blocks, char oc_blocks_tag)
{
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;
    int dilate_w = jcp.dilate_w + 1;
    int str_w = jcp.stride_w;
    const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1));
    int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1);
    if (r_pad1 > 0) n_oi--;

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1,
                    'l', oc_blocks, oc_blocks_tag); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0,
                    'l', oc_blocks, oc_blocks_tag); // "lpad"
        add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);
    }

    Label ow_loop;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop);

        width_blk_step(ur_w, 0, 0,
                'm', oc_blocks, oc_blocks_tag); // "middle"
        add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop, T_NEAR);
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1,
                'r', oc_blocks, oc_blocks_tag); // "rpad"
        add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
        add(reg_output, sizeof(float) * ur_w * oc_blk);
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad,
                't', oc_blocks, oc_blocks_tag); // "tail"
}

void jit_avx2_conv_fwd_kernel_f32::generate()
{
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias)
        mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
    mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);

    int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
    Label tail, exit;

    if (jcp.nb_oc > jcp.nb_oc_blocking) {
        cmp(reg_oc_blocks, jcp.nb_oc_blocking);
        jne(nb_oc_tail ? tail : exit, T_NEAR);

        solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
        jmp(exit, T_NEAR);

        if (nb_oc_tail) {
            L(tail);
            cmp(reg_oc_blocks, nb_oc_tail);
            jne(exit, T_NEAR);
            solve_common(nb_oc_tail, '0' + nb_oc_tail);
        }

        L(exit);
    } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
        solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
    } else {
        solve_common(nb_oc_tail, '0' + nb_oc_tail);
    }

    this->postamble();

    if (jcp.with_eltwise)
        eltwise_injector_->prepare_table();
}

bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };

    switch (p.len_) {
    case 0: return true; // no post_ops
    case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
    case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
    default: return false;
    }

    return false;
}

status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr)
{
    if (!mayiuse(avx)) return status::unimplemented;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
    jcp.ow = dst_d.dims()[ndims-1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
    jcp.kw = weights_d.dims()[with_groups + ndims-1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);

    if (ndims == 3) {
        jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(
                Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
        jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c);
    } else if (ndims == 4) {
        jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(
                Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
        jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c);
    } else if (ndims == 5) {
        jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(
                Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
        jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c);
    }
    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
        if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu)
            return status::unimplemented;
    }

    const int simd_w = 8;
    const bool flat = jcp.ic < simd_w;
    const bool mimo = !flat;

    /* Grouped channel offset to support 'non-blocked data' format for
     * convolution sizes with '(input_channel / ngroups) < simd' */
    jcp.nonblk_group_off =
        one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1;

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo)
            jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true
        && IMPLICATION(flat, true
                && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
                && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
                    gOdhwi8o))
        && IMPLICATION(mimo, true
                && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
                && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
                    OIdhw8i8o, gOIdhw8i8o))
        && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c);
    if (!args_ok) return status::unimplemented;

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.ur_w = 3;

    jcp.oc_block = simd_w;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */

    // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
    // Thus, we can only assign 14 or 15 YMMs for data storage
    const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
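    // With the defaults above (ur_w = 3, nb_oc_blocking = 4), the compute loop
    // needs ur_w * (nb_oc_blocking + 1) = 3 * 5 = 15 YMMs, which fits the
    // Intel AVX2 budget exactly; the Intel AVX path below relaxes the blocking
    // so that the extra temporary register still fits.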
    if (!mayiuse(avx2)) {
        if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
            // current register assignment requires more YMMs than available
            // adjust one of nb_oc_blocking, ur_w, preserving ur_w >= l_pad
            if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
                jcp.ur_w -= 1;
            else
                for (int b = 3; b > 1; b--)
                    if (jcp.nb_oc % b == 0) {
                        jcp.nb_oc_blocking = b;
                        break;
                    }
        }
    }

    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    args_ok = true
        && jcp.oc % simd_w == 0
        && jcp.l_pad <= jcp.ur_w
        && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
                || (jcp.stride_w == 1 && jcp.stride_h == 1))
        && IMPLICATION(mimo, jcp.ic % simd_w == 0);
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));

    if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
        /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
        jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
                nstl::min(jcp.ow, num_avail_regs / 2));
        jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
        jcp.ur_w_tail = jcp.ow % jcp.ur_w;
        /* check again ... */
        r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
                + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
        if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
            return status::unimplemented;
    }
    assert(jcp.nb_oc_blocking > 0);
    assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);

    jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.nb_ic_blocking = 12;
        jcp.nb_ic_blocking_max = 16;
    } else {
        jcp.nb_ic_blocking = 1;
        jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
    }

    return status::success;
}

void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
}

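/* Backward-data compute loop. Register usage (derived from the code below):
 *   Ymm(ur_w * ii + jj)                     - diff_src accumulators
 *   Ymm(nb_ic_block * ur_w + jj / stride_w) - broadcast diff_dst values
 *   ymm15                                   - current weights vector
 * Only the output positions that actually contribute to a given diff_src
 * column (jj stepping by stride_w) are accumulated. */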
void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
        int r_overflow)
{
    int kw = jcp.kw;
    int kh = jcp.kh;
    int kd = jcp.kd;
    int iw = jcp.iw;
    int ih = jcp.ih;
    int id = jcp.id;
    int ow = jcp.ow;

    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int nb_ic_block = jcp.nb_ic_blocking;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;

    Label kd_loop, skip_kd_loop;
    Label oc_loop, skip_oc_loop;

    for (int ii = 0; ii < nb_ic_block; ii++)
        for (int jj = 0; jj < ur_w; jj++) {
            uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                    Ymm(ur_w * ii + jj));
        }

    if (one_of(jcp.ndims, 3, 4)) {
        cmp(reg_channel_work, 0);
        jle(skip_oc_loop, T_NEAR);
        xor_(reg_channel, reg_channel);

        mov(aux_reg_ddst_oc_loop, reg_ddst);
        mov(aux_reg_kernel_oc_loop, reg_kernel);

        L(oc_loop);
        mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
        mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
    }

    if (jcp.ndims == 5) {
        assert(jcp.nb_oc_blocking == 1);
        push(oi_iter);

        mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_dst_d, reg_ddst);
        mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);

        L(kd_loop);
        mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
    } else {
        mov(kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_ddst, aux_reg_dst_d);
        mov(aux_reg_kernel, aux_reg_ker_d);
    }

    Label kh_loop, skip_kh_loop;
    cmp(kj, 0);
    jle(skip_kh_loop, T_NEAR);
    L(kh_loop); {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = get_iw_start(ki, l_overflow); // 0;
            int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
            for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {

                for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                    int aux_output_offset
                        = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
                    vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
                            ptr[aux_reg_ddst
                            + sizeof(float) * aux_output_offset]);
                }

                for (int ii = 0; ii < nb_ic_block; ii++) {
                    int aux_kernel_offset
                        = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
                        + ki * jcp.ic_block * jcp.oc_block
                        + ofm2 * jcp.ic_block;
                    vmovups(ymm15,
                            ptr[aux_reg_kernel
                            + sizeof(float) * aux_kernel_offset]);
                    for (int jj = jj_start; jj < jj_end; jj += stride_w)
                        vfmadd231ps(Ymm(ur_w * ii + jj),
                                Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
                }
            }
        }
        add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block
                * ic_block);
        sub(aux_reg_ddst, sizeof(float) * ow * oc_block);

        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }
    L(skip_kh_loop);

    if (jcp.ndims == 5) {
        sub(aux_reg_dst_d,
                sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
        add(aux_reg_ker_d,
                sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_loop, T_NEAR);
        L(skip_kd_loop);

        pop(oi_iter);
    }

    if (one_of(jcp.ndims, 3, 4)) {
        int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
                * jcp.oc_block;
        int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
                * jcp.ic * jcp.oc_block;

        add(aux_reg_ddst_oc_loop, ddst_oc_shift);
        add(aux_reg_kernel_oc_loop, kernel_oc_shift);

        inc(reg_channel);
        cmp(reg_channel, reg_channel_work);
        jl(oc_loop, T_NEAR);

        L(skip_oc_loop);
        mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    }

    Label no_update_label;
    cmp(reg_channel, 0);
    je(no_update_label, T_NEAR);
    for (int ii = 0; ii < nb_ic_block; ii++) {
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt =
                sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
            vmovups(Ymm(15),
                    make_safe_addr(reg_dsrc, offt, reg_long_offt));
            vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
                    Ymm(15));
        }
    }
    L(no_update_label);

    for (int ii = 0; ii < nb_ic_block; ii++)
        for (int jj = 0; jj < ur_w; jj++) {
            size_t offt =
                sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
            vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt),
                    Ymm(ur_w * ii + jj));
        }
}

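/* Top-level backward-data driver: splits the diff_src width (iw) into
 * ur_w-wide blocks, handling the block affected by left overflow, the
 * overflow-free middle blocks, the right-overflow block, and the ur_w_tail
 * remainder, mirroring the forward solve_common() decomposition. */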
void jit_avx2_conv_bwd_data_kernel_f32::generate() {
    preamble();

    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);

    int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
    int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;

    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
    int r_overflow = nstl::max(0, (jcp.kw - 1
                    - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
    int r_overflow1 = nstl::max(0, (jcp.kw - 1
                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);

    int n_oi = jcp.iw / jcp.ur_w;
    if (r_overflow1 > 0)
        n_oi--;

    if (jcp.ur_w == jcp.iw) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow);
    } else if (n_oi == 0) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow1);
        add(reg_dsrc, dsrc_shift);
        add(reg_ddst, ddst_shift);
        if (jcp.ur_w_tail != 0)
            compute_loop(jcp.ur_w_tail, 0, r_overflow);
    } else {
        xor_(oi_iter, oi_iter);
        if (l_overflow > 0) {
            compute_loop(jcp.ur_w, l_overflow, 0);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);
            inc(oi_iter);
        }

        if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
            Label ow_loop;
            L(ow_loop); {
                compute_loop(jcp.ur_w, 0, 0);
                add(reg_dsrc, dsrc_shift);
                add(reg_ddst, ddst_shift);
                inc(oi_iter);
                cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
            }
        }

        if (r_overflow1 > 0) {
            compute_loop(jcp.ur_w, 0, r_overflow1);
            add(reg_dsrc, dsrc_shift);
            add(reg_ddst, ddst_shift);
        }

        if (jcp.ur_w_tail != 0)
            compute_loop(jcp.ur_w_tail, 0, r_overflow);
    }

    this->postamble();
}

status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d)
{
    if (!mayiuse(avx2)) return status::unimplemented;

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    int ndims = diff_src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
    jcp.iw = diff_src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    const int simd_w = 8;

    /* derivatives */
    jcp.idp = jcp.id + 2 * jcp.f_pad;
    jcp.ihp = jcp.ih + 2 * jcp.t_pad;
    jcp.iwp = jcp.iw + 2 * jcp.l_pad;
    jcp.ohp = jcp.oh; /* do we really need */
    jcp.owp = jcp.ow; /* padded output ??? */

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    /* gemm-based convolution performs better in these cases */
    if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
        return status::unimplemented;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    jcp.oc_block = simd_w;
    if (jcp.oc % jcp.oc_block) return status::unimplemented;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    jcp.ur_h = 1; /* no code-unrolling by h so far */
    jcp.nb_ic_blocking = 1;
    jcp.nb_oc_blocking = 1;
    jcp.ur_w = 1;

    if (one_of(ndims, 3, 4) && jcp.ow < 40)
        jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;

    if (ndims == 3) {
        jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(OIw8o8i, gOIw8o8i);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
    } else if (ndims == 4) {
        jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
    } else if (ndims == 5) {
        jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c);
        jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
    }

    bool args_ok = true
        && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
        && one_of(jcp.wei_tag, gOIw8o8i, OIw8o8i, gOIhw8o8i, OIhw8o8i,
                gOIdhw8o8i, OIdhw8o8i)
        && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
        && jcp.stride_w == jcp.stride_h
        && jcp.stride_d == 1
        && jcp.dilate_d == 0
        && jcp.dilate_h == 0
        && jcp.dilate_w == 0
        && jcp.ic % simd_w == 0
        && jcp.oc % simd_w == 0
        && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1
        && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
        && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
    if (!args_ok) return status::unimplemented;
    jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
    int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);

    const int max_regs = 15; /* Maximum number of registers available for
                                result accumulation and delta dst data.
                                One additional register is reserved for weights
                                data. */

    /* Find the best blocking with maximum number of fma instructions
       per ur_w * nb_ic_blocking compute loops. Number of required registers
       is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
       ur_w must be divisible by stride_w */
    if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
                                        distribution exceeds max_regs */
        return status::unimplemented;

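    /* For example (unit stride): with stride_w = 1 the search below considers
     * nb_ic_blocking in {1, 2, 3, 4} and ur_w up to max_regs / (b + 1), e.g.
     * b = 4 allows ur_w <= 3 (4*3 + 3 = 15 registers) and b = 2 allows
     * ur_w <= 5; the candidate with the most FMAs per iteration wins. */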
    int best_nfmas = 0;
    for (int b = 1; b <= 4; b++)
    {
        if (jcp.nb_ic % b != 0)
            continue;

        for (int u = jcp.stride_w;
                u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
                u += jcp.stride_w)
        {
            int ur_w = nstl::min(u, jcp.iw);
            /* maximum 1 step with l_overflow so far */
            if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
                continue;
            int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
            if (nfmas > best_nfmas
                    || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
                jcp.ur_w = ur_w;
                jcp.nb_ic_blocking = b;
                best_nfmas = nfmas;
            }
        }
    }
    if (best_nfmas == 0) /* can't find appropriate blocking */
        return status::unimplemented;

    jcp.ur_w_tail = jcp.iw % jcp.ur_w;

    int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1
                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
    /* maximum 1 ur_w block with r_overflow so far */
    if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
        return status::unimplemented;

    if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
        return status::unimplemented;

    return status::success;
}

void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    UNUSED(scratchpad);
    UNUSED(jcp);
}

void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    compute_oh_loop_common();
    this->postamble();
}

status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &diff_weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(avx2)) return status::unimplemented;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    if (ndims == 3) {
        jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(
                Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
    } else if (ndims == 4) {
        jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(
                Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
    } else if (ndims == 5) {
        jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(
                Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
    }
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;

    const bool flat = jcp.ic == 3;
    const bool mimo = !flat;

    const int simd_w = 8;

    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);

    int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
            - jcp.f_pad);
    if (ndims == 5)
        if (jcp.f_pad != 0 || back_pad != 0)
            return status::unimplemented;

    const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1);
    const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1);
    const bool boundaries_ok = true
        && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad
        && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad;
    if (!boundaries_ok)
        return status::unimplemented;

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        if (mimo)
            jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true
        && IMPLICATION(flat, true
                && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
                && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
                    gOdhwi8o))
        && IMPLICATION(mimo, true
                && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
                && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
                    OIdhw8i8o, gOIdhw8i8o))
        && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
        && IMPLICATION(mimo, jcp.ic % simd_w == 0)
        && jcp.oc % simd_w == 0
        && jcp.kw < 14
        && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
        && jcp.kh <= jcp.ih /* [bwd_w:r2] */
        && jcp.kd <= jcp.f_pad + jcp.id
        && jcp.kd <= jcp.id
        && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
        && jcp.dilate_d == 0
        && jcp.dilate_h == 0
        && jcp.dilate_w == 0;
    if (!args_ok) return status::unimplemented;

    jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    jcp.oc_block = simd_w;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    return status::success;
}

void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
{
    Label kd_comeback_loop;
    mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
    L(kd_comeback_loop); {
        const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
            ? 1 : jcp.ic_block;
        sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult);
        sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block
                * jcp.oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kd_comeback_loop, T_NEAR);
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
{
    mov(kj, reg_kh);
    Label kh_comeback_loop;
    L(kh_comeback_loop); {
        const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
            ? 1 : jcp.ic_block;
        sub(reg_input, sizeof(float) * jcp.iw * inp_mult);
        sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_comeback_loop, T_NEAR);
    }
}

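/* Accumulates the weight-gradient update for ic_block_step input channels and
 * all kw kernel columns of one kernel row. Register usage (derived from the
 * code below): Ymm(i_kw * ic_block_step + i_ic) hold the diff_weights
 * accumulators; the two registers after them hold the diff_dst vector and the
 * broadcast src value, so kw * ic_block_step + 2 YMMs are needed in total. */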
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
        int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
        int kernel_offset, int output_offset)
{
    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            size_t off
                = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
                + kernel_offset;
            vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]);
        }

    for (int i_ur = 0; i_ur < ur_w; i_ur++) {
        vmovups(Ymm(kw * ic_block_step + 0),
                yword[reg_output
                + sizeof(float) * i_ur * oc_block + output_offset]);

        for (int i_kw = 0; i_kw < kw; i_kw++) {
            int i_iw = i_ur * jcp.stride_w + i_kw;
            if (i_iw - pad_l < 0
                    || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
                continue;
            for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
                size_t i_off = (size_t)input_offset + sizeof(float)*(
                    one_of(jcp.src_tag, ncw, nchw, ncdhw)
                        ? (i_iw - pad_l) + i_ic
                            * ((size_t)jcp.id * jcp.ih * jcp.iw)
                        : (i_iw - pad_l) * ic_block + i_ic);
                vbroadcastss(Ymm(kw * ic_block_step + 1),
                        make_safe_addr(reg_input, i_off, reg_long_offt));
                vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
                        Ymm(kw * ic_block_step + 0),
                        Ymm(kw * ic_block_step + 1));
            }
        }
    }

    for (int i_kw = 0; i_kw < kw; i_kw++)
        for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
            size_t off
                = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
                + kernel_offset;
            vmovups(yword[reg_kernel + off],
                    Ymm(i_kw * ic_block_step + i_ic));
        }
}

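/* Chooses ic_block_step so that compute_ic_block_step() stays within the 16
 * available YMM registers (kw * ic_block_step + 2 <= 16), e.g. kw == 3 uses a
 * step of 4 (3*4 + 2 = 14 registers), and then dispatches either the fully
 * unrolled ow path or the blocked one depending on the output width. */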
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp()
{
    int ic_block_step;
    if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
        ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
    } else {
        ic_block_step = jcp.kw > 7 ? 1
            : jcp.kw > 3 ? 2
            : jcp.kw > 1 ? 4 : 8;
    }

    const int max_ur_w = jcp.ow > 56 ? 14 : 28;

    if (jcp.ow <= max_ur_w)
        compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
    else
        compute_oh_step_common(ic_block_step, max_ur_w);

    if (jcp.ndims == 5) {
        od_step_comeback_pointers();
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    } else {
        oh_step_comeback_pointers();
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
        int ic_block_step, int max_ur_w)
{
    UNUSED(max_ur_w);

    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
    Label kd_loop;

    const int r_pad
        = nstl::max(0,
                (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);

    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop); {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop); {
            compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0,
                    0, 0);
            size_t inp_icblk_stride = sizeof(float) * ic_block_step
                * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
                        ? jcp.id*jcp.ih*jcp.iw : 1);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
            add(b_ic, ic_block_step);
            cmp(b_ic, ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
            size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
            safe_sub(reg_input, offt, reg_long_offt);
            add(reg_input, sizeof(float) * jcp.iw);
        } else {
            add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
        }
        add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
                * oc_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}

inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
        int ic_block_step, int max_ur_w)
{
    const int ic_block = jcp.ic_block;
    const int oc_block = jcp.oc_block;
    const int stride_w = jcp.stride_w;
    int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
    Label kd_loop;

    const int r_pad = jcp.r_pad;

    int ur_w = nstl::min(jcp.ow, max_ur_w);
    int ur_w_trips = jcp.ow / ur_w;
    int ur_w_tail = jcp.ow % ur_w;
    if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
        if (ur_w_trips > 1) {
            ur_w_tail += ur_w;
            ur_w_trips--;
        } else {
            ur_w_tail += (ur_w - ur_w / 2);
            ur_w = ur_w / 2;
        }
    }
    const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block;

    int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult;
    int output_comeback = ur_w_trips * ur_w * oc_block;

    if (jcp.ndims == 5) {
        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);
        mov(ki, jcp.kd);
        L(kd_loop);
        mov(reg_input, aux_reg_input);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    Label kh_loop;
    L(kh_loop); {
        xor_(b_ic, b_ic);
        Label ic_block_loop;
        L(ic_block_loop); {
            if (jcp.l_pad != 0) {
                ur_w_trips--;
                compute_ic_block_step(ur_w,
                        jcp.l_pad, 0, ic_block_step, 0, 0, 0);
                add(reg_input, sizeof(float)
                        * (ur_w * stride_w - jcp.l_pad) * inp_mult);
                add(reg_output, sizeof(float) * ur_w * oc_block);
            }

            if (ur_w_trips > 0) {
                xor_(reg_ur_w_trips, reg_ur_w_trips);
                Label ow_block_loop;
                L(ow_block_loop); {
                    compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
                    add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult);
                    add(reg_output, sizeof(float) * ur_w * oc_block);

                    inc(reg_ur_w_trips);
                    cmp(reg_ur_w_trips, ur_w_trips);
                    jl(ow_block_loop, T_NEAR);
                }
            }

            if (ur_w_tail > 0)
                compute_ic_block_step(ur_w_tail,
                        0, r_pad, ic_block_step, 0, 0, 0);

            sub(reg_input, sizeof(float) * input_comeback);
            sub(reg_output, sizeof(float) * output_comeback);

            size_t inp_icblk_stride = sizeof(float) * ic_block_step
                * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
                        ? jcp.id*jcp.ih*jcp.iw : 1);
            safe_add(reg_input, inp_icblk_stride, reg_long_offt);
            add(reg_kernel, sizeof(float) * ic_block_step * oc_block);

            add(b_ic, ic_block_step);
            cmp(b_ic, jcp.ic_block);
            jl(ic_block_loop, T_NEAR);
        }
        if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
            size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
            safe_sub(reg_input, offt, reg_long_offt);
            add(reg_input, sizeof(float) * jcp.iw);
        } else {
            add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
        }
        add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
        dec(kj);
        cmp(kj, 0);
        jg(kh_loop, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
                * oc_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_loop, T_NEAR);
    }
}

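/* Walks over the output rows: an initial loop for rows whose kernel overlap is
 * clipped by the top padding, a main loop for fully overlapping rows, and a
 * trailing loop for rows clipped by the bottom padding, adjusting reg_kh and
 * the input/kernel pointers accordingly before each compute_oh_step_disp(). */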
inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common()
{
    const int icoc_block = jcp.ic_block * jcp.oc_block;
    const int t_pad = jcp.t_pad;
    const int stride_h = jcp.stride_h;
    const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
        ? 1 : jcp.ic_block;
    int b_pad = jcp.b_pad;

    Label oh_tpad_loop, oh_loop, oh_loop_end;

    mov(reg_kh, jcp.kh);
    xor_(reg_ih_count, reg_ih_count);
    xor_(reg_oj, reg_oj);
    if (t_pad > 0) {
        assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
        mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
        add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block);

        L(oh_tpad_loop); {
            compute_oh_step_disp();
            add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
            sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block);

            inc(reg_oj);
            add(reg_ih_count, stride_h);
            add(reg_kh, stride_h);

            /* the overlap between input and kernel may not reach kernel size.
             * so far we do not support that (until we put constant here) */
            const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
            cmp(reg_kh, final_inp_ker_overlap);
            jl(oh_tpad_loop, T_NEAR);
        }

        if (t_pad % stride_h != 0) {
            int inp_corr = stride_h - t_pad % stride_h;
            add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block);
            add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult);
        }
    }
    cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
    jge(oh_loop_end, T_NEAR);
    cmp(reg_oj, jcp.oh);
    jge(oh_loop, T_NEAR);

    mov(reg_kh, jcp.kh);
    L(oh_loop); {
        compute_oh_step_disp();
        add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
        add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);

        inc(reg_oj);
        add(reg_ih_count, stride_h);

        cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
        jge(oh_loop_end, T_NEAR);

        cmp(reg_oj, jcp.oh);
        jl(oh_loop, T_NEAR);
    }
    L(oh_loop_end);
    if (b_pad > 0) {
        Label oh_bpad_loop, oh_bpad_loop_end;
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_loop_end, T_NEAR);

        mov(reg_kh, jcp.ih + t_pad);
        sub(reg_kh, reg_ih_count);
        L(oh_bpad_loop); {
            compute_oh_step_disp();
            add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
            add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);

            sub(reg_kh, stride_h);
            cmp(reg_kh, 0);
            jle(oh_bpad_loop_end, T_NEAR);

            inc(reg_oj);
            cmp(reg_oj, jcp.oh);
            jl(oh_bpad_loop, T_NEAR);
        }
        L(oh_bpad_loop_end);
    }
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s