1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP
18#define CPU_JIT_AVX2_CONVOLUTION_HPP
19
20#include "c_types_map.hpp"
21#include "memory_tracking.hpp"
22#include "mkldnn_thread.hpp"
23#include "utils.hpp"
24
25#include "cpu_convolution_pd.hpp"
26#include "cpu_reducer.hpp"
27
28#include "jit_avx2_conv_kernel_f32.hpp"
29
30namespace mkldnn {
31namespace impl {
32namespace cpu {
33
34struct jit_avx2_convolution_fwd_t: public cpu_primitive_t {
35 struct pd_t: public cpu_convolution_fwd_pd_t {
36 pd_t(engine_t *engine,
37 const convolution_desc_t *adesc,
38 const primitive_attr_t *attr,
39 const typename pd_t::base_class *hint_fwd_pd)
40 : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
41 , jcp_() {}
42
43 DECLARE_COMMON_PD_T(
44 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
45 jit_avx2_convolution_fwd_t);
46
47 status_t init() {
48 bool ok = true
49 && is_fwd()
50 && set_default_alg_kind(alg_kind::convolution_direct)
51 && expect_data_types(data_type::f32, data_type::f32,
52 data_type::f32, data_type::f32, data_type::f32)
53 && !has_zero_dim_memory()
54 && set_default_formats();
55 if (!ok) return status::unimplemented;
56
57 status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_,
58 *desc(), src_md(), weights_md(), dst_md(), *attr());
59 if (status != status::success) return status;
60
61 auto scratchpad = scratchpad_registry().registrar();
62 jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);
63
64 return status::success;
65 }
66
67 jit_conv_conf_t jcp_;
68
69 protected:
70 bool set_default_formats() {
71 using namespace format_tag;
72
73 const bool flat = IC() < 8;
74 auto src_tag = flat
75 ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
76 : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
77 auto dst_tag =
78 utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
79 auto wei_tag = with_groups()
80 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
81 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
82 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
83 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
84
85 return set_default_formats_common(src_tag, wei_tag, dst_tag);
86 }
87 };
88
89 jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
90 { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); }
91 ~jit_avx2_convolution_fwd_t() { delete kernel_; }
92
93 typedef typename prec_traits<data_type::f32>::type data_t;
94
95 virtual status_t execute(const exec_ctx_t &ctx) const override {
96 execute_forward(ctx);
97 return status::success;
98 }
99
100private:
101 void execute_forward(const exec_ctx_t &ctx) const;
102 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
103
104 jit_avx2_conv_fwd_kernel_f32 *kernel_;
105};
106
107struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t {
108 struct pd_t: public cpu_convolution_bwd_data_pd_t {
109 pd_t(engine_t *engine,
110 const convolution_desc_t *adesc,
111 const primitive_attr_t *attr,
112 const convolution_fwd_pd_t *hint_fwd_pd)
113 : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
114 , jcp_()
115 {}
116
117 DECLARE_COMMON_PD_T(
118 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
119 jit_avx2_convolution_bwd_data_t);
120
121 status_t init() {
122 bool ok = true
123 && desc()->prop_kind == prop_kind::backward_data
124 && set_default_alg_kind(alg_kind::convolution_direct)
125 && expect_data_types(data_type::f32, data_type::f32,
126 data_type::undef, data_type::f32, data_type::f32)
127 && !has_zero_dim_memory()
128 && set_default_formats();
129 if (!ok) return status::unimplemented;
130
131 status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(
132 jcp_, *desc(), *diff_src_md(), *weights_md(),
133 *diff_dst_md());
134 if (status != status::success) return status;
135
136 auto scratchpad = scratchpad_registry().registrar();
137 jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad,
138 jcp_);
139
140 return status::success;
141 }
142
143 jit_conv_conf_t jcp_;
144
145 protected:
146 bool set_default_formats() {
147 using namespace format_tag;
148
149 auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
150 auto wei_tag = with_groups()
151 ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
152 : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
153
154 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
155 }
156 };
157
158 jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
159 { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); }
160 ~jit_avx2_convolution_bwd_data_t() { delete kernel_; }
161
162 typedef typename prec_traits<data_type::f32>::type data_t;
163
164 virtual status_t execute(const exec_ctx_t &ctx) const override {
165 execute_backward_data(ctx);
166 return status::success;
167 }
168
169private:
170 void execute_backward_data(const exec_ctx_t &ctx) const;
171 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
172
173 jit_avx2_conv_bwd_data_kernel_f32 *kernel_;
174};
175
176struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t {
177 struct pd_t: public cpu_convolution_bwd_weights_pd_t {
178 pd_t(engine_t *engine, const convolution_desc_t *adesc,
179 const primitive_attr_t *attr,
180 const convolution_fwd_pd_t *hint_fwd_pd)
181 : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
182 , jcp_() {}
183
184 DECLARE_COMMON_PD_T(
185 JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
186 jit_avx2_convolution_bwd_weights_t);
187
188 status_t init() {
189 bool ok = true
190 && desc()->prop_kind == prop_kind::backward_weights
191 && set_default_alg_kind(alg_kind::convolution_direct)
192 && expect_data_types(data_type::f32, data_type::f32,
193 data_type::f32, data_type::f32, data_type::f32)
194 && !has_zero_dim_memory()
195 && set_default_formats();
196 if (!ok) return status::unimplemented;
197
198 status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
199 jcp_, *desc(), *src_md(), *diff_weights_md(),
200 *diff_dst_md());
201 if (status != status::success) return status;
202
203 init_balancers();
204
205 auto scratchpad = scratchpad_registry().registrar();
206 jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad,
207 jcp_);
208
209 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
210 scratchpad, memory_tracking::names::prefix_reducer_bia);
211 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
212
213 auto reducer_wei_scratchpad = memory_tracking::registrar_t(
214 scratchpad, memory_tracking::names::prefix_reducer_wei);
215 reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
216
217 return status::success;
218 }
219
220 jit_conv_conf_t jcp_;
221 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
222 cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;
223
224 protected:
225 bool set_default_formats() {
226 using namespace format_tag;
227 const bool flat = IC() == 3;
228
229 auto src_tag = flat
230 ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
231 : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
232 auto dst_tag =
233 utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
234 auto wei_tag = with_groups()
235 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
236 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
237 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
238 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
239
240 return set_default_formats_common(src_tag, wei_tag, dst_tag);
241 }
242
243 private:
244 void init_balancers() {
245 const int max_threads = mkldnn_get_max_threads();
246 const size_t max_buffer_size = 1<<21; /* just a heuristic */
247
248 if(with_bias()) {
249 reducer_bia_conf_.init(reduce_balancer_t(max_threads,
250 jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
251 max_buffer_size));
252 }
253
254 reducer_wei_conf_.init(reduce_balancer_t(max_threads,
255 jcp_.kd * jcp_.kh * jcp_.kw
256 * jcp_.ic_block * jcp_.oc_block,
257 jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc,
258 jcp_.mb * jcp_.od, max_buffer_size));
259 }
260 };
261
262 jit_avx2_convolution_bwd_weights_t(const pd_t *apd)
263 : cpu_primitive_t(apd)
264 , kernel_(nullptr)
265 , reducer_weights_(nullptr)
266 , reducer_bias_(nullptr)
267 {
268 kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_);
269 reducer_bias_ =
270 new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
271 reducer_weights_ =
272 new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_);
273 }
274
275 ~jit_avx2_convolution_bwd_weights_t() {
276 delete kernel_;
277 delete reducer_weights_;
278 delete reducer_bias_;
279 }
280
281 typedef typename prec_traits<data_type::f32>::type data_t;
282
283 virtual status_t execute(const exec_ctx_t &ctx) const override {
284 execute_backward_weights(ctx);
285 return status::success;
286 }
287
288private:
289 void execute_backward_weights(const exec_ctx_t &ctx) const;
290 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
291
292 jit_avx2_conv_bwd_weights_kernel_f32 *kernel_;
293 cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_;
294};
295
296}
297}
298}
299
300#endif
301
302// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
303