/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP

#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "cpu_barrier.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_primitive.hpp"
#include "cpu_reducer.hpp"

#include "jit_transpose_src_utils.hpp"
#include "jit_avx512_common_conv_kernel.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {
template <impl::data_type_t src_type,
          impl::data_type_t wei_type = src_type,
          impl::data_type_t dst_type = src_type>
struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(engine_t *engine, const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
            , jcp_()
        {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
                jit_avx512_common_convolution_fwd_t);

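        // init() is the primitive descriptor's acceptance test: it verifies
        // that the descriptor matches what this implementation supports,
        // delegates to the kernel's static init_conf() to compute the
        // blocking parameters in jcp_, and books the scratchpad the kernel
        // will need at execution time. Returning status::unimplemented lets
        // the dispatcher fall through to the next implementation.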
        status_t init() {
            bool ok = true
                && is_fwd()
                && set_default_alg_kind(alg_kind::convolution_direct)
                && expect_data_types(src_type, wei_type, dst_type, dst_type,
                        data_type::undef)
                && !has_zero_dim_memory();
            if (!ok) return status::unimplemented;

            status_t status = jit_avx512_common_conv_fwd_kernel::init_conf(
                    jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_,
                    *attr(), mkldnn_get_max_threads());
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad,
                    jcp_);

            return status;
        }

        jit_conv_conf_t jcp_;
    };

    jit_avx512_common_convolution_fwd_t(const pd_t *apd)
        : cpu_primitive_t(apd)
    {
        kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_,
                *pd()->attr());
    }
    ~jit_avx512_common_convolution_fwd_t() { delete kernel_; }

    typedef typename prec_traits<src_type>::type src_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;

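    // Dispatch on the number of tensor dimensions: ndims() of 3, 4, or 5
    // corresponds to 1D, 2D, or 3D convolution (minibatch and channel
    // dimensions included in the count).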
    virtual status_t execute(const exec_ctx_t &ctx) const override {
        if (pd()->ndims() == 3)
            execute_forward_1d(ctx);
        else if (pd()->ndims() == 4)
            execute_forward_2d(ctx);
        else if (pd()->ndims() == 5)
            execute_forward_3d(ctx);
        else
            assert(false);

        if (pd()->wants_zero_pad_dst())
            ctx.memory(MKLDNN_ARG_DST)->zero_pad();

        return status::success;
    }

private:
    void prepare_padded_bias(const dst_data_t *&bias,
            const memory_tracking::grantor_t &scratchpad) const;
    void execute_forward_1d(const exec_ctx_t &ctx) const;
    void execute_forward_2d(const exec_ctx_t &ctx) const;
    void execute_forward_3d(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

    jit_avx512_common_conv_fwd_kernel *kernel_;
};

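// Backward-by-data convolution: computes diff_src from diff_dst and the
// weights. It mirrors the forward primitive structurally; in the
// expect_data_types() check below, the data_type::undef entries are
// (presumably) the unused bias and accumulator slots, since backward data
// takes neither argument.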
template <impl::data_type_t diff_dst_type,
          impl::data_type_t wei_type = diff_dst_type,
          impl::data_type_t diff_src_type = diff_dst_type>
struct jit_avx512_common_convolution_bwd_data_t : public cpu_primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(engine_t *engine,
                const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
            , jcp_()
        {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
                jit_avx512_common_convolution_bwd_data_t);

        status_t init() {
            bool ok = true
                && desc()->prop_kind == prop_kind::backward_data
                && set_default_alg_kind(alg_kind::convolution_direct)
                && expect_data_types(diff_src_type, wei_type,
                        data_type::undef, diff_dst_type, data_type::undef)
                && !has_zero_dim_memory()
                && set_default_formats();
            if (!ok) return status::unimplemented;

            status_t status =
                jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_,
                        *desc(), *diff_src_md(), *weights_md(),
                        *diff_dst_md());
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
                    scratchpad, jcp_);

            return status::success;
        }

        jit_conv_conf_t jcp_;

    protected:
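        // The 16-wide channel blocking below matches one 512-bit zmm register
        // of f32 lanes: nCw16c/nChw16c/nCdhw16c for activations and the
        // *16o16i tags for weights. utils::pick() indexes by spatial rank
        // (ndims() - 3); the weight index additionally selects the grouped
        // (g-prefixed) variant.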
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i,
                    OIdhw16o16i, gOIdhw16o16i);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }
    };

    jit_avx512_common_convolution_bwd_data_t(const pd_t *apd)
        : cpu_primitive_t(apd)
    { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); }
    ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }

    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;

    virtual status_t execute(const exec_ctx_t &ctx) const override {
        if (pd()->ndims() == 3)
            execute_backward_data_1d(ctx);
        else if (pd()->ndims() == 4)
            execute_backward_data_2d(ctx);
        else if (pd()->ndims() == 5)
            execute_backward_data_3d(ctx);
        else
            assert(false);
        return status::success;
    }

private:
    void execute_backward_data_1d(const exec_ctx_t &ctx) const;
    void execute_backward_data_2d(const exec_ctx_t &ctx) const;
    void execute_backward_data_3d(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

    jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_;
};

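// Backward-by-weights convolution: computes diff_weights (and diff_bias when
// bias is present) from src and diff_dst. It carries more machinery than the
// other two primitives: a source-transpose kernel, a 1D accumulator for
// summing per-thread partial weight updates, and a cpu_reducer_t for the
// bias reduction (see the members at the bottom of the class).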
template <impl::data_type_t src_type,
          impl::data_type_t diff_dst_type = src_type,
          impl::data_type_t diff_weights_type = src_type>
struct jit_avx512_common_convolution_bwd_weights_t : public cpu_primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(engine_t *engine, const convolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr,
                    hint_fwd_pd)
            , jcp_() {}

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
                jit_avx512_common_convolution_bwd_weights_t);

        status_t init() {
            bool ok = true
                && desc()->prop_kind == prop_kind::backward_weights
                && set_default_alg_kind(alg_kind::convolution_direct)
                && expect_data_types(src_type, diff_weights_type,
                        diff_weights_type, diff_dst_type, data_type::undef)
                && !has_zero_dim_memory();
            if (!ok) return status::unimplemented;

            status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32::
                init_conf(jcp_, *desc(), src_md_, diff_weights_md_,
                        diff_bias_md_, diff_dst_md_);
            if (status != status::success) return status;

            init_balancers();

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
                    scratchpad, jcp_);

            auto reducer_bia_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_bia);
            reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);

            return status;
        }

        jit_conv_conf_t jcp_;
        typename cpu_reducer_t<diff_weights_type>::conf_t reducer_bia_conf_;

    private:
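        // Configure the bias reducer, which splits the diff_bias reduction
        // over the minibatch across threads. max_buffer_size caps its
        // intermediate buffer; the 3 * 5 * 5 * 16 * 16 per-thread factor is
        // a sizing heuristic (it does not correspond to a named constant in
        // this header).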
        void init_balancers() {
            const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
            if (with_bias()) {
                reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
                        jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
                        max_buffer_size));
            }
        }
    };

    jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd);
    ~jit_avx512_common_convolution_bwd_weights_t() {
        delete kernel_;
        delete trans_kernel_; // deleting a null pointer is a no-op
        delete acc_ker_;
        delete reducer_bias_;
    }

    typedef typename prec_traits<src_type>::type src_data_t;
    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
    typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;

    virtual status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    void prepare_scratchpad_data(const exec_ctx_t &ctx) const;
    struct thread_info_t;
    void compute_diff_weights(const thread_info_t *) const;
    void compute_diff_weights_3d(const thread_info_t *) const;
    void reduce_diff_weights(const thread_info_t *) const;
    void reduce_diff_weights_3d(const thread_info_t *) const;
    void compute_diff_bias(const thread_info_t *) const;
    void compute_diff_bias_3d(const thread_info_t *) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }

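    // Thread decomposition used for the weight reduction: total thread count
    // and, presumably, the splits across minibatch (mb), groups (g),
    // output-channel blocks (oc_b), and input-channel blocks (ic_b).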
    int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;

    jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_;
    jit_trans_src_t *trans_kernel_;
    cpu_accumulator_1d_t<diff_weights_type> *acc_ker_;
    cpu_reducer_t<diff_weights_type> *reducer_bias_;
};

} // namespace cpu
} // namespace impl
} // namespace mkldnn

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s