1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef PRIMITIVE_DESC_HPP
18#define PRIMITIVE_DESC_HPP
19
20#include "mkldnn.h"
21
22#include "c_types_map.hpp"
23#include "memory_tracking.hpp"
24#include "nstl.hpp"
25#include "type_helpers.hpp"
26#include "primitive_attr.hpp"
27#include "verbose.hpp"
28
29struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
30 using md_t = mkldnn::impl::memory_desc_t;
31
32 mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
33 const mkldnn::impl::primitive_attr_t *attr,
34 mkldnn::impl::primitive_kind_t kind)
35 : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
36
37 mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
38 mkldnn::impl::primitive_kind_t kind)
39 : engine_(engine), kind_(kind) { info_[0] = '\0'; }
40
41 virtual mkldnn_primitive_desc *clone() const = 0;
42 virtual ~mkldnn_primitive_desc() {}
43
44 const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
45 mkldnn::impl::engine_t *engine() const { return engine_; }
46 mkldnn::impl::primitive_kind_t kind() const { return kind_; }
47
48 virtual void init_info() {}
49 const char *info() const { return info_; }
50
51 mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
52 { return scratchpad_registry_; }
53 const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
54 { return scratchpad_registry_; }
55 virtual mkldnn::impl::engine_t *scratchpad_engine() const
56 { return engine_; }
57
58 virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
59
60 enum class arg_usage_t { unused, input, output };
61 virtual arg_usage_t arg_usage(
62 mkldnn::impl::primitive_arg_index_t arg) const {
63 using mkldnn::impl::types::is_zero_md;
64 if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
65 return arg_usage_t::output;
66 return arg_usage_t::unused;
67 }
68
69# define DECLARE_MD_STUB(stub) \
70 virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
71 { return nullptr; }
72
73 DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
74 DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
75 DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
76 DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
77 DECLARE_MD_STUB(workspace_md);
78# undef DECLARE_MD_STUB
79
80 const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
81 return idx == 0 ? &scratchpad_md_ : nullptr;
82 }
83
84 virtual void init_scratchpad_md() {
85 auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
86 mkldnn::impl::dims_t dims = { size };
87 mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
88 mkldnn::impl::data_type::u8, mkldnn_x);
89 }
90
91 /** returns the scratchpad size for the given scratchpad mode. */
92 mkldnn::impl::dim_t scratchpad_size(
93 mkldnn::impl::scratchpad_mode_t mode) const {
94 if (mode != attr_.scratchpad_mode_) return 0;
95 return scratchpad_registry().size();
96 }
97
98 virtual int n_inputs() const { return 0; }
99 virtual int n_outputs() const { return 0; }
100
101 virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
102 void *result) const;
103
104 virtual mkldnn::impl::status_t create_primitive(
105 mkldnn::impl::primitive_t **primitive) const = 0;
106
107 virtual const char *name() const { return "mkldnn_primitive_desc"; }
108
109 /* static magic */
110
111 template<typename pd_t>
112 static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
113 const mkldnn::impl::op_desc_t *adesc,
114 const mkldnn::impl::primitive_attr_t *attr,
115 mkldnn::impl::engine_t *engine,
116 const mkldnn::impl::primitive_desc_t *hint_fwd) {
117 using namespace mkldnn::impl;
118 using namespace mkldnn::impl::status;
119 using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
120 if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
121 assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
122 auto hint =
123 reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
124 auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
125 if (_pd == nullptr) return out_of_memory;
126 if (_pd->init() != success) { delete _pd; return unimplemented; }
127 _pd->init_info();
128 _pd->init_scratchpad_md();
129 *pd = _pd;
130 return success;
131 }
132
133protected:
134 mkldnn::impl::engine_t *engine_;
135 mkldnn::impl::primitive_attr_t attr_;
136 mkldnn::impl::primitive_kind_t kind_;
137
138 mkldnn::impl::memory_desc_t scratchpad_md_;
139
140 char info_[MKLDNN_VERBOSE_BUF_LEN];
141
142 mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
143
144protected:
145 /** compares ws between fwd_pd and this (make sense to use for bwd_pd)
146 * Expectation: this already set workspace, and this workspace should
147 * exactly match the one from fwd_pd */
148 bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
149 using namespace mkldnn::impl;
150 if (!workspace_md()) return true; // the impl lives fine w/o workspace
151 return fwd_pd && fwd_pd->workspace_md()
152 && *fwd_pd->workspace_md() == *workspace_md();
153 }
154};
155
/* Boilerplate injected into a concrete descriptor class named `pd_t`:
 *  - clone():            returns `new pd_t(*this)`;
 *  - create_primitive(): allocates the primitive type given by __VA_ARGS__,
 *                        constructed from `this`, stores it in *p through
 *                        safe_ptr_assign, and at verbose level >= 2 prints
 *                        the elapsed get_msec() time of that step;
 *  - name():             returns impl_name.
 * DECLARE_COMMON_PD_T is a plain alias of DECLARE_COMMON_PD_t. */
#define DECLARE_COMMON_PD_t(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **p) const override { \
        const double start_ms = get_msec(); \
        const auto status \
                = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
        const double elapsed_ms = get_msec() - start_ms; \
        const bool report_timing = mkldnn_verbose()->level >= 2; \
        if (report_timing) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), elapsed_ms); \
            fflush(0); \
        } \
        return status; \
    } \
    virtual const char *name() const override { return impl_name; }
#define DECLARE_COMMON_PD_T(impl_name, ...) \
    DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
171
172#endif
173
174// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
175