1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "cpu_engine.hpp"
20#include "cpu_primitive.hpp"
21#include "cpu_reorder_pd.hpp"
22#include "cpu_memory.hpp"
23#include "type_helpers.hpp"
24
25#include "cpu/jit_uni_reorder.hpp"
26#include "cpu/simple_reorder.hpp"
27#include "cpu/wino_reorder.hpp"
28#include "cpu/rnn/rnn_reorders.hpp"
29
30namespace mkldnn {
31namespace impl {
32namespace cpu {
33
34using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
35
36namespace {
37using namespace mkldnn::impl::data_type;
38using namespace mkldnn::impl::format_tag;
39
/* REG_SR: names the pd_t::create factory of one simple_reorder_t
 * specialization. The four fixed arguments are forwarded as the first four
 * template arguments (presumably input data type / format tag and output
 * data type / format tag -- confirm against simple_reorder.hpp); everything
 * after them (format order and optional spec tag) is forwarded verbatim.
 * At least one trailing argument must be supplied. */
#define REG_SR(in_dt, in_fmt, out_dt, out_fmt, ...)                        \
    simple_reorder_t<in_dt, in_fmt, out_dt, out_fmt, __VA_ARGS__>::pd_t::create

/* REG_SR_BIDIR: expands to two comma-separated REG_SR entries for the same
 * type/format combination, first with fmt_order::keep and then with
 * fmt_order::reverse. */
#define REG_SR_BIDIR(in_dt, in_fmt, out_dt, out_fmt)                       \
    REG_SR(in_dt, in_fmt, out_dt, out_fmt, fmt_order::keep),               \
    REG_SR(in_dt, in_fmt, out_dt, out_fmt, fmt_order::reverse)

/* REG_SR_DIRECT_COPY: expands to two comma-separated REG_SR entries that
 * accept any format on both sides, first with spec::direct_copy and then
 * with spec::direct_copy_except_dim_0. */
#define REG_SR_DIRECT_COPY(in_dt, out_dt)                                  \
    REG_SR(in_dt, any, out_dt, any, fmt_order::any, spec::direct_copy),    \
    REG_SR(in_dt, any, out_dt, any, fmt_order::any,                        \
            spec::direct_copy_except_dim_0)
50
51static const rpd_create_f cpu_reorder_impl_list[] = {
52 /* winograd */
53 wino_reorder_t<f32, f32>::pd_t::create,
54 //wino_reorder_t<f32, s8>::pd_t::create,
55
56 /* rnn reorders */
57 rnn_data_reorder_t<f32, u8>::pd_t::create,
58 rnn_weights_reorder_t<f32, f32>::pd_t::create,
59 rnn_weights_reorder_t<f32, s8>::pd_t::create,
60
61 /* conv reorders w/ compensation */
62 REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
63 REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
64 REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
65 REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
66
67 REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
68 REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
69 REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
70 REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
71
72 REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
73 REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
74 REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
75 REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
76
77 REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
78 REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
79
80 REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
81 REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
82
83 REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
84 REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
85 REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
86 REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
87
88 /* regular reorders */
89
90#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
91 /* Direct copy for icc which is faster than jitted code;
92 * Direct copy for gcc which might or might not be faster than jitted
93 * code, but still worth it because doesn't require jitting, i.e. much
94 * faster creation time. This is tentative solution and should be removed
95 * later (when we will cache jitted code?...). */
96 REG_SR_DIRECT_COPY(f32, f32),
97#endif
98
99#ifdef __INTEL_COMPILER
100 /* direct copy for icc, which is faster than jitted code */
101 /*
102 REG_SR_DIRECT_COPY(f32, s32),
103 REG_SR_DIRECT_COPY(f32, s8),
104 REG_SR_DIRECT_COPY(f32, u8),
105 REG_SR_DIRECT_COPY(s32, f32),
106 REG_SR_DIRECT_COPY(s32, s32),
107 REG_SR_DIRECT_COPY(s32, s8),
108 REG_SR_DIRECT_COPY(s32, u8),
109 REG_SR_DIRECT_COPY(s8, f32),
110 REG_SR_DIRECT_COPY(s8, s32),
111 REG_SR_DIRECT_COPY(s8, s8),
112 REG_SR_DIRECT_COPY(s8, u8),
113 REG_SR_DIRECT_COPY(u8, f32),
114 REG_SR_DIRECT_COPY(u8, s32),
115 REG_SR_DIRECT_COPY(u8, s8),
116 REG_SR_DIRECT_COPY(u8, u8),
117 */
118#endif
119
120 /* jit */
121 jit_uni_reorder_create,
122
123 /* fp32: flat <-> blocked with tail */
124 /*
125 REG_SR_BIDIR(f32, any, f32, nCw4c),
126 REG_SR_BIDIR(f32, any, f32, nCw8c),
127 REG_SR_BIDIR(f32, any, f32, OIw4i4o),
128 REG_SR_BIDIR(f32, any, f32, OIw8i8o),
129 REG_SR_BIDIR(f32, any, f32, OIw8o8i),
130 REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
131 REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
132 REG_SR_BIDIR(f32, any, f32, gOIw8o8i),
133
134 REG_SR_BIDIR(f32, any, f32, nCw16c),
135 REG_SR_BIDIR(f32, any, f32, OIw16o16i),
136 REG_SR_BIDIR(f32, any, f32, OIw16i16o),
137 REG_SR_BIDIR(f32, any, f32, IOw16o16i),
138 REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
139 REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
140 REG_SR_BIDIR(f32, any, f32, gIOw16o16i),
141
142 REG_SR_BIDIR(f32, any, f32, nChw4c),
143 REG_SR_BIDIR(f32, any, f32, nChw8c),
144 REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
145 REG_SR_BIDIR(f32, any, f32, Ohwi8o),
146
147 REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
148 REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
149 REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
150 REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
151 REG_SR_BIDIR(f32, any, f32, gOhwi8o),
152 REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
153 REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),
154
155 REG_SR_BIDIR(f32, any, f32, nChw16c),
156 REG_SR_BIDIR(f32, any, f32, Oihw4o),
157 REG_SR_BIDIR(f32, any, f32, Oihw16o),
158 REG_SR_BIDIR(f32, any, f32, Ohwi4o),
159 REG_SR_BIDIR(f32, any, f32, Ohwi16o),
160 REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
161 REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
162 REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
163 REG_SR_BIDIR(f32, any, f32, gOihw4o),
164 REG_SR_BIDIR(f32, any, f32, gOihw16o),
165 REG_SR_BIDIR(f32, any, f32, gOhwi4o),
166 REG_SR_BIDIR(f32, any, f32, gOhwi16o),
167 REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
168 REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
169 REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),
170
171 REG_SR_BIDIR(f32, any, f32, nCdhw4c),
172 REG_SR_BIDIR(f32, any, f32, nCdhw8c),
173 REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
174 REG_SR_BIDIR(f32, any, f32, Odhwi8o),
175 REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
176 REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
177 REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
178 REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
179 REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
180 REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),
181
182 REG_SR_BIDIR(f32, any, f32, nCdhw16c),
183 REG_SR_BIDIR(f32, any, f32, Oidhw4o),
184 REG_SR_BIDIR(f32, any, f32, Oidhw16o),
185 REG_SR_BIDIR(f32, any, f32, Odhwi16o),
186 REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
187 REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
188 REG_SR_BIDIR(f32, any, f32, gOidhw4o),
189 REG_SR_BIDIR(f32, any, f32, gOidhw16o),
190 REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
191 REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
192 REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
193 */
194
195 /* fp32: blocked <-> blocked with tail */
196 REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
197 REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
198 REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
199
200 /* int: flat <-> blocked with tail */
201 /*
202 REG_SR_BIDIR(f32, any, s32, nChw16c),
203 REG_SR_BIDIR(f32, any, s8, nChw16c),
204 REG_SR_BIDIR(f32, any, u8, nChw16c),
205 REG_SR_BIDIR(s32, any, f32, nChw16c),
206 REG_SR_BIDIR(s32, any, s32, nChw16c),
207 REG_SR_BIDIR(s32, any, s8, nChw16c),
208 REG_SR_BIDIR(s32, any, u8, nChw16c),
209 REG_SR_BIDIR(s8, any, f32, nChw16c),
210 REG_SR_BIDIR(s8, any, s32, nChw16c),
211 REG_SR_BIDIR(s8, any, s8, nChw16c),
212 REG_SR_BIDIR(s8, any, u8, nChw16c),
213 REG_SR_BIDIR(u8, any, f32, nChw16c),
214 REG_SR_BIDIR(u8, any, s32, nChw16c),
215 REG_SR_BIDIR(u8, any, s8, nChw16c),
216 REG_SR_BIDIR(u8, any, u8, nChw16c),
217
218 REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
219 REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
220 REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
221 REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
222 REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
223 REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
224 REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
225 REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),
226 */
227
228 /* reference: the last line of defence */
229 /*
230 REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
231 REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
232 REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
233 REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),
234
235 REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
236 REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
237 REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
238 REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),
239
240 REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
241 REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
242 REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
243 REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),
244
245 REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
246 REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
247 REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
248 REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
249 */
250
251 /* eol */
252 nullptr,
253};
254}
255
256const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
257 return cpu_reorder_impl_list;
258}
259
260}
261}
262}
263