1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef DECONVOLUTION_PD_HPP
18#define DECONVOLUTION_PD_HPP
19
20#include "mkldnn.h"
21
22#include "c_types_map.hpp"
23#include "convolution_pd.hpp"
24#include "primitive_desc.hpp"
25#include "utils.hpp"
26
27namespace mkldnn {
28namespace impl {
29
30struct deconvolution_fwd_pd_t;
31
32struct deconvolution_pd_t: public primitive_desc_t {
33 static constexpr auto base_pkind = primitive_kind::deconvolution;
34
35 deconvolution_pd_t(engine_t *engine,
36 const deconvolution_desc_t *adesc,
37 const primitive_attr_t *attr,
38 const deconvolution_fwd_pd_t *hint_fwd_pd)
39 : primitive_desc_t(engine, attr, base_pkind)
40 , desc_(*adesc)
41 , hint_fwd_pd_(hint_fwd_pd)
42 {}
43
44 const deconvolution_desc_t *desc() const { return &desc_; }
45 virtual const op_desc_t *op_desc() const override
46 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
47 virtual void init_info() override { impl::init_info(this, this->info_); }
48
49 virtual status_t query(query_t what, int idx, void *result) const override {
50 switch (what) {
51 case pkind_traits<base_pkind>::query_d:
52 *(const deconvolution_desc_t **)result = desc();
53 break;
54 default: return primitive_desc_t::query(what, idx, result);
55 }
56 return status::success;
57 }
58
59 /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
60
61 dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
62
63 dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
64 dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
65 dim_t G() const
66 { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
67
68 dim_t ID() const {
69 return ndims() >= 5
70 ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
71 }
72 dim_t IH() const {
73 return ndims() >= 4
74 ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
75 }
76 dim_t IW() const {
77 return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
78 }
79
80 dim_t OD() const {
81 return ndims() >= 5
82 ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
83 }
84 dim_t OH() const {
85 return ndims() >= 4
86 ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
87 }
88 dim_t OW() const {
89 return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
90 }
91
92 dim_t KD() const {
93 const int w_ndims = ndims() + with_groups();
94 return ndims() >= 5
95 ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
96 }
97 dim_t KH() const {
98 const int w_ndims = ndims() + with_groups();
99 return ndims() >= 4
100 ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
101 }
102 dim_t KW() const {
103 const int w_ndims = ndims() + with_groups();
104 return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
105 }
106
107 dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
108 dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
109 dim_t KSW() const { return desc_.strides[ndims() - 3]; }
110
111 dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
112 dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
113 dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
114
115 dim_t padFront() const
116 { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
117 dim_t padBack() const
118 { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
119 dim_t padT() const
120 { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
121 dim_t padB() const
122 { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
123 dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
124 dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
125
126 bool with_bias() const {
127 return
128 !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
129 }
130
131 bool with_groups() const
132 { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
133
134 int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
135
136 bool is_fwd() const {
137 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
138 prop_kind::forward_inference);
139 }
140
141 bool has_zero_dim_memory() const {
142 const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
143 const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
144 return s_d.has_zero_dim() || d_d.has_zero_dim();
145 }
146
147protected:
148 deconvolution_desc_t desc_;
149 const deconvolution_fwd_pd_t *hint_fwd_pd_;
150};
151
152struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
153 typedef deconvolution_fwd_pd_t base_class;
154 typedef deconvolution_fwd_pd_t hint_class;
155
156 deconvolution_fwd_pd_t(engine_t *engine,
157 const deconvolution_desc_t *adesc,
158 const primitive_attr_t *attr,
159 const deconvolution_fwd_pd_t *hint_fwd_pd)
160 : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
161 , src_md_(desc_.src_desc)
162 , weights_md_(desc_.weights_desc)
163 , bias_md_(desc_.bias_desc)
164 , dst_md_(desc_.dst_desc)
165 {}
166
167 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
168 if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
169 return arg_usage_t::input;
170
171 if (arg == MKLDNN_ARG_BIAS && with_bias())
172 return arg_usage_t::input;
173
174 if (arg == MKLDNN_ARG_DST)
175 return arg_usage_t::output;
176
177 return primitive_desc_t::arg_usage(arg);
178 }
179
180 virtual const memory_desc_t *src_md(int index = 0) const override
181 { return index == 0 ? &src_md_ : nullptr; }
182 virtual const memory_desc_t *dst_md(int index = 0) const override
183 { return index == 0 ? &dst_md_ : nullptr; }
184 virtual const memory_desc_t *weights_md(int index = 0) const override {
185 if (index == 0) return &weights_md_;
186 if (index == 1 && with_bias()) return &bias_md_;
187 return nullptr;
188 }
189
190 virtual int n_inputs() const override { return 2 + with_bias(); }
191 virtual int n_outputs() const override { return 1; }
192
193protected:
194 memory_desc_t src_md_;
195 memory_desc_t weights_md_;
196 memory_desc_t bias_md_;
197 memory_desc_t dst_md_;
198};
199
200struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
201 typedef deconvolution_bwd_data_pd_t base_class;
202 typedef deconvolution_fwd_pd_t hint_class;
203
204 deconvolution_bwd_data_pd_t(engine_t *engine,
205 const deconvolution_desc_t *adesc,
206 const primitive_attr_t *attr,
207 const deconvolution_fwd_pd_t *hint_fwd_pd)
208 : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
209 , diff_src_md_(desc_.diff_src_desc)
210 , weights_md_(desc_.weights_desc)
211 , diff_dst_md_(desc_.diff_dst_desc)
212 {}
213
214 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
215 if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
216 return arg_usage_t::input;
217
218 if (arg == MKLDNN_ARG_DIFF_SRC)
219 return arg_usage_t::output;
220
221 return primitive_desc_t::arg_usage(arg);
222 }
223
224 virtual const memory_desc_t *diff_src_md(int index = 0) const override
225 { return index == 0 ? &diff_src_md_ : nullptr; }
226 virtual const memory_desc_t *diff_dst_md(int index = 0) const override
227 { return index == 0 ? &diff_dst_md_ : nullptr; }
228 virtual const memory_desc_t *weights_md(int index = 0) const override
229 { return index == 0 ? &weights_md_ : nullptr; }
230
231 virtual int n_inputs() const override { return 2; }
232 virtual int n_outputs() const override { return 1; }
233
234protected:
235 memory_desc_t diff_src_md_;
236 memory_desc_t weights_md_;
237 memory_desc_t diff_dst_md_;
238};
239
240struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
241 typedef deconvolution_bwd_weights_pd_t base_class;
242 typedef deconvolution_fwd_pd_t hint_class;
243
244 deconvolution_bwd_weights_pd_t(engine_t *engine,
245 const deconvolution_desc_t *adesc,
246 const primitive_attr_t *attr,
247 const deconvolution_fwd_pd_t *hint_fwd_pd)
248 : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
249 , src_md_(desc_.src_desc)
250 , diff_weights_md_(desc_.diff_weights_desc)
251 , diff_bias_md_(desc_.diff_bias_desc)
252 , diff_dst_md_(desc_.diff_dst_desc)
253 {}
254
255 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
256 if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
257 return arg_usage_t::input;
258
259 if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
260 return arg_usage_t::output;
261
262 if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
263 return arg_usage_t::output;
264
265 return primitive_desc_t::arg_usage(arg);
266 }
267
268 virtual const memory_desc_t *src_md(int index = 0) const override
269 { return index == 0 ? &src_md_ : nullptr; }
270 virtual const memory_desc_t *diff_dst_md(int index = 0) const override
271 { return index == 0 ? &diff_dst_md_ : nullptr; }
272 virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
273 if (index == 0) return &diff_weights_md_;
274 if (index == 1 && with_bias()) return &diff_bias_md_;
275 return nullptr;
276 }
277
278 virtual int n_inputs() const override { return 2; }
279 virtual int n_outputs() const override { return 1 + with_bias(); }
280
281protected:
282 memory_desc_t src_md_;
283 memory_desc_t diff_weights_md_;
284 memory_desc_t diff_bias_md_;
285 memory_desc_t diff_dst_md_;
286};
287
288}
289}
290
291#endif
292
293// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
294