/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_avx2_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

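// Helpers that pick the blk_off() overload matching the spatial rank of the
// problem: ndims() == 3 is 1D (ncw), 4 is 2D (nchw), 5 is 3D (ncdhw).
// wht_blk_off additionally prepends the group index when weights have groups.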
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) \
    ? (f).blk_off(n, c, w) \
    : (pd()->ndims() == 4) \
    ? (f).blk_off(n, c, h, w) \
    : (f).blk_off(n, c, d, h, w)

#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
    ? wht_blk_off_(f, g, oc, ic, kw) \
    : (pd()->ndims() == 4) \
    ? wht_blk_off_(f, g, oc, ic, kh, kw) \
    : wht_blk_off_(f, g, oc, ic, kd, kh, kw)

void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto &jcp = kernel_->jcp;

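    // One unit of parallel work is one output row (all of ow) for a chunk of
    // nb_oc_blocking output-channel blocks at a fixed (mb, group, od, oh).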
    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od
        * jcp.oh;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

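        // Input-channel blocks are processed in chunks of nb_ic_blocking;
        // each chunk re-walks the assigned output rows while the jit kernel
        // accumulates its partial sums into dst.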
        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            if (icb_step_rem < jcp.nb_ic_blocking_max)
                icb_step = icb_step_rem;

            size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
            nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                    od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

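                    // How far the (dilated) kernel sticks out into the top /
                    // bottom (height) and front / back (depth) padding for
                    // this output position; the padded taps are skipped.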
                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow = nstl::max(jcp.ih, ij
                            + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;

                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow = nstl::max(jcp.id, dj
                            + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;

                    const size_t _oc = g * jcp.nb_oc + ocb;
                    const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;

                    const int ih = nstl::max(ij - jcp.t_pad
                            + div_up(i_t_overflow,
                                (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);

                    const int id = nstl::max(dj - jcp.f_pad
                            + div_up(d_t_overflow,
                                (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);

                    par_conv.src = &src[src_blk_off(src_d, n,
                            jcp.ic == 3 ? 0 : _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
                            jcp.ic == 3 ? 0 : icb, wd, wh, 0)];

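                    // The first input-channel block initializes the output
                    // (and is the only one that gets the bias pointer); the
                    // last one lets the kernel apply the eltwise post-op.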
                    if (icb == 0) {
                        if (bias)
                            par_conv.bias =
                                    &bias[bias_d.blk_off(_oc * jcp.oc_block)];
                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
                        par_conv.flags |= FLAG_IC_LAST;
                    }

                    par_conv.oc_blocks =
                            nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

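                    // Number of kernel rows / depth slices that still overlap
                    // the input once the padded taps counted above are
                    // clipped away.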
                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh
                            - div_up(i_t_overflow, (jcp.dilate_h + 1))
                            - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                            - div_up(d_t_overflow, (jcp.dilate_d + 1))
                            - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                        od, jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

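    // When oc is not a multiple of the block size, copy the user bias into a
    // zero-padded scratchpad buffer so that full-block reads stay in bounds.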
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(0, ker);

    if (pd()->wants_zero_pad_dst())
        ctx.memory(MKLDNN_ARG_DST)->zero_pad();
}

void jit_avx2_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

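    // Parallel work is (mb, group, ic-block chunk, ih block). If that does
    // not give at least two jobs per thread, shrink the ih block to a single
    // row to expose more parallelism.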
    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;
    if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n{0}, g{0}, icbb{0}, ihb{0};
        nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work,
                ihb, num_ih_blocks);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
                for (int id = 0; id < jcp.id; ++id) {
                    auto par_conv = jit_conv_call_s();

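                    // For this diff_src depth index, count the kernel depth
                    // taps that fall into the front / back padding; od and
                    // d_b_overflow are the diff_dst offset and first weights
                    // tap handed to the kernel.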
                    const int idp = jcp.id + 2 * jcp.f_pad;
                    const int d_t_overflow = nstl::max(0,
                            jcp.kd - 1 - id - jcp.f_pad);
                    const int back_pad = idp - jcp.id - jcp.f_pad;
                    const int d_b_overflow = nstl::max(0,
                            jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
                    const int od = id + jcp.f_pad - d_b_overflow;

                    int ih_start = ihb * ih_block_size;
                    int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                    for (int ih = ih_start; ih < ih_end; ++ih) {

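                        // With strides, only kernel rows congruent to
                        // (ih + t_pad) mod stride_h can reach this diff_src
                        // row; clip that set at the top and bottom borders
                        // and derive the matching diff_dst row oh.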
                        const int i_t_overflow = nstl::max(0, (jcp.kh - 1
                                - ih - jcp.t_pad) / jcp.stride_h);
                        const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
                                + ih - jcp.b_pad) / jcp.stride_h);
                        int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
                                + jcp.b_pad - ih) % jcp.stride_h);
                        int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                        par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
                        par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                                / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
                        par_conv.kw_padding = 0;

                        const int k_lo = overflow_kh_lo
                                + i_b_overflow * jcp.stride_h;
                        const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;

                        par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                                /*jcp.ic == 3 ? 0 :*/
                                g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
                        par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
                                n, g * jcp.nb_oc + oc, od, oh, 0)];
                        par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                                jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
                                d_b_overflow, k_lo, 0)];

                        par_conv.src_prf = nullptr;
                        par_conv.dst_prf = nullptr;
                        par_conv.filt_prf = nullptr;
                        par_conv.channel = oc;
                        par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
                                jcp.nb_oc_blocking);

                        kernel_->jit_ker(&par_conv);
                    }
                }
            nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                    num_ih_blocks);
        }
    };

    parallel(0, ker);
}

void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);

    auto scratchpad = this->scratchpad(ctx);

    data_t *diff_bias = pd()->wants_padded_bias()
            ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;

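    // diff_weights and diff_bias are accumulated into per-thread scratchpad
    // buffers owned by the reducers and merged in a final reduction pass.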
    auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_bia);
    auto rb = this->reducer_bias_;
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
            prefix_reducer_wei);
    auto rw = this->reducer_weights_;
    rw->init(reducer_wei_scratchpad);

    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start{0}, ocb_start{0}, icb_start{0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

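        // Walk the assigned (mb, od) slice; for every weights job (g, ocb,
        // icb) accumulate src x diff_dst contributions into this thread's
        // local buffer, zero-initializing it on the first image.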
        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            const int work_rem = img_end - img_start;
            const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
            for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;
                const size_t _ic = g * jcp.nb_ic + icb;

                /* TODO: put dw <-- 0 in kernel */
                if (img == img_first)
                    array_set(rw->get_local_ptr(ithr, diff_weights,
                                reducer_wei_scratchpad) +
                            w_job_loc * rw->balancer().job_size_, 0,
                            rw->balancer().job_size_);

                for (int od = od_s; od < od_e; ++od) {
                    const int id = od * jcp.stride_d;
                    if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                    auto par_conv = jit_conv_call_s();
                    par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                    par_conv.dst =
                            &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
                    par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                            reducer_wei_scratchpad) +
                            w_job_loc * rw->balancer().job_size_;

                    kernel_->jit_ker(&par_conv);
                }
                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
                        jcp.nb_ic);
            }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }
        rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start{0}, img_end{0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start{0}, ocb_start{0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * jcp.nb_oc + ocb;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                        reducer_bia_scratchpad) +
                        b_job_loc * rb->balancer().job_size_;

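                // Each job reduces one 8-wide output-channel block (the AVX2
                // oc_block): zero the local accumulator on the first image,
                // then sum diff_dst over all spatial positions.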
                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += 8;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }
        rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    parallel(0, [&](const int ithr, const int nthr) {
        ker(ithr, nthr);
        if (pd()->with_bias())
            ker_bias(ithr, nthr);
    });

    /* TODO: put this in ker_bias */
    if (pd()->wants_padded_bias()) {
        assert(jcp.ngroups == 1);
        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
            diff_bias_in[oc] = diff_bias[oc];
    }
}

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s