/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_avx2_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

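/* Offset helpers: pick the blk_off() overload that matches the number of
 * spatial dimensions (ndims == 3 for 1D, 4 for 2D, 5 for 3D convolution) and,
 * for weights, prepend the group index when the primitive uses groups. */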
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) \
        ? (f).blk_off(n, c, w) \
        : (pd()->ndims() == 4) \
            ? (f).blk_off(n, c, h, w) \
            : (f).blk_off(n, c, d, h, w)

#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
        ? wht_blk_off_(f, g, oc, ic, kw) \
        : (pd()->ndims() == 4) \
            ? wht_blk_off_(f, g, oc, ic, kh, kw) \
            : wht_blk_off_(f, g, oc, ic, kd, kh, kw)

void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto &jcp = kernel_->jcp;

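    /* One work item covers a (mini-batch image, group, chunk of
     * nb_oc_blocking output-channel blocks, output depth, output row) tuple;
     * balance211() inside the kernel lambda splits these items across
     * threads. */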
    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od
        * jcp.oh;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

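        /* Sweep input channels in chunks of up to nb_ic_blocking blocks;
         * the last chunk may take all remaining blocks at once. */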
        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            if (icb_step_rem < jcp.nb_ic_blocking_max)
                icb_step = icb_step_rem;

            size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
            nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                    od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

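                    /* Number of kernel rows (kh) and kernel depth elements
                     * (kd) that fall into the top/bottom and front/back
                     * padding for this output position, with dilation taken
                     * into account. */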
                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow = nstl::max(jcp.ih, ij
                        + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;

                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow = nstl::max(jcp.id, dj
                        + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;

                    const size_t _oc = g * jcp.nb_oc + ocb;
                    const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;

                    const int ih = nstl::max(ij - jcp.t_pad
                        + div_up(i_t_overflow,
                            (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);

                    const int id = nstl::max(dj - jcp.f_pad
                        + div_up(d_t_overflow,
                            (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);

                    par_conv.src = &src[src_blk_off(src_d, n,
                        jcp.ic == 3 ? 0 : _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
                        jcp.ic == 3 ? 0 : icb, wd, wh, 0)];

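                    /* On the first input-channel block, pass the bias and set
                     * FLAG_IC_FIRST so the kernel initializes the accumulators
                     * rather than adding on top of previous results. */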
                    if (icb == 0) {
                        if (bias)
                            par_conv.bias =
                                    &bias[bias_d.blk_off(_oc * jcp.oc_block)];
                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
                        par_conv.flags |= FLAG_IC_LAST;
                    }

                    par_conv.oc_blocks =
                            nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh
                        - div_up(i_t_overflow, (jcp.dilate_h + 1))
                        - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                        - div_up(d_t_overflow, (jcp.dilate_d + 1))
                        - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                        od, jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

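    /* The kernel reads the bias in full oc_block chunks; when the number of
     * output channels is not a multiple of the block size, copy it into a
     * zero-padded scratchpad buffer first. */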
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(0, ker);

    if (pd()->wants_zero_pad_dst())
        ctx.memory(MKLDNN_ARG_DST)->zero_pad();
}

void jit_avx2_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;
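    /* If there is too little work to keep all threads busy, also split
     * diff_src over single rows (ih) to create more independent work items. */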
    if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n{0}, g{0}, icbb{0}, ihb{0};
        nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work,
                ihb, num_ih_blocks);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
            for (int id = 0; id < jcp.id; ++id) {
                auto par_conv = jit_conv_call_s();

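                /* Number of kernel depth elements that fall into the front or
                 * back padding for this diff_src depth position; od is the
                 * first diff_dst depth index contributing to it. */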
                const int idp = jcp.id + 2 * jcp.f_pad;
                const int d_t_overflow = nstl::max(0,
                        jcp.kd - 1 - id - jcp.f_pad);
                const int back_pad = idp - jcp.id - jcp.f_pad;
                const int d_b_overflow = nstl::max(0,
                        jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
                const int od = id + jcp.f_pad - d_b_overflow;

                int ih_start = ihb * ih_block_size;
                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                for (int ih = ih_start; ih < ih_end; ++ih) {

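                    /* With strided convolution only every stride_h-th kernel
                     * row contributes to a given input row: find the range of
                     * contributing kernel rows and the first diff_dst row (oh)
                     * that uses them. */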
                    const int i_t_overflow = nstl::max(0, (jcp.kh - 1
                            - ih - jcp.t_pad) / jcp.stride_h);
                    const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
                            + ih - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                            + jcp.b_pad - ih) % jcp.stride_h);
                    int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                    par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
                    par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                            / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
                    par_conv.kw_padding = 0;

                    const int k_lo = overflow_kh_lo
                            + i_b_overflow * jcp.stride_h;
                    const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;

                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                            /*jcp.ic == 3 ? 0 :*/
                            g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
                            n, g * jcp.nb_oc + oc, od, oh, 0)];
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                            jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
                            d_b_overflow, k_lo, 0)];

                    par_conv.src_prf = nullptr;
                    par_conv.dst_prf = nullptr;
                    par_conv.filt_prf = nullptr;
                    par_conv.channel = oc;
                    par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
                            jcp.nb_oc_blocking);

                    kernel_->jit_ker(&par_conv);
                }
            }
            nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                    num_ih_blocks);
        }
    };

    parallel(0, ker);
}

void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);

    auto scratchpad = this->scratchpad(ctx);

    data_t *diff_bias = pd()->wants_padded_bias()
        ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;

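    /* Both reducers accumulate per-thread partial results in the scratchpad
     * and combine them in reduce() once all threads have processed their
     * share of the reduction dimension. */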
    auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_bia);
    auto rb = this->reducer_bias_;
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_wei);
    auto rw = this->reducer_weights_;
    rw->init(reducer_wei_scratchpad);

    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start{0}, ocb_start{0}, icb_start{0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            const int work_rem = img_end - img_start;
            const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
            for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;
                const size_t _ic = g * jcp.nb_ic + icb;

                /* TODO: put dw <-- 0 in kernel */
                if (img == img_first)
                    array_set(rw->get_local_ptr(ithr, diff_weights,
                                reducer_wei_scratchpad) +
                            w_job_loc * rw->balancer().job_size_, 0,
                            rw->balancer().job_size_);

                for (int od = od_s; od < od_e; ++od) {
                    const int id = od * jcp.stride_d;
                    if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                    auto par_conv = jit_conv_call_s();
                    par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                    par_conv.dst =
                            &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
                    par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                                reducer_wei_scratchpad) +
                            w_job_loc * rw->balancer().job_size_;

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
                        jcp.nb_ic);
            }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }
        rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start{0}, img_end{0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start{0}, ocb_start{0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                        reducer_bia_scratchpad) +
                        b_job_loc * rb->balancer().job_size_;

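                /* Zero the local bias accumulator for the first image, then
                 * accumulate diff_dst over every spatial position; the inner
                 * trip count of 8 is the AVX2 output-channel block size
                 * (jcp.oc_block, one ymm register of floats). */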
                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += 8;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }
        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

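    /* Run the weights reduction and, when a bias is present, the bias
     * reduction from the same parallel region. */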
    parallel(0, [&](const int ithr, const int nthr) {
        ker(ithr, nthr);
        if (pd()->with_bias())
            ker_bias(ithr, nthr);
    });

    /* TODO: put this in ker_bias */
    if (pd()->wants_padded_bias()) {
        assert(jcp.ngroups == 1);
        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
            diff_bias_in[oc] = diff_bias[oc];
    }
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s