| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef CONVOLUTION_PD_HPP |
| 18 | #define CONVOLUTION_PD_HPP |
| 19 | |
| 20 | #include "mkldnn.h" |
| 21 | |
| 22 | #include "c_types_map.hpp" |
| 23 | #include "primitive_desc.hpp" |
| 24 | #include "utils.hpp" |
| 25 | |
| 26 | namespace mkldnn { |
| 27 | namespace impl { |
| 28 | |
| 29 | status_t conv_desc_init(convolution_desc_t *conv_desc, |
| 30 | prop_kind_t prop_kind, alg_kind_t alg_kind, |
| 31 | const memory_desc_t *src_desc, const memory_desc_t *weights_desc, |
| 32 | const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, |
| 33 | const dims_t strides, const dims_t dilates, |
| 34 | const dims_t padding_l, const dims_t padding_r, |
| 35 | padding_kind_t padding_kind); |
| 36 | |
| 37 | memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); |
| 38 | memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); |
| 39 | memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); |
| 40 | memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); |
| 41 | const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); |
| 42 | const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); |
| 43 | const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); |
| 44 | const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); |
| 45 | |
| 46 | struct convolution_fwd_pd_t; |
| 47 | |
| 48 | struct convolution_pd_t: public primitive_desc_t { |
| 49 | static constexpr auto base_pkind = primitive_kind::convolution; |
| 50 | |
| 51 | convolution_pd_t(engine_t *engine, |
| 52 | const convolution_desc_t *adesc, |
| 53 | const primitive_attr_t *attr, |
| 54 | const convolution_fwd_pd_t *hint_fwd_pd) |
| 55 | : primitive_desc_t(engine, attr, base_pkind) |
| 56 | , desc_(*adesc) |
| 57 | , hint_fwd_pd_(hint_fwd_pd) |
| 58 | {} |
| 59 | |
| 60 | const convolution_desc_t *desc() const { return &desc_; } |
| 61 | virtual const op_desc_t *op_desc() const override |
| 62 | { return reinterpret_cast<const op_desc_t *>(this->desc()); } |
| 63 | virtual void init_info() override { impl::init_info(this, this->info_); } |
| 64 | |
| 65 | virtual status_t query(query_t what, int idx, void *result) const override { |
| 66 | switch (what) { |
| 67 | case pkind_traits<base_pkind>::query_d: |
| 68 | *(const convolution_desc_t**)result = desc(); break; |
| 69 | default: return primitive_desc_t::query(what, idx, result); |
| 70 | } |
| 71 | return status::success; |
| 72 | } |
| 73 | |
| 74 | /* common conv aux functions */ |
| 75 | |
| 76 | dim_t MB() const { return _src_md()->dims[0]; } |
| 77 | |
| 78 | dim_t IC() const { return _src_md()->dims[1]; } |
| 79 | dim_t OC() const { return _dst_md()->dims[1]; } |
| 80 | dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } |
| 81 | |
| 82 | dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } |
| 83 | dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } |
| 84 | dim_t IW() const { return _src_md()->dims[ndims() - 1]; } |
| 85 | |
| 86 | dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } |
| 87 | dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } |
| 88 | dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } |
| 89 | |
| 90 | dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } |
| 91 | dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } |
| 92 | dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } |
| 93 | |
| 94 | dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } |
| 95 | dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } |
| 96 | dim_t KSW() const { return desc_.strides[ndims() - 3]; } |
| 97 | |
| 98 | dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } |
| 99 | dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } |
| 100 | dim_t KDW() const { return desc_.dilates[ndims() - 3]; } |
| 101 | |
| 102 | dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } |
| 103 | dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } |
| 104 | dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } |
| 105 | dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } |
| 106 | dim_t padL() const { return desc_.padding[0][ndims() - 3]; } |
| 107 | dim_t padR() const { return desc_.padding[1][ndims() - 3]; } |
| 108 | |
| 109 | int ndims() const { return _src_md()->ndims; } |
| 110 | |
| 111 | bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } |
| 112 | bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } |
| 113 | |
| 114 | bool is_fwd() const { |
| 115 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
| 116 | prop_kind::forward_inference); |
| 117 | } |
| 118 | |
| 119 | bool has_zero_dim_memory() const { |
| 120 | const auto s_d = memory_desc_wrapper(*_src_md()); |
| 121 | const auto d_d = memory_desc_wrapper(*_dst_md()); |
| 122 | return s_d.has_zero_dim() || d_d.has_zero_dim(); |
| 123 | } |
| 124 | |
| 125 | protected: |
| 126 | convolution_desc_t desc_; |
| 127 | const convolution_fwd_pd_t *hint_fwd_pd_; |
| 128 | |
| 129 | bool set_default_formats_common_template( |
| 130 | memory_desc_t &src_md, format_tag_t src_tag, |
| 131 | memory_desc_t &wei_md, format_tag_t wei_tag, |
| 132 | memory_desc_t &dst_md, format_tag_t dst_tag, |
| 133 | memory_desc_t &bia_md) { |
| 134 | using namespace format_tag; |
| 135 | |
| 136 | # define IS_OK(f) \ |
| 137 | do { if ((f) != status::success) return false; } while(0) |
| 138 | if (src_md.format_kind == format_kind::any |
| 139 | && !utils::one_of(src_tag, any, undef)) |
| 140 | IS_OK(memory_desc_init_by_tag(src_md, src_tag)); |
| 141 | if (dst_md.format_kind == format_kind::any |
| 142 | && !utils::one_of(dst_tag, any, undef)) |
| 143 | IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); |
| 144 | if (wei_md.format_kind == format_kind::any |
| 145 | && !utils::one_of(wei_tag, any, undef)) |
| 146 | IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); |
| 147 | if (with_bias() && bia_md.format_kind == format_kind::any) |
| 148 | IS_OK(memory_desc_init_by_tag(bia_md, x)); |
| 149 | # undef IS_OK |
| 150 | |
| 151 | return true; |
| 152 | } |
| 153 | |
| 154 | bool set_default_alg_kind(alg_kind_t alg_kind) { |
| 155 | assert(utils::one_of(alg_kind, alg_kind::convolution_direct, |
| 156 | alg_kind::convolution_winograd)); |
| 157 | if (desc_.alg_kind == alg_kind::convolution_auto) |
| 158 | desc_.alg_kind = alg_kind; |
| 159 | return desc_.alg_kind == alg_kind; |
| 160 | } |
| 161 | |
| 162 | bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, |
| 163 | data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { |
| 164 | bool ok = true |
| 165 | && (src_dt == data_type::undef || _src_md()->data_type == src_dt) |
| 166 | && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) |
| 167 | && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) |
| 168 | && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); |
| 169 | if (with_bias() && bia_dt != data_type::undef) |
| 170 | ok = ok && _bia_md()->data_type == bia_dt; |
| 171 | return ok; |
| 172 | } |
| 173 | |
| 174 | private: |
| 175 | const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } |
| 176 | const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } |
| 177 | const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } |
| 178 | const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } |
| 179 | }; |
| 180 | |
| 181 | struct convolution_fwd_pd_t: public convolution_pd_t { |
| 182 | typedef convolution_fwd_pd_t base_class; |
| 183 | typedef convolution_fwd_pd_t hint_class; |
| 184 | |
| 185 | convolution_fwd_pd_t(engine_t *engine, |
| 186 | const convolution_desc_t *adesc, |
| 187 | const primitive_attr_t *attr, |
| 188 | const convolution_fwd_pd_t *hint_fwd_pd) |
| 189 | : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) |
| 190 | , src_md_(desc_.src_desc) |
| 191 | , weights_md_(desc_.weights_desc) |
| 192 | , bias_md_(desc_.bias_desc) |
| 193 | , dst_md_(desc_.dst_desc) |
| 194 | {} |
| 195 | |
| 196 | virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { |
| 197 | if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) |
| 198 | return arg_usage_t::input; |
| 199 | |
| 200 | if (arg == MKLDNN_ARG_BIAS && with_bias()) |
| 201 | return arg_usage_t::input; |
| 202 | |
| 203 | if (arg == MKLDNN_ARG_DST) |
| 204 | return arg_usage_t::output; |
| 205 | |
| 206 | return primitive_desc_t::arg_usage(arg); |
| 207 | } |
| 208 | |
| 209 | virtual const memory_desc_t *src_md(int index = 0) const override |
| 210 | { return index == 0 ? &src_md_ : nullptr; } |
| 211 | virtual const memory_desc_t *dst_md(int index = 0) const override |
| 212 | { return index == 0 ? &dst_md_ : nullptr; } |
| 213 | virtual const memory_desc_t *weights_md(int index = 0) const override { |
| 214 | if (index == 0) return &weights_md_; |
| 215 | if (index == 1 && with_bias()) return &bias_md_; |
| 216 | return nullptr; |
| 217 | } |
| 218 | |
| 219 | virtual int n_inputs() const override { return 2 + with_bias(); } |
| 220 | virtual int n_outputs() const override { return 1; } |
| 221 | |
| 222 | protected: |
| 223 | memory_desc_t src_md_; |
| 224 | memory_desc_t weights_md_; |
| 225 | memory_desc_t bias_md_; |
| 226 | memory_desc_t dst_md_; |
| 227 | |
| 228 | bool set_default_formats_common(format_tag_t src_tag, |
| 229 | format_tag_t wei_tag, format_tag_t dst_tag) { |
| 230 | return set_default_formats_common_template(src_md_, src_tag, |
| 231 | weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); |
| 232 | } |
| 233 | }; |
| 234 | |
| 235 | struct convolution_bwd_data_pd_t: public convolution_pd_t { |
| 236 | typedef convolution_bwd_data_pd_t base_class; |
| 237 | typedef convolution_fwd_pd_t hint_class; |
| 238 | |
| 239 | convolution_bwd_data_pd_t(engine_t *engine, |
| 240 | const convolution_desc_t *adesc, |
| 241 | const primitive_attr_t *attr, |
| 242 | const convolution_fwd_pd_t *hint_fwd_pd) |
| 243 | : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) |
| 244 | , diff_src_md_(desc_.diff_src_desc) |
| 245 | , weights_md_(desc_.weights_desc) |
| 246 | , bias_md_(desc_.bias_desc) |
| 247 | , diff_dst_md_(desc_.diff_dst_desc) |
| 248 | {} |
| 249 | |
| 250 | virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { |
| 251 | if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) |
| 252 | return arg_usage_t::input; |
| 253 | |
| 254 | if (arg == MKLDNN_ARG_DIFF_SRC) |
| 255 | return arg_usage_t::output; |
| 256 | |
| 257 | return primitive_desc_t::arg_usage(arg); |
| 258 | } |
| 259 | |
| 260 | virtual const memory_desc_t *diff_src_md(int index = 0) const override |
| 261 | { return index == 0 ? &diff_src_md_ : nullptr; } |
| 262 | virtual const memory_desc_t *diff_dst_md(int index = 0) const override |
| 263 | { return index == 0 ? &diff_dst_md_ : nullptr; } |
| 264 | virtual const memory_desc_t *weights_md(int index = 0) const override { |
| 265 | if (index == 0) return &weights_md_; |
| 266 | if (index == 1 && with_bias()) return &bias_md_; |
| 267 | return nullptr; |
| 268 | } |
| 269 | |
| 270 | virtual int n_inputs() const override { return 2 + with_bias(); } |
| 271 | virtual int n_outputs() const override { return 1; } |
| 272 | |
| 273 | virtual bool support_bias() const { return false; } |
| 274 | |
| 275 | protected: |
| 276 | memory_desc_t diff_src_md_; |
| 277 | memory_desc_t weights_md_; |
| 278 | memory_desc_t bias_md_; |
| 279 | memory_desc_t diff_dst_md_; |
| 280 | |
| 281 | bool set_default_formats_common(format_tag_t diff_src_tag, |
| 282 | format_tag_t wei_tag, format_tag_t diff_dst_tag) { |
| 283 | return set_default_formats_common_template(diff_src_md_, diff_src_tag, |
| 284 | weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); |
| 285 | } |
| 286 | }; |
| 287 | |
| 288 | struct convolution_bwd_weights_pd_t: public convolution_pd_t { |
| 289 | typedef convolution_bwd_weights_pd_t base_class; |
| 290 | typedef convolution_fwd_pd_t hint_class; |
| 291 | |
| 292 | convolution_bwd_weights_pd_t(engine_t *engine, |
| 293 | const convolution_desc_t *adesc, |
| 294 | const primitive_attr_t *attr, |
| 295 | const convolution_fwd_pd_t *hint_fwd_pd) |
| 296 | : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) |
| 297 | , src_md_(desc_.src_desc) |
| 298 | , diff_weights_md_(desc_.diff_weights_desc) |
| 299 | , diff_bias_md_(desc_.diff_bias_desc) |
| 300 | , diff_dst_md_(desc_.diff_dst_desc) |
| 301 | {} |
| 302 | |
| 303 | virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { |
| 304 | if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) |
| 305 | return arg_usage_t::input; |
| 306 | |
| 307 | if (arg == MKLDNN_ARG_DIFF_WEIGHTS) |
| 308 | return arg_usage_t::output; |
| 309 | |
| 310 | if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) |
| 311 | return arg_usage_t::output; |
| 312 | |
| 313 | return primitive_desc_t::arg_usage(arg); |
| 314 | } |
| 315 | |
| 316 | virtual const memory_desc_t *src_md(int index = 0) const override |
| 317 | { return index == 0 ? &src_md_ : nullptr; } |
| 318 | virtual const memory_desc_t *diff_dst_md(int index = 0) const override |
| 319 | { return index == 0 ? &diff_dst_md_ : nullptr; } |
| 320 | virtual const memory_desc_t *diff_weights_md(int index = 0) const override { |
| 321 | if (index == 0) return &diff_weights_md_; |
| 322 | if (index == 1 && with_bias()) return &diff_bias_md_; |
| 323 | return nullptr; |
| 324 | } |
| 325 | |
| 326 | virtual int n_inputs() const override { return 2; } |
| 327 | virtual int n_outputs() const override { return 1 + with_bias(); } |
| 328 | |
| 329 | protected: |
| 330 | memory_desc_t src_md_; |
| 331 | memory_desc_t diff_weights_md_; |
| 332 | memory_desc_t diff_bias_md_; |
| 333 | memory_desc_t diff_dst_md_; |
| 334 | |
| 335 | bool set_default_formats_common(format_tag_t src_tag, |
| 336 | format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { |
| 337 | return set_default_formats_common_template(src_md_, src_tag, |
| 338 | diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, |
| 339 | diff_bias_md_); |
| 340 | } |
| 341 | }; |
| 342 | |
| 343 | } |
| 344 | } |
| 345 | |
| 346 | #endif |
| 347 | |
| 348 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
| 349 | |