1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "mkldnn.h"
18
19#include "c_types_map.hpp"
20#include "type_helpers.hpp"
21#include "utils.hpp"
22#include "cpu/gemm/os_blas.hpp"
23
24using namespace mkldnn::impl;
25using namespace mkldnn::impl::status;
26using namespace mkldnn::impl::types;
27using namespace mkldnn::impl::utils;
28
29namespace {
30memory_desc_t copy_maybe_null(const memory_desc_t *md) {
31 return md ? *md : zero_md();
32}
33
34rnn_desc_t zero_rnn_desc() {
35 auto rd = rnn_desc_t();
36 rd.src_layer_desc = zero_md();
37 rd.src_iter_desc = zero_md();
38 rd.weights_layer_desc = zero_md();
39 rd.weights_iter_desc = zero_md();
40 rd.bias_desc = zero_md();
41 rd.dst_layer_desc = zero_md();
42 rd.dst_iter_desc = zero_md();
43 rd.diff_src_layer_desc = zero_md();
44 rd.diff_src_iter_desc = zero_md();
45 rd.diff_weights_layer_desc = zero_md();
46 rd.diff_weights_iter_desc = zero_md();
47 rd.diff_bias_desc = zero_md();
48 rd.diff_dst_layer_desc = zero_md();
49 rd.diff_dst_iter_desc = zero_md();
50 return rd;
51}
52}
53
54/* Public C Api */
55
56status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
57 mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
58 unsigned int flags, float alpha, float clipping) {
59 using namespace mkldnn::impl::alg_kind;
60
61 bool args_ok = true
62 && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
63 gru_linear_before_reset)
64 && IMPLICATION(cell_kind == vanilla_rnn,
65 one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
66 if (!args_ok)
67 return invalid_arguments;
68
69 auto rcd = mkldnn_rnn_cell_desc_t();
70
71 rcd.cell_kind = cell_kind;
72 rcd.activation_kind = act_f;
73 rcd.flags = flags;
74 rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
75 rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
76
77 *rnn_cell_desc = rcd;
78
79 return success;
80}
81
82int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
83 switch (rnn_cell_desc->cell_kind) {
84 case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
85 case mkldnn::impl::alg_kind::vanilla_gru: return 3;
86 case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
87 case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
88 default: assert(!"unknown cell kind"); return 0;
89 }
90 return 0;
91}
92
93int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
94 switch (rnn_cell_desc->cell_kind) {
95 case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
96 case mkldnn::impl::alg_kind::vanilla_gru: return 1;
97 case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
98 case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
99 default: assert(!"unknown cell kind"); return 0;
100 }
101 return 0;
102}
103
104status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
105 prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
106 const memory_desc_t *src_iter_desc,
107 const memory_desc_t *weights_layer_desc,
108 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
109 const memory_desc_t *dst_layer_desc,
110 const memory_desc_t *dst_iter_desc) {
111 using namespace data_type;
112 data_type_t src_layer_dt = src_layer_desc->data_type;
113 data_type_t dst_layer_dt = dst_layer_desc->data_type;
114 data_type_t weights_iter_dt = weights_iter_desc->data_type;
115 data_type_t weights_layer_dt = weights_layer_desc->data_type;
116
117 bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
118 weights_layer_dt)
119 && IMPLICATION(!is_zero_md(src_iter_desc),
120 src_iter_desc->data_type == f32)
121 && IMPLICATION(!is_zero_md(dst_iter_desc),
122 dst_iter_desc->data_type == f32)
123 && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
124
125#if USE_MKL_PACKED_GEMM
126 bool is_u8u8u8 = src_layer_dt == u8
127 && IMPLICATION(!is_zero_md(src_iter_desc),
128 src_iter_desc->data_type == u8)
129 && IMPLICATION(!is_zero_md(dst_iter_desc),
130 dst_iter_desc->data_type == u8)
131 && one_of(dst_layer_dt, u8, f32)
132 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
133 && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
134
135 bool is_f32u8f32 = src_layer_dt == u8
136 && IMPLICATION(!is_zero_md(src_iter_desc),
137 src_iter_desc->data_type == f32)
138 && IMPLICATION(!is_zero_md(dst_iter_desc),
139 dst_iter_desc->data_type == f32)
140 && one_of(dst_layer_dt, u8, f32)
141 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
142 && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
143
144 bool is_inference = prop_kind == prop_kind::forward_inference;
145 bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
146
147 return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
148 ? success
149 : unimplemented;
150#else
151 return is_f32 ? success : unimplemented;
152#endif
153}
154
155status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
156 rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
157 int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
158 const memory_desc_t *src_iter_desc,
159 const memory_desc_t *weights_layer_desc,
160 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
161 const memory_desc_t *dst_layer_desc,
162 const memory_desc_t *dst_iter_desc) {
163 bool args_ok;
164
165 // * algorithm specific
166 args_ok = true
167 && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
168 DIC == SIC);
169 if (!args_ok) return invalid_arguments;
170 int extra_bias =
171 rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
172
173 // * on num layers
174 args_ok = true
175 && L == weights_layer_desc->dims[0]
176 && L == weights_iter_desc->dims[0]
177 && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
178 && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
179 && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
180 if (!args_ok) return invalid_arguments;
181
182 // * on num directions
183 args_ok = true
184 && D == weights_layer_desc->dims[1]
185 && D == weights_iter_desc->dims[1]
186 && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
187 && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
188 && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
189 if (!args_ok) return invalid_arguments;
190
191 // * on num iterations
192 args_ok = true
193 && T == src_layer_desc->dims[0]
194 && T == dst_layer_desc->dims[0];
195 if (!args_ok) return invalid_arguments;
196
197 // * on mb
198 args_ok = true
199 && N == src_layer_desc->dims[1]
200 && N == dst_layer_desc->dims[1]
201 && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
202 && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
203 if (!args_ok) return invalid_arguments;
204
205 // * on num gates
206 args_ok = true
207 && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
208 && G == weights_layer_desc->dims[3]
209 && G == weights_iter_desc->dims[3]
210 && IMPLICATION(!is_zero_md(bias_desc),
211 G + extra_bias == bias_desc->dims[2]);
212 if (!args_ok) return invalid_arguments;
213
214 // * on num states
215 args_ok = true
216 && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
217 && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
218 && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
219 if (!args_ok) return invalid_arguments;
220
221 // * on slc
222 args_ok = true
223 && SLC == weights_layer_desc->dims[2]
224 && SLC == src_layer_desc->dims[2];
225 if (!args_ok) return invalid_arguments;
226
227 // * on sic
228 args_ok = true
229 && SIC == weights_iter_desc->dims[2]
230 && IMPLICATION(!is_zero_md(src_iter_desc),
231 SIC == src_iter_desc->dims[4]);
232 if (!args_ok) return invalid_arguments;
233
234 // * on dlc
235 int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
236 args_ok = true
237 && DLC == dlc_multiplier * DIC
238 && DLC == dst_layer_desc->dims[2];
239 if (!args_ok) return invalid_arguments;
240
241 // * on dic
242 args_ok = true
243 && DIC == weights_layer_desc->dims[4]
244 && DIC == weights_iter_desc->dims[4]
245 && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
246 && IMPLICATION(!is_zero_md(dst_iter_desc),
247 DIC == dst_iter_desc->dims[4]);
248 if (!args_ok) return invalid_arguments;
249
250 // * unrolling/fusion conditions
251 args_ok = true
252 && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
253 && IMPLICATION(T > 1, SIC == DIC);
254 if (!args_ok) return invalid_arguments;
255
256 return success;
257}
258
259status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
260 prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
261 const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
262 const memory_desc_t *src_iter_desc,
263 const memory_desc_t *weights_layer_desc,
264 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
265 const memory_desc_t *dst_layer_desc,
266 const memory_desc_t *dst_iter_desc) {
267 bool args_ok = true && rnn_cell_desc != nullptr
268 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
269 dst_layer_desc);
270 if (!args_ok) return invalid_arguments;
271
272 //check dimensions consistency
273 int L = weights_layer_desc->dims[0];
274 int T = src_layer_desc->dims[0];
275 int N = src_layer_desc->dims[1];
276 const int D = one_of(direction, mkldnn_unidirectional_left2right,
277 mkldnn_unidirectional_right2left) ?
278 1 :
279 2;
280 int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
281 int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
282 int SLC = src_layer_desc->dims[2];
283 int SIC = weights_iter_desc->dims[2];
284 int DLC = dst_layer_desc->dims[2];
285 int DIC = weights_layer_desc->dims[4];
286
287 CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
288 G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
289 weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
290 dst_iter_desc));
291
292 CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
293 src_layer_desc, src_iter_desc, weights_layer_desc,
294 weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
295
296 // Create the descriptor
297 mkldnn_rnn_desc_t rd = zero_rnn_desc();
298
299 rd.primitive_kind = primitive_kind::rnn;
300 rd.prop_kind = prop_kind;
301 rd.cell_desc = *rnn_cell_desc;
302 rd.direction = direction;
303 rd.src_layer_desc = copy_maybe_null(src_layer_desc);
304 rd.src_iter_desc = copy_maybe_null(src_iter_desc);
305 rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
306 rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
307 rd.bias_desc = copy_maybe_null(bias_desc);
308 rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
309 rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
310
311 *rnn_desc = rd;
312
313 return success;
314}
315
316status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
317 prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
318 const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
319 const memory_desc_t *src_iter_desc,
320 const memory_desc_t *weights_layer_desc,
321 const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
322 const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
323 const memory_desc_t *diff_src_layer_desc,
324 const memory_desc_t *diff_src_iter_desc,
325 const memory_desc_t *diff_weights_layer_desc,
326 const memory_desc_t *diff_weights_iter_desc,
327 const memory_desc_t *diff_bias_desc,
328 const memory_desc_t *diff_dst_layer_desc,
329 const memory_desc_t *diff_dst_iter_desc) {
330 bool args_ok = true
331 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
332 dst_layer_desc, diff_src_layer_desc,
333 diff_weights_layer_desc, diff_weights_iter_desc,
334 diff_dst_layer_desc);
335 if (!args_ok)
336 return invalid_arguments;
337
338 auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
339 return is_zero_md(a_md) == is_zero_md(b_md);
340 };
341
342 args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
343 && xnor_md(dst_iter_desc, diff_dst_iter_desc)
344 && xnor_md(src_iter_desc, diff_src_iter_desc);
345 if (!args_ok)
346 return invalid_arguments;
347
348 //check dimensions consistency
349 int L = weights_layer_desc->dims[0];
350 int T = src_layer_desc->dims[0];
351 int N = src_layer_desc->dims[1];
352 const int D = one_of(direction, mkldnn_unidirectional_left2right,
353 mkldnn_unidirectional_right2left) ?
354 1 :
355 2;
356 int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
357 int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
358 int SLC = src_layer_desc->dims[2];
359 int SIC = weights_iter_desc->dims[2];
360 int DLC = dst_layer_desc->dims[2];
361 int DIC = weights_layer_desc->dims[4];
362
363 status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
364 G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
365 weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
366 dst_iter_desc);
367 if (st != success) return st;
368
369 st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
370 G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
371 diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
372 diff_dst_layer_desc, diff_dst_iter_desc);
373 if (st != success) return st;
374
375 mkldnn_rnn_desc_t rd = zero_rnn_desc();
376
377 rd.primitive_kind = primitive_kind::rnn;
378 rd.prop_kind = prop_kind;
379 rd.cell_desc = *rnn_cell_desc;
380 rd.direction = direction;
381
382 rd.src_layer_desc = copy_maybe_null(src_layer_desc);
383 rd.src_iter_desc = copy_maybe_null(src_iter_desc);
384 rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
385 rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
386 rd.bias_desc = copy_maybe_null(bias_desc);
387 rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
388 rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
389 rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
390 rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
391 rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
392 rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
393 rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
394 rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
395 rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
396
397 *rnn_desc = rd;
398
399 return success;
400}
401