1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP |
18 | #define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP |
19 | |
20 | #include "c_types_map.hpp" |
21 | #include "memory_tracking.hpp" |
22 | #include "mkldnn_thread.hpp" |
23 | |
24 | #include "cpu_convolution_pd.hpp" |
25 | #include "cpu_primitive.hpp" |
26 | |
27 | #include "jit_avx512_common_conv_winograd_kernel_f32.hpp" |
28 | |
29 | namespace mkldnn { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | namespace winograd_avx512_common { |
34 | inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
35 | const jit_conv_winograd_conf_t &jcp) { |
36 | using namespace memory_tracking::names; |
37 | |
38 | size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; |
39 | size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic |
40 | * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); |
41 | size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc |
42 | * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); |
43 | |
44 | scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); |
45 | scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); |
46 | scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); |
47 | |
48 | if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { |
49 | const int nthr = mkldnn_get_max_threads(); |
50 | |
51 | size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr |
52 | * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; |
53 | scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); |
54 | |
55 | size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; |
56 | scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); |
57 | |
58 | size_t padded_bias_sz = |
59 | jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; |
60 | scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); |
61 | } |
62 | } |
63 | } |
64 | |
65 | template <bool is_fwd> |
66 | struct _jit_avx512_common_convolution_winograd_t { |
67 | _jit_avx512_common_convolution_winograd_t( |
68 | const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) |
69 | : kernel_(nullptr), attr_(attr) { |
70 | kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); |
71 | } |
72 | |
73 | ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } |
74 | |
75 | protected: |
76 | void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, |
77 | float *wei_ptr, float *bias_ptr, |
78 | const memory_tracking::grantor_t &scratchpad) const; |
79 | _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; |
80 | const primitive_attr_t *attr_; |
81 | }; |
82 | |
83 | struct jit_avx512_common_convolution_winograd_fwd_t |
84 | : _jit_avx512_common_convolution_winograd_t<true> |
85 | , public cpu_primitive_t |
86 | { |
87 | struct pd_t : public cpu_convolution_fwd_pd_t { |
88 | pd_t(engine_t *engine, const convolution_desc_t *adesc, |
89 | const primitive_attr_t *attr, |
90 | const typename pd_t::base_class *hint_fwd_pd) |
91 | : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) |
92 | , jcp_() {} |
93 | |
94 | DECLARE_COMMON_PD_T( |
95 | JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), |
96 | jit_avx512_common_convolution_winograd_fwd_t); |
97 | |
98 | status_t init() { |
99 | bool ok = true |
100 | && is_fwd() |
101 | && utils::one_of(desc()->alg_kind, |
102 | alg_kind::convolution_auto, |
103 | alg_kind::convolution_winograd) |
104 | && expect_data_types(data_type::f32, data_type::f32, |
105 | data_type::f32, data_type::f32, data_type::f32) |
106 | && !has_zero_dim_memory() |
107 | && set_default_formats(); |
108 | if (!ok) return status::unimplemented; |
109 | |
110 | status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: |
111 | init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), |
112 | *attr()); |
113 | if (status != status::success) return status; |
114 | set_default_alg_kind(alg_kind::convolution_winograd); |
115 | |
116 | auto scratchpad = scratchpad_registry().registrar(); |
117 | winograd_avx512_common::init_scratchpad(scratchpad, jcp_); |
118 | |
119 | return status; |
120 | } |
121 | |
122 | jit_conv_winograd_conf_t jcp_; |
123 | |
124 | protected: |
125 | bool set_default_formats() { |
126 | using namespace format_tag; |
127 | auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; |
128 | return set_default_formats_common(nChw16c, wei_tag, nChw16c); |
129 | } |
130 | }; |
131 | |
132 | jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) |
133 | : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr()) |
134 | , cpu_primitive_t(apd, true) {} |
135 | |
136 | ~jit_avx512_common_convolution_winograd_fwd_t(){}; |
137 | |
138 | typedef typename prec_traits<data_type::f32>::type data_t; |
139 | |
140 | virtual status_t execute(const exec_ctx_t &ctx) const override |
141 | { |
142 | auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); |
143 | auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); |
144 | auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); |
145 | auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); |
146 | this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, |
147 | (float *)bias, this->scratchpad(ctx)); |
148 | return status::success; |
149 | } |
150 | |
151 | private: |
152 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
153 | }; |
154 | |
155 | struct jit_avx512_common_convolution_winograd_bwd_data_t |
156 | : _jit_avx512_common_convolution_winograd_t<false>, |
157 | public cpu_primitive_t { |
158 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
159 | pd_t(engine_t *engine, const convolution_desc_t *adesc, |
160 | const primitive_attr_t *attr, |
161 | const convolution_fwd_pd_t *hint_fwd_pd) |
162 | : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) |
163 | , jcp_() {} |
164 | |
165 | DECLARE_COMMON_PD_T( |
166 | JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), |
167 | jit_avx512_common_convolution_winograd_bwd_data_t); |
168 | |
169 | status_t init() { |
170 | bool ok = true |
171 | && desc()->prop_kind == prop_kind::backward_data |
172 | && expect_data_types(data_type::f32, data_type::f32, |
173 | data_type::undef, data_type::f32, data_type::f32) |
174 | && utils::one_of(desc()->alg_kind, |
175 | alg_kind::convolution_auto, |
176 | alg_kind::convolution_winograd) |
177 | && !has_zero_dim_memory() |
178 | && set_default_formats() |
179 | && mkldnn_thr_syncable(); |
180 | if (!ok) return status::unimplemented; |
181 | |
182 | status_t status = |
183 | jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( |
184 | jcp_, *desc(), *diff_src_md(), *weights_md(), |
185 | *diff_dst_md()); |
186 | if (status != status::success) return status; |
187 | set_default_alg_kind(alg_kind::convolution_winograd); |
188 | |
189 | auto scratchpad = scratchpad_registry().registrar(); |
190 | winograd_avx512_common::init_scratchpad(scratchpad, jcp_); |
191 | |
192 | return status; |
193 | } |
194 | |
195 | jit_conv_winograd_conf_t jcp_; |
196 | |
197 | protected: |
198 | bool set_default_formats() { |
199 | using namespace format_tag; |
200 | auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; |
201 | return set_default_formats_common(nChw16c, wei_tag, nChw16c); |
202 | } |
203 | }; |
204 | |
205 | jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) |
206 | : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr()) |
207 | , cpu_primitive_t(apd, true) {} |
208 | |
209 | ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; |
210 | |
211 | typedef typename prec_traits<data_type::f32>::type data_t; |
212 | |
213 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
214 | auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); |
215 | auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); |
216 | auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); |
217 | this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, |
218 | (float *)weights, nullptr, this->scratchpad(ctx)); |
219 | return status::success; |
220 | } |
221 | |
222 | private: |
223 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
224 | }; |
225 | |
226 | struct jit_avx512_common_convolution_winograd_bwd_weights_t |
227 | : public cpu_primitive_t { |
228 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
229 | pd_t(engine_t *engine, const convolution_desc_t *adesc, |
230 | const primitive_attr_t *attr, |
231 | const convolution_fwd_pd_t *hint_fwd_pd) |
232 | : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, |
233 | hint_fwd_pd) |
234 | , jcp_() {} |
235 | |
236 | DECLARE_COMMON_PD_T( |
237 | JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), |
238 | jit_avx512_common_convolution_winograd_bwd_weights_t); |
239 | |
240 | status_t init() { |
241 | bool ok = true |
242 | && desc()->prop_kind == prop_kind::backward_weights |
243 | && utils::one_of(desc()->alg_kind, |
244 | alg_kind::convolution_auto, |
245 | alg_kind::convolution_winograd) |
246 | && expect_data_types(data_type::f32, data_type::f32, |
247 | data_type::f32, data_type::f32, data_type::f32) |
248 | && !has_zero_dim_memory() |
249 | && set_default_formats() |
250 | && mkldnn_thr_syncable(); |
251 | if (!ok) return status::unimplemented; |
252 | |
253 | status_t status = |
254 | jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: |
255 | init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), |
256 | *diff_weights_md()); |
257 | if (status != status::success) return status; |
258 | set_default_alg_kind(alg_kind::convolution_winograd); |
259 | |
260 | auto scratchpad = scratchpad_registry().registrar(); |
261 | winograd_avx512_common::init_scratchpad(scratchpad, jcp_); |
262 | |
263 | return status; |
264 | } |
265 | |
266 | jit_conv_winograd_conf_t jcp_; |
267 | |
268 | protected: |
269 | bool set_default_formats() { |
270 | using namespace format_tag; |
271 | auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; |
272 | return set_default_formats_common(nChw16c, wei_tag, nChw16c); |
273 | } |
274 | }; |
275 | |
276 | jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) |
277 | : cpu_primitive_t(apd, true), kernel_(nullptr) |
278 | { |
279 | kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( |
280 | pd()->jcp_); |
281 | } |
282 | |
283 | ~jit_avx512_common_convolution_winograd_bwd_weights_t() |
284 | { delete kernel_; } |
285 | |
286 | typedef typename prec_traits<data_type::f32>::type data_t; |
287 | |
288 | virtual status_t execute(const exec_ctx_t &ctx) const override |
289 | { |
290 | _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); |
291 | return status::success; |
292 | } |
293 | |
294 | private: |
295 | void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, |
296 | const memory_tracking::grantor_t &scratchpad) const; |
297 | void _maybe_execute_diff_bias_copy(float *diff_bias, |
298 | const memory_tracking::grantor_t &scratchpad) const; |
299 | |
300 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
301 | jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; |
302 | }; |
303 | |
// Reference Winograd transform helpers; only declared here, defined out of
// line. Naming pattern: trans_<W|I|O>_<a>x<a>_<b>x<b>[_wu]. W/I/O presumably
// select the weights/input/output transform and "_wu" the weight-update
// variants -- confirm against the definitions.

// _4x4_3x3 family.
void trans_W_4x4_3x3(float Fw[6][6][16][16], float F[3][3][16][16]);
void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]);
void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]);

// _3x3_4x4 family.
void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]);

// _3x3_4x4 "_wu" family.
void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]);
311 | |
312 | } |
313 | } |
314 | } |
315 | |
316 | #endif |
317 | |
318 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
319 | |