| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include "mkldnn_types.h" |
| 18 | |
| 19 | #include "c_types_map.hpp" |
| 20 | #include "jit_sse42_convolution.hpp" |
| 21 | #include "mkldnn_thread.hpp" |
| 22 | |
| 23 | namespace mkldnn { |
| 24 | namespace impl { |
| 25 | namespace cpu { |
| 26 | |
| 27 | using namespace mkldnn::impl::status; |
| 28 | using namespace mkldnn::impl::utils; |
| 29 | |
// Offset of an activation element inside memory descriptor wrapper `f`.
// When pd()->ndims() == 3 the `h` coordinate is dropped and blk_off() is
// called with (n, c, w); otherwise it is called with (n, c, h, w).
// Relies on `pd()` being callable at the expansion site.
// The whole expansion is parenthesized so that operators surrounding the macro
// call (e.g. `src_blk_off(...) + 1`) cannot re-associate with the ternary.
#define src_blk_off(f, n, c, h, w) \
    ((pd()->ndims() == 3) \
    ? (f).blk_off(n, c, w) \
    : (f).blk_off(n, c, h, w))
| 34 | |
// Offset of a weights element inside memory descriptor wrapper `f`.
//
// wht_blk_off_ prepends the group index `g` to the coordinate list only when
// pd()->with_groups() is true. wht_blk_off additionally drops the `kh`
// coordinate when pd()->ndims() == 3.
// Both rely on `pd()` being callable at the expansion site.
// Each expansion is fully parenthesized so that the ternaries cannot be
// re-associated by operators around the macro call (e.g.
// `wht_blk_off(...) + 1`) or by the nesting of one macro inside the other.
#define wht_blk_off_(f, g, ...) \
    (pd()->with_groups() \
    ? (f).blk_off(g, __VA_ARGS__) \
    : (f).blk_off(__VA_ARGS__))
#define wht_blk_off(f, g, oc, ic, kh, kw) \
    (pd()->ndims() == 3 \
    ? wht_blk_off_(f, g, oc, ic, kw) \
    : wht_blk_off_(f, g, oc, ic, kh, kw))
| 43 | |
| 44 | void jit_sse42_convolution_fwd_t::execute_forward( |
| 45 | const exec_ctx_t &ctx) const { |
| 46 | auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); |
| 47 | auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); |
| 48 | auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); |
| 49 | auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); |
| 50 | |
| 51 | const memory_desc_wrapper src_d(pd()->src_md()); |
| 52 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
| 53 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
| 54 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
| 55 | |
| 56 | const auto &jcp = kernel_->jcp; |
| 57 | |
| 58 | int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); |
| 59 | const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; |
| 60 | |
| 61 | parallel(0, [&](const int ithr, const int nthr) { |
| 62 | size_t start{ 0 }, end{ 0 }; |
| 63 | balance211(work_amount, nthr, ithr, start, end); |
| 64 | |
| 65 | int icbb = 0; |
| 66 | while (icbb < jcp.nb_ic) { |
| 67 | int icb_step = jcp.nb_ic_blocking; |
| 68 | int icb_step_rem = jcp.nb_ic - icbb; |
| 69 | if (icb_step_rem < jcp.nb_ic_blocking_max) |
| 70 | icb_step = icb_step_rem; |
| 71 | |
| 72 | size_t n{0}, g{0}, ocbb{0}, oh{0}; |
| 73 | nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, |
| 74 | oh, jcp.oh); |
| 75 | for (size_t iwork = start; iwork < end; ++iwork) { |
| 76 | int ocb = ocbb * jcp.nb_oc_blocking; |
| 77 | int ocb_num = jcp.nb_oc_blocking; |
| 78 | |
| 79 | for (int icb = icbb; icb < icbb + icb_step; ++icb) { |
| 80 | auto par_conv = jit_conv_call_s(); |
| 81 | |
| 82 | const int ij = oh * jcp.stride_h; |
| 83 | const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); |
| 84 | const int i_b_overflow = nstl::max(jcp.ih, ij |
| 85 | + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; |
| 86 | |
| 87 | const size_t _oc = g * jcp.nb_oc + ocb; |
| 88 | const size_t _ic = g * jcp.nb_ic + icb; |
| 89 | |
| 90 | const int ih = nstl::max(ij - jcp.t_pad |
| 91 | + div_up(i_t_overflow, |
| 92 | (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); |
| 93 | par_conv.src = &src[src_blk_off(src_d, n, |
| 94 | jcp.ic == 3 ? 0 : _ic, ih, 0)]; |
| 95 | |
| 96 | par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; |
| 97 | |
| 98 | const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); |
| 99 | par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, |
| 100 | jcp.ic == 3 ? 0 : icb, wh, 0)]; |
| 101 | |
| 102 | if (icb == 0) { |
| 103 | if (bias) |
| 104 | par_conv.bias = |
| 105 | &bias[bias_d.blk_off(_oc * jcp.oc_block)]; |
| 106 | par_conv.flags |= FLAG_IC_FIRST; |
| 107 | } |
| 108 | |
| 109 | if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { |
| 110 | par_conv.flags |= FLAG_IC_LAST; |
| 111 | } |
| 112 | |
| 113 | par_conv.oc_blocks = |
| 114 | nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; |
| 115 | |
| 116 | par_conv.kw_padding = 0; |
| 117 | const int kh_padding = jcp.kh |
| 118 | - div_up(i_t_overflow, (jcp.dilate_h + 1)) |
| 119 | - div_up(i_b_overflow, (jcp.dilate_h + 1)); |
| 120 | par_conv.kh_padding = nstl::max(0, kh_padding); |
| 121 | kernel_->jit_ker(&par_conv); |
| 122 | } |
| 123 | nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, |
| 124 | oh, jcp.oh); |
| 125 | } |
| 126 | icbb += icb_step; |
| 127 | } |
| 128 | }); |
| 129 | |
| 130 | if (pd()->wants_zero_pad_dst()) |
| 131 | ctx.memory(MKLDNN_ARG_DST)->zero_pad(); |
| 132 | } |
| 133 | |
| 134 | } |
| 135 | } |
| 136 | } |
| 137 | |