1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
18#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
19
20#include "c_types_map.hpp"
21#include "memory_tracking.hpp"
22#include "mkldnn_thread.hpp"
23
24#include "cpu_convolution_pd.hpp"
25#include "cpu_primitive.hpp"
26
27#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
28
29namespace mkldnn {
30namespace impl {
31namespace cpu {
32
33namespace winograd_avx512_common {
34inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
35 const jit_conv_winograd_conf_t &jcp) {
36 using namespace memory_tracking::names;
37
38 size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
39 size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic
40 * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
41 size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc
42 * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
43
44 scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
45 scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
46 scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
47
48 if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) {
49 const int nthr = mkldnn_get_max_threads();
50
51 size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr
52 * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block;
53 scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M);
54
55 size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0;
56 scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
57
58 size_t padded_bias_sz =
59 jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0;
60 scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz);
61 }
62}
63}
64
65template <bool is_fwd>
66struct _jit_avx512_common_convolution_winograd_t {
67 _jit_avx512_common_convolution_winograd_t(
68 const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
69 : kernel_(nullptr), attr_(attr) {
70 kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp);
71 }
72
73 ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; }
74
75 protected:
76 void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr,
77 float *wei_ptr, float *bias_ptr,
78 const memory_tracking::grantor_t &scratchpad) const;
79 _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_;
80 const primitive_attr_t *attr_;
81};
82
83struct jit_avx512_common_convolution_winograd_fwd_t
84 : _jit_avx512_common_convolution_winograd_t<true>
85 , public cpu_primitive_t
86 {
87 struct pd_t : public cpu_convolution_fwd_pd_t {
88 pd_t(engine_t *engine, const convolution_desc_t *adesc,
89 const primitive_attr_t *attr,
90 const typename pd_t::base_class *hint_fwd_pd)
91 : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
92 , jcp_() {}
93
94 DECLARE_COMMON_PD_T(
95 JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
96 jit_avx512_common_convolution_winograd_fwd_t);
97
98 status_t init() {
99 bool ok = true
100 && is_fwd()
101 && utils::one_of(desc()->alg_kind,
102 alg_kind::convolution_auto,
103 alg_kind::convolution_winograd)
104 && expect_data_types(data_type::f32, data_type::f32,
105 data_type::f32, data_type::f32, data_type::f32)
106 && !has_zero_dim_memory()
107 && set_default_formats();
108 if (!ok) return status::unimplemented;
109
110 status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32::
111 init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(),
112 *attr());
113 if (status != status::success) return status;
114 set_default_alg_kind(alg_kind::convolution_winograd);
115
116 auto scratchpad = scratchpad_registry().registrar();
117 winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
118
119 return status;
120 }
121
122 jit_conv_winograd_conf_t jcp_;
123
124 protected:
125 bool set_default_formats() {
126 using namespace format_tag;
127 auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
128 return set_default_formats_common(nChw16c, wei_tag, nChw16c);
129 }
130 };
131
132 jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd)
133 : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr())
134 , cpu_primitive_t(apd, true) {}
135
136 ~jit_avx512_common_convolution_winograd_fwd_t(){};
137
138 typedef typename prec_traits<data_type::f32>::type data_t;
139
140 virtual status_t execute(const exec_ctx_t &ctx) const override
141 {
142 auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
143 auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
144 auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
145 auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
146 this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
147 (float *)bias, this->scratchpad(ctx));
148 return status::success;
149 }
150
151private:
152 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
153};
154
155struct jit_avx512_common_convolution_winograd_bwd_data_t
156 : _jit_avx512_common_convolution_winograd_t<false>,
157 public cpu_primitive_t {
158 struct pd_t : public cpu_convolution_bwd_data_pd_t {
159 pd_t(engine_t *engine, const convolution_desc_t *adesc,
160 const primitive_attr_t *attr,
161 const convolution_fwd_pd_t *hint_fwd_pd)
162 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
163 , jcp_() {}
164
165 DECLARE_COMMON_PD_T(
166 JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
167 jit_avx512_common_convolution_winograd_bwd_data_t);
168
169 status_t init() {
170 bool ok = true
171 && desc()->prop_kind == prop_kind::backward_data
172 && expect_data_types(data_type::f32, data_type::f32,
173 data_type::undef, data_type::f32, data_type::f32)
174 && utils::one_of(desc()->alg_kind,
175 alg_kind::convolution_auto,
176 alg_kind::convolution_winograd)
177 && !has_zero_dim_memory()
178 && set_default_formats()
179 && mkldnn_thr_syncable();
180 if (!ok) return status::unimplemented;
181
182 status_t status =
183 jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
184 jcp_, *desc(), *diff_src_md(), *weights_md(),
185 *diff_dst_md());
186 if (status != status::success) return status;
187 set_default_alg_kind(alg_kind::convolution_winograd);
188
189 auto scratchpad = scratchpad_registry().registrar();
190 winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
191
192 return status;
193 }
194
195 jit_conv_winograd_conf_t jcp_;
196
197 protected:
198 bool set_default_formats() {
199 using namespace format_tag;
200 auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
201 return set_default_formats_common(nChw16c, wei_tag, nChw16c);
202 }
203 };
204
205 jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd)
206 : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr())
207 , cpu_primitive_t(apd, true) {}
208
209 ~jit_avx512_common_convolution_winograd_bwd_data_t(){};
210
211 typedef typename prec_traits<data_type::f32>::type data_t;
212
213 virtual status_t execute(const exec_ctx_t &ctx) const override {
214 auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
215 auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
216 auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC);
217 this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
218 (float *)weights, nullptr, this->scratchpad(ctx));
219 return status::success;
220 }
221
222private:
223 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
224};
225
226struct jit_avx512_common_convolution_winograd_bwd_weights_t
227 : public cpu_primitive_t {
228 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
229 pd_t(engine_t *engine, const convolution_desc_t *adesc,
230 const primitive_attr_t *attr,
231 const convolution_fwd_pd_t *hint_fwd_pd)
232 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr,
233 hint_fwd_pd)
234 , jcp_() {}
235
236 DECLARE_COMMON_PD_T(
237 JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
238 jit_avx512_common_convolution_winograd_bwd_weights_t);
239
240 status_t init() {
241 bool ok = true
242 && desc()->prop_kind == prop_kind::backward_weights
243 && utils::one_of(desc()->alg_kind,
244 alg_kind::convolution_auto,
245 alg_kind::convolution_winograd)
246 && expect_data_types(data_type::f32, data_type::f32,
247 data_type::f32, data_type::f32, data_type::f32)
248 && !has_zero_dim_memory()
249 && set_default_formats()
250 && mkldnn_thr_syncable();
251 if (!ok) return status::unimplemented;
252
253 status_t status =
254 jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::
255 init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
256 *diff_weights_md());
257 if (status != status::success) return status;
258 set_default_alg_kind(alg_kind::convolution_winograd);
259
260 auto scratchpad = scratchpad_registry().registrar();
261 winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
262
263 return status;
264 }
265
266 jit_conv_winograd_conf_t jcp_;
267
268 protected:
269 bool set_default_formats() {
270 using namespace format_tag;
271 auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
272 return set_default_formats_common(nChw16c, wei_tag, nChw16c);
273 }
274 };
275
276 jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd)
277 : cpu_primitive_t(apd, true), kernel_(nullptr)
278 {
279 kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
280 pd()->jcp_);
281 }
282
283 ~jit_avx512_common_convolution_winograd_bwd_weights_t()
284 { delete kernel_; }
285
286 typedef typename prec_traits<data_type::f32>::type data_t;
287
288 virtual status_t execute(const exec_ctx_t &ctx) const override
289 {
290 _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx));
291 return status::success;
292 }
293
294private:
295 void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
296 const memory_tracking::grantor_t &scratchpad) const;
297 void _maybe_execute_diff_bias_copy(float *diff_bias,
298 const memory_tracking::grantor_t &scratchpad) const;
299
300 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
301 jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_;
302};
303
/* Reference Winograd transform routines (defined in the .cpp file).
 * NOTE(review): semantics inferred from naming/signatures -- confirm
 * against the definitions. Apparent convention: F/Fw = weights in the
 * spatial / Winograd (6x6) domain, I/Iw = input tile, M/Mw/O = output;
 * the trailing [16] (and [16][16]) dimensions are the simd channel
 * blocks. The "_wu" variants appear to serve the weight-update
 * (backward-by-weights) pass. */
void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]);
void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]);
void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]);
void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]);
void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]);
311
312}
313}
314}
315
316#endif
317
318// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
319