| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef PRIMITIVE_DESC_HPP |
| 18 | #define PRIMITIVE_DESC_HPP |
| 19 | |
| 20 | #include "mkldnn.h" |
| 21 | |
| 22 | #include "c_types_map.hpp" |
| 23 | #include "memory_tracking.hpp" |
| 24 | #include "nstl.hpp" |
| 25 | #include "type_helpers.hpp" |
| 26 | #include "primitive_attr.hpp" |
| 27 | #include "verbose.hpp" |
| 28 | |
| 29 | struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { |
| 30 | using md_t = mkldnn::impl::memory_desc_t; |
| 31 | |
| 32 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, |
| 33 | const mkldnn::impl::primitive_attr_t *attr, |
| 34 | mkldnn::impl::primitive_kind_t kind) |
| 35 | : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } |
| 36 | |
| 37 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, |
| 38 | mkldnn::impl::primitive_kind_t kind) |
| 39 | : engine_(engine), kind_(kind) { info_[0] = '\0'; } |
| 40 | |
| 41 | virtual mkldnn_primitive_desc *clone() const = 0; |
| 42 | virtual ~mkldnn_primitive_desc() {} |
| 43 | |
| 44 | const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } |
| 45 | mkldnn::impl::engine_t *engine() const { return engine_; } |
| 46 | mkldnn::impl::primitive_kind_t kind() const { return kind_; } |
| 47 | |
| 48 | virtual void init_info() {} |
| 49 | const char *info() const { return info_; } |
| 50 | |
| 51 | mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() |
| 52 | { return scratchpad_registry_; } |
| 53 | const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const |
| 54 | { return scratchpad_registry_; } |
| 55 | virtual mkldnn::impl::engine_t *scratchpad_engine() const |
| 56 | { return engine_; } |
| 57 | |
| 58 | virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } |
| 59 | |
| 60 | enum class arg_usage_t { unused, input, output }; |
| 61 | virtual arg_usage_t arg_usage( |
| 62 | mkldnn::impl::primitive_arg_index_t arg) const { |
| 63 | using mkldnn::impl::types::is_zero_md; |
| 64 | if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) |
| 65 | return arg_usage_t::output; |
| 66 | return arg_usage_t::unused; |
| 67 | } |
| 68 | |
| 69 | # define DECLARE_MD_STUB(stub) \ |
| 70 | virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ |
| 71 | { return nullptr; } |
| 72 | |
| 73 | DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); |
| 74 | DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); |
| 75 | DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); |
| 76 | DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); |
| 77 | DECLARE_MD_STUB(workspace_md); |
| 78 | # undef DECLARE_MD_STUB |
| 79 | |
| 80 | const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { |
| 81 | return idx == 0 ? &scratchpad_md_ : nullptr; |
| 82 | } |
| 83 | |
| 84 | virtual void init_scratchpad_md() { |
| 85 | auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); |
| 86 | mkldnn::impl::dims_t dims = { size }; |
| 87 | mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, |
| 88 | mkldnn::impl::data_type::u8, mkldnn_x); |
| 89 | } |
| 90 | |
| 91 | /** returns the scratchpad size for the given scratchpad mode. */ |
| 92 | mkldnn::impl::dim_t scratchpad_size( |
| 93 | mkldnn::impl::scratchpad_mode_t mode) const { |
| 94 | if (mode != attr_.scratchpad_mode_) return 0; |
| 95 | return scratchpad_registry().size(); |
| 96 | } |
| 97 | |
| 98 | virtual int n_inputs() const { return 0; } |
| 99 | virtual int n_outputs() const { return 0; } |
| 100 | |
| 101 | virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, |
| 102 | void *result) const; |
| 103 | |
| 104 | virtual mkldnn::impl::status_t create_primitive( |
| 105 | mkldnn::impl::primitive_t **primitive) const = 0; |
| 106 | |
| 107 | virtual const char *name() const { return "mkldnn_primitive_desc" ; } |
| 108 | |
| 109 | /* static magic */ |
| 110 | |
| 111 | template<typename pd_t> |
| 112 | static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, |
| 113 | const mkldnn::impl::op_desc_t *adesc, |
| 114 | const mkldnn::impl::primitive_attr_t *attr, |
| 115 | mkldnn::impl::engine_t *engine, |
| 116 | const mkldnn::impl::primitive_desc_t *hint_fwd) { |
| 117 | using namespace mkldnn::impl; |
| 118 | using namespace mkldnn::impl::status; |
| 119 | using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; |
| 120 | if (adesc->kind != pd_t::base_pkind) return invalid_arguments; |
| 121 | assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); |
| 122 | auto hint = |
| 123 | reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); |
| 124 | auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); |
| 125 | if (_pd == nullptr) return out_of_memory; |
| 126 | if (_pd->init() != success) { delete _pd; return unimplemented; } |
| 127 | _pd->init_info(); |
| 128 | _pd->init_scratchpad_md(); |
| 129 | *pd = _pd; |
| 130 | return success; |
| 131 | } |
| 132 | |
| 133 | protected: |
| 134 | mkldnn::impl::engine_t *engine_; |
| 135 | mkldnn::impl::primitive_attr_t attr_; |
| 136 | mkldnn::impl::primitive_kind_t kind_; |
| 137 | |
| 138 | mkldnn::impl::memory_desc_t scratchpad_md_; |
| 139 | |
| 140 | char info_[MKLDNN_VERBOSE_BUF_LEN]; |
| 141 | |
| 142 | mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; |
| 143 | |
| 144 | protected: |
| 145 | /** compares ws between fwd_pd and this (make sense to use for bwd_pd) |
| 146 | * Expectation: this already set workspace, and this workspace should |
| 147 | * exactly match the one from fwd_pd */ |
| 148 | bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { |
| 149 | using namespace mkldnn::impl; |
| 150 | if (!workspace_md()) return true; // the impl lives fine w/o workspace |
| 151 | return fwd_pd && fwd_pd->workspace_md() |
| 152 | && *fwd_pd->workspace_md() == *workspace_md(); |
| 153 | } |
| 154 | }; |
| 155 | |
// DECLARE_COMMON_PD_t(impl_name, primitive_type...) expands, inside a pd_t
// class body, to three overrides:
//  - clone() copy-constructs a new pd_t from *this;
//  - create_primitive() allocates `new (primitive_type)(this)`, stores it in
//    *p via safe_ptr_assign, and returns that call's status. It also times
//    the call and, at verbose level >= 2, prints
//    "mkldnn_verbose,create,<info>,<ms>" and flushes all output streams;
//  - name() returns impl_name.
// The primitive type is passed variadically so template types containing
// commas survive macro argument splitting.
#define DECLARE_COMMON_PD_t(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **p) const override { \
        const double start_ms = get_msec(); \
        auto create_status = \
            safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
        const double elapsed_ms = get_msec() - start_ms; \
        if (mkldnn_verbose()->level >= 2) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), \
                    elapsed_ms); \
            fflush(0); \
        } \
        return create_status; \
    } \
    virtual const char *name() const override { return impl_name; }
// Upper-case spelling; forwards every argument to DECLARE_COMMON_PD_t.
#define DECLARE_COMMON_PD_T(impl_name, ...) \
    DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
| 171 | |
| 172 | #endif |
| 173 | |
| 174 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
| 175 | |