1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef RNN_PD_HPP
18#define RNN_PD_HPP
19
20#include "mkldnn.h"
21
22#include "c_types_map.hpp"
23#include "primitive_desc.hpp"
24#include "type_helpers.hpp"
25
26namespace mkldnn {
27namespace impl {
28
29struct rnn_fwd_pd_t;
30
31struct rnn_pd_t : public primitive_desc_t {
32 static constexpr auto base_pkind = primitive_kind::rnn;
33
34 rnn_pd_t(engine_t *engine,
35 const rnn_desc_t *adesc,
36 const primitive_attr_t *attr,
37 const rnn_fwd_pd_t *hint_fwd_pd)
38 : primitive_desc_t(engine, attr, base_pkind)
39 , desc_(*adesc)
40 , hint_fwd_pd_(hint_fwd_pd)
41 , src_layer_md_(desc_.src_layer_desc)
42 , src_iter_md_(desc_.src_iter_desc)
43 , weights_layer_md_(desc_.weights_layer_desc)
44 , weights_iter_md_(desc_.weights_iter_desc)
45 , bias_md_(desc_.bias_desc)
46 , dst_layer_md_(desc_.dst_layer_desc)
47 , dst_iter_md_(desc_.dst_iter_desc)
48 , ws_md_()
49 {}
50
51 const rnn_desc_t *desc() const { return &desc_; }
52 virtual const op_desc_t *op_desc() const override
53 { return reinterpret_cast<const op_desc_t *>(this->desc()); }
54 virtual void init_info() override { impl::init_info(this, this->info_); }
55
56 virtual status_t query(query_t what, int idx, void *result) const override {
57 switch (what) {
58 case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
59 default: return primitive_desc_t::query(what, idx, result);
60 }
61 return status::success;
62 }
63
64 virtual const memory_desc_t *src_md(int index = 0) const override {
65 if (index == 0) return &src_layer_md_;
66 if (index == 1 && with_src_iter()) return &src_iter_md_;
67 return nullptr;
68 }
69 virtual const memory_desc_t *weights_md(int index = 0) const override {
70 if (index == 0) return &weights_layer_md_;
71 if (index == 1) return &weights_iter_md_;
72 if (index == 2 && with_bias()) return &bias_md_;
73 return nullptr;
74 }
75 virtual const memory_desc_t *dst_md(int index = 0) const override {
76 if (index == 0) return &dst_layer_md_;
77 if (index == 1 && with_dst_iter()) return &dst_iter_md_;
78 return nullptr;
79 }
80 virtual const memory_desc_t *workspace_md(int index = 0) const override
81 { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
82
83 /* common pooling aux functions */
84
85 bool is_training() const {
86 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
87 prop_kind::backward);
88 }
89
90 bool is_fwd() const {
91 return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
92 prop_kind::forward_inference);
93 }
94
95 dim_t T() const { return desc_.src_layer_desc.dims[0]; }
96 dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
97
98 dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
99 dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
100
101 dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
102
103 dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
104 dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
105 dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
106
107 dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
108
109 bool with_bias() const
110 { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
111
112 bool with_src_iter() const
113 { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
114
115 bool with_dst_iter() const
116 { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
117
118 mkldnn::impl::alg_kind_t cell_kind() const
119 { return desc_.cell_desc.cell_kind; }
120 mkldnn::impl::alg_kind_t activation_kind() const
121 { return desc_.cell_desc.activation_kind; }
122
123 bool is_lbr() const
124 { return cell_kind() == mkldnn_gru_linear_before_reset; }
125
126 mkldnn_rnn_direction_t direction() const { return desc_.direction; }
127
128protected:
129 rnn_desc_t desc_;
130 const rnn_fwd_pd_t *hint_fwd_pd_;
131
132 memory_desc_t src_layer_md_;
133 memory_desc_t src_iter_md_;
134 memory_desc_t weights_layer_md_;
135 memory_desc_t weights_iter_md_;
136 memory_desc_t bias_md_;
137 memory_desc_t dst_layer_md_;
138 memory_desc_t dst_iter_md_;
139
140 memory_desc_t ws_md_;
141};
142
143struct rnn_fwd_pd_t: public rnn_pd_t {
144 typedef rnn_fwd_pd_t base_class;
145 typedef rnn_fwd_pd_t hint_class;
146
147 rnn_fwd_pd_t(engine_t *engine,
148 const rnn_desc_t *adesc,
149 const primitive_attr_t *attr,
150 const rnn_fwd_pd_t *hint_fwd_pd)
151 : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
152 {}
153
154 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
155 if (arg == MKLDNN_ARG_SRC_LAYER)
156 return arg_usage_t::input;
157
158 if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
159 return arg_usage_t::input;
160
161 if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
162 MKLDNN_ARG_WEIGHTS_ITER))
163 return arg_usage_t::input;
164
165 if (arg == MKLDNN_ARG_BIAS && with_bias())
166 return arg_usage_t::input;
167
168 if (arg == MKLDNN_ARG_DST_LAYER)
169 return arg_usage_t::output;
170
171 if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
172 return arg_usage_t::output;
173
174 if (arg == MKLDNN_ARG_WORKSPACE && is_training())
175 return arg_usage_t::output;
176
177 return primitive_desc_t::arg_usage(arg);
178 }
179
180 virtual int n_inputs() const override
181 { return 3 + with_bias() + with_src_iter(); }
182 virtual int n_outputs() const override
183 { return 1 + with_dst_iter() + is_training(); }
184};
185
186struct rnn_bwd_pd_t : public rnn_pd_t {
187 typedef rnn_bwd_pd_t base_class;
188 typedef rnn_fwd_pd_t hint_class;
189
190 rnn_bwd_pd_t(engine_t *engine,
191 const rnn_desc_t *adesc,
192 const primitive_attr_t *attr,
193 const rnn_fwd_pd_t *hint_fwd_pd)
194 : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
195 , diff_src_layer_md_(desc_.diff_src_layer_desc)
196 , diff_src_iter_md_(desc_.diff_src_iter_desc)
197 , diff_weights_layer_md_(desc_.diff_weights_layer_desc)
198 , diff_weights_iter_md_(desc_.diff_weights_iter_desc)
199 , diff_bias_md_(desc_.diff_bias_desc)
200 , diff_dst_layer_md_(desc_.diff_dst_layer_desc)
201 , diff_dst_iter_md_(desc_.diff_dst_iter_desc)
202 {}
203
204 virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
205 if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
206 MKLDNN_ARG_DIFF_DST_LAYER))
207 return arg_usage_t::input;
208
209 if (with_src_iter()) {
210 if (arg == MKLDNN_ARG_SRC_ITER)
211 return arg_usage_t::input;
212
213 if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
214 return arg_usage_t::output;
215 }
216
217 if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
218 MKLDNN_ARG_WEIGHTS_ITER))
219 return arg_usage_t::input;
220
221 if (with_bias()) {
222 if (arg == MKLDNN_ARG_BIAS)
223 return arg_usage_t::input;
224
225 if (arg == MKLDNN_ARG_DIFF_BIAS)
226 return arg_usage_t::output;
227 }
228
229 if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
230 && with_dst_iter())
231 return arg_usage_t::input;
232
233 if (arg == MKLDNN_ARG_WORKSPACE)
234 return arg_usage_t::input;
235
236 if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
237 MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
238 MKLDNN_ARG_DIFF_WEIGHTS_ITER))
239 return arg_usage_t::output;
240
241 return primitive_desc_t::arg_usage(arg);
242 }
243
244 virtual const memory_desc_t *diff_src_md(int index = 0) const override {
245 if (index == 0) return &diff_src_layer_md_;
246 if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
247 return nullptr;
248 }
249 virtual const memory_desc_t *diff_weights_md(
250 int index = 0) const override {
251 if (index == 0) return &diff_weights_layer_md_;
252 if (index == 1) return &diff_weights_iter_md_;
253 if (index == 2 && with_bias()) return &diff_bias_md_;
254 return nullptr;
255 }
256 virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
257 if (index == 0) return &diff_dst_layer_md_;
258 if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
259 return nullptr;
260 }
261
262 virtual int n_inputs() const override
263 { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
264 virtual int n_outputs() const override
265 { return 3 + with_src_iter() + with_bias(); }
266
267protected:
268 memory_desc_t diff_src_layer_md_;
269 memory_desc_t diff_src_iter_md_;
270 memory_desc_t diff_weights_layer_md_;
271 memory_desc_t diff_weights_iter_md_;
272 memory_desc_t diff_bias_md_;
273 memory_desc_t diff_dst_layer_md_;
274 memory_desc_t diff_dst_iter_md_;
275};
276
277}
278}
279
280#endif
281