/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_avx512_common_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace nstl;

using jit_conv_ker_t = void (*)(jit_conv_call_s *);

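// PIPELINE(x) implements a one-iteration software pipeline: the value passed
// in on the current call is stored as the prefetch hint (x_prf), while the
// value staged on the previous call becomes the one the kernel computes with.
// Hence each pipeline wrapper only calls the kernel once p.src has been staged.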
#define PIPELINE(field) \
    do { \
        p.field = p.field ## _prf; \
        p.field ## _prf = field; \
    } while (0)

inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
        const void *src, const void *dst, const void *filt, const void *bias,
        int channel, int kh_padding)
{
    PIPELINE(src);
    PIPELINE(dst);
    PIPELINE(filt);
    PIPELINE(bias);
    PIPELINE(channel);
    PIPELINE(kh_padding);

    if (p.src)
        ker(&p);
}
// The special case for the driver with ow-parallelization (FWD)
// TODO: implement it for BWD_D and BWD_W too
inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
        const void *src, const void *dst, const void *filt, const void *bias,
        int channel, int kh_padding, int owb)
{
    PIPELINE(src);
    PIPELINE(dst);
    PIPELINE(filt);
    PIPELINE(bias);
    PIPELINE(channel);
    PIPELINE(kh_padding);
    PIPELINE(owb);

    if (p.src)
        ker(&p);
}

inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
        const void *src, const void *dst, const void *filt, const void *bias,
        int channel, int kh_padding, int kd_padding)
{
    PIPELINE(src);
    PIPELINE(dst);
    PIPELINE(filt);
    PIPELINE(bias);
    PIPELINE(channel);
    PIPELINE(kh_padding);
    PIPELINE(kd_padding);

    if (p.src)
        ker(&p);
}
// The special case for the driver with ow-parallelization (FWD)
// TODO: implement it for BWD_D and BWD_W too
inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker,
        jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
        const void *bias, int channel, int kh_padding, int kd_padding, int owb)
{
    PIPELINE(src);
    PIPELINE(dst);
    PIPELINE(filt);
    PIPELINE(bias);
    PIPELINE(channel);
    PIPELINE(kh_padding);
    PIPELINE(kd_padding);
    PIPELINE(owb);

    if (p.src)
        ker(&p);
}

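// Pipeline helper for the 3D backward-by-weights driver: besides the usual
// fields it stages the current depth index, the depth work size, the number of
// non-padded kd iterations, and the offset into the weights that skips the
// front-padded part of the filter.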
void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
        const void *src, const void *dst, const void *filt, const void *bias,
        int channel, int d_index, int d_worksize,
        int kd_padding /* kd_work_size */, size_t kd_offset) {
    PIPELINE(src);
    PIPELINE(dst);
    PIPELINE(filt);
    PIPELINE(bias);
    PIPELINE(channel);
    PIPELINE(kd_padding);
    PIPELINE(d_worksize);
    PIPELINE(d_index);
    PIPELINE(kd_offset);

    if (p.src)
        ker(&p);
}
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() \
     ? (d).blk_off((g), __VA_ARGS__) \
     : (d).blk_off(__VA_ARGS__))

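// When the number of output channels is not a multiple of the block size, the
// jit kernel still reads a full block of bias values; copy the user bias into
// a zero-padded scratchpad buffer and redirect the pointer to it.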
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
        dst_type>::prepare_padded_bias(const dst_data_t *&bias,
        const memory_tracking::grantor_t &scratchpad) const {
    if (!pd()->wants_padded_bias()) return;

    auto padded_bias = scratchpad.template get<dst_data_t>(
            key_conv_padded_bias);
    utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
    utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
            (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
    bias = padded_bias;
}

template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
execute_forward_1d(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);

    prepare_padded_bias(bias, this->scratchpad(ctx));

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow;

    int nthr;
    if (jcp.aligned_threads)
        nthr = jcp.aligned_threads;
    else
        nthr = mkldnn_get_max_threads();

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, owb{0};

            if (jcp.loop_order == loop_cwgn) {
                int dummy{0};
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
                        g, jcp.ngroups, n, jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gncw) {
                int dummy{0};
                nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, dummy, 1);
            } else {
                assert(!"unsupported loop order");
            }

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                auto bias_w = bias ? bias + g_oc : nullptr;
                auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s);
                auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);

                for (int icb = icb_l2;
                     icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
                    jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                            src_w, dst_w, wht_w, bias_w, icb, 1, owb);

                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }
                if (jcp.loop_order == loop_cwgn) {
                    int dummy{0};
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, jcp.ngroups, n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gncw) {
                    int dummy{0};
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb,
                            occ, oc_chunks, owb, jcp.nb_ow, dummy, 1);
                } else {
                    assert(!"unsupported loop order");
                }
            }
        }
        jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0);
    });
}

template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
execute_forward_2d(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);

    prepare_padded_bias(bias, this->scratchpad(ctx));

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow;

    int nthr;
    if (jcp.aligned_threads)
        nthr = jcp.aligned_threads;
    else
        nthr = mkldnn_get_max_threads();

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, oh_s{0}, owb{0};

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
                        g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb,
                        occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;

                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                auto bias_w = bias ? bias + g_oc : nullptr;

                for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
                    int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;

                    auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s);
                    auto src_w
                        = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s);
                    auto wht_w
                        = weights + wht_blk_off(weights_d, g, ocb, icb_l2);

                    for (int icb = icb_l2;
                         icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
                         ++icb) {
                        auto src_c = src_w;
                        auto dst_c = dst_w;
                        for (int oj = oh_b, ij = ih_b;
                             oj < min(oh_e, oh_b + jcp.h_blocking);
                             ++oj, ij += jcp.stride_h) {
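                            // Filter rows that would read above or below the
                            // input are counted as i_t_overflow / i_b_overflow;
                            // kh_padding is the number of rows the kernel
                            // actually computes for this output row oj.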
                            int dilate_h = jcp.dilate_h + 1;
                            int i_t_overflow = div_up(max(0, -ij), dilate_h);
                            int i_b_overflow = div_up(max(0, ij - jcp.ih
                                    + (jcp.kh - 1) * dilate_h + 1), dilate_h);
                            int kh_padding = nstl::max(
                                    0, jcp.kh - i_t_overflow - i_b_overflow);

                            auto aux_src = src_c
                                + i_t_overflow * dilate_h * src_h_stride;
                            auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                            jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker,
                                    par_conv, aux_src, dst_c, aux_wht, bias_w, icb,
                                    kh_padding, owb);

                            src_c += src_h_stride * jcp.stride_h;
                            dst_c += dst_h_stride;
                        }
                        src_w += src_c_stride;
                        wht_w += wht_ic_stride;
                    }
                }

                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                else
                    assert(!"unsupported loop order");
            }
        }

        jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0);
    });
}

template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
execute_forward_3d(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);

    prepare_padded_bias(bias, this->scratchpad(ctx));

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    parallel(0, [&](const int ithr, const int nthr) {
        int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
        int start{0}, end{0}, start_copy;
        int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh
            * jcp.nb_ow;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0};

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start,
                        occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
                        od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start,
                        g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
                        od_s, jcp.od, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_oc = g_ocb * jcp.oc_block;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

                int id_s = -jcp.f_pad + od_s * jcp.stride_d;

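                // Same overflow logic as for the height dimension, applied to
                // depth: kd_padding is the number of kd iterations that fall
                // onto valid input for this output depth position.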
                int dilate_d = jcp.dilate_d + 1;
                int d_t_overflow = div_up(max(0, -id_s), dilate_d);
                int d_b_overflow = div_up(
                        max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                        dilate_d);
                int kd_padding = nstl::max(0,
                        jcp.kd - d_t_overflow - d_b_overflow);

                auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0;
                auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s,
                        iw_s) + d_t_overflow * dilate_d * src_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
                    + d_t_overflow * wht_d_stride;

                for (int icb = icb_l2;
                     icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
                    auto src_c = src_w;
                    auto dst_c = dst_w;
                    for (int oj = oh_s, ij = ih_s;
                         oj < oh_e; ++oj, ij += jcp.stride_h)
                    {
                        int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(max(0, -ij), dilate_h);
                        int i_b_overflow = div_up(
                                max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                    + 1),
                                dilate_h);
                        int kh_padding = nstl::max(0,
                                jcp.kh - i_t_overflow - i_b_overflow);
                        jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker,
                                par_conv,
                                src_c + i_t_overflow * dilate_h * src_h_stride,
                                dst_c, wht_w + i_t_overflow * wht_h_stride,
                                bias_w, icb, kh_padding, kd_padding, owb);

                        src_c += src_h_stride * jcp.stride_h;
                        dst_c += dst_h_stride;
                    }
                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }

                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end,
                            occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
                            od_s, jcp.od, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end,
                            g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
                            od_s, jcp.od, oh_s, jcp.oh);
                else
                    assert(!"unsupported loop order");
            }
        }
        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                src, dst, weights, bias, 0, 0, 0);
    });
}

template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;

template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const
{
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0};
            if (jcp.loop_order == loop_cgn) {
                int dummy{0};
                nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n,
                        jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gnc) {
                int dummy{0};
                nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc,
                        ic_chunks, dummy, 1);
            } else {
                assert(!"unsupported loop order");
            }

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                for (int ocb = ocb_l2;
                     ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                            diff_src_w, diff_dst_w, wht_w, 0, ocb, 1);
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn) {
                    int dummy{0};
                    nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups,
                            n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gnc) {
                    int dummy{0};
                    nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc,
                            ic_chunks, dummy, 1);
                } else {
                    assert(!"unsupported loop order");
                }
            }
        }

        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1);
    });
}

template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const
{
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;

        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0}, ih_s{0};
            if (jcp.loop_order == loop_cgn)
                nd_iterator_init(start,
                        icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_init(start,
                        g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                int work_rem = end - start;
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                for (int ocb = ocb_l2;
                     ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        int oj, k_len, k_lo;
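                        // The three cases below compute, for input row ij, the
                        // range of filter rows [k_lo, k_lo + k_len) that hit a
                        // valid output row, and oj, the first diff_dst row the
                        // kernel should read.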
                        if (is_fast_path) { // dilate == 0 && stride == 1
                            int i_t_overflow = max(0, jcp.kh - 1 - ij
                                    - jcp.t_pad);
                            int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
                                    - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h
                                        - ij - jcp.t_pad), dilate_h);
                            int i_b_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
                                        - jcp.ih + ij - jcp.b_pad), dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            int i_t_overflow = max(0, (jcp.kh - 1 - ij
                                    - jcp.t_pad) / jcp.stride_h);
                            int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
                                    - jcp.b_pad) / jcp.stride_h);
                            int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                    + jcp.b_pad - ij) % jcp.stride_h);
                            int overflow_kh_lo = (ij + jcp.t_pad)
                                % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                / jcp.stride_h + 1 - i_t_overflow
                                - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }
                        assert(k_len >= 0);

                        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride,
                                0, ocb, k_len);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn)
                    nd_iterator_jump(start, end,
                            icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
                else if (jcp.loop_order == loop_gnc)
                    nd_iterator_jump(start, end,
                            g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
                else
                    assert(!"unsupported loop order");
            }
        }

        jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1);
    });
}

template <data_type_t diff_dst_type, data_type_t wei_type,
          data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const
{
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0}, start_copy;
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
        size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
        size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
        bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;

        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0};
            if (jcp.loop_order == loop_cgn)
                nd_iterator_init(start,
                        icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
                        ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_init(start,
                        g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
                        ih_s, jcp.ih);
            else
                assert(!"unsupported loop order");

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                int work_rem = end - start;
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
                int d_len = 0, d_lo = 0, d_oj = 0;
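                // Same bounds computation as for the height dimension, but in
                // depth: d_lo / d_len select the valid kd range and d_oj the
                // first diff_dst depth index for this input depth id_s.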
                if (is_fast_path_d) { // dilate == 0 && stride == 1
                    int d_t_overflow = max(0, jcp.kd - 1 - id_s
                            - jcp.f_pad);
                    int d_b_overflow = max(0, jcp.kd - jcp.id + id_s
                            - jcp.back_pad);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow;
                } else if (jcp.dilate_d != 0) { // stride == 1
                    int dilate_d = jcp.dilate_d + 1;
                    // Note: use div_up to account for "holes" in filter
                    int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d
                            - id_s - jcp.f_pad), dilate_d);
                    int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1
                            - jcp.id + id_s - jcp.back_pad), dilate_d);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
                } else { // dilate == 0
                    int d_t_overflow = max(0, (jcp.kd - 1 - id_s
                            - jcp.f_pad) / jcp.stride_d);
                    int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s
                            - jcp.back_pad) / jcp.stride_d);
                    int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1
                            + jcp.back_pad - id_s) % jcp.stride_d);
                    int overflow_kd_lo = (id_s + jcp.f_pad)
                        % jcp.stride_d;

                    d_len = (overflow_kd_hi - overflow_kd_lo)
                        / jcp.stride_d + 1 - d_t_overflow
                        - d_b_overflow;
                    d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
                    d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
                }
                assert(d_len >= 0);

                auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb)
                    + id_s * diff_src_d_stride;
                auto diff_dst_w = diff_dst
                    + diff_dst_d.blk_off(n, g_ocb + ocb_l2)
                    + d_oj * diff_dst_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
                    + d_lo * wht_d_stride;

                for (int ocb = ocb_l2;
                     ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        int oj, k_len, k_lo;
                        if (is_fast_path_h) { // dilate == 0 && stride == 1
                            int i_t_overflow = max(0, jcp.kh - 1 - ij
                                    - jcp.t_pad);
                            int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
                                    - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h
                                        - ij - jcp.t_pad), dilate_h);
                            int i_b_overflow
                                = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
                                        - jcp.ih + ij - jcp.b_pad), dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            int i_t_overflow = max(0, (jcp.kh - 1 - ij
                                    - jcp.t_pad) / jcp.stride_h);
                            int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
                                    - jcp.b_pad) / jcp.stride_h);
                            int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                    + jcp.b_pad - ij) % jcp.stride_h);
                            int overflow_kh_lo = (ij + jcp.t_pad)
                                % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                / jcp.stride_h + 1 - i_t_overflow
                                - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }
                        assert(k_len >= 0);

                        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride,
                                0, ocb, k_len, d_len);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cgn)
                    nd_iterator_jump(start, end,
                            icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
                            ih_s, jcp.ih);
                else if (jcp.loop_order == loop_gnc)
                    nd_iterator_jump(start, end,
                            g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
                            ih_s, jcp.ih);
                else
                    assert(!"unsupported loop order");
            }
        }

        jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
                diff_src, diff_dst, weights, 0, 0, 1, 1);
    });
}

template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::
jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd)
    : cpu_primitive_t(apd), kernel_(nullptr)
    , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
{
    const auto &j = pd()->jcp_;

    nthr_ = j.nthr;
    nthr_mb_ = j.nthr_mb;
    nthr_g_ = j.nthr_g;
    nthr_oc_b_ = j.nthr_oc_b;
    nthr_ic_b_ = j.nthr_ic_b;

    kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);

    if (j.ver == ver_4fma)
        trans_kernel_ = create_trans_src(&j);

    if (nthr_mb_ > 1)
        acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();

    reducer_bias_ =
        new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::thread_info_t {
    const src_data_t *src;
    const diff_dst_data_t *diff_dst;
    const diff_weights_data_t *diff_weights;
    diff_weights_data_t *diff_bias;

    const memory_tracking::grantor_t scratchpad;

    src_data_t *tr_src;
    simple_barrier::ctx_t *tr_src_bctx;

    diff_dst_data_t *tr_diff_dst;
    simple_barrier::ctx_t *tr_diff_dst_bctx;

    diff_weights_data_t *wei_bia_reduction;
    simple_barrier::ctx_t *wei_bia_reduction_bctx;

    int ithr;
    int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
    int ithr_but_oc;
    int ithr_but_ic;

    int img_start = 0, img_end = 0, img_work;
    int g_start = 0, g_end = 0, g_work;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work;

    thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(self->scratchpad(ctx)), ithr(ithr)
    {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
        diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
        diff_bias = self->pd()->wants_padded_bias()
            ? scratchpad.template get<diff_weights_data_t>(
                    key_conv_padded_bias)
            : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);

        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx);

        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                key_conv_tr_diff_dst);
        tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx);

        wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
                key_conv_wei_bia_reduction);
        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx);

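        /* decompose the linear thread id into (mb, g, oc_b, ic_b) coordinates;
         * ithr_but_oc (resp. ithr_but_ic) is the thread's id with the oc-
         * (resp. ic-) dimension factored out, shared by threads that differ
         * only in that dimension */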
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
            + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
            + ithr_oc_b;

        const auto &jcp = self->kernel_->jcp;

        /* reduction dimension */
        balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start,
                oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start,
                ic_b_end);
        ic_b_work = ic_b_end - ic_b_start;
    }
};

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;

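    /* the first thread along the reduction (mb) dimension writes directly to
     * the user's diff_weights / diff_bias; the remaining mb-threads write to
     * per-thread scratch copies that are summed up in reduce_diff_weights() */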
    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_weights
        : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
    diff_weights_data_t *diff_bia = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_bias
        : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
            + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;

    // TODO: use memory descriptor with the same fmt as src (or use a macro :))
    auto tr_src_off = [&](int ithr_mb, int ic, int ij) {
        const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
        const size_t tr_chn_size = tr_row_size * jcp.ih;
        const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups;

        return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size;
    };

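    /* transpose this thread's share of the source rows into the tr_src layout
     * expected by the 4fma kernel, using a two-deep circular buffer so the
     * next rows are prefetched while the current ones are transposed */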
    auto uker_trans = [&](int img) {
        const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih;

        int start{0}, end{0};
        balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
        const int my_work = end - start;

        int g{0}, ic_b{0}, j{0};
        nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih);
        g += ti->g_start;
        ic_b += ti->ic_b_start;

        const int _ic = g * jcp.nb_ic + ic_b;
        src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
        src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];

        assert(jcp.ic_block == 16);
        const int src_stride = jcp.iw * jcp.ic_block;
        const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

        const int pf_depth = 2;
        struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth];

        for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
            pf_circ_buf[iwork % pf_depth] = {src1, tr_src1};

            if (iwork >= pf_depth - 1) {
                int old_idx = (iwork - pf_depth + 1) % pf_depth;
                auto ctx = jit_trans_src_t::ctx_t();
                ctx.src = pf_circ_buf[old_idx].src;
                ctx.tr_src = pf_circ_buf[old_idx].tr_src;
                ctx.src_prf = src1;
                ctx.tr_src_prf = tr_src1;
                (*trans_kernel_)(&ctx);
            }
            src1 += src_stride;
            tr_src1 += tr_src_stride;
        }
#if 0
        // reference transposition
        const int l_pad = jcp.l_pad;
        const int iwlp = l_pad + jcp.iw;
        const int tr_iw = jcp.tr_iw;

        for (size_t iwork = start; iwork < end; iwork++) {
            PRAGMA_OMP_SIMD()
# pragma unroll
            for (int i = 0; i < l_pad; i++)
                for (int j = 0; j < jcp.ic_block; j++)
                    tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;

            PRAGMA_OMP_SIMD()
# pragma unroll
            for (int i = l_pad; i < iwlp; i++)
                for (int j = 0; j < jcp.ic_block; j++)
                    tr_src1[j * jcp.tr_iw + i]
                        = (src_data_t)src1[(i - l_pad) * 16 + j];

            PRAGMA_OMP_SIMD()
# pragma unroll
            for (int i = iwlp; i < tr_iw; i++)
                for (int j = 0; j < jcp.ic_block; j++)
                    tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;

            src1 += src_stride;
            tr_src1 += tr_src_stride;
        }
#endif
    };

    if (jcp.is_1stconv && jcp.ver == ver_4fma) {
        /* prepare contexts */
        auto tr_ctx = jit_trans_src_t::ctx_t();
        tr_ctx.tr_src = ti->tr_src
            + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;

        assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
        tr_ctx.nthr_oc_b = nthr_oc_b_;
        int ih_start{0}, ih_end{0};
        balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
        tr_ctx.tr_src_ih_start = ih_start;
        tr_ctx.tr_src_ih_end = ih_end;
        tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;

        auto p = jit_conv_call_s();
        p.src = tr_ctx.tr_src;

        /* zero diff_bias if applicable */
        if (jcp.with_bias && ti->ithr_ic_b == 0) {
            assert(jcp.oc_block == 16);
            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
                diff_weights_data_t *db = &diff_bia[oc_b * 16];
                for (int o = 0; o < 16; ++o)
                    db[o] = 0;
            }
        }

        for (int img = ti->img_start; img < ti->img_end; ++img) {
            p.flags = (img == ti->img_start) * FLAG_MB_FIRST;

            for (int g = ti->g_start; g < ti->g_end; ++g) {
                for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
                    const int _ic = g * jcp.nb_ic + ic_b;
                    tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)];

                    (*trans_kernel_)(&tr_ctx);

                    if (ic_b == 0)
                        p.flags |= FLAG_IC_FIRST;
                    else
                        p.flags &= ~FLAG_IC_FIRST;

                    for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
                        const int _oc = g * jcp.nb_oc + oc_b;
                        p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];

                        const size_t off =
                            wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                        p.filt = diff_wei + off;
                        p.bias = diff_bia + _oc * jcp.oc_block;

                        kernel_->jit_ker(&p);
                    }
                }
            }
        }
    } else {
        for (int img = ti->img_start; img < ti->img_end; ++img) {
            auto p = jit_conv_call_s();

            if (jcp.ver == ver_4fma) {
                /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
                using simple_barrier::barrier;
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                uker_trans(img);
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }

            for (int g = ti->g_start; g < ti->g_end; ++g) {
                for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
                    for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
                        const int _oc = g * jcp.nb_oc + oc_b;
                        const int _ic = g * jcp.nb_ic + ic_b;

                        jit_conv_ker_pipeline(kernel_->jit_ker, p,
                                jcp.ver == ver_4fma
                                    ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
                                    : &ti->src[src_d.blk_off(img, _ic)],
                                &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
                                diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                                0, (img == ti->img_start), 0);

                    }
                }
            }

            const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
            const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
            jit_conv_ker_pipeline(kernel_->jit_ker, p,
                    jcp.ver == ver_4fma
                        ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
                        : &ti->src[src_d.blk_off(img + 1, _ic)],
                    &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
                    diff_wei + wht_blk_off(
                        diff_weights_d, ti->g_start,
                        ti->oc_b_start, ti->ic_b_start),
                    0, 0, 0);
        }
    }
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
{
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size
        = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;

    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_weights
        : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
    diff_weights_data_t *diff_bia = ti->ithr_mb == 0
        ? (diff_weights_data_t*)ti->diff_bias
        : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
            + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;

    const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
    const int input_step = jcp.ih * jcp.iw * inp_mult;
    const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
    int img{0}, od_s{0};
    int img_start = ti->img_start, img_end = ti->img_end;
    nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
    const int img_first = img;

    while (img_start < img_end) {
        auto p = jit_conv_call_s();

        int work_rem = img_end - img_start;
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
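        /* how many leading/trailing kd positions of the filter fall onto the
         * front/back padding for this od chunk; kd_pad_off is the byte offset
         * into diff_weights that skips the front-padded positions */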
        const int id_s = od_s * jcp.stride_d;
        const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
        const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
        const int kd_back_pad
            = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd);
        int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw
            * jcp.ic_block * jcp.oc_block * jcp.typesize_out;

        for (int g = ti->g_start; g < ti->g_end; ++g) {
            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
                for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
                    const int _oc = g * jcp.nb_oc + oc_b;
                    const int _ic = g * jcp.nb_ic + ic_b;

                    auto src = &ti->src[src_d.blk_off(img, _ic)
                        + ik_overlap * input_step];
                    auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)
                        + od_s * output_step];

                    jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst,
                            diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                            diff_bia + _oc * 16, (img == img_first), od_s, od_e,
                            jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off);

                    if (ic_b == 0) p.flags = 0;
                    else p.flags = 1;
                }
            }
        }

        const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
        const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
        jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p,
                &ti->src[src_d.blk_off(img + 1, _ic)],
                &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
                diff_wei + wht_blk_off(diff_weights_d, ti->g_start,
                    ti->oc_b_start, ti->ic_b_start),
                diff_bia, 0, 0, 0, 0, 0);
        nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
    }
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
    const int bia_size = jcp.ngroups * jcp.oc;
    const diff_weights_data_t *diff_bias_ws
        = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start{0}, end{0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

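    /* each thread accumulates its [start, end) slice of the weights from every
     * other mb-thread's private copy into the user's diff_weights buffer */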
    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
            const int kh = sub_ic_b_kh_start % jcp.kh;

            const int acc_size
                = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                * jcp.kw * jcp.ic_block * jcp.oc_block;

            const size_t off
                = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);

            diff_weights_data_t *d
                = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;

            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }

        if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) {
            if (ti->ithr == 0)
                acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias,
                        diff_bias_ws, bia_size);
            diff_bias_ws += bia_size;
        }
    }
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
        * jcp.kd;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start{0}, end{0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
            const int kd = sub_ic_b_kh_start % jcp.kd;

            const int acc_size
                = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;

            const size_t off
                = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
            diff_weights_data_t *d
                = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
    }
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    auto rb = this->reducer_bias_;
    assert(nthr_ == rb->balancer().nthr_);

    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
            ti->scratchpad, prefix_reducer_bia);

    const auto &jcp = kernel_->jcp;

    if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;

    const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
    const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);

    if (b_njobs == 0) return;

    /* reduction dimension */
    int img_start{0}, img_end{0};
    balance211(jcp.mb, rb->balancer().nthr_per_group_,
            rb->balancer().id_in_group(ti->ithr), img_start, img_end);

    /* jobs */
    int g_start{0}, ocb_start{0};
    nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
    for (int img = img_start; img < img_end; ++img) {
        int g = g_start, ocb = ocb_start;
        for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
            const size_t _oc = g * jcp.nb_oc + ocb;

            const diff_dst_data_t *d_dst
                = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
            diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
                    ti->diff_bias, reducer_bia_scratchpad)
                + b_job_loc * rb->balancer().job_size_;

            if (img == img_start)
                for (int o = 0; o < 16; ++o)
                    d_bias[o] = 0;
            for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < 16; ++o)
                    d_bias[o] += d_dst[o];
                d_dst += 16;
            }

            nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
        }
    }

    rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {

    const auto &jcp = kernel_->jcp;

    const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh
        * jcp.kw * jcp.kd;
    const int bia_size = jcp.ngroups * jcp.oc;
    const diff_weights_data_t *diff_bias_ws
        = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;

    if (nthr_mb_ > 1) mkldnn_thr_barrier();

    if (ti->ithr == 0)
    {
        for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
            acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
            diff_bias_ws += bia_size;
        }
    }
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx) const
{
    const auto &j = pd()->jcp_;
    auto scratchpad = this->scratchpad(ctx);

    if (j.ver == ver_4fma) {
        if (!j.is_1stconv) {
            // XXX: See the comment about tr_iw and guarding elements in
            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
            const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
            const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;

            auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
            /* to avoid NaNs in computations we zero tail num_guard_elems for
             * each possible thread group */
            for (int ithr = 1; ithr <= max_nthr; ++ithr) {
                src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
                for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
                    ts[i] = 0;
            }
        }

        if (j.nthr_oc_b > 1) {
            const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
            auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_src_bctx);
            for (int i = 0; i < tr_src_bctx_size; ++i)
                simple_barrier::ctx_init(&tr_src_bctx[i]);
        }
    }

    if (nthr_mb_ > 1) {
        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_wei_bia_reduction_bctx));
    }

    const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_bia);
    auto rb = this->reducer_bias_;
    rb->init(reducer_bia_scratchpad);
}

template <data_type_t src_type, data_type_t diff_dst_type,
          data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
    diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx) const {
    prepare_scratchpad_data(ctx);

    parallel(nthr_, [&](const int ithr, const int nthr) {
        assert(nthr_ == nthr);

        thread_info_t thread_info(this, ctx, ithr);

        if (utils::one_of(pd()->ndims(), 3, 4)) {
            compute_diff_weights(&thread_info);
            if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
            if (pd()->with_bias()) compute_diff_bias(&thread_info);
        } else if (pd()->ndims() == 5) {
            compute_diff_weights_3d(&thread_info);
            if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
            if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
        } else {
            assert(false);
        }
    });

    /* TODO: put that into compute_diff_bias() */
    if (pd()->wants_padded_bias()) {
        auto diff_bias = scratchpad(ctx).template get<const diff_weights_data_t>(
                key_conv_padded_bias);
        auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);
        for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
            diff_bias_in[oc] = diff_bias[oc];
    }
}

template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s