1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef INNER_PRODUCT_PD_HPP
18#define INNER_PRODUCT_PD_HPP
19
20#include "mkldnn.h"
21
22#include "c_types_map.hpp"
23#include "primitive_desc.hpp"
24#include "utils.hpp"
25
26namespace mkldnn {
27namespace impl {
28
29memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
30memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
31memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
32memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
33const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
34const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
35const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
36const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
37
38struct inner_product_fwd_pd_t;
39
40struct inner_product_pd_t: public primitive_desc_t {
41 static constexpr auto base_pkind = primitive_kind::inner_product;
42
43 inner_product_pd_t(engine_t *engine,
44 const inner_product_desc_t *adesc,
45 const primitive_attr_t *attr,
46 const inner_product_fwd_pd_t *hint_fwd_pd)
47 : primitive_desc_t(engine, attr, base_pkind)
48 , desc_(*adesc)
49 , hint_fwd_pd_(hint_fwd_pd)
50 {}
51
52 const inner_product_desc_t *desc() const { return &desc_; }
53 virtual const op_desc_t *op_desc() const override
54 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
55 virtual void init_info() override { impl::init_info(this, this->info_); }
56
57 virtual status_t query(query_t what, int idx, void *result) const override {
58 switch (what) {
59 case query::inner_product_d:
60 *(const inner_product_desc_t**)result = desc(); break;
61 default: return primitive_desc_t::query(what, idx, result);
62 }
63 return status::success;
64 }
65
66 /* common inner_product aux functions */
67
68 dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
69 dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
70 dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
71
72 dim_t ID() const {
73 return ndims() >= 5
74 ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
75 }
76 dim_t IH() const {
77 return ndims() >= 4
78 ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
79 }
80 dim_t IW() const {
81 return ndims() >= 3
82 ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
83 }
84
85 dim_t OD() const {
86 return ndims() >= 5
87 ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
88 }
89 dim_t OH() const {
90 return ndims() >= 4
91 ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
92 }
93 dim_t OW() const {
94 return ndims() >= 3
95 ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
96 }
97
98 dim_t KD() const {
99 return ndims() >= 5
100 ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
101 }
102 dim_t KH() const {
103 return ndims() >= 4
104 ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
105 }
106 dim_t KW() const {
107 return ndims() >= 3
108 ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
109 }
110
111 dim_t IC_total() const {
112 return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
113 ndims() - 1);
114 }
115
116 dim_t IC_total_padded() const {
117 auto src_d = desc()->prop_kind == prop_kind::backward_data
118 ? memory_desc_wrapper(diff_src_md())
119 : memory_desc_wrapper(src_md());
120 assert(src_d.is_blocking_desc());
121 if (!src_d.is_blocking_desc()) return -1;
122 return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
123 }
124
125 int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
126
127 bool with_bias() const
128 { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
129
130 bool has_zero_dim_memory() const {
131 const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
132 const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
133 return s_d.has_zero_dim() || d_d.has_zero_dim();
134 }
135
136 bool is_fwd() const {
137 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
138 prop_kind::forward_inference);
139 }
140
141protected:
142 inner_product_desc_t desc_;
143 const inner_product_fwd_pd_t *hint_fwd_pd_;
144
145 status_t template_set_default_params(memory_desc_t &src_md,
146 memory_desc_t &weights_md, memory_desc_t &dst_md,
147 memory_desc_t *bias_md) {
148 using namespace format_tag;
149 if (src_md.format_kind == format_kind::any) {
150 CHECK(memory_desc_init_by_tag(src_md,
151 utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
152 }
153 if (dst_md.format_kind == format_kind::any)
154 CHECK(memory_desc_init_by_tag(dst_md, nc));
155 if (weights_md.format_kind == format_kind::any) {
156 CHECK(memory_desc_init_by_tag(weights_md,
157 utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
158 }
159 if (bias_md && bias_md->format_kind == format_kind::any)
160 CHECK(memory_desc_init_by_tag(*bias_md, x));
161 return status::success;
162 }
163};
164
165struct inner_product_fwd_pd_t: public inner_product_pd_t {
166 typedef inner_product_fwd_pd_t base_class;
167 typedef inner_product_fwd_pd_t hint_class;
168
169 inner_product_fwd_pd_t(engine_t *engine,
170 const inner_product_desc_t *adesc,
171 const primitive_attr_t *attr,
172 const inner_product_fwd_pd_t *hint_fwd_pd)
173 : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
174 , src_md_(desc_.src_desc)
175 , weights_md_(desc_.weights_desc)
176 , bias_md_(desc_.bias_desc)
177 , dst_md_(desc_.dst_desc)
178 {}
179
180 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
181 if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
182 return arg_usage_t::input;
183
184 if (arg == MKLDNN_ARG_BIAS && with_bias())
185 return arg_usage_t::input;
186
187 if (arg == MKLDNN_ARG_DST)
188 return arg_usage_t::output;
189
190 return primitive_desc_t::arg_usage(arg);
191 }
192
193 virtual const memory_desc_t *src_md(int index = 0) const override
194 { return index == 0 ? &src_md_ : nullptr; }
195 virtual const memory_desc_t *dst_md(int index = 0) const override
196 { return index == 0 ? &dst_md_ : nullptr; }
197 virtual const memory_desc_t *weights_md(int index = 0) const override {
198 if (index == 0) return &weights_md_;
199 if (index == 1 && with_bias()) return &bias_md_;
200 return nullptr;
201 }
202
203 virtual int n_inputs() const override { return 2 + with_bias(); }
204 virtual int n_outputs() const override { return 1; }
205
206protected:
207 memory_desc_t src_md_;
208 memory_desc_t weights_md_;
209 memory_desc_t bias_md_;
210 memory_desc_t dst_md_;
211
212 status_t set_default_params() {
213 return template_set_default_params(src_md_, weights_md_, dst_md_,
214 &bias_md_);
215 }
216};
217
218struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
219 typedef inner_product_bwd_data_pd_t base_class;
220 typedef inner_product_fwd_pd_t hint_class;
221
222 inner_product_bwd_data_pd_t(engine_t *engine,
223 const inner_product_desc_t *adesc,
224 const primitive_attr_t *attr,
225 const inner_product_fwd_pd_t *hint_fwd_pd)
226 : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
227 , diff_src_md_(desc_.diff_src_desc)
228 , weights_md_(desc_.weights_desc)
229 , diff_dst_md_(desc_.diff_dst_desc)
230 {}
231
232 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
233 if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
234 return arg_usage_t::input;
235
236 if (arg == MKLDNN_ARG_DIFF_SRC)
237 return arg_usage_t::output;
238
239 return primitive_desc_t::arg_usage(arg);
240 }
241
242 virtual const memory_desc_t *diff_src_md(int index = 0) const override
243 { return index == 0 ? &diff_src_md_ : nullptr; }
244 virtual const memory_desc_t *diff_dst_md(int index = 0) const override
245 { return index == 0 ? &diff_dst_md_ : nullptr; }
246 virtual const memory_desc_t *weights_md(int index = 0) const override
247 { return index == 0 ? &weights_md_ : nullptr; }
248
249 virtual int n_inputs() const override { return 2; }
250 virtual int n_outputs() const override { return 1; }
251
252protected:
253 memory_desc_t diff_src_md_;
254 memory_desc_t weights_md_;
255 memory_desc_t diff_dst_md_;
256
257 status_t set_default_params() {
258 return template_set_default_params(diff_src_md_, weights_md_,
259 diff_dst_md_, nullptr);
260 }
261};
262
263struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
264 typedef inner_product_bwd_weights_pd_t base_class;
265 typedef inner_product_fwd_pd_t hint_class;
266
267 inner_product_bwd_weights_pd_t(engine_t *engine,
268 const inner_product_desc_t *adesc,
269 const primitive_attr_t *attr,
270 const inner_product_fwd_pd_t *hint_fwd_pd)
271 : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
272 , src_md_(desc_.src_desc)
273 , diff_weights_md_(desc_.diff_weights_desc)
274 , diff_bias_md_(desc_.diff_bias_desc)
275 , diff_dst_md_(desc_.diff_dst_desc)
276 {}
277
278 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
279 if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
280 return arg_usage_t::input;
281
282 if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
283 return arg_usage_t::output;
284
285 if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
286 return arg_usage_t::output;
287
288 return primitive_desc_t::arg_usage(arg);
289 }
290
291 virtual const memory_desc_t *src_md(int index = 0) const override
292 { return index == 0 ? &src_md_ : nullptr; }
293 virtual const memory_desc_t *diff_dst_md(int index = 0) const override
294 { return index == 0 ? &diff_dst_md_ : nullptr; }
295 virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
296 if (index == 0) return &diff_weights_md_;
297 if (index == 1 && with_bias()) return &diff_bias_md_;
298 return nullptr;
299 }
300
301 virtual int n_inputs() const override { return 2; }
302 virtual int n_outputs() const override { return 1 + with_bias(); }
303
304protected:
305 memory_desc_t src_md_;
306 memory_desc_t diff_weights_md_;
307 memory_desc_t diff_bias_md_;
308 memory_desc_t diff_dst_md_;
309
310 status_t set_default_params() {
311 return template_set_default_params(src_md_, diff_weights_md_,
312 diff_dst_md_, &diff_bias_md_);
313 }
314};
315
316}
317}
318
319#endif
320
321// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
322