| 1 | /******************************************************************************* | 
|---|
| 2 | * Copyright 2016-2018 Intel Corporation | 
|---|
| 3 | * | 
|---|
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); | 
|---|
| 5 | * you may not use this file except in compliance with the License. | 
|---|
| 6 | * You may obtain a copy of the License at | 
|---|
| 7 | * | 
|---|
| 8 | *     http://www.apache.org/licenses/LICENSE-2.0 | 
|---|
| 9 | * | 
|---|
| 10 | * Unless required by applicable law or agreed to in writing, software | 
|---|
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, | 
|---|
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|---|
| 13 | * See the License for the specific language governing permissions and | 
|---|
| 14 | * limitations under the License. | 
|---|
| 15 | *******************************************************************************/ | 
|---|
| 16 |  | 
|---|
| 17 | #ifndef PRIMITIVE_DESC_HPP | 
|---|
| 18 | #define PRIMITIVE_DESC_HPP | 
|---|
| 19 |  | 
|---|
| 20 | #include "mkldnn.h" | 
|---|
| 21 |  | 
|---|
| 22 | #include "c_types_map.hpp" | 
|---|
| 23 | #include "memory_tracking.hpp" | 
|---|
| 24 | #include "nstl.hpp" | 
|---|
| 25 | #include "type_helpers.hpp" | 
|---|
| 26 | #include "primitive_attr.hpp" | 
|---|
| 27 | #include "verbose.hpp" | 
|---|
| 28 |  | 
|---|
| 29 | struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { | 
|---|
| 30 | using md_t = mkldnn::impl::memory_desc_t; | 
|---|
| 31 |  | 
|---|
| 32 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, | 
|---|
| 33 | const mkldnn::impl::primitive_attr_t *attr, | 
|---|
| 34 | mkldnn::impl::primitive_kind_t kind) | 
|---|
| 35 | : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } | 
|---|
| 36 |  | 
|---|
| 37 | mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, | 
|---|
| 38 | mkldnn::impl::primitive_kind_t kind) | 
|---|
| 39 | : engine_(engine), kind_(kind) { info_[0] = '\0'; } | 
|---|
| 40 |  | 
|---|
| 41 | virtual mkldnn_primitive_desc *clone() const = 0; | 
|---|
| 42 | virtual ~mkldnn_primitive_desc() {} | 
|---|
| 43 |  | 
|---|
| 44 | const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } | 
|---|
| 45 | mkldnn::impl::engine_t *engine() const { return engine_; } | 
|---|
| 46 | mkldnn::impl::primitive_kind_t kind() const { return kind_; } | 
|---|
| 47 |  | 
|---|
| 48 | virtual void init_info() {} | 
|---|
| 49 | const char *info() const { return info_; } | 
|---|
| 50 |  | 
|---|
| 51 | mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() | 
|---|
| 52 | { return scratchpad_registry_; } | 
|---|
| 53 | const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const | 
|---|
| 54 | { return scratchpad_registry_; } | 
|---|
| 55 | virtual mkldnn::impl::engine_t *scratchpad_engine() const | 
|---|
| 56 | { return engine_; } | 
|---|
| 57 |  | 
|---|
| 58 | virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } | 
|---|
| 59 |  | 
|---|
| 60 | enum class arg_usage_t { unused, input, output }; | 
|---|
| 61 | virtual arg_usage_t arg_usage( | 
|---|
| 62 | mkldnn::impl::primitive_arg_index_t arg) const { | 
|---|
| 63 | using mkldnn::impl::types::is_zero_md; | 
|---|
| 64 | if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) | 
|---|
| 65 | return arg_usage_t::output; | 
|---|
| 66 | return arg_usage_t::unused; | 
|---|
| 67 | } | 
|---|
| 68 |  | 
|---|
| 69 | #   define DECLARE_MD_STUB(stub) \ | 
|---|
| 70 | virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ | 
|---|
| 71 | { return nullptr; } | 
|---|
| 72 |  | 
|---|
| 73 | DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); | 
|---|
| 74 | DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); | 
|---|
| 75 | DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); | 
|---|
| 76 | DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); | 
|---|
| 77 | DECLARE_MD_STUB(workspace_md); | 
|---|
| 78 | #   undef DECLARE_MD_STUB | 
|---|
| 79 |  | 
|---|
| 80 | const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { | 
|---|
| 81 | return idx == 0 ? &scratchpad_md_ : nullptr; | 
|---|
| 82 | } | 
|---|
| 83 |  | 
|---|
| 84 | virtual void init_scratchpad_md() { | 
|---|
| 85 | auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); | 
|---|
| 86 | mkldnn::impl::dims_t dims = { size }; | 
|---|
| 87 | mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, | 
|---|
| 88 | mkldnn::impl::data_type::u8, mkldnn_x); | 
|---|
| 89 | } | 
|---|
| 90 |  | 
|---|
| 91 | /** returns the scratchpad size for the given scratchpad mode. */ | 
|---|
| 92 | mkldnn::impl::dim_t scratchpad_size( | 
|---|
| 93 | mkldnn::impl::scratchpad_mode_t mode) const { | 
|---|
| 94 | if (mode != attr_.scratchpad_mode_) return 0; | 
|---|
| 95 | return scratchpad_registry().size(); | 
|---|
| 96 | } | 
|---|
| 97 |  | 
|---|
| 98 | virtual int n_inputs() const { return 0; } | 
|---|
| 99 | virtual int n_outputs() const { return 0; } | 
|---|
| 100 |  | 
|---|
| 101 | virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, | 
|---|
| 102 | void *result) const; | 
|---|
| 103 |  | 
|---|
| 104 | virtual mkldnn::impl::status_t create_primitive( | 
|---|
| 105 | mkldnn::impl::primitive_t **primitive) const = 0; | 
|---|
| 106 |  | 
|---|
| 107 | virtual const char *name() const { return "mkldnn_primitive_desc"; } | 
|---|
| 108 |  | 
|---|
| 109 | /* static magic */ | 
|---|
| 110 |  | 
|---|
| 111 | template<typename pd_t> | 
|---|
| 112 | static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, | 
|---|
| 113 | const mkldnn::impl::op_desc_t *adesc, | 
|---|
| 114 | const mkldnn::impl::primitive_attr_t *attr, | 
|---|
| 115 | mkldnn::impl::engine_t *engine, | 
|---|
| 116 | const mkldnn::impl::primitive_desc_t *hint_fwd) { | 
|---|
| 117 | using namespace mkldnn::impl; | 
|---|
| 118 | using namespace mkldnn::impl::status; | 
|---|
| 119 | using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; | 
|---|
| 120 | if (adesc->kind != pd_t::base_pkind) return invalid_arguments; | 
|---|
| 121 | assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); | 
|---|
| 122 | auto hint = | 
|---|
| 123 | reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); | 
|---|
| 124 | auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); | 
|---|
| 125 | if (_pd == nullptr) return out_of_memory; | 
|---|
| 126 | if (_pd->init() != success) { delete _pd; return unimplemented; } | 
|---|
| 127 | _pd->init_info(); | 
|---|
| 128 | _pd->init_scratchpad_md(); | 
|---|
| 129 | *pd = _pd; | 
|---|
| 130 | return success; | 
|---|
| 131 | } | 
|---|
| 132 |  | 
|---|
| 133 | protected: | 
|---|
| 134 | mkldnn::impl::engine_t *engine_; | 
|---|
| 135 | mkldnn::impl::primitive_attr_t attr_; | 
|---|
| 136 | mkldnn::impl::primitive_kind_t kind_; | 
|---|
| 137 |  | 
|---|
| 138 | mkldnn::impl::memory_desc_t scratchpad_md_; | 
|---|
| 139 |  | 
|---|
| 140 | char info_[MKLDNN_VERBOSE_BUF_LEN]; | 
|---|
| 141 |  | 
|---|
| 142 | mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; | 
|---|
| 143 |  | 
|---|
| 144 | protected: | 
|---|
| 145 | /** compares ws between fwd_pd and this (make sense to use for bwd_pd) | 
|---|
| 146 | * Expectation: this already set workspace, and this workspace should | 
|---|
| 147 | *              exactly match the one from fwd_pd */ | 
|---|
| 148 | bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { | 
|---|
| 149 | using namespace mkldnn::impl; | 
|---|
| 150 | if (!workspace_md()) return true; // the impl lives fine w/o workspace | 
|---|
| 151 | return fwd_pd && fwd_pd->workspace_md() | 
|---|
| 152 | && *fwd_pd->workspace_md() == *workspace_md(); | 
|---|
| 153 | } | 
|---|
| 154 | }; | 
|---|
| 155 |  | 
|---|
// Boilerplate for concrete primitive descriptors. It is expanded inside a
// class that defines pd_t, and provides:
//   - clone():            copy-constructs a new pd_t from *this;
//   - create_primitive(): creates the primitive type given as the variadic
//                         argument (constructed with `this`), stores it in *p
//                         through safe_ptr_assign and, at verbose level >= 2,
//                         prints "mkldnn_verbose,create,<info>,<ms>" then
//                         flushes all output streams (fflush(0));
//   - name():             returns impl_name.
// Taking the primitive type as __VA_ARGS__ lets it contain unparenthesized
// commas, such as template argument lists.
#define DECLARE_COMMON_PD_t(impl_name, ...) \
    virtual pd_t *clone() const override { return new pd_t(*this); } \
    virtual status_t create_primitive(primitive_t **p) const override { \
        const double start_ms = get_msec(); \
        auto create_status \
                = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
        const double elapsed_ms = get_msec() - start_ms; \
        if (mkldnn_verbose()->level >= 2) { \
            printf("mkldnn_verbose,create,%s,%g\n", this->info(), \
                    elapsed_ms); \
            fflush(0); \
        } \
        return create_status; \
    } \
    virtual const char *name() const override { return impl_name; }

// Alias of DECLARE_COMMON_PD_t that forwards all of its arguments unchanged.
#define DECLARE_COMMON_PD_T(impl_name, ...) \
    DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
|---|
| 171 |  | 
|---|
| 172 | #endif | 
|---|
| 173 |  | 
|---|
| 174 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s | 
|---|
| 175 |  | 
|---|