1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef POOLING_PD_HPP |
18 | #define POOLING_PD_HPP |
19 | |
20 | #include "mkldnn.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | #include "type_helpers.hpp" |
25 | |
26 | namespace mkldnn { |
27 | namespace impl { |
28 | |
29 | struct pooling_fwd_pd_t; |
30 | |
31 | struct pooling_pd_t: public primitive_desc_t { |
32 | static constexpr auto base_pkind = primitive_kind::pooling; |
33 | |
34 | pooling_pd_t(engine_t *engine, |
35 | const pooling_desc_t *adesc, |
36 | const primitive_attr_t *attr, |
37 | const pooling_fwd_pd_t *hint_fwd_pd) |
38 | : primitive_desc_t(engine, attr, base_pkind) |
39 | , desc_(*adesc) |
40 | , hint_fwd_pd_(hint_fwd_pd) |
41 | , ws_md_() |
42 | {} |
43 | |
44 | const pooling_desc_t *desc() const { return &desc_; } |
45 | virtual const op_desc_t *op_desc() const override |
46 | { return reinterpret_cast<const op_desc_t *>(this->desc()); } |
47 | virtual void init_info() override { impl::init_info(this, this->info_); } |
48 | |
49 | virtual status_t query(query_t what, int idx, void *result) const override { |
50 | switch (what) { |
51 | case query::pooling_d: |
52 | *(const pooling_desc_t**)result = desc(); break; |
53 | default: return primitive_desc_t::query(what, idx, result); |
54 | } |
55 | return status::success; |
56 | } |
57 | |
58 | /* common pooling aux functions */ |
59 | |
60 | dim_t MB() const { return src_desc().dims[0]; } |
61 | dim_t C() const { return src_desc().dims[1]; } |
62 | |
63 | dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } |
64 | dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } |
65 | dim_t IW() const { return src_desc().dims[ndims() - 1]; } |
66 | |
67 | dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } |
68 | dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } |
69 | dim_t OW() const { return dst_desc().dims[ndims() - 1]; } |
70 | |
71 | dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } |
72 | dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } |
73 | dim_t KW() const { return desc_.kernel[ndims() - 3]; } |
74 | |
75 | dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } |
76 | dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } |
77 | dim_t KSW() const { return desc_.strides[ndims() - 3]; } |
78 | |
79 | dim_t padFront() const |
80 | { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } |
81 | dim_t padBack() const |
82 | { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } |
83 | dim_t padT() const |
84 | { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } |
85 | dim_t padB() const |
86 | { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } |
87 | dim_t padL() const { return desc_.padding[0][ndims() - 3]; } |
88 | dim_t padR() const { return desc_.padding[1][ndims() - 3]; } |
89 | |
90 | int ndims() const { return src_desc().ndims; } |
91 | bool is_3d() const { return ndims() == 5; } |
92 | |
93 | bool has_zero_dim_memory() const |
94 | { return memory_desc_wrapper(src_desc()).has_zero_dim(); } |
95 | |
96 | bool is_fwd() const { |
97 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
98 | prop_kind::forward_inference); |
99 | } |
100 | |
101 | protected: |
102 | pooling_desc_t desc_; |
103 | const pooling_fwd_pd_t *hint_fwd_pd_; |
104 | |
105 | memory_desc_t ws_md_; |
106 | |
107 | void init_default_ws() { |
108 | ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); |
109 | ws_md_.data_type = indices_data_type(); |
110 | } |
111 | |
112 | data_type_t indices_data_type() const { |
113 | /* the simplest way to express 256... */ |
114 | const int u8_max = nstl::numeric_limits< |
115 | typename prec_traits<data_type::u8>::type>::max(); |
116 | return utils::array_product(desc()->kernel, ndims()) <= u8_max |
117 | ? data_type::u8 : data_type::s32; |
118 | } |
119 | |
120 | private: |
121 | const memory_desc_t &src_desc() const |
122 | { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } |
123 | const memory_desc_t &dst_desc() const |
124 | { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } |
125 | }; |
126 | |
127 | struct pooling_fwd_pd_t: public pooling_pd_t { |
128 | typedef pooling_fwd_pd_t base_class; |
129 | typedef pooling_fwd_pd_t hint_class; |
130 | |
131 | pooling_fwd_pd_t(engine_t *engine, |
132 | const pooling_desc_t *adesc, |
133 | const primitive_attr_t *attr, |
134 | const pooling_fwd_pd_t *hint_fwd_pd) |
135 | : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) |
136 | , src_md_(desc_.src_desc) |
137 | , dst_md_(desc_.dst_desc) |
138 | {} |
139 | |
140 | virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { |
141 | if (arg == MKLDNN_ARG_SRC) |
142 | return arg_usage_t::input; |
143 | |
144 | if (arg == MKLDNN_ARG_DST) |
145 | return arg_usage_t::output; |
146 | |
147 | if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) |
148 | return arg_usage_t::output; |
149 | |
150 | return primitive_desc_t::arg_usage(arg); |
151 | } |
152 | |
153 | virtual const memory_desc_t *src_md(int index = 0) const override |
154 | { return index == 0 ? &src_md_ : nullptr; } |
155 | virtual const memory_desc_t *dst_md(int index = 0) const override |
156 | { return index == 0 ? &dst_md_ : nullptr; } |
157 | virtual const memory_desc_t *workspace_md(int index = 0) const override |
158 | { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } |
159 | |
160 | virtual int n_inputs() const override { return 1; } |
161 | virtual int n_outputs() const override |
162 | { return 1 + (workspace_md() != nullptr); } |
163 | |
164 | protected: |
165 | memory_desc_t src_md_; |
166 | memory_desc_t dst_md_; |
167 | |
168 | virtual status_t set_default_params() { |
169 | if (dst_md()->format_kind != format_kind::any) |
170 | return status::success; |
171 | |
172 | if (src_md()->format_kind != format_kind::blocked) |
173 | return status::unimplemented; |
174 | |
175 | return memory_desc_init_by_blocking_desc(dst_md_, |
176 | src_md_.format_desc.blocking); |
177 | } |
178 | }; |
179 | |
180 | struct pooling_bwd_pd_t: public pooling_pd_t { |
181 | typedef pooling_bwd_pd_t base_class; |
182 | typedef pooling_fwd_pd_t hint_class; |
183 | |
184 | pooling_bwd_pd_t(engine_t *engine, |
185 | const pooling_desc_t *adesc, |
186 | const primitive_attr_t *attr, |
187 | const pooling_fwd_pd_t *hint_fwd_pd) |
188 | : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) |
189 | , diff_src_md_(desc_.diff_src_desc) |
190 | , diff_dst_md_(desc_.diff_dst_desc) |
191 | {} |
192 | |
193 | virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { |
194 | if (arg == MKLDNN_ARG_DIFF_DST) |
195 | return arg_usage_t::input; |
196 | |
197 | if (arg == MKLDNN_ARG_DIFF_SRC) |
198 | return arg_usage_t::output; |
199 | |
200 | if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) |
201 | return arg_usage_t::input; |
202 | |
203 | return primitive_desc_t::arg_usage(arg); |
204 | } |
205 | |
206 | virtual const memory_desc_t *diff_src_md(int index = 0) const override |
207 | { return index == 0 ? &diff_src_md_ : nullptr; } |
208 | virtual const memory_desc_t *diff_dst_md(int index = 0) const override |
209 | { return index == 0 ? &diff_dst_md_ : nullptr; } |
210 | virtual const memory_desc_t *workspace_md(int index = 0) const override |
211 | { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } |
212 | |
213 | virtual int n_inputs() const override |
214 | { return 1 + (workspace_md() != nullptr); } |
215 | virtual int n_outputs() const override { return 1; } |
216 | |
217 | protected: |
218 | memory_desc_t diff_src_md_; |
219 | memory_desc_t diff_dst_md_; |
220 | |
221 | virtual status_t set_default_params() { |
222 | if (diff_src_md()->format_kind != format_kind::any) |
223 | return status::success; |
224 | |
225 | if (diff_dst_md()->format_kind != format_kind::blocked) |
226 | return status::unimplemented; |
227 | |
228 | return memory_desc_init_by_blocking_desc(diff_src_md_, |
229 | diff_dst_md_.format_desc.blocking); |
230 | } |
231 | }; |
232 | |
233 | } |
234 | } |
235 | |
236 | #endif |
237 | |
238 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
239 | |