1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef PRIMITIVE_DESC_HPP |
18 | #define PRIMITIVE_DESC_HPP |
19 | |
20 | #include "mkldnn.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "memory_tracking.hpp" |
24 | #include "nstl.hpp" |
25 | #include "type_helpers.hpp" |
26 | #include "primitive_attr.hpp" |
27 | #include "verbose.hpp" |
28 | |
29 | struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { |
30 | using md_t = mkldnn::impl::memory_desc_t; |
31 | |
32 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, |
33 | const mkldnn::impl::primitive_attr_t *attr, |
34 | mkldnn::impl::primitive_kind_t kind) |
35 | : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } |
36 | |
37 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, |
38 | mkldnn::impl::primitive_kind_t kind) |
39 | : engine_(engine), kind_(kind) { info_[0] = '\0'; } |
40 | |
41 | virtual mkldnn_primitive_desc *clone() const = 0; |
42 | virtual ~mkldnn_primitive_desc() {} |
43 | |
44 | const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } |
45 | mkldnn::impl::engine_t *engine() const { return engine_; } |
46 | mkldnn::impl::primitive_kind_t kind() const { return kind_; } |
47 | |
48 | virtual void init_info() {} |
49 | const char *info() const { return info_; } |
50 | |
51 | mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() |
52 | { return scratchpad_registry_; } |
53 | const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const |
54 | { return scratchpad_registry_; } |
55 | virtual mkldnn::impl::engine_t *scratchpad_engine() const |
56 | { return engine_; } |
57 | |
58 | virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } |
59 | |
60 | enum class arg_usage_t { unused, input, output }; |
61 | virtual arg_usage_t arg_usage( |
62 | mkldnn::impl::primitive_arg_index_t arg) const { |
63 | using mkldnn::impl::types::is_zero_md; |
64 | if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) |
65 | return arg_usage_t::output; |
66 | return arg_usage_t::unused; |
67 | } |
68 | |
69 | # define DECLARE_MD_STUB(stub) \ |
70 | virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ |
71 | { return nullptr; } |
72 | |
73 | DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); |
74 | DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); |
75 | DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); |
76 | DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); |
77 | DECLARE_MD_STUB(workspace_md); |
78 | # undef DECLARE_MD_STUB |
79 | |
80 | const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { |
81 | return idx == 0 ? &scratchpad_md_ : nullptr; |
82 | } |
83 | |
84 | virtual void init_scratchpad_md() { |
85 | auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); |
86 | mkldnn::impl::dims_t dims = { size }; |
87 | mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, |
88 | mkldnn::impl::data_type::u8, mkldnn_x); |
89 | } |
90 | |
91 | /** returns the scratchpad size for the given scratchpad mode. */ |
92 | mkldnn::impl::dim_t scratchpad_size( |
93 | mkldnn::impl::scratchpad_mode_t mode) const { |
94 | if (mode != attr_.scratchpad_mode_) return 0; |
95 | return scratchpad_registry().size(); |
96 | } |
97 | |
98 | virtual int n_inputs() const { return 0; } |
99 | virtual int n_outputs() const { return 0; } |
100 | |
101 | virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, |
102 | void *result) const; |
103 | |
104 | virtual mkldnn::impl::status_t create_primitive( |
105 | mkldnn::impl::primitive_t **primitive) const = 0; |
106 | |
107 | virtual const char *name() const { return "mkldnn_primitive_desc" ; } |
108 | |
109 | /* static magic */ |
110 | |
111 | template<typename pd_t> |
112 | static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, |
113 | const mkldnn::impl::op_desc_t *adesc, |
114 | const mkldnn::impl::primitive_attr_t *attr, |
115 | mkldnn::impl::engine_t *engine, |
116 | const mkldnn::impl::primitive_desc_t *hint_fwd) { |
117 | using namespace mkldnn::impl; |
118 | using namespace mkldnn::impl::status; |
119 | using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; |
120 | if (adesc->kind != pd_t::base_pkind) return invalid_arguments; |
121 | assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); |
122 | auto hint = |
123 | reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); |
124 | auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); |
125 | if (_pd == nullptr) return out_of_memory; |
126 | if (_pd->init() != success) { delete _pd; return unimplemented; } |
127 | _pd->init_info(); |
128 | _pd->init_scratchpad_md(); |
129 | *pd = _pd; |
130 | return success; |
131 | } |
132 | |
133 | protected: |
134 | mkldnn::impl::engine_t *engine_; |
135 | mkldnn::impl::primitive_attr_t attr_; |
136 | mkldnn::impl::primitive_kind_t kind_; |
137 | |
138 | mkldnn::impl::memory_desc_t scratchpad_md_; |
139 | |
140 | char info_[MKLDNN_VERBOSE_BUF_LEN]; |
141 | |
142 | mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; |
143 | |
144 | protected: |
145 | /** compares ws between fwd_pd and this (make sense to use for bwd_pd) |
146 | * Expectation: this already set workspace, and this workspace should |
147 | * exactly match the one from fwd_pd */ |
148 | bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { |
149 | using namespace mkldnn::impl; |
150 | if (!workspace_md()) return true; // the impl lives fine w/o workspace |
151 | return fwd_pd && fwd_pd->workspace_md() |
152 | && *fwd_pd->workspace_md() == *workspace_md(); |
153 | } |
154 | }; |
155 | |
/* Boilerplate injected into every concrete `pd_t`:
 *   - clone(): copy-constructs a new pd_t from *this;
 *   - create_primitive(): allocates the primitive type passed as the variadic
 *     arguments (variadic so that a template type containing commas can be
 *     passed), stores it through safe_ptr_assign, and — when the verbose level
 *     is at least 2 — prints the elapsed creation time measured with
 *     get_msec() together with info();
 *   - name(): returns `impl_name`.
 * DECLARE_COMMON_PD_t is an alias of DECLARE_COMMON_PD_T. */
#define DECLARE_COMMON_PD_T(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **p) const override { \
        const double start_ms = get_msec(); \
        const auto create_status = \
            safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
        const double elapsed_ms = get_msec() - start_ms; \
        if (2 <= mkldnn_verbose()->level) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), elapsed_ms); \
            fflush(nullptr); \
        } \
        return create_status; \
    } \
    virtual const char *name() const override { return impl_name; }
#define DECLARE_COMMON_PD_t(impl_name, ...) \
    DECLARE_COMMON_PD_T(impl_name, __VA_ARGS__)
171 | |
172 | #endif |
173 | |
174 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
175 | |