| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include <assert.h> |
| 18 | |
| 19 | #include "cpu_engine.hpp" |
| 20 | #include "cpu_primitive.hpp" |
| 21 | #include "cpu_reorder_pd.hpp" |
| 22 | #include "cpu_memory.hpp" |
| 23 | #include "type_helpers.hpp" |
| 24 | |
| 25 | #include "cpu/jit_uni_reorder.hpp" |
| 26 | #include "cpu/simple_reorder.hpp" |
| 27 | #include "cpu/wino_reorder.hpp" |
| 28 | #include "cpu/rnn/rnn_reorders.hpp" |
| 29 | |
| 30 | namespace mkldnn { |
| 31 | namespace impl { |
| 32 | namespace cpu { |
| 33 | |
| 34 | using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; |
| 35 | |
| 36 | namespace { |
| 37 | using namespace mkldnn::impl::data_type; |
| 38 | using namespace mkldnn::impl::format_tag; |
| 39 | /* REG_SR expands to simple_reorder_t<idt, ifmt, odt, ofmt, ...>::pd_t::create; the trailing arguments (a fmt_order value and, optionally, a spec tag) are forwarded as extra template arguments. */ |
| 40 | #define REG_SR(idt, ifmt, odt, ofmt, ...) \ |
| 41 | simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create |
| 42 | /* REG_SR_BIDIR expands to two list entries: the fmt_order::keep and fmt_order::reverse variants (presumably ifmt -> ofmt and ofmt -> ifmt; TODO confirm in simple_reorder.hpp). */ |
| 43 | #define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ |
| 44 | REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ |
| 45 | REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) |
| 46 | /* REG_SR_DIRECT_COPY expands to two format-agnostic (any -> any) entries: spec::direct_copy and spec::direct_copy_except_dim_0. */ |
| 47 | #define REG_SR_DIRECT_COPY(idt, odt) \ |
| 48 | REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ |
| 49 | REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) |
| 50 | /* Table of reorder primitive-descriptor factories, terminated by a nullptr entry. NOTE(review): entry order presumably sets selection priority (specialized entries first, "reference" entries last); confirm against the consumer of get_reorder_implementation_list(). */ |
| 51 | static const rpd_create_f cpu_reorder_impl_list[] = { |
| 52 | /* winograd (f32 -> f32 only; the f32 -> s8 variant is disabled) */ |
| 53 | wino_reorder_t<f32, f32>::pd_t::create, |
| 54 | //wino_reorder_t<f32, s8>::pd_t::create, |
| 55 | |
| 56 | /* rnn reorders */ |
| 57 | rnn_data_reorder_t<f32, u8>::pd_t::create, |
| 58 | rnn_weights_reorder_t<f32, f32>::pd_t::create, |
| 59 | rnn_weights_reorder_t<f32, s8>::pd_t::create, |
| 60 | |
| 61 | /* conv reorders with compensation (spec::conv_s8s8), f32 or s8 input -> s8 output */ |
| 62 | REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), |
| 63 | REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), |
| 64 | REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), |
| 65 | REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), |
| 66 | |
| 67 | REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 68 | REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 69 | REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 70 | REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 71 | |
| 72 | REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 73 | REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 74 | REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 75 | REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), |
| 76 | |
| 77 | REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), |
| 78 | REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), |
| 79 | |
| 80 | REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), |
| 81 | REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), |
| 82 | |
| 83 | REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), |
| 84 | REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), |
| 85 | REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), |
| 86 | REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), |
| 87 | |
| 88 | /* regular reorders */ |
| 89 | |
| 90 | #if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) |
| 91 | /* Direct copy for icc, where it is faster than jitted code; |
| 92 | * direct copy for gcc, where it may or may not be faster than jitted |
| 93 | * code but is still worthwhile because it needs no jitting and so has |
| 94 | * a much faster creation time. This is a tentative solution that should |
| 95 | * be removed later (possibly once jitted code is cached). */ |
| 96 | REG_SR_DIRECT_COPY(f32, f32), |
| 97 | #endif |
| 98 | |
| 99 | #ifdef __INTEL_COMPILER |
| 100 | /* direct copy for icc, which is faster than jitted code (all entries below are currently commented out) */ |
| 101 | /* |
| 102 | REG_SR_DIRECT_COPY(f32, s32), |
| 103 | REG_SR_DIRECT_COPY(f32, s8), |
| 104 | REG_SR_DIRECT_COPY(f32, u8), |
| 105 | REG_SR_DIRECT_COPY(s32, f32), |
| 106 | REG_SR_DIRECT_COPY(s32, s32), |
| 107 | REG_SR_DIRECT_COPY(s32, s8), |
| 108 | REG_SR_DIRECT_COPY(s32, u8), |
| 109 | REG_SR_DIRECT_COPY(s8, f32), |
| 110 | REG_SR_DIRECT_COPY(s8, s32), |
| 111 | REG_SR_DIRECT_COPY(s8, s8), |
| 112 | REG_SR_DIRECT_COPY(s8, u8), |
| 113 | REG_SR_DIRECT_COPY(u8, f32), |
| 114 | REG_SR_DIRECT_COPY(u8, s32), |
| 115 | REG_SR_DIRECT_COPY(u8, s8), |
| 116 | REG_SR_DIRECT_COPY(u8, u8), |
| 117 | */ |
| 118 | #endif |
| 119 | |
| 120 | /* jit */ |
| 121 | jit_uni_reorder_create, |
| 122 | |
| 123 | /* fp32: flat <-> blocked with tail (currently commented out) */ |
| 124 | /* |
| 125 | REG_SR_BIDIR(f32, any, f32, nCw4c), |
| 126 | REG_SR_BIDIR(f32, any, f32, nCw8c), |
| 127 | REG_SR_BIDIR(f32, any, f32, OIw4i4o), |
| 128 | REG_SR_BIDIR(f32, any, f32, OIw8i8o), |
| 129 | REG_SR_BIDIR(f32, any, f32, OIw8o8i), |
| 130 | REG_SR_BIDIR(f32, any, f32, gOIw4i4o), |
| 131 | REG_SR_BIDIR(f32, any, f32, gOIw8i8o), |
| 132 | REG_SR_BIDIR(f32, any, f32, gOIw8o8i), |
| 133 | |
| 134 | REG_SR_BIDIR(f32, any, f32, nCw16c), |
| 135 | REG_SR_BIDIR(f32, any, f32, OIw16o16i), |
| 136 | REG_SR_BIDIR(f32, any, f32, OIw16i16o), |
| 137 | REG_SR_BIDIR(f32, any, f32, IOw16o16i), |
| 138 | REG_SR_BIDIR(f32, any, f32, gOIw16o16i), |
| 139 | REG_SR_BIDIR(f32, any, f32, gOIw16i16o), |
| 140 | REG_SR_BIDIR(f32, any, f32, gIOw16o16i), |
| 141 | |
| 142 | REG_SR_BIDIR(f32, any, f32, nChw4c), |
| 143 | REG_SR_BIDIR(f32, any, f32, nChw8c), |
| 144 | REG_SR_BIDIR(f32, any, f32, OIhw4i4o), |
| 145 | REG_SR_BIDIR(f32, any, f32, Ohwi8o), |
| 146 | |
| 147 | REG_SR_BIDIR(f32, any, f32, OIhw8i8o), |
| 148 | REG_SR_BIDIR(f32, any, f32, OIhw8o8i), |
| 149 | REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), |
| 150 | REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), |
| 151 | REG_SR_BIDIR(f32, any, f32, gOhwi8o), |
| 152 | REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), |
| 153 | REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), |
| 154 | |
| 155 | REG_SR_BIDIR(f32, any, f32, nChw16c), |
| 156 | REG_SR_BIDIR(f32, any, f32, Oihw4o), |
| 157 | REG_SR_BIDIR(f32, any, f32, Oihw16o), |
| 158 | REG_SR_BIDIR(f32, any, f32, Ohwi4o), |
| 159 | REG_SR_BIDIR(f32, any, f32, Ohwi16o), |
| 160 | REG_SR_BIDIR(f32, any, f32, OIhw16o16i), |
| 161 | REG_SR_BIDIR(f32, any, f32, OIhw16i16o), |
| 162 | REG_SR_BIDIR(f32, any, f32, IOhw16o16i), |
| 163 | REG_SR_BIDIR(f32, any, f32, gOihw4o), |
| 164 | REG_SR_BIDIR(f32, any, f32, gOihw16o), |
| 165 | REG_SR_BIDIR(f32, any, f32, gOhwi4o), |
| 166 | REG_SR_BIDIR(f32, any, f32, gOhwi16o), |
| 167 | REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), |
| 168 | REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), |
| 169 | REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), |
| 170 | |
| 171 | REG_SR_BIDIR(f32, any, f32, nCdhw4c), |
| 172 | REG_SR_BIDIR(f32, any, f32, nCdhw8c), |
| 173 | REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), |
| 174 | REG_SR_BIDIR(f32, any, f32, Odhwi8o), |
| 175 | REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), |
| 176 | REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), |
| 177 | REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), |
| 178 | REG_SR_BIDIR(f32, any, f32, gOdhwi8o), |
| 179 | REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), |
| 180 | REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), |
| 181 | |
| 182 | REG_SR_BIDIR(f32, any, f32, nCdhw16c), |
| 183 | REG_SR_BIDIR(f32, any, f32, Oidhw4o), |
| 184 | REG_SR_BIDIR(f32, any, f32, Oidhw16o), |
| 185 | REG_SR_BIDIR(f32, any, f32, Odhwi16o), |
| 186 | REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), |
| 187 | REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), |
| 188 | REG_SR_BIDIR(f32, any, f32, gOidhw4o), |
| 189 | REG_SR_BIDIR(f32, any, f32, gOidhw16o), |
| 190 | REG_SR_BIDIR(f32, any, f32, gOdhwi16o), |
| 191 | REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), |
| 192 | REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), |
| 193 | */ |
| 194 | |
| 195 | /* fp32: blocked <-> blocked with tail */ |
| 196 | REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), |
| 197 | REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), |
| 198 | REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), |
| 199 | |
| 200 | /* int: flat <-> blocked with tail (currently commented out) */ |
| 201 | /* |
| 202 | REG_SR_BIDIR(f32, any, s32, nChw16c), |
| 203 | REG_SR_BIDIR(f32, any, s8, nChw16c), |
| 204 | REG_SR_BIDIR(f32, any, u8, nChw16c), |
| 205 | REG_SR_BIDIR(s32, any, f32, nChw16c), |
| 206 | REG_SR_BIDIR(s32, any, s32, nChw16c), |
| 207 | REG_SR_BIDIR(s32, any, s8, nChw16c), |
| 208 | REG_SR_BIDIR(s32, any, u8, nChw16c), |
| 209 | REG_SR_BIDIR(s8, any, f32, nChw16c), |
| 210 | REG_SR_BIDIR(s8, any, s32, nChw16c), |
| 211 | REG_SR_BIDIR(s8, any, s8, nChw16c), |
| 212 | REG_SR_BIDIR(s8, any, u8, nChw16c), |
| 213 | REG_SR_BIDIR(u8, any, f32, nChw16c), |
| 214 | REG_SR_BIDIR(u8, any, s32, nChw16c), |
| 215 | REG_SR_BIDIR(u8, any, s8, nChw16c), |
| 216 | REG_SR_BIDIR(u8, any, u8, nChw16c), |
| 217 | |
| 218 | REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), |
| 219 | REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), |
| 220 | REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), |
| 221 | REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), |
| 222 | REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), |
| 223 | REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), |
| 224 | REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), |
| 225 | REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), |
| 226 | */ |
| 227 | |
| 228 | /* reference: the last line of defence (currently commented out, so this list has no generic fallback after the entries above) */ |
| 229 | /* |
| 230 | REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), |
| 231 | REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), |
| 232 | REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), |
| 233 | REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), |
| 234 | |
| 235 | REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), |
| 236 | REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), |
| 237 | REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), |
| 238 | REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), |
| 239 | |
| 240 | REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), |
| 241 | REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), |
| 242 | REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), |
| 243 | REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), |
| 244 | |
| 245 | REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), |
| 246 | REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), |
| 247 | REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), |
| 248 | REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), |
| 249 | */ |
| 250 | |
| 251 | /* eol: nullptr sentinel that terminates the list */ |
| 252 | nullptr, |
| 253 | }; |
| 254 | } // namespace |
| 255 | // Exposes the file-local reorder factory table: the returned pointer addresses its first entry, and the table ends with a nullptr entry. |
| 256 | const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { |
| 257 | return &cpu_reorder_impl_list[0]; |
| 258 | } |
| 259 | |
| 260 | } |
| 261 | } |
| 262 | } |
| 263 | |