1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef MKLDNN_TYPES_H
18#define MKLDNN_TYPES_H
19
20#ifdef __cplusplus
21extern "C" {
22#endif
23
24#ifndef DOXYGEN_SHOULD_SKIP_THIS
25#include <stddef.h>
26#include <stdint.h>
27#endif
28
29/** @addtogroup c_api C API
30 * @{
31 *
32 * @addtogroup c_api_types Types
33 * @{
34 *
35 * @addtogroup c_api_types_generic Generic
36 * @{ */
37
/** Intel(R) MKL-DNN version information. */
typedef struct {
    /** Major version number. */
    int major;
    /** Minor version number. */
    int minor;
    /** Patch version number. */
    int patch;
    /** String identifying the source revision of the build.
     * NOTE(review): presumably a git commit hash -- confirm. */
    const char *hash;
} mkldnn_version_t;
45
/** Status values returned by Intel(R) MKL-DNN functions.
 *
 * The enumerators are numbered consecutively from #mkldnn_success (0) in
 * declaration order. New values may only be appended at the end; inserting
 * or reordering entries would change the numeric codes callers rely on. */
typedef enum {
    /** Success: the operation completed normally. */
    mkldnn_success = 0,
    /** Failure: an out-of-memory condition was encountered. */
    mkldnn_out_of_memory,
    /** Failure: the operation should be retried. */
    mkldnn_try_again,
    /** Failure: the function arguments were incorrect. */
    mkldnn_invalid_arguments,
    /** Failure: a primitive was not ready for execution. */
    mkldnn_not_ready,
    /** Failure: the requested functionality is not implemented. */
    mkldnn_unimplemented,
    /** The primitive iterator went past the last primitive descriptor. */
    mkldnn_iterator_ends,
    /** A primitive or an engine failed during execution. */
    mkldnn_runtime_error,
    /** The queried element is not required for the given primitive. */
    mkldnn_not_required,
} mkldnn_status_t;
68
/** Data type specification.
 *
 * Enumerators are numbered consecutively from 0 in declaration order;
 * append new types at the end only. */
typedef enum {
    /** Undefined data type; used for empty memory descriptors. */
    mkldnn_data_type_undef = 0,
    /** Single-precision (32-bit) floating point. */
    mkldnn_f32,
    /** Signed 32-bit integer. */
    mkldnn_s32,
    /** Signed 8-bit integer. */
    mkldnn_s8,
    /** Unsigned 8-bit integer. */
    mkldnn_u8,
} mkldnn_data_type_t;
82
/** Memory format kind: selects which member of
 * mkldnn_memory_desc_t::format_desc describes the layout. */
typedef enum {
    /** Undefined memory format; used for empty memory descriptors. */
    mkldnn_format_kind_undef = 0,
    /** Unspecified format: the primitive picks a format automatically. */
    mkldnn_format_kind_any = 1,
    /** Generic layout described by per-dimension strides and blocking;
     * see #mkldnn_blocking_desc_t. */
    mkldnn_blocked = 2,
    /** Weights layout used by 8-bit Winograd convolution. */
    mkldnn_format_kind_wino = 3,
    /** Packed weights layout used by RNN. */
    mkldnn_format_kind_rnn_packed = 4,
} mkldnn_format_kind_t;
98
/** Memory format tag specification.
 *
 * Intel MKL-DNN formats describe physical data layout. The physical layout
 * is described as a sequence of the dimensions as they are laid out in the
 * memory (from the outer-most to the inner-most). Note that this order
 * doesn't affect the logical order of the dimensions that is kept in the
 * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the
 * dimensions is specified by the type of tensor.
 *
 * For example, a CNN 5D tensor always has its logical dimensions in the order
 * `(batch, channels, depth, height, width)`, while the physical layout might be
 * #mkldnn_ncdhw or #mkldnn_ndhwc:
 *
 * ~~~cpp
 * int batch = 2, channels = 16, depth = 13, height = 13, width = 13;
 *
 * int ndims = 5; // 5D tensor
 * mkldnn_dims_t dims = {batch, channels, depth, height, width};
 * mkldnn_memory_desc_t data_in_ncdhw;
 * mkldnn_memory_desc_init_by_tag(
 *      &data_in_ncdhw, ndims, dims, mkldnn_f32, mkldnn_ncdhw);
 *
 * // note that in both cases dims passed are the same
 * mkldnn_memory_desc_t data_in_ndhwc;
 * mkldnn_memory_desc_init_by_tag(
 *      &data_in_ndhwc, ndims, dims, mkldnn_f32, mkldnn_ndhwc);
 * ~~~
 *
 * The following notation applies to memory format names:
 *  - @c 'n' denotes the mini-batch dimension
 *  - @c 'c' denotes a channels dimension
 *  - When there are multiple channel dimensions (for example, in convolution
 *    weights tensor), @c 'i' and @c 'o' denote dimensions of input and output
 *    channels
 *  - @c 'g' denotes the groups dimension of grouped weights
 *  - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
 *    respectively
 *  - Upper-case letters indicate that the data is laid out in blocks
 *    for a particular dimension. In such cases, the format name contains both
 *    upper- and lower-case letters for that dimension with a lower-case letter
 *    preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a
 *    format where the outermost dimension is mini-batch, followed by the
 *    channel block number, followed by the spatial height and width, and
 *    finally followed by 8-element channel blocks.
 *
 * @note
 *    Channel designations can be different. For example, both the @c
 *    'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D
 *    tensor.
 *
 * @note
 *    The numeric value of every tag is determined by its position in the
 *    list below, so existing entries must not be reordered or removed.
 *
 * @sa @ref understanding_memory_formats
 */
typedef enum {
    /** Undefined memory format tag */
    mkldnn_format_tag_undef = 0,
    /** Any memory format tag: the primitive selects a format
     * automatically. */
    mkldnn_format_tag_any,

    /* Semantic agnostic section */
    /* The physical order of dimensions is defined by the permutation of the
     * characters, assuming that ab..z defines the natural order.
     */

    /* Plain formats */

    mkldnn_a,
    mkldnn_ab,
    mkldnn_abc,
    mkldnn_abcd,
    mkldnn_abcde,
    mkldnn_abcdef,
    mkldnn_abdec,
    mkldnn_acb,
    mkldnn_acbde,
    mkldnn_acdb,
    mkldnn_acdeb,
    mkldnn_ba,
    mkldnn_bac,
    mkldnn_bacd,
    mkldnn_bcda,
    mkldnn_cba,
    mkldnn_cdba,
    mkldnn_cdeba,
    mkldnn_decab,

    /* Opaque blocked formats */

    mkldnn_Abc16a,
    mkldnn_ABc16a16b,
    mkldnn_aBc16b,
    mkldnn_ABc16b16a,
    mkldnn_Abc4a,
    mkldnn_aBc4b,
    mkldnn_ABc4b16a4b,
    mkldnn_ABc4b4a,
    mkldnn_ABc8a16b2a,
    mkldnn_ABc8a8b,
    mkldnn_aBc8b,
    mkldnn_ABc8b16a2b,
    mkldnn_ABc8b8a,
    mkldnn_Abcd16a,
    mkldnn_ABcd16a16b,
    mkldnn_aBcd16b,
    mkldnn_ABcd16b16a,
    mkldnn_aBCd16b16c,
    mkldnn_aBCd16c16b,
    mkldnn_Abcd4a,
    mkldnn_aBcd4b,
    mkldnn_ABcd4b16a4b,
    mkldnn_ABcd4b4a,
    mkldnn_aBCd4c16b4c,
    mkldnn_aBCd4c4b,
    mkldnn_ABcd8a16b2a,
    mkldnn_ABcd8a8b,
    mkldnn_aBcd8b,
    mkldnn_ABcd8b16a2b,
    mkldnn_aBCd8b16c2b,
    mkldnn_ABcd8b8a,
    mkldnn_aBCd8b8c,
    mkldnn_aBCd8c16b2c,
    mkldnn_aBCd8c8b,
    mkldnn_Abcde16a,
    mkldnn_ABcde16a16b,
    mkldnn_aBcde16b,
    mkldnn_ABcde16b16a,
    mkldnn_aBCde16b16c,
    mkldnn_aBCde16c16b,
    mkldnn_aBCde2c8b4c,
    mkldnn_Abcde4a,
    mkldnn_aBcde4b,
    mkldnn_ABcde4b4a,
    mkldnn_aBCde4b4c,
    mkldnn_aBCde4c16b4c,
    mkldnn_aBCde4c4b,
    mkldnn_Abcde8a,
    mkldnn_ABcde8a8b,
    mkldnn_aBcde8b,
    mkldnn_ABcde8b16a2b,
    mkldnn_aBCde8b16c2b,
    mkldnn_ABcde8b8a,
    mkldnn_aBCde8b8c,
    mkldnn_aBCde8c16b2c,
    mkldnn_aBCde8c8b,
    mkldnn_aBcdef16b,
    mkldnn_aBCdef16b16c,
    mkldnn_aBCdef16c16b,
    mkldnn_aBcdef4b,
    mkldnn_aBCdef4c4b,
    mkldnn_aBCdef8b8c,
    mkldnn_aBCdef8c16b2c,
    mkldnn_aBCdef8c8b,
    mkldnn_aBdc16b,
    mkldnn_aBdc4b,
    mkldnn_aBdc8b,
    mkldnn_aBdec16b,
    mkldnn_aBdec4b,
    mkldnn_aBdec8b,
    mkldnn_aBdefc16b,
    mkldnn_aBdefc4b,
    mkldnn_aBdefc8b,
    mkldnn_Acb16a,
    mkldnn_Acb4a,
    mkldnn_Acb8a,
    mkldnn_aCBd16b16c,
    mkldnn_aCBde16b16c,
    mkldnn_Acdb16a,
    mkldnn_Acdb4a,
    mkldnn_Acdb8a,
    mkldnn_Acdeb16a,
    mkldnn_Acdeb4a,
    mkldnn_Acdeb8a,
    mkldnn_BAc16a16b,
    mkldnn_BAcd16a16b,

    /** Just a sentinel, not a real memory format tag. It must stay after
     * the last non-alias tag: add new tags immediately before it. */
    mkldnn_format_tag_last,

    /* Aliases */

    mkldnn_x = mkldnn_a,
    mkldnn_nc = mkldnn_ab,
    mkldnn_cn = mkldnn_ba,
    mkldnn_ncw = mkldnn_abc,
    mkldnn_nwc = mkldnn_acb,
    mkldnn_nchw = mkldnn_abcd,
    mkldnn_nhwc = mkldnn_acdb,
    mkldnn_chwn = mkldnn_bcda,
    mkldnn_ncdhw = mkldnn_abcde,
    mkldnn_ndhwc = mkldnn_acdeb,

    mkldnn_oi = mkldnn_ab,
    mkldnn_io = mkldnn_ba,
    mkldnn_oiw = mkldnn_abc,
    mkldnn_wio = mkldnn_cba,
    mkldnn_oihw = mkldnn_abcd,
    mkldnn_hwio = mkldnn_cdba,
    mkldnn_ihwo = mkldnn_bcda,
    mkldnn_iohw = mkldnn_bacd,
    mkldnn_oidhw = mkldnn_abcde,
    mkldnn_dhwio = mkldnn_cdeba,
    mkldnn_goiw = mkldnn_abcd,
    mkldnn_goihw = mkldnn_abcde,
    mkldnn_hwigo = mkldnn_decab,
    mkldnn_giohw = mkldnn_acbde,
    mkldnn_goidhw = mkldnn_abcdef,

    /** 3D RNN data tensor in the format (seq_length, batch, input channels). */
    mkldnn_tnc = mkldnn_abc,
    /** 3D RNN data tensor in the format (batch, seq_length, input channels). */
    mkldnn_ntc = mkldnn_bac,
    /** 5D RNN states tensor in the format (num_layers, num_directions,
     * num_states, batch, state channels). */
    mkldnn_ldsnc = mkldnn_abcde,
    /** 5D RNN weights tensor in the format (num_layers, num_directions,
     *  input_channels, num_gates, output_channels).
     *
     *  - For LSTM cells, the gates order is input, forget, candidate
     *    and output gate.
     *  - For GRU cells, the gates order is update, reset and output gate. */
    mkldnn_ldigo = mkldnn_abcde,
    /** 5D RNN weights tensor in the format (num_layers, num_directions,
     * num_gates, output_channels, input_channels).
     *
     *  - For LSTM cells, the gates order is input, forget, candidate
     *    and output gate.
     *  - For GRU cells, the gates order is update, reset and output gate. */
    mkldnn_ldgoi = mkldnn_abdec,
    /** 4D RNN bias tensor in the format (num_layers, num_directions,
     * num_gates, output_channels).
     *
     *  - For LSTM cells, the gates order is input, forget, candidate
     *    and output gate.
     *  - For GRU cells, the gates order is update, reset and output gate. */
    mkldnn_ldgo = mkldnn_abcd,

    /* Aliases for opaque blocked format tags; these are not intended to be
     * used explicitly. */

    /* data */
    mkldnn_nCdhw16c = mkldnn_aBcde16b,
    mkldnn_nCdhw4c = mkldnn_aBcde4b,
    mkldnn_nCdhw8c = mkldnn_aBcde8b,
    mkldnn_nChw16c = mkldnn_aBcd16b,
    mkldnn_nChw4c = mkldnn_aBcd4b,
    mkldnn_nChw8c = mkldnn_aBcd8b,
    mkldnn_nCw16c = mkldnn_aBc16b,
    mkldnn_nCw4c = mkldnn_aBc4b,
    mkldnn_nCw8c = mkldnn_aBc8b,

    /* weights, 3D */
    mkldnn_IOw16o16i = mkldnn_BAc16a16b,
    mkldnn_OIw16i16o = mkldnn_ABc16b16a,
    mkldnn_OIw16o16i = mkldnn_ABc16a16b,
    mkldnn_Oiw16o = mkldnn_Abc16a,
    mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b,
    mkldnn_OIw4i4o = mkldnn_ABc4b4a,
    mkldnn_Oiw4o = mkldnn_Abc4a,
    mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b,
    mkldnn_OIw8i8o = mkldnn_ABc8b8a,
    mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a,
    mkldnn_OIw8o8i = mkldnn_ABc8a8b,
    mkldnn_Owi16o = mkldnn_Acb16a,
    mkldnn_Owi4o = mkldnn_Acb4a,
    mkldnn_Owi8o = mkldnn_Acb8a,

    /* weights, 4D */
    mkldnn_IOhw16o16i = mkldnn_BAcd16a16b,
    mkldnn_Ohwi16o = mkldnn_Acdb16a,
    mkldnn_Ohwi4o = mkldnn_Acdb4a,
    mkldnn_Ohwi8o = mkldnn_Acdb8a,
    mkldnn_OIhw16i16o = mkldnn_ABcd16b16a,
    mkldnn_OIhw16o16i = mkldnn_ABcd16a16b,
    mkldnn_Oihw16o = mkldnn_Abcd16a,
    mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b,
    mkldnn_OIhw4i4o = mkldnn_ABcd4b4a,
    mkldnn_Oihw4o = mkldnn_Abcd4a,
    mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b,
    mkldnn_OIhw8i8o = mkldnn_ABcd8b8a,
    mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a,
    mkldnn_OIhw8o8i = mkldnn_ABcd8a8b,

    /* weights, 5D */
    mkldnn_Odhwi16o = mkldnn_Acdeb16a,
    mkldnn_Odhwi4o = mkldnn_Acdeb4a,
    mkldnn_Odhwi8o = mkldnn_Acdeb8a,
    mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a,
    mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b,
    mkldnn_Oidhw16o = mkldnn_Abcde16a,
    mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a,
    mkldnn_Oidhw4o = mkldnn_Abcde4a,
    mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b,
    mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a,
    mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b,

    /* weights w/ groups, 4D tensors (grouped variant of 3D weights) */
    mkldnn_Goiw16g = mkldnn_Abcd16a,
    mkldnn_gIOw16o16i = mkldnn_aCBd16b16c,
    mkldnn_gOIw16i16o = mkldnn_aBCd16c16b,
    mkldnn_gOIw16o16i = mkldnn_aBCd16b16c,
    mkldnn_gOiw16o = mkldnn_aBcd16b,
    mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c,
    mkldnn_gOIw4i4o = mkldnn_aBCd4c4b,
    mkldnn_gOiw4o = mkldnn_aBcd4b,
    mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c,
    mkldnn_gOIw8i8o = mkldnn_aBCd8c8b,
    mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b,
    mkldnn_gOIw8o8i = mkldnn_aBCd8b8c,
    mkldnn_gOwi16o = mkldnn_aBdc16b,
    mkldnn_gOwi4o = mkldnn_aBdc4b,
    mkldnn_gOwi8o = mkldnn_aBdc8b,

    /* weights w/ groups, 5D tensors (grouped variant of 4D weights) */
    mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c,
    mkldnn_gOhwi16o = mkldnn_aBdec16b,
    mkldnn_gOhwi4o = mkldnn_aBdec4b,
    mkldnn_gOhwi8o = mkldnn_aBdec8b,
    mkldnn_Goihw16g = mkldnn_Abcde16a,
    mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b,
    mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c,
    mkldnn_gOihw16o = mkldnn_aBcde16b,
    mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c,
    mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c,
    mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b,
    mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c,
    mkldnn_gOihw4o = mkldnn_aBcde4b,
    mkldnn_Goihw8g = mkldnn_Abcde8a,
    mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c,
    mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b,
    mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b,
    mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c,

    /* weights w/ groups, 6D tensors (grouped variant of 5D weights) */
    mkldnn_gOdhwi16o = mkldnn_aBdefc16b,
    mkldnn_gOdhwi4o = mkldnn_aBdefc4b,
    mkldnn_gOdhwi8o = mkldnn_aBdefc8b,
    mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b,
    mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c,
    mkldnn_gOidhw16o = mkldnn_aBcdef16b,
    mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b,
    mkldnn_gOidhw4o = mkldnn_aBcdef4b,
    mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c,
    mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b,
    mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c,
} mkldnn_format_tag_t;
443
/** Kinds of padding: how the data located in padding regions is to be
 * interpreted. */
typedef enum {
    /** Padding regions contain zeros. */
    mkldnn_padding_zero = 0,
} mkldnn_padding_kind_t;
449
/** Kinds of propagation.
 *
 * Values are spelled in hexadecimal; aliases share the value of the kind
 * they alias. */
typedef enum {
    /* TODO: suggest renames */
    /** Undefined propagation kind. */
    mkldnn_prop_kind_undef = 0x0,
    /** Forward data propagation, training mode: primitives also perform the
     * computations needed by a subsequent backward propagation. */
    mkldnn_forward_training = 0x40,
    /** Forward data propagation, inference mode: primitives perform only the
     * computations needed for inference and skip those needed solely for
     * backward propagation. */
    mkldnn_forward_inference = 0x60,
    /** Forward data propagation; alias for #mkldnn_forward_inference. */
    mkldnn_forward_scoring = mkldnn_forward_inference,
    /** Forward data propagation; alias for #mkldnn_forward_training. */
    mkldnn_forward = mkldnn_forward_training,
    /** Backward propagation (with respect to all parameters). */
    mkldnn_backward = 0x80,
    /** Backward propagation with respect to data. */
    mkldnn_backward_data = 0xa0,
    /** Backward propagation with respect to weights. */
    mkldnn_backward_weights = 0xc0,
    /** Backward propagation with respect to bias. */
    mkldnn_backward_bias = 0xc1,
} mkldnn_prop_kind_t;
475
/** Kinds of primitives. Provides a way to extend the library with new
 * primitives without changing the ABI; values are pinned explicitly. */
typedef enum {
    /** Undefined primitive kind. */
    mkldnn_undefined_primitive = 0,
    /** Reorder primitive. */
    mkldnn_reorder = 1,
    /** Shuffle primitive. */
    mkldnn_shuffle = 2,
    /** Concat primitive (out-of-place). */
    mkldnn_concat = 3,
    /** Sum primitive. */
    mkldnn_sum = 4,
    /** Convolution primitive. */
    mkldnn_convolution = 5,
    /** Deconvolution primitive. */
    mkldnn_deconvolution = 6,
    /** Element-wise primitive. */
    mkldnn_eltwise = 7,
    /** Softmax primitive. */
    mkldnn_softmax = 8,
    /** Pooling primitive. */
    mkldnn_pooling = 9,
    /** Local response normalization (LRN) primitive. */
    mkldnn_lrn = 10,
    /** Batch normalization primitive. */
    mkldnn_batch_normalization = 11,
    /** Inner product primitive. */
    mkldnn_inner_product = 12,
    /** RNN primitive. */
    mkldnn_rnn = 13,
} mkldnn_primitive_kind_t;
508
/** Kinds of algorithms. */
typedef enum {
    /** Undefined algorithm kind (implicitly 0). */
    mkldnn_alg_kind_undef,
    /** Direct convolution */
    mkldnn_convolution_direct = 0x1,
    /** Winograd convolution */
    mkldnn_convolution_winograd = 0x2,
    /** Convolution algorithm (either direct or Winograd) is chosen just in
     * time */
    mkldnn_convolution_auto = 0x3,
    /** Direct deconvolution */
    mkldnn_deconvolution_direct = 0xa,
    /** Winograd deconvolution */
    mkldnn_deconvolution_winograd = 0xb,
    /** Eltwise: ReLU */
    mkldnn_eltwise_relu = 0x1f,
    /** Eltwise: hyperbolic tangent non-linearity (tanh) */
    mkldnn_eltwise_tanh = 0x2f,
    /** Eltwise: parametric exponential linear unit (elu) */
    mkldnn_eltwise_elu = 0x3f,
    /** Eltwise: square */
    mkldnn_eltwise_square = 0x4f,
    /** Eltwise: abs */
    mkldnn_eltwise_abs = 0x5f,
    /** Eltwise: square root */
    mkldnn_eltwise_sqrt = 0x6f,
    /** Eltwise: linear */
    mkldnn_eltwise_linear = 0x7f,
    /** Eltwise: bounded_relu */
    mkldnn_eltwise_bounded_relu = 0x8f,
    /** Eltwise: soft_relu */
    mkldnn_eltwise_soft_relu = 0x9f,
    /** Eltwise: logistic */
    mkldnn_eltwise_logistic = 0xaf,
    /** Max pooling */
    mkldnn_pooling_max = 0x1ff,
    /** Average pooling; padding elements are included in the averaging */
    mkldnn_pooling_avg_include_padding = 0x2ff,
    /** Average pooling; padding elements are excluded from the averaging */
    mkldnn_pooling_avg_exclude_padding = 0x3ff,
    /** Average pooling; alias for #mkldnn_pooling_avg_exclude_padding */
    mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
    /** Local response normalization (LRN) across multiple channels */
    mkldnn_lrn_across_channels = 0xaff,
    /** LRN within a single channel */
    mkldnn_lrn_within_channel = 0xbff,
    /** RNN cell */
    mkldnn_vanilla_rnn = 0x1fff,
    /** LSTM cell */
    mkldnn_vanilla_lstm = 0x2fff,
    /** GRU cell */
    mkldnn_vanilla_gru = 0x3fff,
    /** GRU cell with linear before reset
     *
     * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
     * in how the new memory gate is calculated:
     * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
     * Primitive expects 4 biases on input:
     * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
     * */
    mkldnn_gru_linear_before_reset = 0x4fff,
} mkldnn_alg_kind_t;
569
/** Flags for the batch normalization primitive.
 *
 * Each flag occupies its own bit, so flags can be combined with bitwise OR. */
typedef enum {
    /** Use global statistics.
     *
     * When set:
     *  - forward propagation uses the mean and variance supplied by the user
     *    (as inputs);
     *  - backward propagation needs fewer computations, because mean and
     *    variance are treated as constants.
     *
     * When not set:
     *  - forward propagation computes mean and variance and stores them as
     *    outputs;
     *  - backward propagation computes the full derivative w.r.t. data.
     */
    mkldnn_use_global_stats = 1U << 0,
    /** Use scale and shift parameters.
     *
     * When set:
     *  - forward propagation applies scale and shift (aka scale and bias) to
     *    the batch normalization results;
     *  - backward propagation (for prop_kind == #mkldnn_backward) computes
     *    the diff w.r.t. scale and shift (hence one extra output is used).
     *
     * When not set:
     *  - backward propagation with prop_kind == #mkldnn_backward_data behaves
     *    the same as with prop_kind == #mkldnn_backward.
     */
    mkldnn_use_scaleshift = 1U << 1,
    /** Fuse with ReLU.
     *
     * When set:
     *  - for inference, behaves as if the primitive were fused with ReLU via
     *    the post-ops API;
     *  - for training, the primitive requires a workspace (needed to perform
     *    the backward pass).
     */
    mkldnn_fuse_bn_relu = 1U << 2,
} mkldnn_batch_normalization_flag_t;
608
609/** @} */
610
611/** @addtogroup c_api_types_memory Memory
612 * @{ */
613
/** Maximum number of dimensions a tensor can have. Only restricts the amount
 * of space used for the tensor description. Individual computational
 * primitives may support only tensors of certain dimensions. */
#define MKLDNN_MAX_NDIMS 12

/** A type to describe a single tensor dimension (signed 64-bit integer). */
typedef int64_t mkldnn_dim_t;

/** A type to describe tensor dimensions: a fixed-capacity array of
 * #MKLDNN_MAX_NDIMS entries. The number of entries actually in use is stored
 * separately (e.g. in mkldnn_memory_desc_t::ndims). */
typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS];

/** A type to describe strides within a tensor; has the same shape as
 * mkldnn_dims_t. */
typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS];
627
/** Generic description of blocked data layout for most memory formats.
 *
 * @sa @ref understanding_memory_formats */
typedef struct {
    /** The strides between the outermost blocks.
     * In case of plain (non-blocked) formats the strides between dimensions. */
    mkldnn_dims_t strides;
    /* Innermost section
     * ASSUMPTION: the innermost blocks are always dense */
    /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i` */
    int inner_nblks;
    /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */
    mkldnn_dims_t inner_blks;
    /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of
     * `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim */
    mkldnn_dims_t inner_idxs;
} mkldnn_blocking_desc_t;
645
/** Layouts of Winograd convolution weights, as used by
 * mkldnn_wino_desc_t::wino_format. Values are pinned explicitly. */
typedef enum {
    /** Undefined memory format, used for empty memory descriptors. */
    mkldnn_wino_undef = 0,
    /* Weights tensors for 2x3 Winograd convolutions. */
    mkldnn_wino_wei_aaOIoi = 1,
    mkldnn_wino_wei_aaOio = 2,
    mkldnn_wino_wei_aaOBiOo = 3,
    /** Weights tensor for 4x3 Winograd convolution. */
    mkldnn_wino_wei_OBaaIBOIio = 4
} mkldnn_wino_memory_format_t;
656
/** Description of a tensor of weights for Winograd convolution. The concrete
 * layout (2x3 or 4x3 variants) is selected by @p wino_format. */
typedef struct {
    /** Layout of the Winograd weights. */
    mkldnn_wino_memory_format_t wino_format;
    /** NOTE(review): presumably the filter (kernel) size -- confirm. */
    int r;
    /** NOTE(review): presumably the transformed tile size -- confirm. */
    int alpha;
    /** Number of input channels. */
    int ic;
    /** Number of output channels. */
    int oc;
    /** Input-channel blocking factor (presumably first-level -- confirm). */
    int ic_block;
    /** Output-channel blocking factor (presumably first-level -- confirm). */
    int oc_block;
    /** Second-level input-channel blocking factor (presumably -- confirm). */
    int ic2_block;
    /** Second-level output-channel blocking factor (presumably -- confirm). */
    int oc2_block;
    /** Adjustment scale. NOTE(review): presumably applied to the quantized
     * weights -- confirm. */
    float adj_scale;
    /** Total size of the weights buffer (presumably in bytes -- confirm). */
    size_t size;
} mkldnn_wino_desc_t;
671
/** Layouts of packed RNN weights, as used by
 * mkldnn_rnn_packed_desc_t::format. Values are pinned explicitly. */
typedef enum {
    /** Undefined packed format. */
    mkldnn_packed_format_undef = 0,
    /** Packed weights whose logical layout presumably corresponds to
     * #mkldnn_ldigo (inferred from the name -- confirm). */
    mkldnn_ldigo_p = 1,
    /** Packed weights whose logical layout presumably corresponds to
     * #mkldnn_ldgoi (inferred from the name -- confirm). */
    mkldnn_ldgoi_p = 2
} mkldnn_rnn_packed_memory_format_t;
677
/** Maximum number of parts of RNN weights tensor that require separate
 * computation. Sets the capacity of the per-part arrays in
 * mkldnn_rnn_packed_desc_t. */
#define MKLDNN_RNN_MAX_N_PARTS 4
681
/** Description of a tensor of packed weights for RNN. */
typedef struct {
    /** Packed layout of the weights. */
    mkldnn_rnn_packed_memory_format_t format;
    /** Number of parts the weights are split into. NOTE(review): presumably
     * at most #MKLDNN_RNN_MAX_N_PARTS, the capacity of @p parts and
     * @p part_pack_size -- confirm. */
    int n_parts;
    /** NOTE(review): meaning not visible here; presumably a GEMM dimension
     * shared by all parts -- confirm. */
    int n;
    /** Per-part extents. NOTE(review): exact meaning not visible here --
     * confirm. */
    int parts[MKLDNN_RNN_MAX_N_PARTS];
    /** Size of each packed part (presumably in bytes -- confirm). */
    size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS];
    /** Offset of the compensation data within the buffer (presumably in
     * bytes -- confirm). */
    size_t offset_compensation;
    /** Total size of the packed buffer (presumably in bytes -- confirm). */
    size_t size;
} mkldnn_rnn_packed_desc_t;
692
/** Flags stored in mkldnn_memory_extra_desc_t::flags. Each flag occupies its
 * own bit, so flags can be combined with bitwise OR. */
typedef enum {
    /** No extra information. */
    mkldnn_memory_extra_flag_none = 0U,
    /** The weights carry an additional buffer whose shape depends on
     * @p compensation_mask.
     *
     * For instance, in the 4D case with a compensation mask equal to (1 << 0)
     * the additional buffer consists of OC values:
     * O[oc : 0,OC] =
     *  -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
     */
    mkldnn_memory_extra_flag_compensation_conv_s8s8 = 1U << 0,
    /** The data has a scale adjustment. NOTE(review): presumably pairs with
     * mkldnn_memory_extra_desc_t::scale_adjust -- confirm. */
    mkldnn_memory_extra_flag_scale_adjust = 1U << 1,
} mkldnn_memory_extra_flags_t;
706
/** Description of extra information stored in memory */
typedef struct {
    /** The flags contain arbitrary extra information, such as compensation.
     * A bitwise OR of mkldnn_memory_extra_flags_t values.
     * @sa mkldnn_memory_extra_flags_t */
    uint64_t flags;
    /** Compensation mask: selects the dimensions along which compensation
     * values are stored (e.g. (1 << 0) gives one value per output channel
     * of 4D weights; see #mkldnn_memory_extra_flag_compensation_conv_s8s8) */
    int compensation_mask;
    /** Scale applied to the data. NOTE(review): presumably meaningful only
     * when #mkldnn_memory_extra_flag_scale_adjust is set -- confirm. */
    float scale_adjust;
    /** Reserved space so that fields can be added later while preserving
     * backward compatibility */
    char reserved[64];
} mkldnn_memory_extra_desc_t;
719
/** Memory descriptor. The description is based on a number of dimensions,
 * dimensions themselves, plus information about elements type and memory
 * format. Additionally, contains format-specific descriptions of the data
 * layout. */
typedef struct {
    /** Number of dimensions (number of meaningful entries in @p dims) */
    int ndims;
    /** Dimensions in the following order:
     * - CNN data tensors: mini-batch, channel, spatial
     *   (<code>{N, C, [[D,] H,] W}</code>)
     * - CNN weight tensors: group (optional), output channel, input channel,
     *   spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
     * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
     * - RNN states tensors: layers, directions, states, mini-batch, channels
     *   (<code>{L, D, S, N, C}</code>)
     * - RNN weight tensor: layers, directions, input channel, gates, output
     *   channels (<code>{L, D, I, G, O}</code>).
     *
     * @note
     *    The order of dimensions does not depend on the memory format, so
     *    whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc
     *    the dims for 4D CNN data tensor would be <code>{N, C, H, W}</code>.
     */
    mkldnn_dims_t dims;
    /** Data type of the tensor elements. */
    mkldnn_data_type_t data_type;

    /** Size of the data including padding in each dimension. */
    mkldnn_dims_t padded_dims;
    /** Per-dimension offset from the padding to actual data, the top-level
     * tensor with offsets applied must lie within the padding area. */
    mkldnn_dims_t padded_offsets;

    /** Offset from memory origin to the current block, non-zero only in
     * a description of a memory sub-block. */
    mkldnn_dim_t offset0;

    /** Memory format kind. Selects which member of @p format_desc is
     * valid. */
    mkldnn_format_kind_t format_kind;
    union {
        /** Description of the data layout for memory formats that use
         * blocking. */
        mkldnn_blocking_desc_t blocking;
        /** Tensor of weights for integer 8bit winograd convolution. */
        mkldnn_wino_desc_t wino_desc;
        /** Tensor of packed weights for RNN. */
        mkldnn_rnn_packed_desc_t rnn_packed_desc;
        /* ... other descriptions possible */
    } format_desc;

    /** Extra information attached to the memory (e.g. compensation); see
     * mkldnn_memory_extra_desc_t. */
    mkldnn_memory_extra_desc_t extra;
} mkldnn_memory_desc_t;
771
/** @struct mkldnn_memory
 * An opaque structure to describe a memory. Its contents are not exposed in
 * this header; it is only referred to through the handle types below. */
struct mkldnn_memory;

/** A memory handle. */
typedef struct mkldnn_memory *mkldnn_memory_t;

/** A constant memory handle. */
typedef const struct mkldnn_memory *const_mkldnn_memory_t;

/** Special native-handle value meaning "no native handle" (a null pointer).
 * NOTE(review): presumably used when creating a memory object without a
 * user buffer -- confirm against the memory creation API. */
#define MKLDNN_NATIVE_HANDLE_NONE (NULL)
/** Special native-handle value: the pointer obtained from (size_t)-1, which
 * is not expected to be a real buffer address. NOTE(review): presumably asks
 * the library to allocate the buffer itself -- confirm against the memory
 * creation API. */
#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1)
784
785/** @} */
786
787/** @addtogroup c_api_types_op_descs Operation descriptors
788 * @{*/
789
/** A pointer to any of the operation descriptors. The concrete descriptor
 * type is identified by its leading @c primitive_kind field. */
typedef void *mkldnn_op_desc_t;
/** A pointer to any of the operation descriptors (constant variant). */
typedef const void *const_mkldnn_op_desc_t;
794
/** A descriptor of a convolution operation. The same structure is also used
 * for deconvolution (see mkldnn_deconvolution_desc_t). */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_convolution (or #mkldnn_deconvolution
     * when used as mkldnn_deconvolution_desc_t). */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward_data,
     * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
    mkldnn_prop_kind_t prop_kind;
    /** The kind of the convolution algorithm: one of the convolution kinds
     * of mkldnn_alg_kind_t (#mkldnn_convolution_direct,
     * #mkldnn_convolution_winograd, #mkldnn_convolution_auto), or a
     * deconvolution kind for deconvolution descriptors.
     * NOTE(review): confirm which kinds each implementation accepts. */
    mkldnn_alg_kind_t alg_kind;
    /** Source memory descriptor. */
    mkldnn_memory_desc_t src_desc;
    /** Source gradient memory descriptor. */
    mkldnn_memory_desc_t diff_src_desc;
    /** Weights memory descriptor. */
    mkldnn_memory_desc_t weights_desc;
    /** Weights gradient memory descriptor. */
    mkldnn_memory_desc_t diff_weights_desc;
    /** Bias memory descriptor. */
    mkldnn_memory_desc_t bias_desc;
    /** Bias gradient memory descriptor. */
    mkldnn_memory_desc_t diff_bias_desc;
    /** Destination memory descriptor. */
    mkldnn_memory_desc_t dst_desc;
    /** Destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_dst_desc;
    /** Convolution strides in each spatial dimension. */
    mkldnn_dims_t strides;
    /** Convolution dilates in each spatial dimension. */
    mkldnn_dims_t dilates;
    /** Padding in each spatial dimension. padding[0] is a padding in the
     * beginning (@p padding_l), padding[1] is a padding in the end (@p
     * padding_r). */
    mkldnn_dims_t padding[2];
    /** The kind of padding to use. */
    mkldnn_padding_kind_t padding_kind;
    /** The accumulator data type. Initialized automatically. */
    mkldnn_data_type_t accum_data_type;
} mkldnn_convolution_desc_t;
836
/** A descriptor of a deconvolution operation. Shares its layout with
 * mkldnn_convolution_desc_t; @c primitive_kind is expected to be
 * #mkldnn_deconvolution. */
typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
839
/** A descriptor of a shuffle operation. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_shuffle. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, and #mkldnn_backward_data. */
    mkldnn_prop_kind_t prop_kind;
    /** Source and destination memory descriptor,
     * and source and destination gradient memory descriptor. */
    mkldnn_memory_desc_t data_desc;
    /** Axis for shuffling. */
    int axis;
    /** Number of groups in group convolution. */
    mkldnn_dim_t group_size;
} mkldnn_shuffle_desc_t;
856
/** A descriptor of an element-wise operation. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_eltwise. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
     */
    mkldnn_prop_kind_t prop_kind;
    /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
     * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
     * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
     * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
     * #mkldnn_eltwise_logistic. */
    mkldnn_alg_kind_t alg_kind;
    /** Source and destination memory descriptor. */
    mkldnn_memory_desc_t data_desc;
    /** Source and destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_data_desc;
    /** Algorithm specific parameters.
     * Accordance table:
     *  - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
     *  - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
     *  - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored
     *  - #mkldnn_eltwise_square: @p alpha and @p beta ignored
     *  - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
     *  - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
     *  - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
     *  - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
     *  - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
     *  - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
     */
    float alpha, beta;
} mkldnn_eltwise_desc_t;
891
/** A descriptor of a Softmax operation. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_softmax. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training and
     * #mkldnn_forward_inference.
     * NOTE(review): the presence of @p diff_desc suggests backward
     * propagation (#mkldnn_backward_data) is also supported -- confirm and
     * update this list. */
    mkldnn_prop_kind_t prop_kind;
    /** Source and destination memory descriptor. */
    mkldnn_memory_desc_t data_desc;
    /** Source and destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_desc;
    /** The axis along which to perform the softmax. */
    int softmax_axis;
} mkldnn_softmax_desc_t;
907
/** A descriptor of a pooling operation.
 *
 * Source, destination and their gradients each have their own memory
 * descriptor. Kernel geometry is given by @p strides, @p kernel and
 * @p padding. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_pooling. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
     */
    mkldnn_prop_kind_t prop_kind;
    /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and
     * #mkldnn_pooling_avg. */
    mkldnn_alg_kind_t alg_kind;
    /** Source memory descriptor. */
    mkldnn_memory_desc_t src_desc;
    /** Source gradient memory descriptor. */
    mkldnn_memory_desc_t diff_src_desc;
    /** Destination memory descriptor. */
    mkldnn_memory_desc_t dst_desc;
    /** Destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_dst_desc;
    /** Pooling kernel strides for the spatial dimensions. */
    mkldnn_dims_t strides;
    /** Pooling kernel sizes for the spatial dimensions. */
    mkldnn_dims_t kernel;
    /** Padding in each spatial dimension: padding[0] is the padding at the
     * beginning (@p padding_l) and padding[1] is the padding at the end
     * (@p padding_r). */
    mkldnn_dims_t padding[2];
    /** The kind of padding to use. */
    mkldnn_padding_kind_t padding_kind;
    /** The accumulator data type. Initialized automatically. */
    mkldnn_data_type_t accum_data_type;
} mkldnn_pooling_desc_t;
941
/** A descriptor of a Local Response Normalization (LRN) operation.
 *
 * Source and destination share one memory descriptor, and so do their
 * gradients. The normalization is parameterized by @p local_size,
 * @p lrn_alpha, @p lrn_beta and @p lrn_k. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_lrn. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
     */
    mkldnn_prop_kind_t prop_kind;
    /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and
     * #mkldnn_lrn_across_channels. */
    mkldnn_alg_kind_t alg_kind;
    /** Source and destination memory descriptor. */
    mkldnn_memory_desc_t data_desc;
    /** Source and destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_data_desc;
    /** The number of channels to sum over (for cross-channel LRN) or the side
     * length of the square region to sum over (for within-channel LRN). */
    mkldnn_dim_t local_size;
    /** LRN alpha parameter. */
    float lrn_alpha;
    /** LRN beta parameter. */
    float lrn_beta;
    /** LRN k parameter. */
    float lrn_k;
} mkldnn_lrn_desc_t;
968
/** A descriptor of a Batch Normalization operation. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_batch_normalization. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
     */
    mkldnn_prop_kind_t prop_kind;
    /** Source and destination memory descriptor. */
    mkldnn_memory_desc_t data_desc;
    /** Source and destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_data_desc;
    /** Scale and shift data memory descriptor.
     *
     * Uses the 2D #mkldnn_nc format [2, Channels]: the first row (index 0 of
     * the first dimension) holds the gamma (scale) parameters and the second
     * row (index 1) holds the beta (shift) parameters. */
    mkldnn_memory_desc_t data_scaleshift_desc;
    /** Scale and shift gradient memory descriptor; same [2, Channels] layout
     * as @p data_scaleshift_desc. */
    mkldnn_memory_desc_t diff_data_scaleshift_desc;
    /** Mean memory descriptor; uses the 1D #mkldnn_x format [Channels]. */
    mkldnn_memory_desc_t mean_desc;
    /** Variance memory descriptor; uses the 1D #mkldnn_x format [Channels]. */
    mkldnn_memory_desc_t variance_desc;
    /** Batch normalization epsilon parameter. */
    float batch_norm_epsilon;
    /** Batch normalization flags.
     * NOTE(review): presumably a bitwise combination of the batch
     * normalization flag values declared earlier in this header -- confirm. */
    unsigned flags;
} mkldnn_batch_normalization_desc_t;
999
/** A descriptor of an inner product operation.
 *
 * Source, weights, bias, destination and each of their gradients are described
 * by separate memory descriptors. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_inner_product. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, #mkldnn_backward_data,
     * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
    mkldnn_prop_kind_t prop_kind;
    /** Source memory descriptor. */
    mkldnn_memory_desc_t src_desc;
    /** Source gradient memory descriptor. */
    mkldnn_memory_desc_t diff_src_desc;
    /** Weights memory descriptor. */
    mkldnn_memory_desc_t weights_desc;
    /** Weights gradient memory descriptor. */
    mkldnn_memory_desc_t diff_weights_desc;
    /** Bias memory descriptor. */
    mkldnn_memory_desc_t bias_desc;
    /** Bias gradient memory descriptor. */
    mkldnn_memory_desc_t diff_bias_desc;
    /** Destination memory descriptor. */
    mkldnn_memory_desc_t dst_desc;
    /** Destination gradient memory descriptor. */
    mkldnn_memory_desc_t diff_dst_desc;
    /** The accumulator data type. Initialized automatically. */
    mkldnn_data_type_t accum_data_type;
} mkldnn_inner_product_desc_t;
1028
/** Flags for an RNN cell.
 *
 * Each enumerator occupies its own bit so the values can be OR-ed together
 * into the @p flags field of an RNN cell descriptor. */
typedef enum {
    /** The cell's @p alpha is used as the negative slope. */
    mkldnn_rnn_cell_with_relu = 1U << 0,
    /** The cell's @p clipping parameter is applied. */
    mkldnn_rnn_cell_with_clipping = 1U << 1,
} mkldnn_rnn_cell_flags_t;
1034
/** A descriptor of an RNN cell: its kind, activation, flags and the
 * flag-dependent parameters @p alpha and @p clipping. */
typedef struct {
    /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
     * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru,
     * or #mkldnn_gru_linear_before_reset. */
    mkldnn_alg_kind_t cell_kind;
    /** Activation function used. Must be either #mkldnn_eltwise_relu or
     * #mkldnn_eltwise_tanh. */
    mkldnn_alg_kind_t activation_kind;
    /** RNN cell flags: a combination of #mkldnn_rnn_cell_flags_t values,
     * tested bitwise to decide whether @p alpha and @p clipping apply. */
    unsigned int flags;
    /** @c alpha is a negative slope parameter (used only if
     * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */
    float alpha;
    /** Clipping parameter (used only if
     * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */
    float clipping;
} mkldnn_rnn_cell_desc_t;
1052
/** A direction of RNN primitive execution. */
typedef enum {
    /** Unidirectional execution of RNN primitive from left to right. */
    mkldnn_unidirectional_left2right = 0,
    /** Unidirectional execution of RNN primitive from right to left. */
    mkldnn_unidirectional_right2left = 1,
    /** Bidirectional execution of RNN primitive with concatenation of the
     * results. */
    mkldnn_bidirectional_concat = 2,
    /** Bidirectional execution of RNN primitive with summation of the
     * results. */
    mkldnn_bidirectional_sum = 3,
    /** Alias for #mkldnn_unidirectional_left2right. */
    mkldnn_unidirectional = mkldnn_unidirectional_left2right,
} mkldnn_rnn_direction_t;
1067
/** A descriptor for an RNN operation.
 *
 * "Layer" descriptors describe data flowing between layers, and "iteration"
 * (iter) descriptors describe data carried between iterations. Each data and
 * weights descriptor has a matching gradient (diff_*) descriptor. */
typedef struct {
    /** The kind of primitive. Used for self-identifying the primitive
     * descriptor. Must be #mkldnn_rnn. */
    mkldnn_primitive_kind_t primitive_kind;
    /** The kind of propagation. Possible values: #mkldnn_forward_training,
     * #mkldnn_forward_inference, and #mkldnn_backward. */
    mkldnn_prop_kind_t prop_kind;
    /** The RNN cell descriptor. */
    mkldnn_rnn_cell_desc_t cell_desc;
    /** The direction of RNN primitive execution. */
    mkldnn_rnn_direction_t direction;
    /** Source layer memory descriptor. */
    mkldnn_memory_desc_t src_layer_desc;
    /** Source iteration memory descriptor. */
    mkldnn_memory_desc_t src_iter_desc;
    /** Weights layer memory descriptor. */
    mkldnn_memory_desc_t weights_layer_desc;
    /** Weights iteration memory descriptor. */
    mkldnn_memory_desc_t weights_iter_desc;
    /** Bias memory descriptor. */
    mkldnn_memory_desc_t bias_desc;
    /** Destination layer memory descriptor. */
    mkldnn_memory_desc_t dst_layer_desc;
    /** Destination iteration memory descriptor. */
    mkldnn_memory_desc_t dst_iter_desc;
    /** Source gradient layer memory descriptor. */
    mkldnn_memory_desc_t diff_src_layer_desc;
    /** Source gradient iteration memory descriptor. */
    mkldnn_memory_desc_t diff_src_iter_desc;
    /** Weights gradient layer memory descriptor. */
    mkldnn_memory_desc_t diff_weights_layer_desc;
    /** Weights gradient iteration memory descriptor. */
    mkldnn_memory_desc_t diff_weights_iter_desc;
    /** Bias gradient memory descriptor. */
    mkldnn_memory_desc_t diff_bias_desc;
    /** Destination gradient layer memory descriptor. */
    mkldnn_memory_desc_t diff_dst_layer_desc;
    /** Destination gradient iteration memory descriptor. */
    mkldnn_memory_desc_t diff_dst_iter_desc;
} mkldnn_rnn_desc_t;
1109
1110/** @} */
1111
1112/** @addtogroup c_api_engine_types Engine
1113 * @{ */
1114
/** @brief Kinds of engines (the kind of device a computation runs on). */
typedef enum {
    mkldnn_any_engine = 0, /**< An unspecified engine. */
    mkldnn_cpu = 1,        /**< CPU engine. */
} mkldnn_engine_kind_t;
1122
/** @struct mkldnn_engine
 * @brief An opaque structure to describe an engine.
 *
 * Unlike most other handle types in this header, no constant
 * (const_mkldnn_*) engine handle type is defined. */
struct mkldnn_engine;
/** @brief An engine handle. */
typedef struct mkldnn_engine *mkldnn_engine_t;
1133
1134/** @} */
1135
1136/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators
1137 * @{ */
1138
/** @struct mkldnn_primitive_desc_iterator
 * @brief Opaque type representing an iterator over primitive descriptors.
 *
 * Its layout is private to the library; users only hold handles to it. */
struct mkldnn_primitive_desc_iterator;

/** @brief Handle to a primitive descriptor iterator. */
typedef struct mkldnn_primitive_desc_iterator *
        mkldnn_primitive_desc_iterator_t;

/** @brief Handle to a primitive descriptor iterator that must not be
 * modified through this handle. */
typedef const struct mkldnn_primitive_desc_iterator *
        const_mkldnn_primitive_desc_iterator_t;
1150
1151/** @} */
1152
1153/** @addtogroup c_api_primitive_descs Primitive descriptors
1154 * @{ */
1155
/** @struct mkldnn_primitive_desc
 * @brief Opaque type representing a primitive descriptor.
 *
 * Its layout is private to the library; users only hold handles to it. */
struct mkldnn_primitive_desc;

/** @brief Handle to a primitive descriptor. */
typedef struct mkldnn_primitive_desc *
        mkldnn_primitive_desc_t;

/** @brief Handle to a primitive descriptor that must not be modified through
 * this handle. */
typedef const struct mkldnn_primitive_desc *
        const_mkldnn_primitive_desc_t;
1165
1166/** @} */
1167
1168/** @addtogroup c_api_primitive_attr Primitive descriptor attributes
1169 * @{ */
1170
/** Scratchpad mode: selects who is responsible for the scratchpad memory used
 * by primitives. */
typedef enum {
    /** The library manages the scratchpad (default). */
    mkldnn_scratchpad_mode_library = 0,
    /** The user queries and provides the scratchpad memory to primitives. */
    mkldnn_scratchpad_mode_user = 1,
} mkldnn_scratchpad_mode_t;
1178
/** @struct mkldnn_primitive_attr
 * @brief Opaque type holding primitive descriptor attributes.
 *
 * Attributes may contain, for example:
 *  - output scales (used to scale the result before it is stored to memory)
 */
struct mkldnn_primitive_attr;

/** @brief Handle to primitive descriptor attributes, which control primitive
 * behavior. */
typedef struct mkldnn_primitive_attr *
        mkldnn_primitive_attr_t;

/** @brief Handle to primitive descriptor attributes that must not be
 * modified through this handle. */
typedef const struct mkldnn_primitive_attr *
        const_mkldnn_primitive_attr_t;
1193
/** @struct mkldnn_post_ops
 * @brief Opaque type holding a chain of post operations.
 *
 * Post operations perform simple extra work, such as accumulation or
 * eltwise, after certain primitives like convolution have run.
 *
 * Several post operations can be combined into a chain. For instance, a
 * convolution can be configured to be followed by an accumulation and then
 * by an eltwise operation, which is especially beneficial for residual
 * learning blocks.
 *
 * @warning
 *     Not every combination is supported, so the user should be prepared to
 *     handle errors.
 *
 * Supported post operations:
 *  - accumulation (base primitive: convolution)
 *  - eltwise (base primitive: convolution)
 */
struct mkldnn_post_ops;

/** @brief Handle to a post operation chain. */
typedef struct mkldnn_post_ops *
        mkldnn_post_ops_t;

/** @brief Handle to a post operation chain that must not be modified through
 * this handle. */
typedef const struct mkldnn_post_ops *
        const_mkldnn_post_ops_t;
1220
1221/** @} */
1222
1223/** @addtogroup c_api_types_primitive Primitive
1224 * @{ */
1225
/** @struct mkldnn_primitive
 * @brief Opaque type representing a primitive.
 *
 * Its layout is private to the library; users only hold handles to it. */
struct mkldnn_primitive;

/** @brief Handle to a primitive. */
typedef struct mkldnn_primitive *
        mkldnn_primitive_t;

/** @brief Handle to a primitive that must not be modified through this
 * handle. */
typedef const struct mkldnn_primitive *
        const_mkldnn_primitive_t;
1233
1234/** @addtogroup c_api_types_arguments Argument indices
1235 * @{ */
1236
/* Argument indices. Numbered kinds (*_0, *_1) distinguish several arguments
 * of the same kind; the unnumbered names below each one are aliases for it. */

/** Source argument #0. */
#define MKLDNN_ARG_SRC_0 1
#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0
#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0
#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0

/** Source argument #1. */
#define MKLDNN_ARG_SRC_1 2
#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1

/** Destination argument #0. */
#define MKLDNN_ARG_DST_0 17
#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0
#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0
#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0

/** Destination argument #1. */
#define MKLDNN_ARG_DST_1 18
#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1

/** Weights argument #0. */
#define MKLDNN_ARG_WEIGHTS_0 33
#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0
#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0
#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0

/** Weights argument #1. */
#define MKLDNN_ARG_WEIGHTS_1 34
#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1

/** Bias argument. */
#define MKLDNN_ARG_BIAS 41

/** Mean and variance arguments. */
#define MKLDNN_ARG_MEAN 49
#define MKLDNN_ARG_VARIANCE 50

/** Workspace and scratchpad arguments. */
#define MKLDNN_ARG_WORKSPACE 64
#define MKLDNN_ARG_SCRATCHPAD 80

/** Gradient with respect to source argument #0. */
#define MKLDNN_ARG_DIFF_SRC_0 129
#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0
#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0

/** Gradient with respect to source argument #1. */
#define MKLDNN_ARG_DIFF_SRC_1 130
#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1

/** Gradient with respect to destination argument #0. */
#define MKLDNN_ARG_DIFF_DST_0 145
#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0
#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0

/** Gradient with respect to destination argument #1. */
#define MKLDNN_ARG_DIFF_DST_1 146
#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1

/** Gradient with respect to weights argument #0. */
#define MKLDNN_ARG_DIFF_WEIGHTS_0 161
#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0
#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0
#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0

/** Gradient with respect to weights argument #1. */
#define MKLDNN_ARG_DIFF_WEIGHTS_1 162
#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1

/** Gradient with respect to bias argument. */
#define MKLDNN_ARG_DIFF_BIAS 169

/** Base indices for primitives with a variable number of sources or
 * destinations. NOTE(review): presumably the i-th argument is passed as
 * base + i -- confirm against the primitive implementations. */
#define MKLDNN_ARG_MULTIPLE_SRC 1024
#define MKLDNN_ARG_MULTIPLE_DST 2048
1295
1296/** @} */
1297
/** An auxiliary structure to specify a primitive's inputs/outputs at
 * execution time: it pairs an argument index with the memory bound to it.
 *
 * @warning
 *     With this API it is impossible to preserve constness of memory, so all
 *     memories are passed without a const qualifier. However, only memories
 *     with output semantics might be changed during the execution. */
typedef struct {
    int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */
    mkldnn_memory_t memory; /**< Input/output memory */
} mkldnn_exec_arg_t;
1308
1309/** @} */
1310
1311/** @addtogroup c_api_types_query Queries
1312 * @{ */
1313
/** Primitive descriptor query specification.
 *
 * For the generic function mkldnn_primitive_desc_query(), the type of the
 * result must agree with the queried item, according to this table:
 *
 *     Query                           | type of result
 *     --------------------------------------------------------------
 *     #mkldnn_query_engine            | mkldnn_engine_t *
 *     #mkldnn_query_scratchpad_engine | mkldnn_engine_t *
 *     #mkldnn_query_primitive_kind    | mkldnn_primitive_kind_t *
 *     *_s32                           | int *
 *     *_s64                           | mkldnn_dim_t * (same as int64_t *)
 *     *_f64                           | double *
 *     *_str                           | const char **
 *     #mkldnn_query_op_d              | const_mkldnn_op_desc_t *
 *     *_md                            | const mkldnn_memory_desc_t **
 *     *_${op}_d                       | const mkldnn_${op}_desc_t **
 *     *_pd                            | const_mkldnn_primitive_desc_t *
 *
 * @note
 *     Rule of thumb: all opaque types and structures are returned by
 *     reference. All numbers are returned by value.
 *
 * @warning
 *     All returned references point to constant objects and are valid only
 *     during the lifetime of the queried primitive descriptor. Returned
 *     objects must not be destroyed by the user. To keep an object longer
 *     than the lifetime of the queried primitive descriptor, make a copy
 *     with mkldnn_primitive_desc_clone().
 *
 * Enumerator values are spelled out explicitly; each section starts at its
 * own base value (0, 64, 128). */
typedef enum {
    mkldnn_query_undef = 0, /**< no query */

    mkldnn_query_engine = 1, /**< execution engine */
    mkldnn_query_primitive_kind = 2, /**< primitive kind */

    mkldnn_query_num_of_inputs_s32 = 3, /**< number of inputs expected */
    mkldnn_query_num_of_outputs_s32 = 4, /**< number of outputs expected */

    mkldnn_query_time_estimate_f64 = 5, /**< runtime estimation (seconds) */
    /** memory consumption (bytes) -- extra (scratch) memory, in addition to
     * the memory of all inputs and outputs */
    mkldnn_query_memory_consumption_s64 = 6,

    /** scratchpad engine -- the engine to be used for creating scratchpad
     * memory */
    mkldnn_query_scratchpad_engine = 7,

    mkldnn_query_impl_info_str = 8, /**< implementation name */

    /* op descriptor section */
    mkldnn_query_some_d = 64, /**< stub */
    mkldnn_query_op_d = 65, /**< op descriptor */
    mkldnn_query_convolution_d = 66, /**< convolution descriptor */
    mkldnn_query_deconvolution_d = 67, /**< deconvolution descriptor */
    mkldnn_query_shuffle_d = 68, /**< shuffle descriptor */
    mkldnn_query_eltwise_d = 69, /**< eltwise descriptor */
    mkldnn_query_softmax_d = 70, /**< softmax descriptor */
    mkldnn_query_pooling_d = 71, /**< pooling descriptor */
    mkldnn_query_lrn_d = 72, /**< lrn descriptor */
    mkldnn_query_batch_normalization_d = 73, /**< batch normalization descriptor */
    mkldnn_query_inner_product_d = 74, /**< inner product descriptor */
    mkldnn_query_rnn_d = 75, /**< rnn descriptor */

    /* memory descriptor section */
    mkldnn_query_some_md = 128, /**< stub */
    mkldnn_query_src_md = 129, /**< source memory desc */
    mkldnn_query_diff_src_md = 130, /**< source gradient memory desc */
    mkldnn_query_weights_md = 131, /**< weights memory desc */
    mkldnn_query_diff_weights_md = 132, /**< weights grad. memory desc */
    mkldnn_query_dst_md = 133, /**< destination memory desc */
    mkldnn_query_diff_dst_md = 134, /**< destination grad. memory desc */
    mkldnn_query_workspace_md = 135, /**< workspace memory desc */
    mkldnn_query_scratchpad_md = 136, /**< scratchpad memory desc */
} mkldnn_query_t;
1386
1387/** @} */
1388
1389/** @addtogroup c_api_types_stream Execution stream
1390 * @{ */
1391
/** @brief Flags controlling stream configuration. */
typedef enum {
    mkldnn_stream_default_flags = 0, /**< A default stream configuration. */
} mkldnn_stream_flags_t;
1397
/** @struct mkldnn_stream
 * @brief Opaque type representing an execution stream.
 *
 * Its layout is private to the library; users only hold handles to it. */
struct mkldnn_stream;

/** @brief Handle to an execution stream. */
typedef struct mkldnn_stream *
        mkldnn_stream_t;

/** @brief Handle to an execution stream that must not be modified through
 * this handle. */
typedef const struct mkldnn_stream *
        const_mkldnn_stream_t;
1405
1406/** @} */
1407/** @} */
1408/** @} */
1409
1410#ifdef __cplusplus
1411}
1412#endif
1413
1414
1415#endif
1416