1 | /******************************************************************************* |
2 | * Copyright 2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "mkldnn.h" |
18 | |
19 | #include "c_types_map.hpp" |
20 | #include "type_helpers.hpp" |
21 | #include "utils.hpp" |
22 | #include "cpu/gemm/os_blas.hpp" |
23 | |
24 | using namespace mkldnn::impl; |
25 | using namespace mkldnn::impl::status; |
26 | using namespace mkldnn::impl::types; |
27 | using namespace mkldnn::impl::utils; |
28 | |
29 | namespace { |
30 | memory_desc_t copy_maybe_null(const memory_desc_t *md) { |
31 | return md ? *md : zero_md(); |
32 | } |
33 | |
34 | rnn_desc_t zero_rnn_desc() { |
35 | auto rd = rnn_desc_t(); |
36 | rd.src_layer_desc = zero_md(); |
37 | rd.src_iter_desc = zero_md(); |
38 | rd.weights_layer_desc = zero_md(); |
39 | rd.weights_iter_desc = zero_md(); |
40 | rd.bias_desc = zero_md(); |
41 | rd.dst_layer_desc = zero_md(); |
42 | rd.dst_iter_desc = zero_md(); |
43 | rd.diff_src_layer_desc = zero_md(); |
44 | rd.diff_src_iter_desc = zero_md(); |
45 | rd.diff_weights_layer_desc = zero_md(); |
46 | rd.diff_weights_iter_desc = zero_md(); |
47 | rd.diff_bias_desc = zero_md(); |
48 | rd.diff_dst_layer_desc = zero_md(); |
49 | rd.diff_dst_iter_desc = zero_md(); |
50 | return rd; |
51 | } |
52 | } |
53 | |
54 | /* Public C Api */ |
55 | |
56 | status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, |
57 | mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, |
58 | unsigned int flags, float alpha, float clipping) { |
59 | using namespace mkldnn::impl::alg_kind; |
60 | |
61 | bool args_ok = true |
62 | && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, |
63 | gru_linear_before_reset) |
64 | && IMPLICATION(cell_kind == vanilla_rnn, |
65 | one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); |
66 | if (!args_ok) |
67 | return invalid_arguments; |
68 | |
69 | auto rcd = mkldnn_rnn_cell_desc_t(); |
70 | |
71 | rcd.cell_kind = cell_kind; |
72 | rcd.activation_kind = act_f; |
73 | rcd.flags = flags; |
74 | rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; |
75 | rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; |
76 | |
77 | *rnn_cell_desc = rcd; |
78 | |
79 | return success; |
80 | } |
81 | |
82 | int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { |
83 | switch (rnn_cell_desc->cell_kind) { |
84 | case mkldnn::impl::alg_kind::vanilla_rnn: return 1; |
85 | case mkldnn::impl::alg_kind::vanilla_gru: return 3; |
86 | case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; |
87 | case mkldnn::impl::alg_kind::vanilla_lstm: return 4; |
88 | default: assert(!"unknown cell kind" ); return 0; |
89 | } |
90 | return 0; |
91 | } |
92 | |
93 | int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { |
94 | switch (rnn_cell_desc->cell_kind) { |
95 | case mkldnn::impl::alg_kind::vanilla_rnn: return 1; |
96 | case mkldnn::impl::alg_kind::vanilla_gru: return 1; |
97 | case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; |
98 | case mkldnn::impl::alg_kind::vanilla_lstm: return 2; |
99 | default: assert(!"unknown cell kind" ); return 0; |
100 | } |
101 | return 0; |
102 | } |
103 | |
104 | status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, |
105 | prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, |
106 | const memory_desc_t *src_iter_desc, |
107 | const memory_desc_t *weights_layer_desc, |
108 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
109 | const memory_desc_t *dst_layer_desc, |
110 | const memory_desc_t *dst_iter_desc) { |
111 | using namespace data_type; |
112 | data_type_t src_layer_dt = src_layer_desc->data_type; |
113 | data_type_t dst_layer_dt = dst_layer_desc->data_type; |
114 | data_type_t weights_iter_dt = weights_iter_desc->data_type; |
115 | data_type_t weights_layer_dt = weights_layer_desc->data_type; |
116 | |
117 | bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, |
118 | weights_layer_dt) |
119 | && IMPLICATION(!is_zero_md(src_iter_desc), |
120 | src_iter_desc->data_type == f32) |
121 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
122 | dst_iter_desc->data_type == f32) |
123 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
124 | |
125 | #if USE_MKL_PACKED_GEMM |
126 | bool is_u8u8u8 = src_layer_dt == u8 |
127 | && IMPLICATION(!is_zero_md(src_iter_desc), |
128 | src_iter_desc->data_type == u8) |
129 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
130 | dst_iter_desc->data_type == u8) |
131 | && one_of(dst_layer_dt, u8, f32) |
132 | && everyone_is(s8, weights_iter_dt, weights_layer_dt) |
133 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
134 | |
135 | bool is_f32u8f32 = src_layer_dt == u8 |
136 | && IMPLICATION(!is_zero_md(src_iter_desc), |
137 | src_iter_desc->data_type == f32) |
138 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
139 | dst_iter_desc->data_type == f32) |
140 | && one_of(dst_layer_dt, u8, f32) |
141 | && everyone_is(s8, weights_iter_dt, weights_layer_dt) |
142 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
143 | |
144 | bool is_inference = prop_kind == prop_kind::forward_inference; |
145 | bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; |
146 | |
147 | return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) |
148 | ? success |
149 | : unimplemented; |
150 | #else |
151 | return is_f32 ? success : unimplemented; |
152 | #endif |
153 | } |
154 | |
155 | status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, |
156 | rnn_direction_t direction, int L, int D, int T, int N, int S, int G, |
157 | int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, |
158 | const memory_desc_t *src_iter_desc, |
159 | const memory_desc_t *weights_layer_desc, |
160 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
161 | const memory_desc_t *dst_layer_desc, |
162 | const memory_desc_t *dst_iter_desc) { |
163 | bool args_ok; |
164 | |
165 | // * algorithm specific |
166 | args_ok = true |
167 | && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, |
168 | DIC == SIC); |
169 | if (!args_ok) return invalid_arguments; |
170 | int = |
171 | rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; |
172 | |
173 | // * on num layers |
174 | args_ok = true |
175 | && L == weights_layer_desc->dims[0] |
176 | && L == weights_iter_desc->dims[0] |
177 | && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) |
178 | && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) |
179 | && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); |
180 | if (!args_ok) return invalid_arguments; |
181 | |
182 | // * on num directions |
183 | args_ok = true |
184 | && D == weights_layer_desc->dims[1] |
185 | && D == weights_iter_desc->dims[1] |
186 | && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) |
187 | && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) |
188 | && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); |
189 | if (!args_ok) return invalid_arguments; |
190 | |
191 | // * on num iterations |
192 | args_ok = true |
193 | && T == src_layer_desc->dims[0] |
194 | && T == dst_layer_desc->dims[0]; |
195 | if (!args_ok) return invalid_arguments; |
196 | |
197 | // * on mb |
198 | args_ok = true |
199 | && N == src_layer_desc->dims[1] |
200 | && N == dst_layer_desc->dims[1] |
201 | && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) |
202 | && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); |
203 | if (!args_ok) return invalid_arguments; |
204 | |
205 | // * on num gates |
206 | args_ok = true |
207 | && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) |
208 | && G == weights_layer_desc->dims[3] |
209 | && G == weights_iter_desc->dims[3] |
210 | && IMPLICATION(!is_zero_md(bias_desc), |
211 | G + extra_bias == bias_desc->dims[2]); |
212 | if (!args_ok) return invalid_arguments; |
213 | |
214 | // * on num states |
215 | args_ok = true |
216 | && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) |
217 | && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) |
218 | && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); |
219 | if (!args_ok) return invalid_arguments; |
220 | |
221 | // * on slc |
222 | args_ok = true |
223 | && SLC == weights_layer_desc->dims[2] |
224 | && SLC == src_layer_desc->dims[2]; |
225 | if (!args_ok) return invalid_arguments; |
226 | |
227 | // * on sic |
228 | args_ok = true |
229 | && SIC == weights_iter_desc->dims[2] |
230 | && IMPLICATION(!is_zero_md(src_iter_desc), |
231 | SIC == src_iter_desc->dims[4]); |
232 | if (!args_ok) return invalid_arguments; |
233 | |
234 | // * on dlc |
235 | int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; |
236 | args_ok = true |
237 | && DLC == dlc_multiplier * DIC |
238 | && DLC == dst_layer_desc->dims[2]; |
239 | if (!args_ok) return invalid_arguments; |
240 | |
241 | // * on dic |
242 | args_ok = true |
243 | && DIC == weights_layer_desc->dims[4] |
244 | && DIC == weights_iter_desc->dims[4] |
245 | && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) |
246 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
247 | DIC == dst_iter_desc->dims[4]); |
248 | if (!args_ok) return invalid_arguments; |
249 | |
250 | // * unrolling/fusion conditions |
251 | args_ok = true |
252 | && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) |
253 | && IMPLICATION(T > 1, SIC == DIC); |
254 | if (!args_ok) return invalid_arguments; |
255 | |
256 | return success; |
257 | } |
258 | |
259 | status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, |
260 | prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, |
261 | const rnn_direction_t direction, const memory_desc_t *src_layer_desc, |
262 | const memory_desc_t *src_iter_desc, |
263 | const memory_desc_t *weights_layer_desc, |
264 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
265 | const memory_desc_t *dst_layer_desc, |
266 | const memory_desc_t *dst_iter_desc) { |
267 | bool args_ok = true && rnn_cell_desc != nullptr |
268 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
269 | dst_layer_desc); |
270 | if (!args_ok) return invalid_arguments; |
271 | |
272 | //check dimensions consistency |
273 | int L = weights_layer_desc->dims[0]; |
274 | int T = src_layer_desc->dims[0]; |
275 | int N = src_layer_desc->dims[1]; |
276 | const int D = one_of(direction, mkldnn_unidirectional_left2right, |
277 | mkldnn_unidirectional_right2left) ? |
278 | 1 : |
279 | 2; |
280 | int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); |
281 | int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); |
282 | int SLC = src_layer_desc->dims[2]; |
283 | int SIC = weights_iter_desc->dims[2]; |
284 | int DLC = dst_layer_desc->dims[2]; |
285 | int DIC = weights_layer_desc->dims[4]; |
286 | |
287 | CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
288 | G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, |
289 | weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, |
290 | dst_iter_desc)); |
291 | |
292 | CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, |
293 | src_layer_desc, src_iter_desc, weights_layer_desc, |
294 | weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); |
295 | |
296 | // Create the descriptor |
297 | mkldnn_rnn_desc_t rd = zero_rnn_desc(); |
298 | |
299 | rd.primitive_kind = primitive_kind::rnn; |
300 | rd.prop_kind = prop_kind; |
301 | rd.cell_desc = *rnn_cell_desc; |
302 | rd.direction = direction; |
303 | rd.src_layer_desc = copy_maybe_null(src_layer_desc); |
304 | rd.src_iter_desc = copy_maybe_null(src_iter_desc); |
305 | rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); |
306 | rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); |
307 | rd.bias_desc = copy_maybe_null(bias_desc); |
308 | rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); |
309 | rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); |
310 | |
311 | *rnn_desc = rd; |
312 | |
313 | return success; |
314 | } |
315 | |
316 | status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, |
317 | prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, |
318 | const rnn_direction_t direction, const memory_desc_t *src_layer_desc, |
319 | const memory_desc_t *src_iter_desc, |
320 | const memory_desc_t *weights_layer_desc, |
321 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
322 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
323 | const memory_desc_t *diff_src_layer_desc, |
324 | const memory_desc_t *diff_src_iter_desc, |
325 | const memory_desc_t *diff_weights_layer_desc, |
326 | const memory_desc_t *diff_weights_iter_desc, |
327 | const memory_desc_t *diff_bias_desc, |
328 | const memory_desc_t *diff_dst_layer_desc, |
329 | const memory_desc_t *diff_dst_iter_desc) { |
330 | bool args_ok = true |
331 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
332 | dst_layer_desc, diff_src_layer_desc, |
333 | diff_weights_layer_desc, diff_weights_iter_desc, |
334 | diff_dst_layer_desc); |
335 | if (!args_ok) |
336 | return invalid_arguments; |
337 | |
338 | auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { |
339 | return is_zero_md(a_md) == is_zero_md(b_md); |
340 | }; |
341 | |
342 | args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) |
343 | && xnor_md(dst_iter_desc, diff_dst_iter_desc) |
344 | && xnor_md(src_iter_desc, diff_src_iter_desc); |
345 | if (!args_ok) |
346 | return invalid_arguments; |
347 | |
348 | //check dimensions consistency |
349 | int L = weights_layer_desc->dims[0]; |
350 | int T = src_layer_desc->dims[0]; |
351 | int N = src_layer_desc->dims[1]; |
352 | const int D = one_of(direction, mkldnn_unidirectional_left2right, |
353 | mkldnn_unidirectional_right2left) ? |
354 | 1 : |
355 | 2; |
356 | int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); |
357 | int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); |
358 | int SLC = src_layer_desc->dims[2]; |
359 | int SIC = weights_iter_desc->dims[2]; |
360 | int DLC = dst_layer_desc->dims[2]; |
361 | int DIC = weights_layer_desc->dims[4]; |
362 | |
363 | status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
364 | G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, |
365 | weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, |
366 | dst_iter_desc); |
367 | if (st != success) return st; |
368 | |
369 | st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
370 | G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, |
371 | diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, |
372 | diff_dst_layer_desc, diff_dst_iter_desc); |
373 | if (st != success) return st; |
374 | |
375 | mkldnn_rnn_desc_t rd = zero_rnn_desc(); |
376 | |
377 | rd.primitive_kind = primitive_kind::rnn; |
378 | rd.prop_kind = prop_kind; |
379 | rd.cell_desc = *rnn_cell_desc; |
380 | rd.direction = direction; |
381 | |
382 | rd.src_layer_desc = copy_maybe_null(src_layer_desc); |
383 | rd.src_iter_desc = copy_maybe_null(src_iter_desc); |
384 | rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); |
385 | rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); |
386 | rd.bias_desc = copy_maybe_null(bias_desc); |
387 | rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); |
388 | rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); |
389 | rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); |
390 | rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); |
391 | rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); |
392 | rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); |
393 | rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); |
394 | rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); |
395 | rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); |
396 | |
397 | *rnn_desc = rd; |
398 | |
399 | return success; |
400 | } |
401 | |