1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "type_helpers.hpp" |
20 | #include "verbose.hpp" |
21 | |
22 | #include "cpu_engine.hpp" |
23 | #include "cpu_memory.hpp" |
24 | |
25 | //#include "cpu/rnn/ref_rnn.hpp" |
26 | |
27 | //#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" |
28 | //#include "cpu/jit_avx512_common_1x1_convolution.hpp" |
29 | #include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" |
30 | #include "cpu/jit_avx512_common_convolution_winograd.hpp" |
31 | //#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" |
32 | #include "cpu/jit_avx512_common_convolution.hpp" |
33 | //#include "cpu/jit_avx2_1x1_convolution.hpp" |
34 | //#include "cpu/jit_sse42_1x1_convolution.hpp" |
35 | #include "cpu/jit_avx2_convolution.hpp" |
36 | #include "cpu/jit_sse42_convolution.hpp" |
37 | //#include "cpu/gemm_convolution.hpp" |
38 | //#include "cpu/gemm_x8s8s32x_convolution.hpp" |
39 | //#include "cpu/ref_convolution.hpp" |
40 | //#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" |
41 | //#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" |
42 | //#include "cpu/ref_deconvolution.hpp" |
43 | //#include "cpu/ref_shuffle.hpp" |
44 | //#include "cpu/jit_uni_eltwise.hpp" |
45 | //#include "cpu/ref_eltwise.hpp" |
46 | //#include "cpu/ref_softmax.hpp" |
47 | #include "cpu/jit_uni_pooling.hpp" |
48 | //#include "cpu/jit_uni_i8i8_pooling.hpp" |
49 | //#include "cpu/ref_pooling.hpp" |
50 | //#include "cpu/nchw_pooling.hpp" |
51 | //#include "cpu/nhwc_pooling.hpp" |
52 | //#include "cpu/jit_avx512_common_lrn.hpp" |
53 | //#include "cpu/jit_uni_lrn.hpp" |
54 | //#include "cpu/ref_lrn.hpp" |
55 | //#include "cpu/jit_uni_batch_normalization.hpp" |
56 | //#include "cpu/ref_batch_normalization.hpp" |
57 | //#include "cpu/ncsp_batch_normalization.hpp" |
58 | //#include "cpu/nspc_batch_normalization.hpp" |
59 | //#include "cpu/ref_inner_product.hpp" |
60 | //#include "cpu/gemm_inner_product.hpp" |
61 | //#include "cpu/gemm_x8s8s32x_inner_product.hpp" |
62 | //#include "cpu/jit_uni_dw_convolution.hpp" |
63 | //#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" |
64 | #include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" |
65 | |
66 | namespace mkldnn { |
67 | namespace impl { |
68 | namespace cpu { |
69 | |
70 | status_t cpu_engine_t::memory_create(memory_t **memory, |
71 | const memory_desc_t *md, void *handle) { |
72 | auto _memory = new cpu_memory_t(this, md, handle); |
73 | if (_memory == nullptr) |
74 | return status::out_of_memory; |
75 | |
76 | status_t status = _memory->init(); |
77 | if (status != status::success) { |
78 | delete _memory; |
79 | return status; |
80 | } |
81 | |
82 | return safe_ptr_assign<memory_t>(*memory, _memory); |
83 | } |
84 | |
85 | using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; |
86 | |
87 | namespace { |
88 | using namespace mkldnn::impl::data_type; |
89 | |
90 | #define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t> |
91 | static const pd_create_f cpu_impl_list[] = { |
92 | /* RNN */ |
93 | /* |
94 | INSTANCE(ref_rnn_fwd_f32_t), |
95 | INSTANCE(ref_rnn_fwd_u8s8_t), |
96 | INSTANCE(ref_rnn_bwd_f32_t), |
97 | */ |
98 | /* conv */ |
99 | /* |
100 | INSTANCE(jit_avx512_common_dw_convolution_fwd_t), |
101 | INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), |
102 | INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), |
103 | INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), |
104 | INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), |
105 | INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), |
106 | */ |
107 | INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), |
108 | INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), |
109 | //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), |
110 | //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), |
111 | INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), |
112 | //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), |
113 | //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), |
114 | INSTANCE(jit_avx512_common_convolution_fwd_t<f32>), |
115 | //INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>), |
116 | //INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>), |
117 | /* |
118 | INSTANCE(jit_avx2_dw_convolution_fwd_t), |
119 | INSTANCE(jit_avx2_dw_convolution_bwd_data_t), |
120 | INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), |
121 | INSTANCE(jit_avx2_1x1_convolution_fwd_t), |
122 | INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), |
123 | INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), |
124 | INSTANCE(jit_sse42_dw_convolution_fwd_t), |
125 | INSTANCE(jit_sse42_dw_convolution_bwd_data_t), |
126 | INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), |
127 | INSTANCE(jit_sse42_1x1_convolution_fwd_t), |
128 | */ |
129 | INSTANCE(jit_avx2_convolution_fwd_t), |
130 | //INSTANCE(jit_avx2_convolution_bwd_data_t), |
131 | //INSTANCE(jit_avx2_convolution_bwd_weights_t), |
132 | INSTANCE(jit_sse42_convolution_fwd_t), |
133 | /* |
134 | INSTANCE(gemm_convolution_fwd_t), |
135 | INSTANCE(gemm_convolution_bwd_data_t), |
136 | INSTANCE(gemm_convolution_bwd_weights_t), |
137 | INSTANCE(ref_convolution_fwd_t<f32>), |
138 | INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>), |
139 | INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>), |
140 | */ |
141 | /* conv (int) */ |
142 | /* |
143 | INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>), |
144 | INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>), |
145 | INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>), |
146 | INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>), |
147 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>), |
148 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>), |
149 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>), |
150 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>), |
151 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>), |
152 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>), |
153 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>), |
154 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>), |
155 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>), |
156 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>), |
157 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>), |
158 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>), |
159 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>), |
160 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>), |
161 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>), |
162 | INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>), |
163 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>), |
164 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>), |
165 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>), |
166 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>), |
167 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>), |
168 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>), |
169 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>), |
170 | INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>), |
171 | INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>), |
172 | INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>), |
173 | INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>), |
174 | INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>), |
175 | INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>), |
176 | INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>), |
177 | INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>), |
178 | INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>), |
179 | INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>), |
180 | INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>), |
181 | INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>), |
182 | INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>), |
183 | */ |
184 | /* deconv */ |
185 | /* |
186 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>), |
187 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>), |
188 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>), |
189 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>), |
190 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>), |
191 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>), |
192 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>), |
193 | INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>), |
194 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>), |
195 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>), |
196 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>), |
197 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>), |
198 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>), |
199 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>), |
200 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>), |
201 | INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>), |
202 | INSTANCE(ref_deconvolution_bwd_weights_t), |
203 | INSTANCE(ref_deconvolution_bwd_data_t), |
204 | INSTANCE(ref_deconvolution_fwd_t), |
205 | */ |
206 | /* shuffle */ |
207 | /* |
208 | INSTANCE(ref_shuffle_t<4>), // f32 or s32 |
209 | INSTANCE(ref_shuffle_t<1>), // s8 or u8 |
210 | */ |
211 | /* eltwise */ |
212 | /* |
213 | INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>), |
214 | INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>), |
215 | INSTANCE(jit_uni_eltwise_fwd_t<avx2>), |
216 | INSTANCE(jit_uni_eltwise_bwd_t<avx2>), |
217 | INSTANCE(jit_uni_eltwise_fwd_t<sse42>), |
218 | INSTANCE(jit_uni_eltwise_bwd_t<sse42>), |
219 | INSTANCE(ref_eltwise_fwd_t<f32>), |
220 | INSTANCE(ref_eltwise_bwd_t<f32>), |
221 | */ |
222 | /* eltwise (int) */ |
223 | /* |
224 | INSTANCE(ref_eltwise_fwd_t<s32>), |
225 | INSTANCE(ref_eltwise_fwd_t<s8>), |
226 | INSTANCE(ref_eltwise_fwd_t<u8>), |
227 | INSTANCE(ref_eltwise_bwd_t<s32>), |
228 | */ |
229 | /* softmax */ |
230 | /* |
231 | INSTANCE(ref_softmax_fwd_t<f32>), |
232 | INSTANCE(ref_softmax_bwd_t<f32>), |
233 | */ |
234 | /* pool */ |
235 | INSTANCE(jit_uni_pooling_fwd_t<avx512_common>), |
236 | //INSTANCE(jit_uni_pooling_bwd_t<avx512_common>), |
237 | INSTANCE(jit_uni_pooling_fwd_t<avx>), |
238 | //INSTANCE(jit_uni_pooling_bwd_t<avx>), |
239 | INSTANCE(jit_uni_pooling_fwd_t<sse42>), |
240 | //INSTANCE(jit_uni_pooling_bwd_t<sse42>), |
241 | /* |
242 | INSTANCE(nchw_pooling_fwd_t<f32>), |
243 | INSTANCE(nchw_pooling_bwd_t<f32>), |
244 | INSTANCE(nhwc_pooling_fwd_t<f32>), |
245 | INSTANCE(nhwc_pooling_bwd_t<f32>), |
246 | INSTANCE(ref_pooling_fwd_t<f32>), |
247 | INSTANCE(ref_pooling_bwd_t<f32>), |
248 | */ |
249 | /* pool (int) */ |
250 | /* |
251 | INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>), |
252 | INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>), |
253 | INSTANCE(ref_pooling_fwd_t<s32>), |
254 | INSTANCE(ref_pooling_fwd_t<s8, s32>), |
255 | INSTANCE(ref_pooling_fwd_t<u8, s32>), |
256 | INSTANCE(ref_pooling_bwd_t<s32>), |
257 | */ |
258 | /* lrn */ |
259 | /* |
260 | INSTANCE(jit_avx512_common_lrn_fwd_t), |
261 | INSTANCE(jit_avx512_common_lrn_bwd_t), |
262 | INSTANCE(jit_uni_lrn_fwd_t<avx2>), |
263 | INSTANCE(jit_uni_lrn_bwd_t<avx2>), |
264 | INSTANCE(jit_uni_lrn_fwd_t<sse42>), |
265 | INSTANCE(ref_lrn_fwd_t<f32>), |
266 | INSTANCE(ref_lrn_bwd_t<f32>), |
267 | */ |
268 | /* batch normalization */ |
269 | /* |
270 | INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>), |
271 | INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>), |
272 | INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>), |
273 | INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>), |
274 | INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>), |
275 | INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>), |
276 | INSTANCE(ncsp_batch_normalization_fwd_t), |
277 | INSTANCE(ncsp_batch_normalization_bwd_t), |
278 | INSTANCE(nspc_batch_normalization_fwd_t), |
279 | INSTANCE(nspc_batch_normalization_bwd_t), |
280 | INSTANCE(ref_batch_normalization_fwd_t<f32>), |
281 | INSTANCE(ref_batch_normalization_bwd_t<f32>), |
282 | INSTANCE(ref_batch_normalization_fwd_t<s8>), |
283 | */ |
284 | /* inner product */ |
285 | /* |
286 | INSTANCE(gemm_inner_product_fwd_t<f32>), |
287 | INSTANCE(gemm_inner_product_bwd_data_t<f32>), |
288 | INSTANCE(gemm_inner_product_bwd_weights_t<f32>), |
289 | INSTANCE(ref_inner_product_fwd_t<f32>), |
290 | INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>), |
291 | INSTANCE(ref_inner_product_bwd_weights_t<f32>), |
292 | */ |
293 | /* inner product (int) */ |
294 | /* |
295 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>), |
296 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>), |
297 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>), |
298 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>), |
299 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>), |
300 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>), |
301 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>), |
302 | INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>), |
303 | INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>), |
304 | INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>), |
305 | INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>), |
306 | INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>), |
307 | */ |
308 | /* eol */ |
309 | nullptr, |
310 | }; |
311 | #undef INSTANCE |
312 | } |
313 | |
314 | const pd_create_f* cpu_engine_t::get_implementation_list() const { |
315 | return cpu_impl_list; |
316 | } |
317 | |
318 | cpu_engine_factory_t engine_factory; |
319 | |
320 | } |
321 | } |
322 | } |
323 | |
324 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
325 | |