1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "type_helpers.hpp"
20#include "verbose.hpp"
21
22#include "cpu_engine.hpp"
23#include "cpu_memory.hpp"
24
25//#include "cpu/rnn/ref_rnn.hpp"
26
27//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
28//#include "cpu/jit_avx512_common_1x1_convolution.hpp"
29#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp"
30#include "cpu/jit_avx512_common_convolution_winograd.hpp"
31//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp"
32#include "cpu/jit_avx512_common_convolution.hpp"
33//#include "cpu/jit_avx2_1x1_convolution.hpp"
34//#include "cpu/jit_sse42_1x1_convolution.hpp"
35#include "cpu/jit_avx2_convolution.hpp"
36#include "cpu/jit_sse42_convolution.hpp"
37//#include "cpu/gemm_convolution.hpp"
38//#include "cpu/gemm_x8s8s32x_convolution.hpp"
39//#include "cpu/ref_convolution.hpp"
40//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
41//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
42//#include "cpu/ref_deconvolution.hpp"
43//#include "cpu/ref_shuffle.hpp"
44//#include "cpu/jit_uni_eltwise.hpp"
45//#include "cpu/ref_eltwise.hpp"
46//#include "cpu/ref_softmax.hpp"
47#include "cpu/jit_uni_pooling.hpp"
48//#include "cpu/jit_uni_i8i8_pooling.hpp"
49//#include "cpu/ref_pooling.hpp"
50//#include "cpu/nchw_pooling.hpp"
51//#include "cpu/nhwc_pooling.hpp"
52//#include "cpu/jit_avx512_common_lrn.hpp"
53//#include "cpu/jit_uni_lrn.hpp"
54//#include "cpu/ref_lrn.hpp"
55//#include "cpu/jit_uni_batch_normalization.hpp"
56//#include "cpu/ref_batch_normalization.hpp"
57//#include "cpu/ncsp_batch_normalization.hpp"
58//#include "cpu/nspc_batch_normalization.hpp"
59//#include "cpu/ref_inner_product.hpp"
60//#include "cpu/gemm_inner_product.hpp"
61//#include "cpu/gemm_x8s8s32x_inner_product.hpp"
62//#include "cpu/jit_uni_dw_convolution.hpp"
63//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
64#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
65
66namespace mkldnn {
67namespace impl {
68namespace cpu {
69
70status_t cpu_engine_t::memory_create(memory_t **memory,
71 const memory_desc_t *md, void *handle) {
72 auto _memory = new cpu_memory_t(this, md, handle);
73 if (_memory == nullptr)
74 return status::out_of_memory;
75
76 status_t status = _memory->init();
77 if (status != status::success) {
78 delete _memory;
79 return status;
80 }
81
82 return safe_ptr_assign<memory_t>(*memory, _memory);
83}
84
85using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
86
87namespace {
88using namespace mkldnn::impl::data_type;
89
90#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>
91static const pd_create_f cpu_impl_list[] = {
92 /* RNN */
93 /*
94 INSTANCE(ref_rnn_fwd_f32_t),
95 INSTANCE(ref_rnn_fwd_u8s8_t),
96 INSTANCE(ref_rnn_bwd_f32_t),
97 */
98 /* conv */
99 /*
100 INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
101 INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
102 INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
103 INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t),
104 INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t),
105 INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t),
106 */
107 INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t),
108 INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t),
109 //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t),
110 //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t),
111 INSTANCE(jit_avx512_common_convolution_winograd_fwd_t),
112 //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t),
113 //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t),
114 INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
115 //INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
116 //INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
117 /*
118 INSTANCE(jit_avx2_dw_convolution_fwd_t),
119 INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
120 INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
121 INSTANCE(jit_avx2_1x1_convolution_fwd_t),
122 INSTANCE(jit_avx2_1x1_convolution_bwd_data_t),
123 INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t),
124 INSTANCE(jit_sse42_dw_convolution_fwd_t),
125 INSTANCE(jit_sse42_dw_convolution_bwd_data_t),
126 INSTANCE(jit_sse42_dw_convolution_bwd_weights_t),
127 INSTANCE(jit_sse42_1x1_convolution_fwd_t),
128 */
129 INSTANCE(jit_avx2_convolution_fwd_t),
130 //INSTANCE(jit_avx2_convolution_bwd_data_t),
131 //INSTANCE(jit_avx2_convolution_bwd_weights_t),
132 INSTANCE(jit_sse42_convolution_fwd_t),
133 /*
134 INSTANCE(gemm_convolution_fwd_t),
135 INSTANCE(gemm_convolution_bwd_data_t),
136 INSTANCE(gemm_convolution_bwd_weights_t),
137 INSTANCE(ref_convolution_fwd_t<f32>),
138 INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>),
139 INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>),
140 */
141 /* conv (int) */
142 /*
143 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>),
144 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>),
145 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>),
146 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>),
147 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>),
148 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>),
149 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>),
150 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>),
151 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>),
152 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>),
153 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>),
154 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>),
155 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>),
156 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>),
157 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>),
158 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>),
159 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>),
160 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>),
161 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>),
162 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>),
163 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
164 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
165 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
166 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
167 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
168 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
169 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
170 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
171 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
172 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
173 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
174 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>),
175 INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>),
176 INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>),
177 INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>),
178 INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>),
179 INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>),
180 INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>),
181 INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>),
182 INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
183 */
184 /* deconv */
185 /*
186 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
187 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
188 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
189 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
190 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
191 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
192 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
193 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
194 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
195 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
196 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
197 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
198 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
199 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
200 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
201 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
202 INSTANCE(ref_deconvolution_bwd_weights_t),
203 INSTANCE(ref_deconvolution_bwd_data_t),
204 INSTANCE(ref_deconvolution_fwd_t),
205 */
206 /* shuffle */
207 /*
208 INSTANCE(ref_shuffle_t<4>), // f32 or s32
209 INSTANCE(ref_shuffle_t<1>), // s8 or u8
210 */
211 /* eltwise */
212 /*
213 INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>),
214 INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>),
215 INSTANCE(jit_uni_eltwise_fwd_t<avx2>),
216 INSTANCE(jit_uni_eltwise_bwd_t<avx2>),
217 INSTANCE(jit_uni_eltwise_fwd_t<sse42>),
218 INSTANCE(jit_uni_eltwise_bwd_t<sse42>),
219 INSTANCE(ref_eltwise_fwd_t<f32>),
220 INSTANCE(ref_eltwise_bwd_t<f32>),
221 */
222 /* eltwise (int) */
223 /*
224 INSTANCE(ref_eltwise_fwd_t<s32>),
225 INSTANCE(ref_eltwise_fwd_t<s8>),
226 INSTANCE(ref_eltwise_fwd_t<u8>),
227 INSTANCE(ref_eltwise_bwd_t<s32>),
228 */
229 /* softmax */
230 /*
231 INSTANCE(ref_softmax_fwd_t<f32>),
232 INSTANCE(ref_softmax_bwd_t<f32>),
233 */
234 /* pool */
235 INSTANCE(jit_uni_pooling_fwd_t<avx512_common>),
236 //INSTANCE(jit_uni_pooling_bwd_t<avx512_common>),
237 INSTANCE(jit_uni_pooling_fwd_t<avx>),
238 //INSTANCE(jit_uni_pooling_bwd_t<avx>),
239 INSTANCE(jit_uni_pooling_fwd_t<sse42>),
240 //INSTANCE(jit_uni_pooling_bwd_t<sse42>),
241 /*
242 INSTANCE(nchw_pooling_fwd_t<f32>),
243 INSTANCE(nchw_pooling_bwd_t<f32>),
244 INSTANCE(nhwc_pooling_fwd_t<f32>),
245 INSTANCE(nhwc_pooling_bwd_t<f32>),
246 INSTANCE(ref_pooling_fwd_t<f32>),
247 INSTANCE(ref_pooling_bwd_t<f32>),
248 */
249 /* pool (int) */
250 /*
251 INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
252 INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
253 INSTANCE(ref_pooling_fwd_t<s32>),
254 INSTANCE(ref_pooling_fwd_t<s8, s32>),
255 INSTANCE(ref_pooling_fwd_t<u8, s32>),
256 INSTANCE(ref_pooling_bwd_t<s32>),
257 */
258 /* lrn */
259 /*
260 INSTANCE(jit_avx512_common_lrn_fwd_t),
261 INSTANCE(jit_avx512_common_lrn_bwd_t),
262 INSTANCE(jit_uni_lrn_fwd_t<avx2>),
263 INSTANCE(jit_uni_lrn_bwd_t<avx2>),
264 INSTANCE(jit_uni_lrn_fwd_t<sse42>),
265 INSTANCE(ref_lrn_fwd_t<f32>),
266 INSTANCE(ref_lrn_bwd_t<f32>),
267 */
268 /* batch normalization */
269 /*
270 INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>),
271 INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>),
272 INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>),
273 INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>),
274 INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>),
275 INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>),
276 INSTANCE(ncsp_batch_normalization_fwd_t),
277 INSTANCE(ncsp_batch_normalization_bwd_t),
278 INSTANCE(nspc_batch_normalization_fwd_t),
279 INSTANCE(nspc_batch_normalization_bwd_t),
280 INSTANCE(ref_batch_normalization_fwd_t<f32>),
281 INSTANCE(ref_batch_normalization_bwd_t<f32>),
282 INSTANCE(ref_batch_normalization_fwd_t<s8>),
283 */
284 /* inner product */
285 /*
286 INSTANCE(gemm_inner_product_fwd_t<f32>),
287 INSTANCE(gemm_inner_product_bwd_data_t<f32>),
288 INSTANCE(gemm_inner_product_bwd_weights_t<f32>),
289 INSTANCE(ref_inner_product_fwd_t<f32>),
290 INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
291 INSTANCE(ref_inner_product_bwd_weights_t<f32>),
292 */
293 /* inner product (int) */
294 /*
295 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
296 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
297 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
298 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
299 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
300 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
301 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
302 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
303 INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
304 INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
305 INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
306 INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
307 */
308 /* eol */
309 nullptr,
310};
311#undef INSTANCE
312}
313
314const pd_create_f* cpu_engine_t::get_implementation_list() const {
315 return cpu_impl_list;
316}
317
318cpu_engine_factory_t engine_factory;
319
320}
321}
322}
323
324// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
325