1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef MKLDNN_HPP
18#define MKLDNN_HPP
19
20#ifndef DOXYGEN_SHOULD_SKIP_THIS
21#include <stdlib.h>
22#include <memory>
23#include <vector>
24#include <unordered_map>
25#include <algorithm>
26#include <iterator>
27
28#include "mkldnn.h"
29#endif
30
31namespace mkldnn {
32
33/// @addtogroup cpp_api C++ API
34/// @{
35
36/// @addtogroup cpp_api_utils Utils
37/// @{
38
/// Provides the destructor for an Intel(R) MKL-DNN C handle type @p T.
///
/// The primary template is intentionally empty: only explicit
/// specializations define a `destructor` member, so wrapping a C handle type
/// that has no specialization fails to compile when handle::reset() is
/// instantiated. It is declared with the `struct` class-key to match all of
/// its specializations (mixing `class` and `struct` for the same template
/// triggers MSVC warning C4099).
template <typename T> struct handle_traits {};
41
42/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
43/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
44/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
45/// can be passed by value. This class enables wrapping:
46/// - Newly constructed handles.
47/// @n In this case, the constructed handle uses reference counting provided
48/// by @p std::shared_ptr with a proper deleter function specified through
49/// the @p handle_traits class.
50/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
51/// example, through mkldnn_primitive_get_primitive_desc()).
52/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
53/// deleter because it is assumed that the handle wrapper for the original
54/// object deletes the handle (this model is similar to @p std::weak_ptr).
55template <typename T, typename traits=handle_traits<T>> class handle {
56private:
57 std::shared_ptr<typename std::remove_pointer<T>::type> _data;
58 handle(const handle &&) = delete;
59 handle &operator=(const handle &&other) = delete;
60protected:
61 bool operator==(const T other) const { return other == _data.get(); }
62 bool operator!=(const T other) const { return !(*this == other); }
63public:
64 /// Constructs a C handle wrapper.
65 /// @param t The C handle to wrap.
66 /// @param weak A flag to specify whether to construct a weak wrapper.
67 handle(T t = 0, bool weak = false): _data(0) {
68 reset(t, weak);
69 }
70
71 handle(const handle &other): _data(other._data) {}
72 handle &operator=(const handle &other) {
73 _data = other._data;
74 return *this;
75 }
76 /// Resets the value of a C handle.
77 /// @param t The new value of the C handle.
78 /// @param weak A flag to specify whether the wrapper should be weak.
79 void reset(T t, bool weak = false) {
80 auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
81 _data.reset(t, weak ? dummy_destructor : traits::destructor);
82 }
83
84 /// Returns the value of the underlying C handle.
85 T get() const { return _data.get(); }
86
87 bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
88 bool operator!=(const handle &other) const { return !(*this == other); }
89};
90
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Specializations binding each wrapped C handle type to the C API function
// that destroys it. handle::reset() installs `destructor` as the deleter of
// owning (non-weak) wrappers and uses its return type for the no-op deleter.
template <> struct handle_traits<mkldnn_memory_t> {
    static constexpr auto destructor = &mkldnn_memory_destroy;
};

template <> struct handle_traits<mkldnn_primitive_desc_t> {
    static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
};

template <> struct handle_traits<mkldnn_primitive_t> {
    static constexpr auto destructor = &mkldnn_primitive_destroy;
};

template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
    static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
};
#endif
108
// Forward declarations: primitive's member declarations below refer to
// these types before they are defined.
struct memory;
struct primitive_desc;
111
112/// Base class for all computational primitives.
113class primitive: public handle<mkldnn_primitive_t> {
114 friend struct error;
115 friend struct stream;
116 using handle::handle;
117public:
118 /// A proxy to C primitive kind enum
119 enum class kind {
120 undefined_primitive = mkldnn_undefined_primitive,
121 reorder = mkldnn_reorder,
122 concat = mkldnn_concat,
123 sum = mkldnn_sum,
124 convolution = mkldnn_convolution,
125 deconvolution = mkldnn_deconvolution,
126 shuffle = mkldnn_shuffle,
127 eltwise = mkldnn_eltwise,
128 softmax = mkldnn_softmax,
129 pooling = mkldnn_pooling,
130 lrn = mkldnn_lrn,
131 batch_normalization = mkldnn_batch_normalization,
132 inner_product = mkldnn_inner_product,
133 rnn = mkldnn_rnn,
134 };
135
136 primitive(const_mkldnn_primitive_desc_t c_pd);
137 primitive(const primitive_desc &pd);
138
139 /// Returns the descriptor of the underlying C API primitive.
140 inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
141 // TODO: use the C++ API wrapper structure.
142
143 void execute(struct stream &astream,
144 const std::unordered_map<int, memory> &args) const;
145};
146
147inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
148 return static_cast<mkldnn_primitive_kind_t>(akind);
149}
150/// Intel(R) MKL-DNN exception class.
151///
152/// This class captures the status returned by the failed C API function, error
153/// message, and, optionally, handle of the primitive that caused the error.
154struct error: public std::exception {
155 mkldnn_status_t status;
156 const char *message;
157
158 /// Constructs an error instance.
159 ///
160 /// @param astatus The error status returned by the C API.
161 /// @param amessage The error message.
162 error(mkldnn_status_t astatus, const char *amessage)
163 : status(astatus), message(amessage) {}
164
165 /// A convenience function for wrapping calls to the C API. Checks the
166 /// return status and throws an #error in case of failure.
167 ///
168 /// @param status The error status returned by the C API.
169 /// @param message The error message.
170 static void wrap_c_api(mkldnn_status_t status, const char *message) {
171 if (status != mkldnn_success)
172 throw error(status, message);
173 }
174};
175
176const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
177 const_mkldnn_primitive_desc_t pd;
178 error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
179 "could not get primitive descriptor by primitive");
180 return pd;
181}
182/// @}
183
184/// @addtogroup cpp_api_enums Common data types and enumerations
185/// A proxy to @ref c_api_types in @ref c_api.
186///
187/// @{
188
189enum scratchpad_mode {
190 scratchpad_mode_library = mkldnn_scratchpad_mode_library,
191 scratchpad_mode_user = mkldnn_scratchpad_mode_user,
192};
193
194inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
195 return static_cast<mkldnn_scratchpad_mode_t>(mode);
196}
197
198enum padding_kind {
199 zero = mkldnn_padding_zero
200};
201
202inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
203 return static_cast<mkldnn_padding_kind_t>(kind);
204}
205
206enum prop_kind {
207 forward_training = mkldnn_forward_training,
208 forward_scoring = mkldnn_forward_scoring,
209 forward_inference = mkldnn_forward_inference,
210 forward = mkldnn_forward,
211 backward = mkldnn_backward,
212 backward_data = mkldnn_backward_data,
213 backward_weights = mkldnn_backward_weights,
214 backward_bias = mkldnn_backward_bias
215};
216
217inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
218 return static_cast<mkldnn_prop_kind_t>(kind);
219}
220
221enum algorithm {
222 algorithm_undef = mkldnn_alg_kind_undef,
223 convolution_auto = mkldnn_convolution_auto,
224 convolution_direct = mkldnn_convolution_direct,
225 convolution_winograd = mkldnn_convolution_winograd,
226 deconvolution_direct = mkldnn_deconvolution_direct,
227 deconvolution_winograd = mkldnn_deconvolution_winograd,
228 eltwise_relu = mkldnn_eltwise_relu,
229 eltwise_tanh = mkldnn_eltwise_tanh,
230 eltwise_elu = mkldnn_eltwise_elu,
231 eltwise_square = mkldnn_eltwise_square,
232 eltwise_abs = mkldnn_eltwise_abs,
233 eltwise_sqrt = mkldnn_eltwise_sqrt,
234 eltwise_linear = mkldnn_eltwise_linear,
235 eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
236 eltwise_soft_relu = mkldnn_eltwise_soft_relu,
237 eltwise_logistic = mkldnn_eltwise_logistic,
238 lrn_across_channels = mkldnn_lrn_across_channels,
239 lrn_within_channel = mkldnn_lrn_within_channel,
240 pooling_max = mkldnn_pooling_max,
241 pooling_avg = mkldnn_pooling_avg,
242 pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
243 pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
244 vanilla_rnn = mkldnn_vanilla_rnn,
245 vanilla_lstm = mkldnn_vanilla_lstm,
246 vanilla_gru = mkldnn_vanilla_gru,
247 gru_linear_before_reset = mkldnn_gru_linear_before_reset
248};
249
250inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
251 return static_cast<mkldnn_alg_kind_t>(aalgorithm);
252}
253
254enum batch_normalization_flag {
255 use_global_stats = mkldnn_use_global_stats,
256 use_scale_shift = mkldnn_use_scaleshift,
257 fuse_bn_relu = mkldnn_fuse_bn_relu
258};
259
260inline mkldnn_batch_normalization_flag_t convert_to_c(
261 batch_normalization_flag aflag) {
262 return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
263}
264
265enum rnn_direction {
266 unidirectional_left2right = mkldnn_unidirectional_left2right,
267 unidirectional_right2left = mkldnn_unidirectional_right2left,
268 unidirectional = mkldnn_unidirectional,
269 bidirectional_concat = mkldnn_bidirectional_concat,
270 bidirectional_sum = mkldnn_bidirectional_sum,
271};
272
273inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
274 return static_cast<mkldnn_rnn_direction_t>(adir);
275}
276
277enum query {
278 undef = mkldnn_query_undef,
279
280 query_engine = mkldnn_query_engine,
281 primitive_kind = mkldnn_query_primitive_kind,
282
283 num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
284 num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
285
286 time_estimate_f64 = mkldnn_query_time_estimate_f64,
287 memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
288
289 query_scratchpad_engine = mkldnn_query_scratchpad_engine,
290
291 impl_info_str = mkldnn_query_impl_info_str,
292
293 op_d = mkldnn_query_op_d,
294 convolution_d = mkldnn_query_convolution_d,
295 deconvolution_d = mkldnn_query_deconvolution_d,
296 shuffle_d = mkldnn_query_shuffle_d,
297 eltwise_d = mkldnn_query_eltwise_d,
298 softmax_d = mkldnn_query_softmax_d,
299 pooling_d = mkldnn_query_pooling_d,
300 lrn_d = mkldnn_query_lrn_d,
301 batch_normalization_d = mkldnn_query_batch_normalization_d,
302 inner_product_d = mkldnn_query_inner_product_d,
303 rnn_d = mkldnn_query_rnn_d,
304
305 src_md = mkldnn_query_src_md,
306 diff_src_md = mkldnn_query_diff_src_md,
307 weights_md = mkldnn_query_weights_md,
308 diff_weights_md = mkldnn_query_diff_weights_md,
309 dst_md = mkldnn_query_dst_md,
310 diff_dst_md = mkldnn_query_diff_dst_md,
311 workspace_md = mkldnn_query_workspace_md,
312 scratchpad_md = mkldnn_query_scratchpad_md,
313};
314
315inline mkldnn_query_t convert_to_c(query aquery) {
316 return static_cast<mkldnn_query_t>(aquery);
317}
318
319/// @}
320
321/// @addtogroup cpp_api_attr Attributes
322/// An extension for controlling primitive behavior.
323///
324/// @sa @ref c_api_attributes in @ref c_api
325/// @{
326
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Lets owning handle<mkldnn_post_ops_t> wrappers release their sequence
// with mkldnn_post_ops_destroy().
template <> struct handle_traits<mkldnn_post_ops_t> {
    static constexpr auto destructor = &mkldnn_post_ops_destroy;
};
#endif
332
333struct post_ops: public handle<mkldnn_post_ops_t> {
334 post_ops() {
335 mkldnn_post_ops_t result;
336 error::wrap_c_api(mkldnn_post_ops_create(&result),
337 "could not create post operation sequence");
338 reset(result);
339 }
340
341 int len() const { return mkldnn_post_ops_len(get()); }
342
343 primitive::kind kind(int index) const {
344 error::wrap_c_api(
345 index < len() ? mkldnn_success : mkldnn_invalid_arguments,
346 "post_ops index is out of range");
347 return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
348 index));
349 }
350
351 void append_sum(float scale = 1.) {
352 error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
353 "could not append sum");
354 }
355
356 void get_params_sum(int index, float &scale) const {
357 error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
358 "could not get sum params");
359 }
360
361 void append_eltwise(float scale, algorithm alg, float alpha,
362 float beta) {
363 error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
364 convert_to_c(alg), alpha, beta),
365 "could not append eltwise");
366 }
367
368 void get_params_eltwise(int index, float &scale, algorithm &alg,
369 float &alpha, float &beta) const {
370 mkldnn_alg_kind_t c_alg;
371 error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
372 &scale, &c_alg, &alpha, &beta),
373 "could not get eltwise params");
374 alg = static_cast<algorithm>(c_alg);
375 }
376};
377
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Lets owning handle<mkldnn_primitive_attr_t> wrappers release their
// attribute object with mkldnn_primitive_attr_destroy().
template <> struct handle_traits<mkldnn_primitive_attr_t> {
    static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
};
#endif
383
384struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
385 primitive_attr() {
386 mkldnn_primitive_attr_t result;
387 error::wrap_c_api(mkldnn_primitive_attr_create(&result),
388 "could not create a primitive attr");
389 reset(result);
390 }
391
392 scratchpad_mode get_scratchpad_mode() const {
393 mkldnn_scratchpad_mode_t result;
394 error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode(
395 get(), &result), "could not get scratchpad mode");
396 return scratchpad_mode(result);
397 }
398
399 void set_scratchpad_mode(scratchpad_mode mode) {
400 error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode(
401 get(), mkldnn::convert_to_c(mode)),
402 "could not set scratchpad mode");
403 }
404
405 void get_output_scales(int &mask, std::vector<float> &scales) const
406 {
407 mkldnn_dim_t count;
408 int c_mask;
409 const float *c_scales;
410 error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
411 &count, &c_mask, &c_scales),
412 "could not get int output scales");
413 scales.resize(count);
414
415 mask = c_mask;
416 for (mkldnn_dim_t c = 0; c < count; ++c)
417 scales[c] = c_scales[c];
418 }
419
420 void set_output_scales(int mask, const std::vector<float> &scales)
421 {
422 error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
423 (mkldnn_dim_t)scales.size(), mask, &scales[0]),
424 "could not set int output scales");
425 }
426
427 const post_ops get_post_ops() const {
428 post_ops result;
429 const_mkldnn_post_ops_t c_result;
430 error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
431 "could not get post operation sequence");
432 result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
433 return result;
434 }
435
436 void set_post_ops(post_ops ops) {
437 error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
438 "could not set post operation sequence");
439 }
440
441 void set_rnn_data_qparams(const float scale, const float shift)
442 {
443 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
444 scale, shift), "could not set rnn data int scale/shift");
445 }
446
447 void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
448 {
449 error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
450 (int)scales.size(), mask, &scales[0]),
451 "could not set rnn weights int scales");
452 }
453};
454
455/// @}
456
457/// @addtogroup cpp_api_engine Engine
458/// Engine operations.
459///
460/// @sa @ref c_api_engine in @ref c_api
461/// @{
462
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Lets owning handle<mkldnn_engine_t> wrappers release their engine with
// mkldnn_engine_destroy().
template <> struct handle_traits<mkldnn_engine_t> {
    static constexpr auto destructor = &mkldnn_engine_destroy;
};
#endif
468
469/// An execution engine.
470struct engine: public handle<mkldnn_engine_t> {
471 friend class primitive;
472 // gcc bug??? using handle::handle;
473
474 /// Kinds of engines.
475 enum kind {
476 /// An unspecified engine
477 any = mkldnn_any_engine,
478 /// CPU engine
479 cpu = mkldnn_cpu,
480 };
481
482 /// Returns the number of engines of a certain kind.
483 ///
484 /// @param akind The kind of engines to count.
485
486 static size_t get_count(kind akind) {
487 return mkldnn_engine_get_count(convert_to_c(akind));
488 }
489
490 /// Constructs an engine.
491 ///
492 /// @param akind The kind of engine to construct.
493 /// @param index The index of the engine. Must be less than the value
494 /// returned by #get_count() for this particular kind of engine.
495
496 engine(kind akind, size_t index) {
497 mkldnn_engine_t aengine;
498 error::wrap_c_api(
499 mkldnn_engine_create(&aengine,
500 convert_to_c(akind), index),
501 "could not create an engine");
502 reset(aengine);
503 }
504
505 explicit engine(const mkldnn_engine_t& aengine)
506 : handle(aengine, true) {}
507
508 engine(const handle<mkldnn_primitive_desc_t> &pd) {
509 mkldnn_engine_t engine_q;
510 error::wrap_c_api(
511 mkldnn_primitive_desc_query(pd.get(),
512 mkldnn::convert_to_c(query_engine), 0, &engine_q),
513 "could not get engine from primitive_desc");
514 reset(engine_q, true);
515 }
516
517 template <class primitive_desc>
518 static engine query(const primitive_desc &pd) {
519 mkldnn_engine_t engine_q;
520 error::wrap_c_api(
521 mkldnn_primitive_desc_query(pd.get(),
522 mkldnn::convert_to_c(query_engine), 0, &engine_q),
523 "could not get engine from primitive_desc");
524
525 return engine(engine_q);
526 }
527
528private:
529 static mkldnn_engine_kind_t convert_to_c(kind akind) {
530 return static_cast<mkldnn_engine_kind_t>(akind);
531 }
532};
533
534/// @}
535
536/// @addtogroup cpp_api_stream Stream
/// Execution stream operations.
538///
539/// @sa @ref c_api_stream in @ref c_api
540/// @{
541
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Lets owning handle<mkldnn_stream_t> wrappers release their stream with
// mkldnn_stream_destroy().
template <> struct handle_traits<mkldnn_stream_t> {
    static constexpr auto destructor = &mkldnn_stream_destroy;
};
#endif
547
548struct stream: public handle<mkldnn_stream_t> {
549 using handle::handle;
550
551 enum: unsigned {
552 default_flags = mkldnn_stream_default_flags,
553 };
554
555 /// Constructs a stream.
556 stream(const engine &aengine,
557 unsigned flags = static_cast<unsigned>(default_flags)) {
558 mkldnn_stream_t astream;
559 error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags),
560 "could not create a stream");
561 reset(astream);
562 }
563};
564
565/// @}
566
567/// @addtogroup cpp_api_memory_related Memory and memory related operations
568/// @{
569
570/// @addtogroup cpp_api_memory Memory
/// An object that describes and stores data.
572///
573/// For more information, refer to @ref c_api_memory in @ref c_api.
574/// @{
575
576/// Memory that describes the data.
577struct memory: public handle<mkldnn_memory_t> {
578 public:
579 typedef mkldnn_dim_t dim;
580 typedef std::vector<dim> dims;
581
582 template <typename T> static void validate_dims(const std::vector<T> &v) {
583 if (v.size() > MKLDNN_MAX_NDIMS)
584 throw error(mkldnn_invalid_arguments, "invalid dimensions");
585 }
586
587 /// Data type specification. See #mkldnn_data_type_t for a detailed
588 /// description.
589 enum data_type {
590 data_undef = mkldnn_data_type_undef,
591 f32 = mkldnn_f32,
592 s32 = mkldnn_s32,
593 s8 = mkldnn_s8,
594 u8 = mkldnn_u8,
595 };
596
597 /// Memory format tag specification. See #mkldnn_format_tag_t
598 /// for a detailed description.
599 enum format_tag {
600 format_tag_undef = mkldnn_format_tag_undef,
601 any = mkldnn_format_tag_any,
602 a = mkldnn_a,
603 ab = mkldnn_ab,
604 abc = mkldnn_abc,
605 abcd = mkldnn_abcd,
606 abcde = mkldnn_abcde,
607 abcdef = mkldnn_abcdef,
608 abdec = mkldnn_abdec,
609 acb = mkldnn_acb,
610 acbde = mkldnn_acbde,
611 acdb = mkldnn_acdb,
612 acdeb = mkldnn_acdeb,
613 ba = mkldnn_ba,
614 bac = mkldnn_bac,
615 bacd = mkldnn_bacd,
616 bcda = mkldnn_bcda,
617 cba = mkldnn_cba,
618 cdba = mkldnn_cdba,
619 cdeba = mkldnn_cdeba,
620 decab = mkldnn_decab,
621 Abc16a = mkldnn_Abc16a,
622 ABc16a16b = mkldnn_ABc16a16b,
623 aBc16b = mkldnn_aBc16b,
624 ABc16b16a = mkldnn_ABc16b16a,
625 Abc4a = mkldnn_Abc4a,
626 aBc4b = mkldnn_aBc4b,
627 ABc4b16a4b = mkldnn_ABc4b16a4b,
628 ABc4b4a = mkldnn_ABc4b4a,
629 ABc8a16b2a = mkldnn_ABc8a16b2a,
630 ABc8a8b = mkldnn_ABc8a8b,
631 aBc8b = mkldnn_aBc8b,
632 ABc8b16a2b = mkldnn_ABc8b16a2b,
633 ABc8b8a = mkldnn_ABc8b8a,
634 Abcd16a = mkldnn_Abcd16a,
635 ABcd16a16b = mkldnn_ABcd16a16b,
636 aBcd16b = mkldnn_aBcd16b,
637 ABcd16b16a = mkldnn_ABcd16b16a,
638 aBCd16b16c = mkldnn_aBCd16b16c,
639 aBCd16c16b = mkldnn_aBCd16c16b,
640 Abcd4a = mkldnn_Abcd4a,
641 aBcd4b = mkldnn_aBcd4b,
642 ABcd4b16a4b = mkldnn_ABcd4b16a4b,
643 ABcd4b4a = mkldnn_ABcd4b4a,
644 aBCd4c16b4c = mkldnn_aBCd4c16b4c,
645 aBCd4c4b = mkldnn_aBCd4c4b,
646 ABcd8a16b2a = mkldnn_ABcd8a16b2a,
647 ABcd8a8b = mkldnn_ABcd8a8b,
648 aBcd8b = mkldnn_aBcd8b,
649 ABcd8b16a2b = mkldnn_ABcd8b16a2b,
650 aBCd8b16c2b = mkldnn_aBCd8b16c2b,
651 ABcd8b8a = mkldnn_ABcd8b8a,
652 aBCd8b8c = mkldnn_aBCd8b8c,
653 aBCd8c16b2c = mkldnn_aBCd8c16b2c,
654 aBCd8c8b = mkldnn_aBCd8c8b,
655 Abcde16a = mkldnn_Abcde16a,
656 ABcde16a16b = mkldnn_ABcde16a16b,
657 aBcde16b = mkldnn_aBcde16b,
658 ABcde16b16a = mkldnn_ABcde16b16a,
659 aBCde16b16c = mkldnn_aBCde16b16c,
660 aBCde16c16b = mkldnn_aBCde16c16b,
661 aBCde2c8b4c = mkldnn_aBCde2c8b4c,
662 Abcde4a = mkldnn_Abcde4a,
663 aBcde4b = mkldnn_aBcde4b,
664 ABcde4b4a = mkldnn_ABcde4b4a,
665 aBCde4b4c = mkldnn_aBCde4b4c,
666 aBCde4c16b4c = mkldnn_aBCde4c16b4c,
667 aBCde4c4b = mkldnn_aBCde4c4b,
668 Abcde8a = mkldnn_Abcde8a,
669 ABcde8a8b = mkldnn_ABcde8a8b,
670 aBcde8b = mkldnn_aBcde8b,
671 ABcde8b16a2b = mkldnn_ABcde8b16a2b,
672 aBCde8b16c2b = mkldnn_aBCde8b16c2b,
673 ABcde8b8a = mkldnn_ABcde8b8a,
674 aBCde8b8c = mkldnn_aBCde8b8c,
675 aBCde8c16b2c = mkldnn_aBCde8c16b2c,
676 aBCde8c8b = mkldnn_aBCde8c8b,
677 aBcdef16b = mkldnn_aBcdef16b,
678 aBCdef16b16c = mkldnn_aBCdef16b16c,
679 aBCdef16c16b = mkldnn_aBCdef16c16b,
680 aBcdef4b = mkldnn_aBcdef4b,
681 aBCdef4c4b = mkldnn_aBCdef4c4b,
682 aBCdef8b8c = mkldnn_aBCdef8b8c,
683 aBCdef8c16b2c = mkldnn_aBCdef8c16b2c,
684 aBCdef8c8b = mkldnn_aBCdef8c8b,
685 aBdc16b = mkldnn_aBdc16b,
686 aBdc4b = mkldnn_aBdc4b,
687 aBdc8b = mkldnn_aBdc8b,
688 aBdec16b = mkldnn_aBdec16b,
689 aBdec4b = mkldnn_aBdec4b,
690 aBdec8b = mkldnn_aBdec8b,
691 aBdefc16b = mkldnn_aBdefc16b,
692 aBdefc4b = mkldnn_aBdefc4b,
693 aBdefc8b = mkldnn_aBdefc8b,
694 Acb16a = mkldnn_Acb16a,
695 Acb4a = mkldnn_Acb4a,
696 Acb8a = mkldnn_Acb8a,
697 aCBd16b16c = mkldnn_aCBd16b16c,
698 aCBde16b16c = mkldnn_aCBde16b16c,
699 Acdb16a = mkldnn_Acdb16a,
700 Acdb4a = mkldnn_Acdb4a,
701 Acdb8a = mkldnn_Acdb8a,
702 Acdeb16a = mkldnn_Acdeb16a,
703 Acdeb4a = mkldnn_Acdeb4a,
704 Acdeb8a = mkldnn_Acdeb8a,
705 BAc16a16b = mkldnn_BAc16a16b,
706 BAcd16a16b = mkldnn_BAcd16a16b,
707 format_tag_last = mkldnn_format_tag_last,
708
709 x = mkldnn_x,
710 nc = mkldnn_nc,
711 cn = mkldnn_cn,
712 ncw = mkldnn_ncw,
713 nwc = mkldnn_nwc,
714 nchw = mkldnn_nchw,
715 nhwc = mkldnn_nhwc,
716 chwn = mkldnn_chwn,
717 ncdhw = mkldnn_ncdhw,
718 ndhwc = mkldnn_ndhwc,
719 oi = mkldnn_oi,
720 io = mkldnn_io,
721 oiw = mkldnn_oiw,
722 wio = mkldnn_wio,
723 oihw = mkldnn_oihw,
724 hwio = mkldnn_hwio,
725 ihwo = mkldnn_ihwo,
726 iohw = mkldnn_iohw,
727 oidhw = mkldnn_oidhw,
728 dhwio = mkldnn_dhwio,
729 goiw = mkldnn_goiw,
730 goihw = mkldnn_goihw,
731 hwigo = mkldnn_hwigo,
732 giohw = mkldnn_giohw,
733 goidhw = mkldnn_goidhw,
734 tnc = mkldnn_tnc,
735 ntc = mkldnn_ntc,
736 ldsnc = mkldnn_ldsnc,
737 ldigo = mkldnn_ldigo,
738 ldgoi = mkldnn_ldgoi,
739 ldgo = mkldnn_ldgo,
740 nCdhw16c = mkldnn_nCdhw16c,
741 nCdhw4c = mkldnn_nCdhw4c,
742 nCdhw8c = mkldnn_nCdhw8c,
743 nChw16c = mkldnn_nChw16c,
744 nChw4c = mkldnn_nChw4c,
745 nChw8c = mkldnn_nChw8c,
746 nCw16c = mkldnn_nCw16c,
747 nCw4c = mkldnn_nCw4c,
748 nCw8c = mkldnn_nCw8c,
749 IOw16o16i = mkldnn_IOw16o16i,
750 OIw16i16o = mkldnn_OIw16i16o,
751 OIw16o16i = mkldnn_OIw16o16i,
752 Oiw16o = mkldnn_Oiw16o,
753 OIw4i16o4i = mkldnn_OIw4i16o4i,
754 OIw4i4o = mkldnn_OIw4i4o,
755 Oiw4o = mkldnn_Oiw4o,
756 OIw8i16o2i = mkldnn_OIw8i16o2i,
757 OIw8i8o = mkldnn_OIw8i8o,
758 OIw8o16i2o = mkldnn_OIw8o16i2o,
759 OIw8o8i = mkldnn_OIw8o8i,
760 Owi16o = mkldnn_Owi16o,
761 Owi4o = mkldnn_Owi4o,
762 Owi8o = mkldnn_Owi8o,
763 IOhw16o16i = mkldnn_IOhw16o16i,
764 Ohwi16o = mkldnn_Ohwi16o,
765 Ohwi4o = mkldnn_Ohwi4o,
766 Ohwi8o = mkldnn_Ohwi8o,
767 OIhw16i16o = mkldnn_OIhw16i16o,
768 OIhw16o16i = mkldnn_OIhw16o16i,
769 Oihw16o = mkldnn_Oihw16o,
770 OIhw4i16o4i = mkldnn_OIhw4i16o4i,
771 OIhw4i4o = mkldnn_OIhw4i4o,
772 Oihw4o = mkldnn_Oihw4o,
773 OIhw8i16o2i = mkldnn_OIhw8i16o2i,
774 OIhw8i8o = mkldnn_OIhw8i8o,
775 OIhw8o16i2o = mkldnn_OIhw8o16i2o,
776 OIhw8o8i = mkldnn_OIhw8o8i,
777 Odhwi16o = mkldnn_Odhwi16o,
778 Odhwi4o = mkldnn_Odhwi4o,
779 Odhwi8o = mkldnn_Odhwi8o,
780 OIdhw16i16o = mkldnn_OIdhw16i16o,
781 OIdhw16o16i = mkldnn_OIdhw16o16i,
782 Oidhw16o = mkldnn_Oidhw16o,
783 OIdhw4i4o = mkldnn_OIdhw4i4o,
784 Oidhw4o = mkldnn_Oidhw4o,
785 OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
786 OIdhw8i8o = mkldnn_OIdhw8i8o,
787 OIdhw8o8i = mkldnn_OIdhw8o8i,
788 gIOw16o16i = mkldnn_gIOw16o16i,
789 gOIw16i16o = mkldnn_gOIw16i16o,
790 gOIw16o16i = mkldnn_gOIw16o16i,
791 gOiw16o = mkldnn_gOiw16o,
792 gOIw4i16o4i = mkldnn_gOIw4i16o4i,
793 gOIw4i4o = mkldnn_gOIw4i4o,
794 gOiw4o = mkldnn_gOiw4o,
795 gOIw8i16o2i = mkldnn_gOIw8i16o2i,
796 gOIw8i8o = mkldnn_gOIw8i8o,
797 gOIw8o16i2o = mkldnn_gOIw8o16i2o,
798 gOIw8o8i = mkldnn_gOIw8o8i,
799 gOwi16o = mkldnn_gOwi16o,
800 gOwi4o = mkldnn_gOwi4o,
801 gOwi8o = mkldnn_gOwi8o,
802 gIOhw16o16i = mkldnn_gIOhw16o16i,
803 gOhwi16o = mkldnn_gOhwi16o,
804 gOhwi4o = mkldnn_gOhwi4o,
805 gOhwi8o = mkldnn_gOhwi8o,
806 Goihw16g = mkldnn_Goihw16g,
807 gOIhw16i16o = mkldnn_gOIhw16i16o,
808 gOIhw16o16i = mkldnn_gOIhw16o16i,
809 gOihw16o = mkldnn_gOihw16o,
810 gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
811 gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
812 gOIhw4i4o = mkldnn_gOIhw4i4o,
813 gOIhw4o4i = mkldnn_gOIhw4o4i,
814 gOihw4o = mkldnn_gOihw4o,
815 Goihw8g = mkldnn_Goihw8g,
816 gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
817 gOIhw8i8o = mkldnn_gOIhw8i8o,
818 gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
819 gOIhw8o8i = mkldnn_gOIhw8o8i,
820 gOdhwi16o = mkldnn_gOdhwi16o,
821 gOdhwi4o = mkldnn_gOdhwi4o,
822 gOdhwi8o = mkldnn_gOdhwi8o,
823 gOIdhw16i16o = mkldnn_gOIdhw16i16o,
824 gOIdhw16o16i = mkldnn_gOIdhw16o16i,
825 gOidhw16o = mkldnn_gOidhw16o,
826 gOIdhw4i4o = mkldnn_gOIdhw4i4o,
827 gOidhw4o = mkldnn_gOidhw4o,
828 gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
829 gOIdhw8i8o = mkldnn_gOIdhw8i8o,
830 gOIdhw8o8i = mkldnn_gOIdhw8o8i,
831 };
832
833 /// A memory descriptor.
834 struct desc {
835 friend struct memory;
836 /// The underlying C API data structure.
837 mkldnn_memory_desc_t data;
838
839 /// Constructs a zero memory descriptor
840 desc(): data() {}
841
842 /// Constructs a memory descriptor.
843 ///
844 /// @param adims Data dimensions
845 /// @param adata_type Data precision/type.
846 /// @param aformat Data layout format tag.
847 desc(const dims &adims, data_type adata_type,
848 format_tag aformat) {
849 validate_dims(adims);
850 error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(),
851 adims.size() == 0 ? nullptr : &adims[0],
852 convert_to_c(adata_type), convert_to_c(aformat)),
853 "could not initialize a memory descriptor");
854 }
855
856 /// Constructs a memory descriptor from a C API data structure.
857 ///
858 /// @param adata A C API #mkldnn_memory_desc_t structure.
859 desc(const mkldnn_memory_desc_t &adata): data(adata) {}
860
861 /// Constructs a sub-memory descriptor
862 //
863 /// @param adims Sizes of a sub-memory
864 /// @param offsets Offsets of a sub-memory
865 desc submemory_desc(const dims &adims, const dims &offsets) {
866 mkldnn_memory_desc_t sub_md;
867 error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md,
868 &data, &adims[0], &offsets[0]),
869 "could not initialize a sub-memory");
870 return desc(sub_md);
871 }
872
873 /// Returns the number of bytes required to allocate the memory described
874 /// including the padding area.
875 size_t get_size() const { return mkldnn_memory_desc_get_size(&data); }
876
877 bool operator==(const desc &other) const {
878 return mkldnn_memory_desc_equal(&data, &other.data) != 0;
879 }
880
881 bool operator!=(const desc &other) const { return !operator==(other); }
882 };
883
884 /// Constructs a memory.
885 ///
886 /// @param md Memory descriptor.
887 /// @param aengine Engine.
888 /// @param ahandle Native handle.
889 memory(const desc &md, const engine &aengine, void *ahandle) {
890 mkldnn_memory_t result;
891 error::wrap_c_api(mkldnn_memory_create(&result, &md.data,
892 aengine.get(), ahandle), "could not create a memory");
893 reset(result);
894 }
895
896 /// Constructs a memory.
897 ///
898 /// @param md Memory descriptor.
899 /// @param aengine Engine.
900 memory(const desc &md, const engine &aengine)
901 : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {}
902
903 /// Returns the descriptor of the memory.
904 desc get_desc() const {
905 const mkldnn_memory_desc_t *cdesc;
906 error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc),
907 "could not get memory descriptor from a memory");
908 return desc(*cdesc);
909 }
910
911 /// Returns the engine of the memory.
912 engine get_engine() const {
913 mkldnn_engine_t engine_q;
914 error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q),
915 "could not get engine from a memory");
916 return engine(engine_q);
917 }
918
919 /// Returns a handle of the data contained in the memory.
920 ///
921 /// On the CPU engine, this is a pointer to the allocated memory.
922 void *get_data_handle() const {
923 void *handle;
924 error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
925 "could not get native handle");
926 return handle;
927 }
928
929 void set_data_handle(void *handle) const {
930 error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
931 "could not set native handle");
932 }
933
934 // Must go away or be private:
935 static mkldnn_data_type_t convert_to_c(data_type adata_type) {
936 return static_cast<mkldnn_data_type_t>(adata_type);
937 }
938 static mkldnn_format_tag_t convert_to_c(format_tag aformat) {
939 return static_cast<mkldnn_format_tag_t>(aformat);
940 }
941};
942
943inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
944 return a == memory::convert_to_c(b);
945}
946inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
947 return !(a == b);
948}
949inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
950 return b == a;
951}
952inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
953 return !(a == b);
954}
955
956inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) {
957 return a == memory::convert_to_c(b);
958}
959inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) {
960 return !(a == b);
961}
962inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) {
963 return b == a;
964}
965inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) {
966 return !(a == b);
967}
968
969/// @}
970
971/// @addtogroup cpp_api_reorder Reorder
972/// A primitive to copy data between memory formats.
973///
974/// @sa @ref c_api_reorder in @ref c_api
975/// @{
976
977struct reorder : public primitive {
978 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
979 primitive_desc(const engine &src_engine, const memory::desc &src_md,
980 const engine &dst_engine, const memory::desc &dst_md,
981 const primitive_attr &aattr) {
982 mkldnn_primitive_desc_t result;
983 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
984 src_engine.get(), &src_md.data,
985 dst_engine.get(), &dst_md.data, aattr.get()),
986 "could not create a reorder primitive descriptor");
987 reset(result);
988 }
989
990 primitive_desc(const engine &src_engine, const memory::desc &src_md,
991 const engine &dst_engine, const memory::desc &dst_md) {
992 mkldnn_primitive_desc_t result;
993 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
994 src_engine.get(), &src_md.data,
995 dst_engine.get(), &dst_md.data, nullptr),
996 "could not create a reorder primitive descriptor");
997 reset(result);
998 }
999
1000 primitive_desc(const memory &src, const memory &dst,
1001 const primitive_attr &aattr) {
1002 mkldnn_primitive_desc_t result;
1003 auto src_md = src.get_desc();
1004 auto dst_md = dst.get_desc();
1005 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
1006 src.get_engine().get(), &src_md.data,
1007 dst.get_engine().get(), &dst_md.data, aattr.get()),
1008 "could not create a reorder primitive descriptor");
1009 reset(result);
1010 }
1011
1012 primitive_desc(const memory &src, const memory &dst) {
1013 mkldnn_primitive_desc_t result;
1014 auto src_md = src.get_desc();
1015 auto dst_md = dst.get_desc();
1016 error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
1017 src.get_engine().get(), &src_md.data,
1018 dst.get_engine().get(), &dst_md.data, nullptr),
1019 "could not create a reorder primitive descriptor");
1020 reset(result);
1021 }
1022
1023 memory::desc scratchpad_desc() const {
1024 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1025 get(), mkldnn::convert_to_c(scratchpad_md), 0);
1026 if (cdesc == nullptr)
1027 return memory::desc();
1028 return memory::desc(*cdesc);
1029 }
1030
1031 engine scratchpad_engine() {
1032 mkldnn_engine_t engine_q;
1033 error::wrap_c_api(
1034 mkldnn_primitive_desc_query(get(),
1035 mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q),
1036 "could not get scratchpad engine from reorder primitive_desc");
1037
1038 return engine(engine_q);
1039 }
1040
1041 engine get_engine() { return engine::query(*this); }
1042 };
1043
1044 reorder(const primitive_desc &pd): primitive(pd.get()) {}
1045
1046 reorder(const memory &src, const memory &dst):
1047 primitive(primitive_desc(src, dst).get()) {}
1048
1049 void execute(stream astream, memory &src, memory &dst) {
1050 primitive::execute(astream,
1051 {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
1052 }
1053};
1054
1055/// @}
1056
1057/// @addtogroup cpp_api_concat Concat
/// A primitive to concatenate data along an arbitrary dimension.
1059///
1060/// @sa @ref c_api_concat in @ref c_api
1061/// @{
1062
1063struct concat : public primitive {
1064 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1065 std::vector<mkldnn_memory_desc_t> cpp_to_c(
1066 const std::vector<memory::desc> &srcs) {
1067 std::vector<mkldnn_memory_desc_t> c_api_srcs;
1068 c_api_srcs.reserve(srcs.size());
1069 for (const auto &s : srcs) c_api_srcs.push_back(s.data);
1070 return c_api_srcs;
1071 }
1072
1073 primitive_desc(const memory::desc &dst, int concat_dimension,
1074 const std::vector<memory::desc> &srcs, const engine &aengine) {
1075 auto c_api_srcs = cpp_to_c(srcs);
1076
1077 mkldnn_primitive_desc_t result;
1078 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1079 &result, &dst.data, (int)c_api_srcs.size(),
1080 concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
1081 "could not create a concat primitive descriptor");
1082 reset(result);
1083 }
1084
1085 primitive_desc(int concat_dimension,
1086 const std::vector<memory::desc> &srcs, const engine &aengine) {
1087 auto c_api_srcs = cpp_to_c(srcs);
1088
1089 mkldnn_primitive_desc_t result;
1090 error::wrap_c_api(mkldnn_concat_primitive_desc_create(
1091 &result, nullptr, (int)c_api_srcs.size(),
1092 concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
1093 "could not create a concat primitive descriptor");
1094 reset(result);
1095 }
1096
1097 memory::desc dst_desc() const {
1098 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1099 get(), mkldnn::convert_to_c(dst_md), 0);
1100 error::wrap_c_api(
1101 cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
1102 "could not get a dst memory descriptor");
1103 return memory::desc(*cdesc);
1104 }
1105
1106 memory::desc scratchpad_desc() const {
1107 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1108 get(), mkldnn::convert_to_c(scratchpad_md), 0);
1109 if (cdesc == nullptr)
1110 return memory::desc();
1111 return memory::desc(*cdesc);
1112 }
1113
1114 engine get_engine() { return engine::query(*this); }
1115 };
1116
1117 concat(const primitive_desc &pd): primitive(pd.get()) {}
1118};
1119
1120/// @}
1121
1122/// @addtogroup cpp_api_sum Sum
1123/// A primitive to sum data.
1124///
1125/// @sa @ref c_api_sum in @ref c_api
1126/// @{
1127
1128struct sum : public primitive {
1129 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1130 std::vector<mkldnn_memory_desc_t> cpp_to_c(
1131 const std::vector<memory::desc> &srcs) {
1132 std::vector<mkldnn_memory_desc_t> c_api_srcs;
1133 c_api_srcs.reserve(srcs.size());
1134 for (const auto &s : srcs) c_api_srcs.push_back(s.data);
1135 return c_api_srcs;
1136 }
1137
1138 primitive_desc(const memory::desc &dst,
1139 const std::vector<float> &scales,
1140 const std::vector<memory::desc> &srcs, const engine &aengine) {
1141 error::wrap_c_api(scales.size() == srcs.size()
1142 ? mkldnn_success : mkldnn_invalid_arguments,
1143 "number of scales not equal to number of srcs");
1144
1145 auto c_api_srcs = cpp_to_c(srcs);
1146
1147 mkldnn_primitive_desc_t result;
1148 error::wrap_c_api(mkldnn_sum_primitive_desc_create(
1149 &result, &dst.data, (int)c_api_srcs.size(),
1150 &scales[0], &c_api_srcs[0], nullptr, aengine.get()),
1151 "could not create a sum primitive descriptor");
1152 reset(result);
1153 }
1154
1155 primitive_desc(const std::vector<float> &scales,
1156 const std::vector<memory::desc> &srcs, const engine &aengine) {
1157 error::wrap_c_api(scales.size() == srcs.size()
1158 ? mkldnn_success : mkldnn_invalid_arguments,
1159 "number of scales not equal to number of srcs");
1160
1161 auto c_api_srcs = cpp_to_c(srcs);
1162 mkldnn_primitive_desc_t result;
1163 error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result,
1164 nullptr, (int)c_api_srcs.size(), &scales[0],
1165 &c_api_srcs[0], nullptr, aengine.get()),
1166 "could not create a sum primitive descriptor");
1167 reset(result);
1168 }
1169
1170 memory::desc dst_desc() const {
1171 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1172 get(), mkldnn::convert_to_c(dst_md), 0);
1173 error::wrap_c_api(
1174 cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
1175 "could not get a dst memory descriptor");
1176 return memory::desc(*cdesc);
1177 }
1178
1179 memory::desc scratchpad_desc() const {
1180 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1181 get(), mkldnn::convert_to_c(scratchpad_md), 0);
1182 if (cdesc == nullptr)
1183 return memory::desc();
1184 return memory::desc(*cdesc);
1185 }
1186
1187 engine get_engine() { return engine::query(*this); }
1188 };
1189
1190 sum(const primitive_desc &pd): primitive(pd.get()) {}
1191};
1192
1193/// @}
1194
1195/// @}
1196
1197/// @addtogroup cpp_api_primitives Primitives
1198/// @{
1199
1200/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
1201/// @{
1202
1203/// A base class for all primitive descriptors.
1204struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1205 primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
1206 const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1207 mkldnn_primitive_desc_iterator_t iterator = nullptr;
1208 mkldnn_status_t status = mkldnn_primitive_desc_iterator_create(
1209 &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1210 hint_fwd_pd);
1211 error::wrap_c_api(status,
1212 "could not create a primitive descriptor iterator");
1213 pd_iterator.reset(iterator);
1214 fetch_impl();
1215 }
1216
1217 engine get_engine() { return engine::query(*this); }
1218
1219 primitive_attr get_primitive_attr() const {
1220 const_mkldnn_primitive_attr_t const_cattr;
1221 error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
1222 "could not get attributes");
1223 mkldnn_primitive_attr_t cattr;
1224 error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1225 "could not clone attributes");
1226
1227 primitive_attr attr;
1228 attr.reset(cattr);
1229 return attr;
1230 }
1231
1232 /// Returns implementation name
1233 const char *impl_info_str() const {
1234 const char *res;
1235 error::wrap_c_api(mkldnn_primitive_desc_query(get(),
1236 mkldnn_query_impl_info_str, 0, &res),
1237 "could not query implementation info string");
1238 return res;
1239 }
1240
1241 /// Queries the memory::dim value (same as int64_t)
1242 memory::dim query_s64(query q) const {
1243 memory::dim res;
1244 mkldnn_status_t status = mkldnn_primitive_desc_query(get(),
1245 mkldnn::convert_to_c(q), 0, &res);
1246 return status == mkldnn_success ? res : 0;
1247 }
1248
1249 /// Advances the next implementation for the given op descriptor.
1250 ///
1251 /// Returns:
1252 /// - @c true on success
1253 /// - @c false if the last implementation reached, and
1254 /// the primitive descriptor itself is kept unchanged
1255 bool next_impl() {
1256 mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
1257 pd_iterator.get());
1258 if (status == mkldnn_iterator_ends) return false;
1259 error::wrap_c_api(status, "primitive descriptor iterator next failed");
1260
1261 fetch_impl();
1262 return true;
1263 }
1264
1265 /// Queries and returns requested memory descriptor.
1266 memory::desc query_md(query what, int idx = 0) const {
1267 std::vector<query> valid_q{src_md, diff_src_md, weights_md,
1268 diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md};
1269 if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
1270 [=](query q) { return what == q; }))
1271 throw error(mkldnn_invalid_arguments, "invalid memory query");
1272
1273 const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
1274 get(), mkldnn::convert_to_c(what), idx);
1275 if (cdesc == nullptr) return memory::desc();
1276
1277 return memory::desc(*cdesc);
1278 }
1279
1280 // register specialized queries, e.g. src_desc()
1281# define REG_QUERY_MD(name, what, idx) \
1282 memory::desc name ## _desc() const { return query_md(what ## _md, idx); }
1283
1284 private:
1285 handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1286 void fetch_impl() {
1287 mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1288 pd_iterator.get());
1289 error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
1290 "could not fetch a primitive descriptor from the iterator");
1291 reset(pd);
1292 }
1293};
1294
1295/// @}
1296
1297/// @addtogroup cpp_api_convolution Convolution
1298/// A primitive to compute convolution using different algorithms.
1299///
1300/// @sa @ref c_api_convolution in @ref c_api
1301/// @{
1302
1303struct convolution_forward: public primitive {
1304 struct desc {
1305 mkldnn_convolution_desc_t data;
1306 desc(prop_kind aprop_kind, algorithm aalgorithm,
1307 const memory::desc &src_desc,
1308 const memory::desc &weights_desc,
1309 const memory::desc &bias_desc,
1310 const memory::desc &dst_desc,
1311 const memory::dims strides,
1312 const memory::dims padding_l,
1313 const memory::dims padding_r,
1314 const padding_kind apadding_kind) {
1315 memory::validate_dims(strides);
1316 memory::validate_dims(padding_l);
1317 memory::validate_dims(padding_r);
1318 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1319 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1320 &src_desc.data, &weights_desc.data, &bias_desc.data,
1321 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1322 mkldnn::convert_to_c(apadding_kind)),
1323 "could not create a convolution forward descriptor");
1324 }
1325 desc(prop_kind aprop_kind, algorithm aalgorithm,
1326 const memory::desc &src_desc,
1327 const memory::desc &weights_desc,
1328 const memory::desc &dst_desc,
1329 const memory::dims strides,
1330 const memory::dims padding_l,
1331 const memory::dims padding_r,
1332 const padding_kind apadding_kind) {
1333 memory::validate_dims(strides);
1334 memory::validate_dims(padding_l);
1335 memory::validate_dims(padding_r);
1336 error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
1337 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1338 &src_desc.data, &weights_desc.data, nullptr,
1339 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1340 mkldnn::convert_to_c(apadding_kind)),
1341 "could not create a convolution forward descriptor");
1342 }
1343 desc(prop_kind aprop_kind, algorithm aalgorithm,
1344 const memory::desc &src_desc,
1345 const memory::desc &weights_desc,
1346 const memory::desc &bias_desc,
1347 const memory::desc &dst_desc,
1348 const memory::dims strides,
1349 const memory::dims dilates,
1350 const memory::dims padding_l,
1351 const memory::dims padding_r,
1352 const padding_kind apadding_kind) {
1353 memory::validate_dims(strides);
1354 memory::validate_dims(dilates);
1355 memory::validate_dims(padding_l);
1356 memory::validate_dims(padding_r);
1357 error::wrap_c_api(
1358 mkldnn_dilated_convolution_forward_desc_init(&data,
1359 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1360 &src_desc.data, &weights_desc.data, &bias_desc.data,
1361 &dst_desc.data, &strides[0], &dilates[0],
1362 &padding_l[0], &padding_r[0],
1363 mkldnn::convert_to_c(apadding_kind)),
1364 "could not create a dilated convolution forward descriptor");
1365 }
1366 desc(prop_kind aprop_kind, algorithm aalgorithm,
1367 const memory::desc &src_desc,
1368 const memory::desc &weights_desc,
1369 const memory::desc &dst_desc,
1370 const memory::dims strides,
1371 const memory::dims dilates,
1372 const memory::dims padding_l,
1373 const memory::dims padding_r,
1374 const padding_kind apadding_kind) {
1375 memory::validate_dims(strides);
1376 memory::validate_dims(dilates);
1377 memory::validate_dims(padding_l);
1378 memory::validate_dims(padding_r);
1379 error::wrap_c_api(
1380 mkldnn_dilated_convolution_forward_desc_init(&data,
1381 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1382 &src_desc.data, &weights_desc.data, nullptr,
1383 &dst_desc.data, &strides[0], &dilates[0],
1384 &padding_l[0], &padding_r[0],
1385 mkldnn::convert_to_c(apadding_kind)),
1386 "could not create a dilated convolution forward descriptor");
1387 }
1388 };
1389
1390 struct primitive_desc : public mkldnn::primitive_desc {
1391 primitive_desc(const desc &desc, const engine &e)
1392 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1393
1394 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1395 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1396
1397 REG_QUERY_MD(src, src, 0);
1398 REG_QUERY_MD(weights, weights, 0);
1399 REG_QUERY_MD(bias, weights, 1);
1400 REG_QUERY_MD(dst, dst, 0);
1401 REG_QUERY_MD(scratchpad, scratchpad, 0);
1402 };
1403
1404 convolution_forward(const primitive_desc &pd): primitive(pd) {}
1405};
1406
1407struct convolution_backward_data : public primitive {
1408 struct desc {
1409 mkldnn_convolution_desc_t data;
1410 desc(algorithm aalgorithm,
1411 const memory::desc &diff_src_desc,
1412 const memory::desc &weights_desc,
1413 const memory::desc &diff_dst_desc,
1414 const memory::dims strides,
1415 const memory::dims padding_l,
1416 const memory::dims padding_r,
1417 const padding_kind apadding_kind) {
1418 memory::validate_dims(strides);
1419 memory::validate_dims(padding_l);
1420 memory::validate_dims(padding_r);
1421 error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
1422 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1423 &weights_desc.data, &diff_dst_desc.data,
1424 &strides[0], &padding_l[0], &padding_r[0],
1425 mkldnn::convert_to_c(apadding_kind)),
1426 "could not create a convolution backward data descriptor");
1427 }
1428 desc(algorithm aalgorithm,
1429 const memory::desc &diff_src_desc,
1430 const memory::desc &weights_desc,
1431 const memory::desc &diff_dst_desc,
1432 const memory::dims strides,
1433 const memory::dims dilates,
1434 const memory::dims padding_l,
1435 const memory::dims padding_r,
1436 const padding_kind apadding_kind) {
1437 memory::validate_dims(strides);
1438 memory::validate_dims(dilates);
1439 memory::validate_dims(padding_l);
1440 memory::validate_dims(padding_r);
1441 error::wrap_c_api(
1442 mkldnn_dilated_convolution_backward_data_desc_init(
1443 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1444 &weights_desc.data, &diff_dst_desc.data,
1445 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1446 mkldnn::convert_to_c(apadding_kind)),
1447 "could not create a convolution backward data descriptor");
1448 }
1449 };
1450
1451 struct primitive_desc : public mkldnn::primitive_desc {
1452 primitive_desc(const desc &desc, const engine &e,
1453 const convolution_forward::primitive_desc &hint_fwd_pd)
1454 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1455
1456 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1457 const convolution_forward::primitive_desc &hint_fwd_pd)
1458 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1459
1460 REG_QUERY_MD(diff_src, diff_src, 0);
1461 REG_QUERY_MD(weights, weights, 0);
1462 REG_QUERY_MD(diff_dst, diff_dst, 0);
1463 REG_QUERY_MD(scratchpad, scratchpad, 0);
1464 };
1465
1466 convolution_backward_data(const primitive_desc &pd): primitive(pd) {}
1467};
1468
1469struct convolution_backward_weights : public primitive {
1470 struct desc {
1471 mkldnn_convolution_desc_t data;
1472 desc(algorithm aalgorithm,
1473 const memory::desc &src_desc,
1474 const memory::desc &diff_weights_desc,
1475 const memory::desc &diff_bias_desc,
1476 const memory::desc &diff_dst_desc,
1477 const memory::dims strides,
1478 const memory::dims padding_l,
1479 const memory::dims padding_r,
1480 const padding_kind apadding_kind) {
1481 memory::validate_dims(strides);
1482 memory::validate_dims(padding_l);
1483 memory::validate_dims(padding_r);
1484 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1485 &data, convert_to_c(aalgorithm), &src_desc.data,
1486 &diff_weights_desc.data, &diff_bias_desc.data,
1487 &diff_dst_desc.data,
1488 &strides[0], &padding_l[0], &padding_r[0],
1489 mkldnn::convert_to_c(apadding_kind)),
1490 "could not create a convolution backward weights descriptor");
1491 }
1492 desc(algorithm aalgorithm,
1493 const memory::desc &src_desc,
1494 const memory::desc &diff_weights_desc,
1495 const memory::desc &diff_dst_desc,
1496 const memory::dims strides,
1497 const memory::dims padding_l,
1498 const memory::dims padding_r,
1499 const padding_kind apadding_kind) {
1500 memory::validate_dims(strides);
1501 memory::validate_dims(padding_l);
1502 memory::validate_dims(padding_r);
1503 error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
1504 &data, convert_to_c(aalgorithm), &src_desc.data,
1505 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1506 &strides[0], &padding_l[0], &padding_r[0],
1507 mkldnn::convert_to_c(apadding_kind)),
1508 "could not create a convolution backward weights descriptor");
1509 }
1510 desc(algorithm aalgorithm,
1511 const memory::desc &src_desc,
1512 const memory::desc &diff_weights_desc,
1513 const memory::desc &diff_bias_desc,
1514 const memory::desc &diff_dst_desc,
1515 const memory::dims strides,
1516 const memory::dims dilates,
1517 const memory::dims padding_l,
1518 const memory::dims padding_r,
1519 const padding_kind apadding_kind) {
1520 memory::validate_dims(strides);
1521 memory::validate_dims(dilates);
1522 memory::validate_dims(padding_l);
1523 memory::validate_dims(padding_r);
1524 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1525 &data, convert_to_c(aalgorithm), &src_desc.data,
1526 &diff_weights_desc.data, &diff_bias_desc.data,
1527 &diff_dst_desc.data,
1528 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1529 mkldnn::convert_to_c(apadding_kind)),
1530 "could not create a convolution backward weights descriptor");
1531 }
1532 desc(algorithm aalgorithm,
1533 const memory::desc &src_desc,
1534 const memory::desc &diff_weights_desc,
1535 const memory::desc &diff_dst_desc,
1536 const memory::dims strides,
1537 const memory::dims dilates,
1538 const memory::dims padding_l,
1539 const memory::dims padding_r,
1540 const padding_kind apadding_kind) {
1541 memory::validate_dims(strides);
1542 memory::validate_dims(dilates);
1543 memory::validate_dims(padding_l);
1544 memory::validate_dims(padding_r);
1545 error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
1546 &data, convert_to_c(aalgorithm), &src_desc.data,
1547 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1548 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1549 mkldnn::convert_to_c(apadding_kind)),
1550 "could not create a convolution backward weights descriptor");
1551 }
1552
1553 };
1554
1555 struct primitive_desc : public mkldnn::primitive_desc {
1556 primitive_desc(const desc &desc, const engine &e,
1557 const convolution_forward::primitive_desc &hint_fwd_pd)
1558 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1559
1560 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1561 const convolution_forward::primitive_desc &hint_fwd_pd)
1562 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1563
1564 REG_QUERY_MD(src, src, 0);
1565 REG_QUERY_MD(diff_weights, diff_weights, 0);
1566 REG_QUERY_MD(diff_bias, diff_weights, 1);
1567 REG_QUERY_MD(diff_dst, diff_dst, 0);
1568 REG_QUERY_MD(scratchpad, scratchpad, 0);
1569 };
1570
1571 convolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
1572};
1573
1574/// @}
1575//
1576/// @addtogroup cpp_api_deconvolution Deconvolution
1577/// A primitive to compute deconvolution using different algorithms.
1578///
1579/// @sa @ref c_api_deconvolution in @ref c_api
1580/// @{
1581
1582struct deconvolution_forward: public primitive {
1583 struct desc {
1584 mkldnn_deconvolution_desc_t data;
1585 desc(prop_kind aprop_kind, algorithm aalgorithm,
1586 const memory::desc &src_desc,
1587 const memory::desc &weights_desc,
1588 const memory::desc &bias_desc,
1589 const memory::desc &dst_desc,
1590 const memory::dims strides,
1591 const memory::dims padding_l,
1592 const memory::dims padding_r,
1593 const padding_kind apadding_kind) {
1594 memory::validate_dims(strides);
1595 memory::validate_dims(padding_l);
1596 memory::validate_dims(padding_r);
1597 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1598 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1599 &src_desc.data, &weights_desc.data, &bias_desc.data,
1600 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1601 mkldnn::convert_to_c(apadding_kind)),
1602 "could not create a deconvolution forward descriptor");
1603 }
1604 desc(prop_kind aprop_kind, algorithm aalgorithm,
1605 const memory::desc &src_desc,
1606 const memory::desc &weights_desc,
1607 const memory::desc &dst_desc,
1608 const memory::dims strides,
1609 const memory::dims padding_l,
1610 const memory::dims padding_r,
1611 const padding_kind apadding_kind) {
1612 memory::validate_dims(strides);
1613 memory::validate_dims(padding_l);
1614 memory::validate_dims(padding_r);
1615 error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
1616 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1617 &src_desc.data, &weights_desc.data, nullptr,
1618 &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1619 mkldnn::convert_to_c(apadding_kind)),
1620 "could not create a deconvolution forward descriptor");
1621 }
1622 desc(prop_kind aprop_kind, algorithm aalgorithm,
1623 const memory::desc &src_desc,
1624 const memory::desc &weights_desc,
1625 const memory::desc &bias_desc,
1626 const memory::desc &dst_desc,
1627 const memory::dims strides,
1628 const memory::dims dilates,
1629 const memory::dims padding_l,
1630 const memory::dims padding_r,
1631 const padding_kind apadding_kind) {
1632 memory::validate_dims(strides);
1633 memory::validate_dims(dilates);
1634 memory::validate_dims(padding_l);
1635 memory::validate_dims(padding_r);
1636 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1637 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1638 &src_desc.data, &weights_desc.data, &bias_desc.data,
1639 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1640 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1641 "could not create a dilated deconvolution forward descriptor");
1642 }
1643 desc(prop_kind aprop_kind, algorithm aalgorithm,
1644 const memory::desc &src_desc,
1645 const memory::desc &weights_desc,
1646 const memory::desc &dst_desc,
1647 const memory::dims strides,
1648 const memory::dims dilates,
1649 const memory::dims padding_l,
1650 const memory::dims padding_r,
1651 const padding_kind apadding_kind) {
1652 memory::validate_dims(strides);
1653 memory::validate_dims(dilates);
1654 memory::validate_dims(padding_l);
1655 memory::validate_dims(padding_r);
1656 error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
1657 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1658 &src_desc.data, &weights_desc.data, nullptr,
1659 &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1660 &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1661 "could not create a dilated deconvolution forward descriptor");
1662 }
1663 };
1664
1665 struct primitive_desc : public mkldnn::primitive_desc {
1666 primitive_desc(const desc &desc, const engine &e)
1667 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1668
1669 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1670 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1671
1672 REG_QUERY_MD(src, src, 0);
1673 REG_QUERY_MD(weights, weights, 0);
1674 REG_QUERY_MD(bias, weights, 1);
1675 REG_QUERY_MD(dst, dst, 0);
1676 REG_QUERY_MD(scratchpad, scratchpad, 0);
1677 };
1678
1679 deconvolution_forward(const primitive_desc &pd): primitive(pd) {}
1680};
1681
1682struct deconvolution_backward_data : public primitive {
1683 struct desc {
1684 mkldnn_deconvolution_desc_t data;
1685 desc(algorithm aalgorithm,
1686 const memory::desc &diff_src_desc,
1687 const memory::desc &weights_desc,
1688 const memory::desc &diff_dst_desc,
1689 const memory::dims strides,
1690 const memory::dims padding_l,
1691 const memory::dims padding_r,
1692 const padding_kind apadding_kind) {
1693 memory::validate_dims(strides);
1694 memory::validate_dims(padding_l);
1695 memory::validate_dims(padding_r);
1696 error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
1697 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1698 &weights_desc.data, &diff_dst_desc.data,
1699 &strides[0], &padding_l[0], &padding_r[0],
1700 mkldnn::convert_to_c(apadding_kind)),
1701 "could not create a deconvolution backward data descriptor");
1702 }
1703 desc(algorithm aalgorithm,
1704 const memory::desc &diff_src_desc,
1705 const memory::desc &weights_desc,
1706 const memory::desc &diff_dst_desc,
1707 const memory::dims strides,
1708 const memory::dims dilates,
1709 const memory::dims padding_l,
1710 const memory::dims padding_r,
1711 const padding_kind apadding_kind) {
1712 memory::validate_dims(strides);
1713 memory::validate_dims(dilates);
1714 memory::validate_dims(padding_l);
1715 memory::validate_dims(padding_r);
1716 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
1717 &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1718 &weights_desc.data, &diff_dst_desc.data,
1719 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1720 mkldnn::convert_to_c(apadding_kind)),
1721 "could not create a dilated deconvolution backward data descriptor");
1722 }
1723 };
1724
1725 struct primitive_desc : public mkldnn::primitive_desc {
1726 primitive_desc(const desc &desc, const engine &e,
1727 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1728 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1729
1730 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1731 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1732 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1733
1734 REG_QUERY_MD(diff_src, diff_src, 0);
1735 REG_QUERY_MD(weights, weights, 0);
1736 REG_QUERY_MD(diff_dst, diff_dst, 0);
1737 REG_QUERY_MD(scratchpad, scratchpad, 0);
1738 };
1739
1740 deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {}
1741};
1742
1743struct deconvolution_backward_weights : public primitive {
1744 struct desc {
1745 mkldnn_deconvolution_desc_t data;
1746 desc(algorithm aalgorithm,
1747 const memory::desc &src_desc,
1748 const memory::desc &diff_weights_desc,
1749 const memory::desc &diff_bias_desc,
1750 const memory::desc &diff_dst_desc,
1751 const memory::dims strides,
1752 const memory::dims padding_l,
1753 const memory::dims padding_r,
1754 const padding_kind apadding_kind) {
1755 memory::validate_dims(strides);
1756 memory::validate_dims(padding_l);
1757 memory::validate_dims(padding_r);
1758 error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
1759 &data, convert_to_c(aalgorithm), &src_desc.data,
1760 &diff_weights_desc.data, &diff_bias_desc.data,
1761 &diff_dst_desc.data,
1762 &strides[0], &padding_l[0], &padding_r[0],
1763 mkldnn::convert_to_c(apadding_kind)),
1764 "could not create a deconvolution backward weights descriptor");
1765 }
1766 desc(algorithm aalgorithm,
1767 const memory::desc &src_desc,
1768 const memory::desc &diff_weights_desc,
1769 const memory::desc &diff_dst_desc,
1770 const memory::dims strides,
1771 const memory::dims padding_l,
1772 const memory::dims padding_r,
1773 const padding_kind apadding_kind) {
1774 memory::validate_dims(strides);
1775 memory::validate_dims(padding_l);
1776 memory::validate_dims(padding_r);
1777 error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
1778 &data, convert_to_c(aalgorithm), &src_desc.data,
1779 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1780 &strides[0], &padding_l[0], &padding_r[0],
1781 mkldnn::convert_to_c(apadding_kind)),
1782 "could not create a deconvolution backward weights descriptor");
1783 }
1784 desc(algorithm aalgorithm,
1785 const memory::desc &src_desc,
1786 const memory::desc &diff_weights_desc,
1787 const memory::desc &diff_bias_desc,
1788 const memory::desc &diff_dst_desc,
1789 const memory::dims strides,
1790 const memory::dims dilates,
1791 const memory::dims padding_l,
1792 const memory::dims padding_r,
1793 const padding_kind apadding_kind) {
1794 memory::validate_dims(strides);
1795 memory::validate_dims(dilates);
1796 memory::validate_dims(padding_l);
1797 memory::validate_dims(padding_r);
1798 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
1799 &data, convert_to_c(aalgorithm), &src_desc.data,
1800 &diff_weights_desc.data, &diff_bias_desc.data,
1801 &diff_dst_desc.data,
1802 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1803 mkldnn::convert_to_c(apadding_kind)),
1804 "could not create a dilated deconvolution backward weights descriptor");
1805 }
1806 desc(algorithm aalgorithm,
1807 const memory::desc &src_desc,
1808 const memory::desc &diff_weights_desc,
1809 const memory::desc &diff_dst_desc,
1810 const memory::dims strides,
1811 const memory::dims dilates,
1812 const memory::dims padding_l,
1813 const memory::dims padding_r,
1814 const padding_kind apadding_kind) {
1815 memory::validate_dims(strides);
1816 memory::validate_dims(dilates);
1817 memory::validate_dims(padding_l);
1818 memory::validate_dims(padding_r);
1819 error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
1820 &data, convert_to_c(aalgorithm), &src_desc.data,
1821 &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1822 &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1823 mkldnn::convert_to_c(apadding_kind)),
1824 "could not create a dilated deconvolution backward weights descriptor");
1825 }
1826 };
1827
1828 struct primitive_desc : public mkldnn::primitive_desc {
1829 primitive_desc(const desc &desc, const engine &e,
1830 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1831 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1832
1833 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1834 const deconvolution_forward::primitive_desc &hint_fwd_pd)
1835 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1836
1837 REG_QUERY_MD(src, src, 0);
1838 REG_QUERY_MD(diff_weights, diff_weights, 0);
1839 REG_QUERY_MD(diff_bias, diff_weights, 1);
1840 REG_QUERY_MD(diff_dst, diff_dst, 0);
1841 REG_QUERY_MD(scratchpad, scratchpad, 0);
1842 };
1843
1844 deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
1845};
1846
1847/// @}
1848
1849/// @addtogroup cpp_api_lrn LRN
1850/// A primitive to perform local response normalization (LRN) across or within
1851/// channels.
1852///
1853/// @sa @ref c_api_lrn in @ref c_api
1854/// @{
1855
1856struct lrn_forward : public primitive {
1857 struct desc {
1858 mkldnn_lrn_desc_t data;
1859
1860 desc(prop_kind aprop_kind, algorithm aalgorithm,
1861 const memory::desc &src_desc, memory::dim local_size,
1862 float alpha, float beta, float k = 1.f) {
1863 error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
1864 mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1865 &src_desc.data, local_size, alpha, beta, k),
1866 "could not create a lrn forward descriptor");
1867 }
1868 };
1869
1870 struct primitive_desc : public mkldnn::primitive_desc {
1871 primitive_desc(const desc &desc, const engine &e)
1872 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1873
1874 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1875 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1876
1877 REG_QUERY_MD(src, src, 0);
1878 REG_QUERY_MD(dst, dst, 0);
1879 REG_QUERY_MD(workspace, workspace, 0);
1880 REG_QUERY_MD(scratchpad, scratchpad, 0);
1881 };
1882
1883 lrn_forward(const primitive_desc &pd): primitive(pd) {}
1884};
1885
1886struct lrn_backward : public primitive {
1887 struct desc {
1888 mkldnn_lrn_desc_t data;
1889
1890 desc(algorithm aalgorithm, const memory::desc &data_desc,
1891 const memory::desc &diff_data_desc, memory::dim local_size,
1892 float alpha, float beta, float k = 1.f) {
1893 error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
1894 convert_to_c(aalgorithm), &diff_data_desc.data,
1895 &data_desc.data, local_size, alpha, beta, k),
1896 "could not create a lrn backward descriptor");
1897 }
1898 };
1899
1900 struct primitive_desc : public mkldnn::primitive_desc {
1901 primitive_desc(const desc &desc, const engine &e,
1902 const lrn_forward::primitive_desc &hint_fwd_pd)
1903 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1904
1905 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1906 const lrn_forward::primitive_desc &hint_fwd_pd)
1907 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1908
1909 REG_QUERY_MD(diff_src, diff_src, 0);
1910 REG_QUERY_MD(diff_dst, diff_dst, 0);
1911 REG_QUERY_MD(workspace, workspace, 0);
1912 REG_QUERY_MD(scratchpad, scratchpad, 0);
1913 };
1914
1915 lrn_backward(const primitive_desc &pd): primitive(pd) {}
1916};
1917
1918/// @}
1919
1920/// @addtogroup cpp_api_pooling Pooling
1921/// A primitive to perform max or average pooling.
1922///
1923/// @sa @ref c_api_pooling in @ref c_api
1924/// @{
1925
1926struct pooling_forward : public primitive {
1927 struct desc {
1928 mkldnn_pooling_desc_t data;
1929 desc(prop_kind aprop_kind, algorithm aalgorithm,
1930 const memory::desc &src_desc,
1931 const memory::desc &dst_desc,
1932 const memory::dims strides,
1933 const memory::dims kernel,
1934 const memory::dims padding_l,
1935 const memory::dims padding_r,
1936 const padding_kind apadding_kind) {
1937 memory::validate_dims(strides);
1938 memory::validate_dims(kernel);
1939 memory::validate_dims(padding_l);
1940 memory::validate_dims(padding_r);
1941 error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
1942 mkldnn::convert_to_c(aprop_kind),
1943 convert_to_c(aalgorithm),
1944 &src_desc.data, &dst_desc.data,
1945 &strides[0], &kernel[0],
1946 &padding_l[0], &padding_r[0],
1947 mkldnn::convert_to_c(apadding_kind)),
1948 "could not init a forward pooling descriptor");
1949 }
1950 };
1951
1952 struct primitive_desc : public mkldnn::primitive_desc {
1953 primitive_desc(const desc &desc, const engine &e)
1954 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1955
1956 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1957 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1958
1959 REG_QUERY_MD(src, src, 0);
1960 REG_QUERY_MD(dst, dst, 0);
1961 REG_QUERY_MD(workspace, workspace, 0);
1962 REG_QUERY_MD(scratchpad, scratchpad, 0);
1963 };
1964
1965 pooling_forward(const primitive_desc &pd): primitive(pd) {}
1966};
1967
1968struct pooling_backward : public primitive {
1969 struct desc {
1970 mkldnn_pooling_desc_t data;
1971 desc(algorithm aalgorithm,
1972 const memory::desc &diff_src_desc,
1973 const memory::desc &diff_dst_desc,
1974 const memory::dims &strides,
1975 const memory::dims &kernel,
1976 const memory::dims &padding_l,
1977 const memory::dims &padding_r,
1978 const padding_kind apadding_kind) {
1979 memory::validate_dims(strides);
1980 memory::validate_dims(kernel);
1981 memory::validate_dims(padding_l);
1982 memory::validate_dims(padding_r);
1983 error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
1984 convert_to_c(aalgorithm),
1985 &diff_src_desc.data, &diff_dst_desc.data,
1986 &strides[0], &kernel[0],
1987 &padding_l[0], &padding_r[0],
1988 mkldnn::convert_to_c(apadding_kind)),
1989 "could not init a backward pooling descriptor");
1990 }
1991 };
1992
1993 struct primitive_desc : public mkldnn::primitive_desc {
1994 primitive_desc(const desc &desc, const engine &e,
1995 const pooling_forward::primitive_desc &hint_fwd_pd)
1996 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1997
1998 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1999 const pooling_forward::primitive_desc &hint_fwd_pd)
2000 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2001
2002 REG_QUERY_MD(diff_src, diff_src, 0);
2003 REG_QUERY_MD(diff_dst, diff_dst, 0);
2004 REG_QUERY_MD(workspace, workspace, 0);
2005 REG_QUERY_MD(scratchpad, scratchpad, 0);
2006 };
2007
2008 pooling_backward(const primitive_desc &pd): primitive(pd) {}
2009};
2010
2011/// @}
2012
2013/// @addtogroup cpp_api_eltwise Eltwise
/// A primitive to compute element-wise operations such as the rectified
/// linear unit (ReLU).
2016///
2017/// @sa @ref c_api_eltwise in @ref c_api
2018/// @{
2019
2020struct eltwise_forward : public primitive {
2021 struct desc {
2022 mkldnn_eltwise_desc_t data;
2023 template <typename T>
2024 desc(prop_kind aprop_kind, algorithm alg_kind,
2025 const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2026 error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
2027 mkldnn::convert_to_c(aprop_kind),
2028 mkldnn::convert_to_c(alg_kind), &src_desc.data,
2029 static_cast<float>(alpha), static_cast<float>(beta)),
2030 "could not create a eltwise forward descriptor");
2031 }
2032 };
2033
2034 struct primitive_desc : public mkldnn::primitive_desc {
2035 primitive_desc(const desc &desc, const engine &e)
2036 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2037
2038 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2039 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2040
2041 REG_QUERY_MD(src, src, 0);
2042 REG_QUERY_MD(dst, dst, 0);
2043 REG_QUERY_MD(scratchpad, scratchpad, 0);
2044 };
2045
2046 eltwise_forward(const primitive_desc &pd): primitive(pd) {}
2047};
2048
2049struct eltwise_backward : public primitive {
2050 struct desc {
2051 mkldnn_eltwise_desc_t data;
2052
2053 template <typename T>
2054 desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2055 const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2056 error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
2057 mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2058 &data_desc.data, static_cast<float>(alpha),
2059 static_cast<float>(beta)),
2060 "could not create a eltwise backward descriptor");
2061 }
2062 };
2063
2064 struct primitive_desc : public mkldnn::primitive_desc {
2065 primitive_desc(const desc &desc, const engine &e,
2066 const eltwise_forward::primitive_desc &hint_fwd_pd)
2067 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2068
2069 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2070 const eltwise_forward::primitive_desc &hint_fwd_pd)
2071 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2072
2073 REG_QUERY_MD(src, src, 0);
2074 REG_QUERY_MD(diff_src, diff_src, 0);
2075 REG_QUERY_MD(diff_dst, diff_dst, 0);
2076 REG_QUERY_MD(scratchpad, scratchpad, 0);
2077 };
2078
2079 eltwise_backward(const primitive_desc &pd): primitive(pd) {}
2080};
2081
2082/// @}
2083
2084/// @addtogroup cpp_api_softmax Softmax
2085/// A primitive to perform softmax.
2086///
2087/// @sa @ref c_api_softmax in @ref c_api
2088/// @{
2089
2090struct softmax_forward : public primitive {
2091 struct desc {
2092 mkldnn_softmax_desc_t data;
2093 desc(prop_kind aprop_kind, const memory::desc &data_desc,
2094 int softmax_axis) {
2095 error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
2096 mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2097 softmax_axis),
2098 "could not create a softmax forward descriptor");
2099 }
2100 };
2101
2102 struct primitive_desc : public mkldnn::primitive_desc {
2103 primitive_desc(const desc &desc, const engine &e)
2104 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2105
2106 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2107 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2108
2109 REG_QUERY_MD(src, src, 0);
2110 REG_QUERY_MD(dst, dst, 0);
2111 REG_QUERY_MD(scratchpad, scratchpad, 0);
2112 };
2113
2114 softmax_forward(const primitive_desc &pd): primitive(pd) {}
2115};
2116
2117struct softmax_backward : public primitive {
2118 struct desc {
2119 mkldnn_softmax_desc_t data;
2120 desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2121 int softmax_axis) {
2122 error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
2123 &diff_desc.data, &data_desc.data, softmax_axis),
2124 "could not init a backward softmax descriptor");
2125 }
2126 };
2127
2128 struct primitive_desc : public mkldnn::primitive_desc {
2129 primitive_desc(const desc &desc, const engine &e,
2130 const softmax_forward::primitive_desc &hint_fwd_pd)
2131 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2132
2133 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2134 const softmax_forward::primitive_desc &hint_fwd_pd)
2135 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2136
2137 REG_QUERY_MD(dst, dst, 0);
2138 REG_QUERY_MD(diff_src, diff_src, 0);
2139 REG_QUERY_MD(diff_dst, diff_dst, 0);
2140 REG_QUERY_MD(workspace, workspace, 0);
2141 REG_QUERY_MD(scratchpad, scratchpad, 0);
2142 };
2143
2144 softmax_backward(const primitive_desc &pd): primitive(pd) {}
2145};
2146
2147/// @}
2148
2149/// @addtogroup cpp_api_batch_norm Batch normalization
2150/// A primitive to perform batch normalization.
2151///
2152/// @sa @ref c_api_batch_normalization in @ref c_api
2153/// @{
2154
2155struct batch_normalization_forward : public primitive {
2156 struct desc {
2157 mkldnn_batch_normalization_desc_t data;
2158 template <typename T>
2159 desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2160 unsigned flags) {
2161 error::wrap_c_api(
2162 mkldnn_batch_normalization_forward_desc_init(&data,
2163 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2164 static_cast<float>(epsilon), flags),
2165 "could not create a batch normalization forward descriptor");
2166 }
2167 };
2168
2169 struct primitive_desc : public mkldnn::primitive_desc {
2170 primitive_desc(const desc &desc, const engine &e)
2171 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2172
2173 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2174 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2175
2176 REG_QUERY_MD(src, src, 0);
2177 REG_QUERY_MD(weights, weights, 0);
2178 REG_QUERY_MD(dst, dst, 0);
2179 REG_QUERY_MD(workspace, workspace, 0);
2180 REG_QUERY_MD(scratchpad, scratchpad, 0);
2181
2182 memory::desc mean_desc() const { return stat_desc(mean); }
2183 memory::desc variance_desc() const { return stat_desc(var); }
2184
2185 private:
2186 enum { mean = 1, var = 2, };
2187 memory::desc stat_desc(int kind) const {
2188 mkldnn_batch_normalization_desc_t *p;
2189 error::wrap_c_api(mkldnn_primitive_desc_query(
2190 get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
2191 "could not get a batch-normalization descriptor");
2192 return query_md(p->flags & use_global_stats ? src_md : dst_md, kind);
2193 }
2194 };
2195
2196 batch_normalization_forward(const primitive_desc &pd): primitive(pd) {}
2197};
2198
2199struct batch_normalization_backward : public primitive {
2200 struct desc {
2201 mkldnn_batch_normalization_desc_t data;
2202 template <typename T>
2203 desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2204 const memory::desc &data_desc, T epsilon, unsigned flags) {
2205 error::wrap_c_api(
2206 mkldnn_batch_normalization_backward_desc_init(&data,
2207 mkldnn::convert_to_c(aprop_kind),
2208 &diff_data_desc.data, &data_desc.data,
2209 static_cast<float>(epsilon), flags),
2210 "could not create a batch normalization backward descriptor");
2211 }
2212 };
2213
2214 struct primitive_desc : public mkldnn::primitive_desc {
2215 primitive_desc(const desc &desc, const engine &e,
2216 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2217 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2218
2219 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2220 const batch_normalization_forward::primitive_desc &hint_fwd_pd)
2221 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2222
2223 REG_QUERY_MD(src, src, 0);
2224 REG_QUERY_MD(mean, src, 1);
2225 REG_QUERY_MD(variance, src, 2);
2226 REG_QUERY_MD(weights, weights, 0);
2227 REG_QUERY_MD(dst, dst, 0);
2228 REG_QUERY_MD(diff_dst, diff_dst, 0);
2229 REG_QUERY_MD(workspace, workspace, 0);
2230
2231 REG_QUERY_MD(diff_src, diff_src, 0);
2232 REG_QUERY_MD(diff_weights, diff_weights, 0);
2233 REG_QUERY_MD(scratchpad, scratchpad, 0);
2234 };
2235
2236 batch_normalization_backward(const primitive_desc &pd): primitive(pd) {}
2237};
2238
2239/// @}
2240
2241/// @addtogroup cpp_api_inner_product Inner Product
2242/// A primitive to compute an inner product.
2243///
2244/// @sa @ref c_api_inner_product in @ref c_api
2245/// @{
2246
2247struct inner_product_forward: public primitive {
2248 struct desc {
2249 mkldnn_inner_product_desc_t data;
2250 desc(prop_kind aprop_kind, const memory::desc &src_desc,
2251 const memory::desc &weights_desc,
2252 const memory::desc &bias_desc,
2253 const memory::desc &dst_desc) {
2254 error::wrap_c_api(
2255 mkldnn_inner_product_forward_desc_init(&data,
2256 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2257 &weights_desc.data, &bias_desc.data, &dst_desc.data),
2258 "could not create a inner product forward descriptor");
2259 }
2260
2261 desc(prop_kind aprop_kind, const memory::desc &src_desc,
2262 const memory::desc &weights_desc,
2263 const memory::desc &dst_desc) {
2264 error::wrap_c_api(
2265 mkldnn_inner_product_forward_desc_init(&data,
2266 mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2267 &weights_desc.data, nullptr, &dst_desc.data),
2268 "could not create a inner product forward descriptor");
2269 }
2270 };
2271
2272 struct primitive_desc : public mkldnn::primitive_desc {
2273 primitive_desc(const desc &desc, const engine &e)
2274 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2275
2276 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2277 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2278
2279 REG_QUERY_MD(src, src, 0);
2280 REG_QUERY_MD(weights, weights, 0);
2281 REG_QUERY_MD(bias, weights, 1);
2282 REG_QUERY_MD(dst, dst, 0);
2283 REG_QUERY_MD(scratchpad, scratchpad, 0);
2284 };
2285
2286 inner_product_forward(const primitive_desc &pd): primitive(pd) {}
2287};
2288
2289struct inner_product_backward_data: public primitive {
2290 struct desc {
2291 mkldnn_inner_product_desc_t data;
2292 desc(const memory::desc &diff_src_desc,
2293 const memory::desc &weights_desc,
2294 const memory::desc &diff_dst_desc) {
2295 error::wrap_c_api(
2296 mkldnn_inner_product_backward_data_desc_init(&data,
2297 &diff_src_desc.data, &weights_desc.data,
2298 &diff_dst_desc.data),
2299 "could not create a inner product backward data descriptor");
2300 }
2301 };
2302
2303 struct primitive_desc : public mkldnn::primitive_desc {
2304 primitive_desc(const desc &desc, const engine &e,
2305 const inner_product_forward::primitive_desc &hint_fwd_pd)
2306 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2307
2308 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2309 const inner_product_forward::primitive_desc &hint_fwd_pd)
2310 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2311
2312 REG_QUERY_MD(diff_src, diff_src, 0);
2313 REG_QUERY_MD(weights, weights, 0);
2314 REG_QUERY_MD(diff_dst, diff_dst, 0);
2315 REG_QUERY_MD(scratchpad, scratchpad, 0);
2316 };
2317
2318 inner_product_backward_data(const primitive_desc &pd): primitive(pd) {}
2319};
2320
2321struct inner_product_backward_weights: public primitive {
2322 struct desc {
2323 mkldnn_inner_product_desc_t data;
2324 desc(const memory::desc &src_desc,
2325 const memory::desc &diff_weights_desc,
2326 const memory::desc &diff_bias_desc,
2327 const memory::desc &diff_dst_desc) {
2328 error::wrap_c_api(
2329 mkldnn_inner_product_backward_weights_desc_init(
2330 &data, &src_desc.data, &diff_weights_desc.data,
2331 &diff_bias_desc.data, &diff_dst_desc.data),
2332 "could not create a inner product backward weights descriptor");
2333 }
2334 desc(const memory::desc &src_desc,
2335 const memory::desc &diff_weights_desc,
2336 const memory::desc &diff_dst_desc) {
2337 error::wrap_c_api(
2338 mkldnn_inner_product_backward_weights_desc_init(
2339 &data, &src_desc.data, &diff_weights_desc.data,
2340 nullptr, &diff_dst_desc.data),
2341 "could not create a inner product backward weights descriptor");
2342 }
2343 };
2344
2345 struct primitive_desc : public mkldnn::primitive_desc {
2346 primitive_desc(const desc &desc, const engine &e,
2347 const inner_product_forward::primitive_desc &hint_fwd_pd)
2348 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2349
2350 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2351 const inner_product_forward::primitive_desc &hint_fwd_pd)
2352 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2353
2354 REG_QUERY_MD(src, src, 0);
2355 REG_QUERY_MD(diff_weights, diff_weights, 0);
2356 REG_QUERY_MD(diff_bias, diff_weights, 1);
2357 REG_QUERY_MD(diff_dst, diff_dst, 0);
2358 REG_QUERY_MD(scratchpad, scratchpad, 0);
2359 };
2360
2361 inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {}
2362};
2363
2364/// @}
2365
2366/// @addtogroup cpp_api_rnn RNN
/// A primitive to compute common recurrent layers.
2368///
2369/// @sa @ref c_api_rnn in @ref c_api
2370/// @{
2371
2372struct rnn_cell {
2373 struct desc {
2374 mkldnn_rnn_cell_desc_t c_rnn_cell_;
2375
2376 desc(algorithm kind, algorithm activation_f) {
2377 error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
2378 mkldnn::convert_to_c(kind),
2379 mkldnn::convert_to_c(activation_f), 0U, 0, 0),
2380 "could not init an rnn cell descriptor");
2381 }
2382 desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
2383
2384 operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
2385
2386 algorithm get_cell_kind() const
2387 { return algorithm(c_rnn_cell_.cell_kind); }
2388 algorithm get_activation() const
2389 { return algorithm(c_rnn_cell_.activation_kind); }
2390
2391 float get_alpha() const { return c_rnn_cell_.alpha; }
2392 void set_alpha(float alpha) {
2393 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
2394 c_rnn_cell_.alpha = alpha;
2395 }
2396
2397 float get_clipping() const { return c_rnn_cell_.clipping; }
2398 void set_clipping(float clipping) {
2399 c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
2400 c_rnn_cell_.clipping = clipping;
2401 }
2402
2403 int get_gates_count() const {
2404 return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
2405 }
2406 int get_state_count() const {
2407 return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
2408 }
2409 };
2410};
2411
2412struct rnn_forward : public primitive {
2413 struct desc {
2414 mkldnn_rnn_desc_t data;
2415 desc(prop_kind aprop_kind, rnn_cell::desc cell,
2416 const rnn_direction direction,
2417 const memory::desc &src_layer_desc,
2418 const memory::desc &src_iter_desc,
2419 const memory::desc &weights_layer_desc,
2420 const memory::desc &weights_iter_desc,
2421 const memory::desc &bias_desc,
2422 const memory::desc &dst_layer_desc,
2423 const memory::desc &dst_iter_desc
2424 ) {
2425 error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
2426 mkldnn::convert_to_c(aprop_kind), cell,
2427 mkldnn::convert_to_c(direction),
2428 &src_layer_desc.data, &src_iter_desc.data,
2429 &weights_layer_desc.data, &weights_iter_desc.data,
2430 &bias_desc.data,
2431 &dst_layer_desc.data, &dst_iter_desc.data),
2432 "could not create an RNN forward descriptor");
2433 }
2434
2435 };
2436
2437 struct primitive_desc : public mkldnn::primitive_desc {
2438 primitive_desc(const desc &desc, const engine &e)
2439 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2440
2441 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2442 : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2443
2444 REG_QUERY_MD(src_layer, src, 0);
2445 REG_QUERY_MD(src_iter, src, 1);
2446 REG_QUERY_MD(weights_layer, weights, 0);
2447 REG_QUERY_MD(weights_iter, weights, 1);
2448 REG_QUERY_MD(bias, weights, 2);
2449 REG_QUERY_MD(dst_layer, dst, 0);
2450 REG_QUERY_MD(dst_iter, dst, 1);
2451 REG_QUERY_MD(workspace, workspace, 0);
2452 REG_QUERY_MD(scratchpad, scratchpad, 0);
2453 };
2454
2455 rnn_forward(const primitive_desc &pd): primitive(pd) {}
2456};
2457
2458struct rnn_backward : public primitive {
2459 struct desc {
2460 mkldnn_rnn_desc_t data;
2461 desc(prop_kind aprop_kind, rnn_cell::desc cell,
2462 const rnn_direction direction,
2463 const memory::desc &src_layer_desc,
2464 const memory::desc &src_iter_desc,
2465 const memory::desc &weights_layer_desc,
2466 const memory::desc &weights_iter_desc,
2467 const memory::desc &bias_desc,
2468 const memory::desc &dst_layer_desc,
2469 const memory::desc &dst_iter_desc,
2470 const memory::desc &diff_src_layer_desc,
2471 const memory::desc &diff_src_iter_desc,
2472 const memory::desc &diff_weights_layer_desc,
2473 const memory::desc &diff_weights_iter_desc,
2474 const memory::desc &diff_bias_desc,
2475 const memory::desc &diff_dst_layer_desc,
2476 const memory::desc &diff_dst_iter_desc) {
2477 error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
2478 mkldnn::convert_to_c(aprop_kind), cell,
2479 mkldnn::convert_to_c(direction),
2480 &src_layer_desc.data, &src_iter_desc.data,
2481 &weights_layer_desc.data, &weights_iter_desc.data,
2482 &bias_desc.data,
2483 &dst_layer_desc.data, &dst_iter_desc.data,
2484 &diff_src_layer_desc.data, &diff_src_iter_desc.data,
2485 &diff_weights_layer_desc.data,
2486 &diff_weights_iter_desc.data, &diff_bias_desc.data,
2487 &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
2488 "could not create an RNN backward descriptor");
2489 }
2490
2491 };
2492
2493 struct primitive_desc : public mkldnn::primitive_desc {
2494 primitive_desc(const desc &desc, const engine &e,
2495 const rnn_forward::primitive_desc &hint_fwd_pd)
2496 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2497
2498 primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2499 const rnn_forward::primitive_desc &hint_fwd_pd)
2500 : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2501
2502 REG_QUERY_MD(src_layer, src, 0);
2503 REG_QUERY_MD(src_iter, src, 1);
2504 REG_QUERY_MD(weights_layer, weights, 0);
2505 REG_QUERY_MD(weights_iter, weights, 1);
2506 REG_QUERY_MD(bias, weights, 2);
2507 REG_QUERY_MD(dst_layer, dst, 0);
2508 REG_QUERY_MD(dst_iter, dst, 1);
2509 REG_QUERY_MD(workspace, workspace, 0);
2510
2511 REG_QUERY_MD(diff_src_layer, diff_src, 0);
2512 REG_QUERY_MD(diff_src_iter, diff_src, 1);
2513 REG_QUERY_MD(diff_weights_layer, diff_weights, 0);
2514 REG_QUERY_MD(diff_weights_iter, diff_weights, 1);
2515 REG_QUERY_MD(diff_bias, diff_weights, 2);
2516 REG_QUERY_MD(diff_dst_layer, diff_dst, 0);
2517 REG_QUERY_MD(diff_dst_iter, diff_dst, 1);
2518 REG_QUERY_MD(scratchpad, scratchpad, 0);
2519 };
2520
2521 // With last iteration (with and without input src_iter)
2522 rnn_backward(const primitive_desc &pd): primitive(pd) {}
2523};
2524
2525/// @}
2526
2527/// @addtogroup cpp_api_shuffle Shuffle
2528/// A primitive to shuffle data along the axis.
2529///
2530/// @sa @ref c_api_shuffle in @ref c_api
2531/// @{
2532
2533struct shuffle_forward : public primitive {
2534 struct desc {
2535 mkldnn_shuffle_desc_t data;
2536 desc(prop_kind aprop_kind, const memory::desc &data_desc,
2537 int axis, int group_size) {
2538 error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
2539 mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2540 axis, group_size),
2541 "could not create a shuffle forward descriptor");
2542 }
2543 };
2544
2545 struct primitive_desc : public mkldnn::primitive_desc {
2546 primitive_desc(const desc &desc, const engine &e)
2547 : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2548
2549 REG_QUERY_MD(src, src, 0);
2550 REG_QUERY_MD(dst, dst, 0);
2551 REG_QUERY_MD(scratchpad, scratchpad, 0);
2552 };
2553
2554 shuffle_forward(const primitive_desc &pd): primitive(pd) {}
2555};
2556
2557struct shuffle_backward : public primitive {
2558 struct desc {
2559 mkldnn_shuffle_desc_t data;
2560 desc(const memory::desc &diff_data_desc, int axis, int group_size) {
2561 error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
2562 &diff_data_desc.data, axis, group_size),
2563 "could not create a shuffle backward descriptor");
2564 }
2565 };
2566
2567 struct primitive_desc : public mkldnn::primitive_desc {
2568 primitive_desc(const desc &desc, const engine &e,
2569 const shuffle_forward::primitive_desc &hint_fwd_pd)
2570 : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2571
2572 REG_QUERY_MD(diff_src, diff_src, 0);
2573 REG_QUERY_MD(diff_dst, diff_dst, 0);
2574 REG_QUERY_MD(scratchpad, scratchpad, 0);
2575 };
2576
2577 shuffle_backward(const primitive_desc &pd): primitive(pd) {}
2578};
2579
2580/// @}
2581
2582/// @} Primitives
2583
2584/// @} C++ API
2585
2586#undef REG_QUERY_MD
2587
2588// implementation section
2589#ifndef DOXYGEN_SHOULD_SKIP_THIS
2590
2591inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) {
2592 mkldnn_primitive_t result;
2593 error::wrap_c_api(mkldnn_primitive_create(&result, c_pd),
2594 "could not create a primitive");
2595 reset(result);
2596}
2597
2598inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {}
2599
2600inline void primitive::execute(stream &astream,
2601 const std::unordered_map<int, memory> &args) const {
2602 std::vector<mkldnn_exec_arg_t> c_args;
2603 c_args.reserve(args.size());
2604 for (const auto &a: args)
2605 c_args.push_back({a.first, a.second.get()});
2606
2607 error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(),
2608 (int)c_args.size(), c_args.data()),
2609 "primitive execution fail");
2610}
2611#endif // DOXYGEN_SHOULD_SKIP_THIS
2612
2613} // namespace mkldnn
2614
2615#endif
2616