1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "cpu_engine.hpp" |
20 | #include "cpu_primitive.hpp" |
21 | #include "cpu_reorder_pd.hpp" |
22 | #include "cpu_memory.hpp" |
23 | #include "type_helpers.hpp" |
24 | |
25 | #include "cpu/jit_uni_reorder.hpp" |
26 | #include "cpu/simple_reorder.hpp" |
27 | #include "cpu/wino_reorder.hpp" |
28 | #include "cpu/rnn/rnn_reorders.hpp" |
29 | |
30 | namespace mkldnn { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; |
35 | |
namespace {
using namespace mkldnn::impl::data_type;
using namespace mkldnn::impl::format_tag;

/* REG_SR(idt, ifmt, odt, ofmt, ...): expands to the pd_t::create function of
 * the simple_reorder_t specialization for input (data type, format)
 * = (idt, ifmt), output (data type, format) = (odt, ofmt), followed by the
 * extra template arguments in __VA_ARGS__ (a fmt_order value and,
 * optionally, a spec tag). */
#define REG_SR(idt, ifmt, odt, ofmt, ...) \
    simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create

/* REG_SR_BIDIR: registers two simple_reorder_t entries for the same
 * (idt, ifmt, odt, ofmt) tuple, one with fmt_order::keep and one with
 * fmt_order::reverse.
 * NOTE(review): presumably keep means ifmt -> ofmt and reverse means
 * ofmt -> ifmt (hence "BIDIR"); confirm against simple_reorder.hpp. */
#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \
    REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \
    REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse)

/* REG_SR_DIRECT_COPY: registers the spec::direct_copy and
 * spec::direct_copy_except_dim_0 simple_reorder_t variants for the given
 * data-type pair, accepting any input and any output format. */
#define REG_SR_DIRECT_COPY(idt, odt) \
    REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \
    REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0)

/* Table of reorder primitive-descriptor creators for the CPU engine,
 * terminated by a nullptr sentinel (the "eol" entry at the end).
 * NOTE(review): entries are presumably tried front to back until one
 * create() accepts the requested reorder (specialized implementations come
 * first and the reference ones are labelled "the last line of defence"), so
 * the order of this list encodes priority -- preserve it when editing.
 * NOTE(review): several sections below are commented out in this revision,
 * including every reference entry, so there is currently no generic
 * fallback after jit_uni_reorder_create -- confirm this is intentional. */
static const rpd_create_f cpu_reorder_impl_list[] = {
    /* winograd */
    wino_reorder_t<f32, f32>::pd_t::create,
    //wino_reorder_t<f32, s8>::pd_t::create,

    /* rnn reorders */
    rnn_data_reorder_t<f32, u8>::pd_t::create,
    rnn_weights_reorder_t<f32, f32>::pd_t::create,
    rnn_weights_reorder_t<f32, s8>::pd_t::create,

    /* conv reorders w/ compensation (spec::conv_s8s8; all s8 outputs) */
    REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
    REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),

    REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),

    REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),

    REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),

    REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),

    REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
    REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
    REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),

    /* regular reorders */

#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
    /* Direct copy for icc, which is faster than jitted code.
     * Direct copy for gcc, which might or might not be faster than jitted
     * code, but is still worth it because it doesn't require jitting, i.e.
     * it has a much faster creation time. This is a tentative solution and
     * should be removed later (once jitted code is cached?). Placed before
     * the jit entry so it is considered first. */
    REG_SR_DIRECT_COPY(f32, f32),
#endif

#ifdef __INTEL_COMPILER
    /* direct copy for icc, which is faster than jitted code */
    /* NOTE(review): every entry in this block is commented out, so this
     * #ifdef currently contributes no table entries. */
    /*
    REG_SR_DIRECT_COPY(f32, s32),
    REG_SR_DIRECT_COPY(f32, s8),
    REG_SR_DIRECT_COPY(f32, u8),
    REG_SR_DIRECT_COPY(s32, f32),
    REG_SR_DIRECT_COPY(s32, s32),
    REG_SR_DIRECT_COPY(s32, s8),
    REG_SR_DIRECT_COPY(s32, u8),
    REG_SR_DIRECT_COPY(s8, f32),
    REG_SR_DIRECT_COPY(s8, s32),
    REG_SR_DIRECT_COPY(s8, s8),
    REG_SR_DIRECT_COPY(s8, u8),
    REG_SR_DIRECT_COPY(u8, f32),
    REG_SR_DIRECT_COPY(u8, s32),
    REG_SR_DIRECT_COPY(u8, s8),
    REG_SR_DIRECT_COPY(u8, u8),
    */
#endif

    /* jit */
    jit_uni_reorder_create,

    /* fp32: flat <-> blocked with tail */
    /* NOTE(review): disabled in this revision; kept for reference. */
    /*
    REG_SR_BIDIR(f32, any, f32, nCw4c),
    REG_SR_BIDIR(f32, any, f32, nCw8c),
    REG_SR_BIDIR(f32, any, f32, OIw4i4o),
    REG_SR_BIDIR(f32, any, f32, OIw8i8o),
    REG_SR_BIDIR(f32, any, f32, OIw8o8i),
    REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
    REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
    REG_SR_BIDIR(f32, any, f32, gOIw8o8i),

    REG_SR_BIDIR(f32, any, f32, nCw16c),
    REG_SR_BIDIR(f32, any, f32, OIw16o16i),
    REG_SR_BIDIR(f32, any, f32, OIw16i16o),
    REG_SR_BIDIR(f32, any, f32, IOw16o16i),
    REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
    REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
    REG_SR_BIDIR(f32, any, f32, gIOw16o16i),

    REG_SR_BIDIR(f32, any, f32, nChw4c),
    REG_SR_BIDIR(f32, any, f32, nChw8c),
    REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
    REG_SR_BIDIR(f32, any, f32, Ohwi8o),

    REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
    REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
    REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
    REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
    REG_SR_BIDIR(f32, any, f32, gOhwi8o),
    REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
    REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),

    REG_SR_BIDIR(f32, any, f32, nChw16c),
    REG_SR_BIDIR(f32, any, f32, Oihw4o),
    REG_SR_BIDIR(f32, any, f32, Oihw16o),
    REG_SR_BIDIR(f32, any, f32, Ohwi4o),
    REG_SR_BIDIR(f32, any, f32, Ohwi16o),
    REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
    REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
    REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
    REG_SR_BIDIR(f32, any, f32, gOihw4o),
    REG_SR_BIDIR(f32, any, f32, gOihw16o),
    REG_SR_BIDIR(f32, any, f32, gOhwi4o),
    REG_SR_BIDIR(f32, any, f32, gOhwi16o),
    REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
    REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
    REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),

    REG_SR_BIDIR(f32, any, f32, nCdhw4c),
    REG_SR_BIDIR(f32, any, f32, nCdhw8c),
    REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
    REG_SR_BIDIR(f32, any, f32, Odhwi8o),
    REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
    REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
    REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
    REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
    REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
    REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),

    REG_SR_BIDIR(f32, any, f32, nCdhw16c),
    REG_SR_BIDIR(f32, any, f32, Oidhw4o),
    REG_SR_BIDIR(f32, any, f32, Oidhw16o),
    REG_SR_BIDIR(f32, any, f32, Odhwi16o),
    REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
    REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
    REG_SR_BIDIR(f32, any, f32, gOidhw4o),
    REG_SR_BIDIR(f32, any, f32, gOidhw16o),
    REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
    REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
    REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
    */

    /* fp32: blocked <-> blocked with tail (8c <-> 16c channel blocking) */
    REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
    REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
    REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),

    /* int: flat <-> blocked with tail */
    /* NOTE(review): disabled in this revision; kept for reference. */
    /*
    REG_SR_BIDIR(f32, any, s32, nChw16c),
    REG_SR_BIDIR(f32, any, s8, nChw16c),
    REG_SR_BIDIR(f32, any, u8, nChw16c),
    REG_SR_BIDIR(s32, any, f32, nChw16c),
    REG_SR_BIDIR(s32, any, s32, nChw16c),
    REG_SR_BIDIR(s32, any, s8, nChw16c),
    REG_SR_BIDIR(s32, any, u8, nChw16c),
    REG_SR_BIDIR(s8, any, f32, nChw16c),
    REG_SR_BIDIR(s8, any, s32, nChw16c),
    REG_SR_BIDIR(s8, any, s8, nChw16c),
    REG_SR_BIDIR(s8, any, u8, nChw16c),
    REG_SR_BIDIR(u8, any, f32, nChw16c),
    REG_SR_BIDIR(u8, any, s32, nChw16c),
    REG_SR_BIDIR(u8, any, s8, nChw16c),
    REG_SR_BIDIR(u8, any, u8, nChw16c),

    REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
    REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
    REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
    REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
    REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
    REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
    REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
    REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),
    */

    /* reference: the last line of defence */
    /* NOTE(review): disabled in this revision, so reorders not accepted by
     * any entry above have no generic fallback here. */
    /*
    REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
    REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
    REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
    REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),

    REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
    REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
    REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
    REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),

    REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
    REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
    REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
    REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),

    REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
    REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
    REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
    REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
    */

    /* eol: nullptr sentinel marking the end of the table */
    nullptr,
};
} // namespace
255 | |
256 | const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { |
257 | return cpu_reorder_impl_list; |
258 | } |
259 | |
260 | } |
261 | } |
262 | } |
263 | |