| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef CPU_WINO_REORDER_HPP |
| 18 | #define CPU_WINO_REORDER_HPP |
| 19 | |
| 20 | #include "mkldnn_thread.hpp" |
| 21 | |
| 22 | #include "simple_q10n.hpp" |
| 23 | |
| 24 | namespace mkldnn { |
| 25 | namespace impl { |
| 26 | namespace cpu { |
| 27 | |
| 28 | template <data_type_t type_i, data_type_t type_o> |
| 29 | struct wino_reorder_t : public cpu_primitive_t { |
| 30 | struct pd_t : public cpu_reorder_pd_t { |
| 31 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
| 32 | |
| 33 | DECLARE_COMMON_PD_T("wino_reorder" , wino_reorder_t); |
| 34 | |
| 35 | static status_t create(reorder_pd_t **reorder_pd, |
| 36 | engine_t *engine, const primitive_attr_t *attr, |
| 37 | engine_t *src_engine, const memory_desc_t *src_md, |
| 38 | engine_t *dst_engine, const memory_desc_t *dst_md) { |
| 39 | const memory_desc_wrapper id(src_md), od(dst_md); |
| 40 | bool args_ok = true |
| 41 | && id.data_type() == type_i |
| 42 | && od.data_type() == type_o |
| 43 | && id.matches_tag(utils::pick(id.ndims() - 4, |
| 44 | format_tag::oihw, format_tag::goihw)) |
| 45 | && od.format_kind() == format_kind::wino |
| 46 | && utils::one_of(od.wino_desc().wino_format, |
| 47 | mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, |
| 48 | mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); |
| 49 | if (!args_ok) return status::invalid_arguments; |
| 50 | |
| 51 | auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, |
| 52 | dst_md); |
| 53 | if (_pd == nullptr) return status::out_of_memory; |
| 54 | if (_pd->init() != status::success) { |
| 55 | delete _pd; |
| 56 | return status::unimplemented; |
| 57 | } |
| 58 | return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); |
| 59 | } |
| 60 | |
| 61 | status_t init() { |
| 62 | status_t status = cpu_reorder_pd_t::init(); |
| 63 | if (status != status::success) return status; |
| 64 | |
| 65 | init_scratchpad(); |
| 66 | |
| 67 | return status::success; |
| 68 | } |
| 69 | |
| 70 | private: |
| 71 | void init_scratchpad() { |
| 72 | auto &o = memory_desc_wrapper(dst_md()).wino_desc(); |
| 73 | size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; |
| 74 | size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; |
| 75 | |
| 76 | using namespace memory_tracking::names; |
| 77 | auto scratchpad = scratchpad_registry().registrar(); |
| 78 | scratchpad.book(key_reorder_wino_transform_space, |
| 79 | sizeof(in_data_t) * transform_space_size); |
| 80 | scratchpad.book(key_reorder_wino_plain, |
| 81 | sizeof(out_data_t) * plain_size); |
| 82 | } |
| 83 | }; |
| 84 | |
| 85 | private: |
| 86 | typedef typename prec_traits<type_i>::type in_data_t; |
| 87 | typedef typename prec_traits<type_o>::type out_data_t; |
| 88 | const int unsign_val_in_wino_domain_ = 5; |
| 89 | |
| 90 | wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { |
| 91 | const memory_desc_wrapper src_d(pd()->src_md()); |
| 92 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
| 93 | |
| 94 | r_ = dst_d.wino_desc().r; |
| 95 | w_alpha_ = dst_d.wino_desc().alpha; |
| 96 | wino_format_ = dst_d.wino_desc().wino_format; |
| 97 | |
| 98 | const auto &in_dims = src_d.dims(); |
| 99 | int groups; |
| 100 | int groups_offset; |
| 101 | if (src_d.ndims() == 5) { |
| 102 | groups = in_dims[0]; |
| 103 | groups_offset = 1; |
| 104 | } else { |
| 105 | groups = 1; |
| 106 | groups_offset = 0; |
| 107 | } |
| 108 | assert(groups == 1); // groups are not supported now |
| 109 | MAYBE_UNUSED(groups); |
| 110 | |
| 111 | or_oc_ = in_dims[0 + groups_offset]; |
| 112 | or_ic_ = in_dims[1 + groups_offset]; |
| 113 | kh_ = in_dims[2 + groups_offset]; |
| 114 | kw_ = in_dims[3 + groups_offset]; |
| 115 | |
| 116 | oc_ = dst_d.wino_desc().oc; |
| 117 | ic_ = dst_d.wino_desc().ic; |
| 118 | oc_block_ = dst_d.wino_desc().oc_block; |
| 119 | ic_block_ = dst_d.wino_desc().ic_block; |
| 120 | assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); |
| 121 | nb_oc_ = oc_ / oc_block_; |
| 122 | nb_ic_ = ic_ / ic_block_; |
| 123 | ic2_block_ = 1; |
| 124 | if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) |
| 125 | ic2_block_ = dst_d.wino_desc().ic2_block; |
| 126 | oc2_block_ = dst_d.wino_desc().oc2_block; |
| 127 | assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); |
| 128 | |
| 129 | adj_scale_ = dst_d.wino_desc().adj_scale; |
| 130 | |
| 131 | size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; |
| 132 | size_wspace_ = r_ * w_alpha_ * oc_block_; |
| 133 | } |
| 134 | |
| 135 | void transform(out_data_t *__restrict tmp_wei, |
| 136 | const in_data_t *__restrict input, |
| 137 | in_data_t *__restrict wspace) const { |
| 138 | const memory_desc_wrapper src_d(pd()->src_md()); |
| 139 | |
| 140 | const int smask = pd()->attr()->output_scales_.mask_; |
| 141 | const int ndims_mask = math::ilog2q(smask + 1); |
| 142 | const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); |
| 143 | const float *__restrict scales = pd()->attr()->output_scales_.scales_; |
| 144 | assert(D_mask == 1 || D_mask == (size_t)oc_); |
| 145 | |
| 146 | /* transform weights to winograd domain */ |
| 147 | const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, |
| 148 | { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; |
| 149 | |
| 150 | const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, |
| 151 | { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, |
| 152 | { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, |
| 153 | { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, |
| 154 | { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, |
| 155 | { 0.f, 0.f, 1.f } }; |
| 156 | |
| 157 | float *__restrict g; |
| 158 | if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, |
| 159 | mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) |
| 160 | g = (float *)G_2x2_3x3; |
| 161 | else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) |
| 162 | g = (float *)G_4x4_3x3; |
| 163 | else { |
| 164 | assert("Unknown winograd weights target layout" ); |
| 165 | return; |
| 166 | } |
| 167 | |
| 168 | int Z = oc_ * ic_; |
| 169 | assert(r_ == kh_ && r_ == kw_); |
| 170 | |
| 171 | for (int iic = 0; iic < ic_; iic++) { |
| 172 | for (int ob = 0; ob < nb_oc_; ob++) { |
| 173 | const in_data_t *__restrict _inp |
| 174 | = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; |
| 175 | out_data_t *__restrict _out |
| 176 | = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; |
| 177 | |
| 178 | for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); |
| 179 | |
| 180 | for_nd(0, 1, r_, w_alpha_, oc_block_, |
| 181 | [&](int ih, int j, int ioc) { |
| 182 | for (int iw = 0; iw < r_; ++iw) { |
| 183 | int inp_oc = ob * oc_block_ + ioc; |
| 184 | int inp_ic = iic; |
| 185 | in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) |
| 186 | ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] |
| 187 | : 0.f; |
| 188 | wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] |
| 189 | += inp_v * g[j * r_ + iw]; |
| 190 | } |
| 191 | }); |
| 192 | |
| 193 | for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, |
| 194 | [&](int i, int j, int ioc) { |
| 195 | float t = 0; |
| 196 | for (int k = 0; k < r_; ++k) |
| 197 | t += g[i * r_ + k] |
| 198 | * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; |
| 199 | if (type_o == data_type::s8) { |
| 200 | const float scale = (D_mask == 1) |
| 201 | ? scales[0] |
| 202 | : scales[ob * oc_block_ + ioc]; |
| 203 | _out[(i * w_alpha_ + j) * Z + ioc] |
| 204 | = qz_b0<in_data_t, out_data_t>()( |
| 205 | (in_data_t)t, scale * adj_scale_); |
| 206 | } else { |
| 207 | _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; |
| 208 | } |
| 209 | }); |
| 210 | }} |
| 211 | } |
| 212 | |
| 213 | void reorder_to_aaOIoi(out_data_t *__restrict output, |
| 214 | const out_data_t *__restrict tmp_wei) const { |
| 215 | int32_t *__restrict dst_bias = nullptr; |
| 216 | if (type_o == data_type::s8) { |
| 217 | const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; |
| 218 | const size_t bias_size = w_alpha_ * w_alpha_ * oc_; |
| 219 | |
| 220 | dst_bias = (int32_t *)(output + bias_shift); |
| 221 | utils::array_set((int32_t *)dst_bias, 0, bias_size); |
| 222 | } |
| 223 | int index = 0; |
| 224 | for (int u_h = 0; u_h < w_alpha_; u_h++) { |
| 225 | for (int u_w = 0; u_w < w_alpha_; u_w++) { |
| 226 | for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { |
| 227 | int u_h_shift = u_h * w_alpha_ * ic_ * oc_; |
| 228 | int u_w_shift = u_w * ic_ * oc_; |
| 229 | int u_h_shift_b = u_h * w_alpha_ * oc_; |
| 230 | int u_w_shift_b = u_w * oc_; |
| 231 | int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; |
| 232 | for (int ib = 0; ib < nb_ic_; ib++) { |
| 233 | for (int i = 0; i < ic_block_; i++) { |
| 234 | int _i = ib * ic_block_; |
| 235 | int _o = ob * oc_block_; |
| 236 | int ic_shift = (_i + i) * oc_; |
| 237 | int oc_shift = (_o + o); |
| 238 | int ic_block_shift = ib * oc_block_ * ic_block_ + i; |
| 239 | int src_offset = |
| 240 | u_h_shift + u_w_shift + ic_shift + oc_shift; |
| 241 | int dst_offset = u_h_shift + u_w_shift + oc_block_shift |
| 242 | + ic_block_shift; |
| 243 | |
| 244 | output[dst_offset] = tmp_wei[src_offset]; |
| 245 | if (type_o == data_type::s8) { |
| 246 | int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; |
| 247 | if (index != unsign_val_in_wino_domain_) |
| 248 | dst_bias[bias_offset] |
| 249 | -= (128 * (int32_t)output[dst_offset]); |
| 250 | else |
| 251 | dst_bias[bias_offset] = 0; |
| 252 | } |
| 253 | }} |
| 254 | }); |
| 255 | index++; |
| 256 | }} |
| 257 | } |
| 258 | |
| 259 | void reorder_to_aaOio(out_data_t *__restrict output, |
| 260 | const out_data_t *__restrict tmp_wei) const { |
| 261 | for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, |
| 262 | [&](int u_h, int u_w, int ob) { |
| 263 | for (int ib = 0; ib < nb_ic_; ib++) { |
| 264 | for (int i = 0; i < ic_block_; i++) { |
| 265 | for (int o = 0; o < oc_block_; o++) { |
| 266 | int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ |
| 267 | + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); |
| 268 | |
| 269 | int dst_offset |
| 270 | = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ |
| 271 | + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ |
| 272 | + ob * nb_ic_ * ic_block_ * oc_block_ |
| 273 | + ib * ic_block_ * oc_block_ + i * oc_block_ + o; |
| 274 | output[dst_offset] = tmp_wei[src_offset]; |
| 275 | }}} |
| 276 | }); |
| 277 | } |
| 278 | |
| 279 | void reorder_to_aaOBiOo(out_data_t *__restrict output, |
| 280 | const out_data_t *__restrict tmp_wei) const { |
| 281 | int oc_chunks = nb_oc_ / oc2_block_; |
| 282 | |
| 283 | for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, |
| 284 | [&](int u_h, int u_w, int occ) { |
| 285 | for (int ib = 0; ib < nb_ic_; ib++) { |
| 286 | out_data_t *__restrict wei_ptr = output |
| 287 | + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) |
| 288 | * oc2_block_ * ic_block_ * oc_block_; |
| 289 | int wei_offset = 0; |
| 290 | for (int i = 0; i < ic_block_; i++) { |
| 291 | for (int ob2 = 0; ob2 < oc2_block_; ob2++) { |
| 292 | for (int o = 0; o < oc_block_; o++) { |
| 293 | int icp = ib * ic_block_ + i; |
| 294 | int ocp = |
| 295 | occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; |
| 296 | |
| 297 | int src_offset = u_h * w_alpha_ * ic_ * oc_ |
| 298 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
| 299 | wei_ptr[wei_offset + o] = tmp_wei[src_offset]; |
| 300 | } |
| 301 | wei_offset += oc_block_; |
| 302 | }} |
| 303 | } |
| 304 | }); |
| 305 | } |
| 306 | |
| 307 | void reorder_to_OBaaIBOIio(out_data_t *__restrict output, |
| 308 | const out_data_t *__restrict tmp_wei) const { |
| 309 | int ic_chunks = nb_ic_ / ic2_block_; |
| 310 | int oc_chunks = nb_oc_ / oc2_block_; |
| 311 | |
| 312 | for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, |
| 313 | [&](int occ, int u_h, int u_w) { |
| 314 | for (int icc = 0; icc < ic_chunks; icc++) { |
| 315 | for (int ob = 0; ob < oc2_block_; ob++) { |
| 316 | int ocp = (occ * oc2_block_ + ob) * oc_block_; |
| 317 | for (int ib = 0; ib < ic2_block_; ib++) { |
| 318 | for (int i = 0; i < ic_block_; i++) { |
| 319 | int icp = (icc * ic2_block_ + ib) * ic_block_ + i; |
| 320 | |
| 321 | int src_offset = u_h * w_alpha_ * ic_ * oc_ |
| 322 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
| 323 | int wei_offset |
| 324 | = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) |
| 325 | * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ |
| 326 | + ib) * ic_block_ + i) * oc_block_; |
| 327 | for (int o = 0; o < oc_block_; o++) |
| 328 | output[wei_offset + o] = tmp_wei[src_offset + o]; |
| 329 | }} |
| 330 | }} |
| 331 | }); |
| 332 | } |
| 333 | |
| 334 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
| 335 | auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); |
| 336 | auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); |
| 337 | |
| 338 | auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>( |
| 339 | memory_tracking::names::key_reorder_wino_transform_space); |
| 340 | auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>( |
| 341 | memory_tracking::names::key_reorder_wino_plain); |
| 342 | |
| 343 | transform(tmp_wei, input, wspace); |
| 344 | |
| 345 | /* reorder to winograd domain */ |
| 346 | switch (wino_format_) { |
| 347 | case mkldnn_wino_wei_aaOIoi: |
| 348 | reorder_to_aaOIoi(output, tmp_wei); break; |
| 349 | case mkldnn_wino_wei_aaOio: |
| 350 | reorder_to_aaOio(output, tmp_wei); break; |
| 351 | case mkldnn_wino_wei_aaOBiOo: |
| 352 | reorder_to_aaOBiOo(output, tmp_wei); break; |
| 353 | case mkldnn_wino_wei_OBaaIBOIio: |
| 354 | reorder_to_OBaaIBOIio(output, tmp_wei); break; |
| 355 | default: assert("Unknown wino format" ); break; |
| 356 | } |
| 357 | |
| 358 | return status::success; |
| 359 | } |
| 360 | |
| 361 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
| 362 | int r_, w_alpha_; |
| 363 | int ic_, oc_, or_ic_, or_oc_, kh_, kw_; |
| 364 | int oc_block_, ic_block_, oc2_block_, ic2_block_; |
| 365 | float adj_scale_; |
| 366 | int nb_oc_, nb_ic_; |
| 367 | mkldnn_wino_memory_format_t wino_format_; |
| 368 | int size_wino_wei_; |
| 369 | int size_wspace_; |
| 370 | }; |
| 371 | |
| 372 | } // namespace cpu |
| 373 | } // namespace impl |
| 374 | } // namespace mkldnn |
| 375 | |
| 376 | #endif |
| 377 | |