| 1 | /******************************************************************************* | 
| 2 | * Copyright 2016-2018 Intel Corporation | 
| 3 | * | 
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); | 
| 5 | * you may not use this file except in compliance with the License. | 
| 6 | * You may obtain a copy of the License at | 
| 7 | * | 
| 8 | *     http://www.apache.org/licenses/LICENSE-2.0 | 
| 9 | * | 
| 10 | * Unless required by applicable law or agreed to in writing, software | 
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, | 
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
| 13 | * See the License for the specific language governing permissions and | 
| 14 | * limitations under the License. | 
| 15 | *******************************************************************************/ | 
| 16 |  | 
| 17 | #ifndef INNER_PRODUCT_PD_HPP | 
| 18 | #define INNER_PRODUCT_PD_HPP | 
| 19 |  | 
| 20 | #include "mkldnn.h" | 
| 21 |  | 
| 22 | #include "c_types_map.hpp" | 
| 23 | #include "primitive_desc.hpp" | 
| 24 | #include "utils.hpp" | 
| 25 |  | 
| 26 | namespace mkldnn { | 
| 27 | namespace impl { | 
| 28 |  | 
/* Accessors returning one of the memory descriptors (source, weights, bias,
 * destination) that belong to an inner product operation descriptor.
 * Only the declarations are visible in this header; the definitions live
 * elsewhere.
 * NOTE(review): judging by the "prop_invariant" naming, these presumably
 * select between the plain and the diff_ descriptor according to
 * desc->prop_kind -- confirm against the implementation. */

/* Overloads for a mutable operation descriptor. */
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
/* Overloads for a const operation descriptor; the result is const as well. */
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);

/* Forward declaration: inner_product_pd_t below takes and stores a pointer to
 * the forward primitive descriptor (hint) before that type is defined. */
struct inner_product_fwd_pd_t;
| 39 |  | 
| 40 | struct inner_product_pd_t: public primitive_desc_t { | 
| 41 |     static constexpr auto base_pkind = primitive_kind::inner_product; | 
| 42 |  | 
| 43 |     inner_product_pd_t(engine_t *engine, | 
| 44 |             const inner_product_desc_t *adesc, | 
| 45 |             const primitive_attr_t *attr, | 
| 46 |             const inner_product_fwd_pd_t *hint_fwd_pd) | 
| 47 |         : primitive_desc_t(engine, attr, base_pkind) | 
| 48 |         , desc_(*adesc) | 
| 49 |         , hint_fwd_pd_(hint_fwd_pd) | 
| 50 |     {} | 
| 51 |  | 
| 52 |     const inner_product_desc_t *desc() const { return &desc_; } | 
| 53 |     virtual const op_desc_t *op_desc() const override | 
| 54 |     { return reinterpret_cast<const op_desc_t *>(this->desc()); } | 
| 55 |     virtual void init_info() override { impl::init_info(this, this->info_); } | 
| 56 |  | 
| 57 |     virtual status_t query(query_t what, int idx, void *result) const override { | 
| 58 |         switch (what) { | 
| 59 |         case query::inner_product_d: | 
| 60 |             *(const inner_product_desc_t**)result = desc(); break; | 
| 61 |         default: return primitive_desc_t::query(what, idx, result); | 
| 62 |         } | 
| 63 |         return status::success; | 
| 64 |     } | 
| 65 |  | 
| 66 |     /* common inner_product aux functions */ | 
| 67 |  | 
| 68 |     dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } | 
| 69 |     dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; } | 
| 70 |     dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } | 
| 71 |  | 
| 72 |     dim_t ID() const { | 
| 73 |         return ndims() >= 5 | 
| 74 |             ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; | 
| 75 |     } | 
| 76 |     dim_t IH() const { | 
| 77 |         return ndims() >= 4 | 
| 78 |             ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; | 
| 79 |     } | 
| 80 |     dim_t IW() const { | 
| 81 |         return ndims() >= 3 | 
| 82 |             ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; | 
| 83 |     } | 
| 84 |  | 
| 85 |     dim_t OD() const { | 
| 86 |         return ndims() >= 5 | 
| 87 |             ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; | 
| 88 |     } | 
| 89 |     dim_t OH() const { | 
| 90 |         return ndims() >= 4 | 
| 91 |             ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; | 
| 92 |     } | 
| 93 |     dim_t OW() const { | 
| 94 |         return ndims() >= 3 | 
| 95 |             ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; | 
| 96 |     } | 
| 97 |  | 
| 98 |     dim_t KD() const { | 
| 99 |         return ndims() >= 5 | 
| 100 |             ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; | 
| 101 |     } | 
| 102 |     dim_t KH() const { | 
| 103 |         return ndims() >= 4 | 
| 104 |             ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; | 
| 105 |     } | 
| 106 |     dim_t KW() const { | 
| 107 |         return ndims() >= 3 | 
| 108 |             ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; | 
| 109 |     } | 
| 110 |  | 
| 111 |     dim_t IC_total() const { | 
| 112 |         return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], | 
| 113 |                 ndims() - 1); | 
| 114 |     } | 
| 115 |  | 
| 116 |     dim_t IC_total_padded() const { | 
| 117 |         auto src_d = desc()->prop_kind == prop_kind::backward_data | 
| 118 |             ? memory_desc_wrapper(diff_src_md()) | 
| 119 |             : memory_desc_wrapper(src_md()); | 
| 120 |         assert(src_d.is_blocking_desc()); | 
| 121 |         if (!src_d.is_blocking_desc()) return -1; | 
| 122 |         return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); | 
| 123 |     } | 
| 124 |  | 
| 125 |     int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } | 
| 126 |  | 
| 127 |     bool with_bias() const | 
| 128 |     { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } | 
| 129 |  | 
| 130 |     bool has_zero_dim_memory() const { | 
| 131 |         const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); | 
| 132 |         const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); | 
| 133 |         return s_d.has_zero_dim() || d_d.has_zero_dim(); | 
| 134 |     } | 
| 135 |  | 
| 136 |     bool is_fwd() const { | 
| 137 |         return utils::one_of(desc_.prop_kind, prop_kind::forward_training, | 
| 138 |                 prop_kind::forward_inference); | 
| 139 |     } | 
| 140 |  | 
| 141 | protected: | 
| 142 |     inner_product_desc_t desc_; | 
| 143 |     const inner_product_fwd_pd_t *hint_fwd_pd_; | 
| 144 |  | 
| 145 |     status_t template_set_default_params(memory_desc_t &src_md, | 
| 146 |             memory_desc_t &weights_md, memory_desc_t &dst_md, | 
| 147 |             memory_desc_t *bias_md) { | 
| 148 |         using namespace format_tag; | 
| 149 |         if (src_md.format_kind == format_kind::any) { | 
| 150 |             CHECK(memory_desc_init_by_tag(src_md, | 
| 151 |                         utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); | 
| 152 |         } | 
| 153 |         if (dst_md.format_kind == format_kind::any) | 
| 154 |             CHECK(memory_desc_init_by_tag(dst_md, nc)); | 
| 155 |         if (weights_md.format_kind == format_kind::any) { | 
| 156 |             CHECK(memory_desc_init_by_tag(weights_md, | 
| 157 |                         utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); | 
| 158 |         } | 
| 159 |         if (bias_md && bias_md->format_kind == format_kind::any) | 
| 160 |             CHECK(memory_desc_init_by_tag(*bias_md, x)); | 
| 161 |         return status::success; | 
| 162 |     } | 
| 163 | }; | 
| 164 |  | 
| 165 | struct inner_product_fwd_pd_t: public inner_product_pd_t { | 
| 166 |     typedef inner_product_fwd_pd_t base_class; | 
| 167 |     typedef inner_product_fwd_pd_t hint_class; | 
| 168 |  | 
| 169 |     inner_product_fwd_pd_t(engine_t *engine, | 
| 170 |             const inner_product_desc_t *adesc, | 
| 171 |             const primitive_attr_t *attr, | 
| 172 |             const inner_product_fwd_pd_t *hint_fwd_pd) | 
| 173 |         : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) | 
| 174 |         , src_md_(desc_.src_desc) | 
| 175 |         , weights_md_(desc_.weights_desc) | 
| 176 |         , bias_md_(desc_.bias_desc) | 
| 177 |         , dst_md_(desc_.dst_desc) | 
| 178 |     {} | 
| 179 |  | 
| 180 |     virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { | 
| 181 |         if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) | 
| 182 |             return arg_usage_t::input; | 
| 183 |  | 
| 184 |         if (arg == MKLDNN_ARG_BIAS && with_bias()) | 
| 185 |             return arg_usage_t::input; | 
| 186 |  | 
| 187 |         if (arg == MKLDNN_ARG_DST) | 
| 188 |             return arg_usage_t::output; | 
| 189 |  | 
| 190 |         return primitive_desc_t::arg_usage(arg); | 
| 191 |     } | 
| 192 |  | 
| 193 |     virtual const memory_desc_t *src_md(int index = 0) const override | 
| 194 |     { return index == 0 ? &src_md_ : nullptr; } | 
| 195 |     virtual const memory_desc_t *dst_md(int index = 0) const override | 
| 196 |     { return index == 0 ? &dst_md_ : nullptr; } | 
| 197 |     virtual const memory_desc_t *weights_md(int index = 0) const override { | 
| 198 |         if (index == 0) return &weights_md_; | 
| 199 |         if (index == 1 && with_bias()) return &bias_md_; | 
| 200 |         return nullptr; | 
| 201 |     } | 
| 202 |  | 
| 203 |     virtual int n_inputs() const override { return 2 + with_bias(); } | 
| 204 |     virtual int n_outputs() const override { return 1; } | 
| 205 |  | 
| 206 | protected: | 
| 207 |     memory_desc_t src_md_; | 
| 208 |     memory_desc_t weights_md_; | 
| 209 |     memory_desc_t bias_md_; | 
| 210 |     memory_desc_t dst_md_; | 
| 211 |  | 
| 212 |     status_t set_default_params() { | 
| 213 |         return template_set_default_params(src_md_, weights_md_, dst_md_, | 
| 214 |                 &bias_md_); | 
| 215 |     } | 
| 216 | }; | 
| 217 |  | 
| 218 | struct inner_product_bwd_data_pd_t: public inner_product_pd_t { | 
| 219 |     typedef inner_product_bwd_data_pd_t base_class; | 
| 220 |     typedef inner_product_fwd_pd_t hint_class; | 
| 221 |  | 
| 222 |     inner_product_bwd_data_pd_t(engine_t *engine, | 
| 223 |             const inner_product_desc_t *adesc, | 
| 224 |             const primitive_attr_t *attr, | 
| 225 |             const inner_product_fwd_pd_t *hint_fwd_pd) | 
| 226 |         : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) | 
| 227 |         , diff_src_md_(desc_.diff_src_desc) | 
| 228 |         , weights_md_(desc_.weights_desc) | 
| 229 |         , diff_dst_md_(desc_.diff_dst_desc) | 
| 230 |     {} | 
| 231 |  | 
| 232 |     virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { | 
| 233 |         if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) | 
| 234 |             return arg_usage_t::input; | 
| 235 |  | 
| 236 |         if (arg == MKLDNN_ARG_DIFF_SRC) | 
| 237 |             return arg_usage_t::output; | 
| 238 |  | 
| 239 |         return primitive_desc_t::arg_usage(arg); | 
| 240 |     } | 
| 241 |  | 
| 242 |     virtual const memory_desc_t *diff_src_md(int index = 0) const override | 
| 243 |     { return index == 0 ? &diff_src_md_ : nullptr; } | 
| 244 |     virtual const memory_desc_t *diff_dst_md(int index = 0) const override | 
| 245 |     { return index == 0 ? &diff_dst_md_ : nullptr; } | 
| 246 |     virtual const memory_desc_t *weights_md(int index = 0) const override | 
| 247 |     { return index == 0 ? &weights_md_ : nullptr; } | 
| 248 |  | 
| 249 |     virtual int n_inputs() const override { return 2; } | 
| 250 |     virtual int n_outputs() const override { return 1; } | 
| 251 |  | 
| 252 | protected: | 
| 253 |     memory_desc_t diff_src_md_; | 
| 254 |     memory_desc_t weights_md_; | 
| 255 |     memory_desc_t diff_dst_md_; | 
| 256 |  | 
| 257 |     status_t set_default_params() { | 
| 258 |         return template_set_default_params(diff_src_md_, weights_md_, | 
| 259 |                 diff_dst_md_, nullptr); | 
| 260 |     } | 
| 261 | }; | 
| 262 |  | 
| 263 | struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { | 
| 264 |     typedef inner_product_bwd_weights_pd_t base_class; | 
| 265 |     typedef inner_product_fwd_pd_t hint_class; | 
| 266 |  | 
| 267 |     inner_product_bwd_weights_pd_t(engine_t *engine, | 
| 268 |             const inner_product_desc_t *adesc, | 
| 269 |             const primitive_attr_t *attr, | 
| 270 |             const inner_product_fwd_pd_t *hint_fwd_pd) | 
| 271 |         : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) | 
| 272 |         , src_md_(desc_.src_desc) | 
| 273 |         , diff_weights_md_(desc_.diff_weights_desc) | 
| 274 |         , diff_bias_md_(desc_.diff_bias_desc) | 
| 275 |         , diff_dst_md_(desc_.diff_dst_desc) | 
| 276 |     {} | 
| 277 |  | 
| 278 |     virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { | 
| 279 |         if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) | 
| 280 |             return arg_usage_t::input; | 
| 281 |  | 
| 282 |         if (arg == MKLDNN_ARG_DIFF_WEIGHTS) | 
| 283 |             return arg_usage_t::output; | 
| 284 |  | 
| 285 |         if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) | 
| 286 |             return arg_usage_t::output; | 
| 287 |  | 
| 288 |         return primitive_desc_t::arg_usage(arg); | 
| 289 |     } | 
| 290 |  | 
| 291 |     virtual const memory_desc_t *src_md(int index = 0) const override | 
| 292 |     { return index == 0 ? &src_md_ : nullptr; } | 
| 293 |     virtual const memory_desc_t *diff_dst_md(int index = 0) const override | 
| 294 |     { return index == 0 ? &diff_dst_md_ : nullptr; } | 
| 295 |     virtual const memory_desc_t *diff_weights_md(int index = 0) const override { | 
| 296 |         if (index == 0) return &diff_weights_md_; | 
| 297 |         if (index == 1 && with_bias()) return &diff_bias_md_; | 
| 298 |         return nullptr; | 
| 299 |     } | 
| 300 |  | 
| 301 |     virtual int n_inputs() const override { return 2; } | 
| 302 |     virtual int n_outputs() const override { return 1 + with_bias(); } | 
| 303 |  | 
| 304 | protected: | 
| 305 |     memory_desc_t src_md_; | 
| 306 |     memory_desc_t diff_weights_md_; | 
| 307 |     memory_desc_t diff_bias_md_; | 
| 308 |     memory_desc_t diff_dst_md_; | 
| 309 |  | 
| 310 |     status_t set_default_params() { | 
| 311 |         return template_set_default_params(src_md_, diff_weights_md_, | 
| 312 |                 diff_dst_md_, &diff_bias_md_); | 
| 313 |     } | 
| 314 | }; | 
| 315 |  | 
| 316 | } | 
| 317 | } | 
| 318 |  | 
| 319 | #endif | 
| 320 |  | 
| 321 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s | 
| 322 |  |