/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_SIMPLE_REORDER_HPP
#define CPU_SIMPLE_REORDER_HPP

#include <assert.h>
#include <limits.h> /* INT_MIN is used as a poison block size below */

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "math_utils.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"

#include "tag_traits.hpp"
#include "cpu_reorder_pd.hpp"
#include "cpu_primitive.hpp"

#include "simple_q10n.hpp"
#include "cpu_isa_traits.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::format_tag;
using namespace mkldnn::impl::data_type;

using bd = block_dim_t;
using ib = inner_blk_t;

using namespace mkldnn::impl::utils;
using math::saturate;

template <impl::data_type_t type>
using data_t = typename prec_traits<type>::type;

template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;

template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz = qz<data_t<type_i>, data_t<type_o>>;

namespace fmt_order {
    const bool keep = true;
    const bool reverse = false;
    const bool any = keep;
}
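/* order_keep == fmt_order::keep means the reorder goes from tag_i to tag_o as
 * written in the template arguments; fmt_order::reverse instantiates the
 * opposite direction (tag_o -> tag_i), so a single implementation can cover,
 * e.g., both the plain-to-blocked and the blocked-to-plain copies. */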

namespace spec {
struct direct_copy {};
struct direct_copy_except_dim_0 {};
struct reference {};
struct conv_s8s8 {};
}
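/* The empty spec:: structs are dispatch tags only: they select which
 * simple_reorder_impl specialization gets instantiated (a dense direct copy,
 * a dense copy done slice-by-slice along dim 0, the generic reference loop,
 * or the int8 weights reorders that also compute compensation). */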

#define SIMPLE_REORDER_TEMPL_DECL \
    impl::data_type_t type_i, impl::format_tag_t tag_i, \
    impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep
#define SIMPLE_REORDER_TEMPL_CALL \
    type_i, tag_i, type_o, tag_o, order_keep

#define DECLARE_COMMON_PARAMS() \
    const memory_desc_wrapper &input_d = pd->src_md(); \
    const memory_desc_wrapper &output_d = pd->dst_md(); \
    const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
    const float beta = pd->beta(); MAYBE_UNUSED(beta);

/* specific reorders: common template */
template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_impl {};

namespace {
inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i,
        impl::format_tag_t tag_o, const memory_desc_wrapper &input_d,
        const memory_desc_wrapper &output_d) {
    return input_d.matches_tag(order_keep ? tag_i : tag_o)
        && output_d.matches_tag(order_keep ? tag_o : tag_i);
}
inline bool simple_attr_check(const primitive_attr_t *attr,
        bool many_scales_support) {
    if (many_scales_support)
        return true;
    return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
}
}

/* specific reorders: implementation */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == any && (false
    || tag_o == hwio
    || tag_o == hwigo)
    , spec::conv_s8s8>::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
    {
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(attr->output_scales_.mask_ + 1));
        const int oc = (input_d.dims()[(tag_o == hwigo) + 0]);
        const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1;

        return output_d.matches_tag(tag_o)
            && (output_d.extra().flags
                    & memory_extra_flags::compensation_conv_s8s8)
            && (input_d.data_type() == f32 || input_d.data_type() == s8)
            && output_d.data_type() == s8
            && (D_mask == 1 || D_mask == (size_t)g * oc);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        static constexpr bool w_groups = tag_o == hwigo;

        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int IC = dims[w_groups + 1];
        const int H = dims[w_groups + 2];
        const int W = dims[w_groups + 3];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));

        assert(output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8);
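        /* Some int8 convolution kernels require the weights to be pre-scaled
         * at reorder time (extra().scale_adjust, typically 0.5 on ISAs
         * without VNNI, to keep the intermediate accumulation from
         * overflowing); adj_scale folds that factor into the quantization. */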
        float adj_scale =
            (output_d.extra().flags & memory_extra_flags::scale_adjust)
            ? output_d.extra().scale_adjust : 1.f;

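        /* Per-(g, oc) compensation values are stored as int32 right after the
         * (padded) weights: cp[g * OC + oc] = -128 * sum of the quantized
         * weights over (ic, h, w). The int8 convolution kernels add this term
         * to the accumulator so that s8 activations can be processed as u8
         * values shifted by 128. */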
        size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
        int32_t *cp = reinterpret_cast<int32_t *>(output + offset);

        parallel_nd(G, OC, [&](int g, int oc) {
            cp[g * OC + oc] = 0;
            for (int ic = 0; ic < IC; ic++)
            for (int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
                auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
                const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];

                o = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        i, s * adj_scale);
                cp[g * OC + oc] -= (int32_t)o;
            }
            cp[g * OC + oc] *= 128;
        });
        return success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
    typename utils::enable_if<
          (tag_i == goiw && tag_o == gOIw4i16o4i)
       || (tag_i == oiw && tag_o == OIw4i16o4i)
       || (tag_i == goihw && tag_o == gOIhw4i16o4i)
       || (tag_i == oihw && tag_o == OIhw4i16o4i)
       || (tag_i == goihw && tag_o == gOIhw2i8o4i)
       || (tag_i == goihw && tag_o == gOIhw4o4i)
    , spec::conv_s8s8>::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
    {
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(attr->output_scales_.mask_ + 1));
        const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
        const int oc = (input_d.dims()[w_groups ? 1 : 0]);
        const int g = w_groups ? input_d.dims()[0] : 1;

        return input_d.matches_tag(tag_i)
            && output_d.matches_tag(tag_o)
            && (output_d.extra().flags
                    & memory_extra_flags::compensation_conv_s8s8)
            && (input_d.data_type() == f32 || input_d.data_type() == s8)
            && output_d.data_type() == s8
            && (D_mask == 1 || D_mask == (size_t)g * oc);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        static constexpr bool w_groups =
            !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
        constexpr int is_1d =
            utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i);
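        /* The OC/IC block size is implied by the destination tag's inner
         * blocking: 4 for the *4o4i tags, 8 for *2i8o4i, and 16 for the
         * *4i16o4i family. */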
        constexpr int blksize = tag_traits<tag_o>::inner_blks == ib::_4b4c
            ? 4
            : tag_traits<tag_o>::inner_blks == ib::_2c8b4c
                ? 8
                : 16;

        const auto &_g_oihw_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims = order_keep
            ? output_d.padded_dims()
            : input_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int NB_OC = pdims[w_groups + 0] / blksize;
        const int IC = dims[w_groups + 1];
        const int NB_IC = pdims[w_groups + 1] / blksize;
        const int H = is_1d ? 1 : dims[w_groups + 2];
        const int W = dims[w_groups + 3 - is_1d];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));

        assert(output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8);
        float adj_scale =
            (output_d.extra().flags & memory_extra_flags::scale_adjust)
            ? output_d.extra().scale_adjust : 1.f;

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
            int32_t *c, const float *s, const int oc_block, const int ic_block) {
#           define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>

            for (int ic = 0; ic < ic_block; ++ic) {
            for (int oc = 0; oc < oc_block; ++oc) {
                const auto _g_oihw_off =
                    oc * _g_oihw_d.blocking_desc().strides[w_groups + 0]
                    + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1];
                out[index(oc, ic)]
                    = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            inp[_g_oihw_off], s[oc] * adj_scale);
                c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
            }
            }
#           undef index
        };

        constexpr int i_mult = blksize;
        constexpr int o_mult = 1;

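        /* The int32 compensation entries sit right after the padded weights
         * and are zeroed up front, because ker() only subtracts into them. */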
        size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
        int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
        parallel_nd(G * NB_OC * blksize, [&](int i) {
            cp[i] = 0;
        });

#       define wei_blk_off(md, g, o, i, h, w) \
        (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
         : (md).blk_off<!w_groups>(g, o, i, h, w))

        parallel_nd(G, NB_OC, [&](int g, int O) {
            for (int I = 0; I < NB_IC; I++)
                for (int h = 0; h < H; h++)
                for (int w = 0; w < W; w++) {
                    auto i = &input[wei_blk_off(
                            input_d, g, i_mult * O, i_mult * I, h, w)];
                    auto o = &output[wei_blk_off(
                            output_d, g, o_mult * O, o_mult * I, h, w)];
                    const int oc_block = nstl::min(blksize, OC - O * blksize);
                    const int ic_block = nstl::min(blksize, IC - I * blksize);

                    int _offset = (g * NB_OC + O) * blksize;
                    ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
                            &scales[(D_mask == 1) ? 0 : _offset],
                            oc_block, ic_block);
                }
        });

#       undef wei_blk_off

        return success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
    typename utils::enable_if<false
    || (tag_i == goiw && tag_o == Goiw16g)
    || (tag_i == goihw && tag_o == Goihw16g)
    , spec::conv_s8s8>::type>
{
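    /* Weights reorder to Goi(h)w16g, i.e. the groups dimension itself is
     * blocked by 16; this layout is presumably consumed by the int8 depthwise
     * convolution kernels, and the compensation is computed per group. */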
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(attr->output_scales_.mask_ + 1));
        const int oc = input_d.dims()[1];
        const int g = input_d.dims()[0];

        return true
            && order_keep
            && input_d.matches_tag(tag_i)
            && output_d.matches_tag(tag_o)
            && (output_d.extra().flags
                    & memory_extra_flags::compensation_conv_s8s8)
            && (input_d.data_type() == f32 || input_d.data_type() == s8)
            && output_d.data_type() == s8
            && (D_mask == 1 || D_mask == (size_t)g * oc);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        constexpr bool is_1d = tag_i == goiw;
        constexpr int blksize = 16;

        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();
        const int G = dims[0];
        const int Gp = pdims[0];
        const int OC = dims[1];
        const int IC = dims[2];
        const int H = is_1d ? 1 : dims[3];
        const int W = dims[4 - is_1d];

        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const float *scales = pd->attr()->output_scales_.scales_;

        assert(output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8);
        float adj_scale =
            (output_d.extra().flags & memory_extra_flags::scale_adjust)
            ? output_d.extra().scale_adjust : 1.f;

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                int32_t *cp, const float *s, const int g_block) {
            PRAGMA_OMP_SIMD()
            for (int g = 0; g < g_block; g++) {
                const auto i_off = g * input_d.blocking_desc().strides[0];
                out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[i_off], s[g * OC] * adj_scale);
                cp[g * OC] -= 128 * (int32_t)(out[g]);
            }
        };

        size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
        int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
        parallel_nd((Gp / blksize) * OC, [&](int ib) {
            PRAGMA_OMP_SIMD()
            for (int i = 0; i < blksize; i++)
                cp[ib * blksize + i] = 0;
        });

#       define wei_blk_off(md, g, o, i, h, w) \
        (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))

        parallel_nd(Gp / blksize, OC, [&](int gb, int O) {
            for (int I = 0; I < IC; I++) {
                for (int h = 0; h < H; h++)
                for (int w = 0; w < W; w++) {
                    const int g_block = nstl::min(G - gb * blksize, blksize);
                    const auto inp = &input[wei_blk_off(
                            input_d, gb * blksize, O, I, h, w)];
                    const auto out = &output[wei_blk_off(
                            output_d, gb, O, I, h, w)];
                    int offset = gb * blksize + O;
                    ker(inp, out, &cp[offset],
                            &scales[(D_mask == 1) ? 0 : offset], g_block);
                }
            }
        });

#       undef wei_blk_off

        return success;
    }
};

/* reorders with tail support */

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<false
    || (tag_i == nCdhw8c && tag_o == nCdhw16c)
    || (tag_i == nChw8c && tag_o == nChw16c)
    || (tag_i == nCw8c && tag_o == nCw16c)
    >::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
    {
        return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d)
            && simple_attr_check(attr, false);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        constexpr int is_1d = tag_i == nCw8c;
        constexpr int is_3d = tag_i == nCdhw8c;
        constexpr int blksize_16 = 16;
        constexpr int blksize_8 = 8;
        constexpr int ic_mult = order_keep ? 2 : 1;
        constexpr int oc_mult = order_keep ? 1 : 2;
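        /* One 16-channel block corresponds to two consecutive 8-channel
         * blocks, so the channel-block index is scaled by 2 on the 8c side
         * and by 1 on the 16c side (and vice versa for the reverse order). */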

        const auto &dims = input_d.dims();
        const auto &pdims = order_keep ? output_d.padded_dims()
                                       : input_d.padded_dims();

        const int C = dims[1];
        const int D = is_3d ? dims[2] : 1;
        const int H = is_1d ? 1 : dims[2 + is_3d];
        const int W = dims[3 + is_3d - is_1d];

        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                const int block_16) {
            const int nb = (block_16 - 1) / blksize_8 + 1;
            if (alpha == 1.0 && beta == 0.0) {
                for (int b = 0; b < nb; ++b) {
                    const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
                    const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
                    const int block_8 = nstl::min(blksize_8,
                            block_16 - b * blksize_8);
                    for (int c = 0; c < block_8; ++c) {
                        o[o_off + c] = _qz_a1b0<type_i, type_o>()(
                                i[i_off + c]);
                    }
                }
            } else {
                for (int b = 0; b < nb; ++b) {
                    const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
                    const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
                    const int block_8 = nstl::min(blksize_8,
                            block_16 - b * blksize_8);
                    for (int c = 0; c < block_8; ++c) {
                        o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
                                o[o_off + c], alpha, beta);
                    }
                }
            }
        };

#       define data_blk_off(md, n, c, d, h, w) \
        (is_1d ? (md).blk_off(n, c, w) \
         : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))

        parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
            [&](int n, int nb_c, int d, int h, int w) {
            auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
            auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
            const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
            ker(i, o, block_16);
        });

#       undef data_blk_off

        return success;
    }
};

#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
    static bool is_applicable(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
        return simple_attr_check(attr, false) && (order_keep \
                ? output_d.matches_tag(tag_o) && input_d.is_plain() \
                : input_d.matches_tag(tag_o) && output_d.is_plain()); \
    }

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == any
    && (tag_traits<tag_o>::block_dims == bd::_A
            || tag_traits<tag_o>::block_dims == bd::_B)
    && tag_traits<tag_o>::ndims >= 3
    && tag_traits<tag_o>::ndims <= 6
    >::type>
{
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &block_d = order_keep ? output_d : input_d;
        const auto &dims = input_d.dims();
        const auto &pdims = block_d.padded_dims();

        constexpr int ndims = tag_traits<tag_o>::ndims;
        constexpr int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;

        const dim_t H0 = dims[0];
        const dim_t H1 = dims[1];
        const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1;
        const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
        const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
        const dim_t L = dims[ndims - 1];
        const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
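        /* The logical shape is viewed as (H0, H1, M0, M1, M2, L): H0/H1 are
         * the two leading dims (one of which is blocked), M0..M2 are up to
         * three middle dims (1 when absent), and L is the innermost dim that
         * ker() copies contiguously. */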

        constexpr int blksize = false ? 0
            : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_4a, ib::_4b) ? 4
            : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_8a, ib::_8b) ? 8
            : 16;

        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
            if (alpha == 1.0 && beta == 0.0) {
                for (int l = 0; l < L; ++l)
                for (int blk = 0; blk < block; ++blk) {
                    const dim_t flat_off = 0
                        + blk * flat_d.blocking_desc().strides[blk_idx]
                        + l * flat_d.blocking_desc().strides[ndims - 1];
                    if (order_keep) {
                        o[l * l_blk_stride + blk] = _qz_a1b0<type_i, type_o>()(
                                i[flat_off]);
                    } else {
                        o[flat_off] = _qz_a1b0<type_i, type_o>()(
                                i[l * l_blk_stride + blk]);
                    }
                }
            } else {
                for (int l = 0; l < L; ++l)
                for (int blk = 0; blk < block; ++blk) {
                    const dim_t flat_off = 0
                        + blk * flat_d.blocking_desc().strides[blk_idx]
                        + l * flat_d.blocking_desc().strides[ndims - 1];
                    if (order_keep) {
                        o[l * l_blk_stride + blk] = _qz<type_i, type_o>()(
                                i[flat_off], o[l * l_blk_stride + blk],
                                alpha, beta);
                    } else {
                        o[flat_off] = _qz<type_i, type_o>()(
                                i[l * l_blk_stride + blk], o[flat_off],
                                alpha, beta);
                    }
                }
            }
        };

#       define off(md, h0, h1, m0, m1, m2) \
        (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
         : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
         : ndims >= 4 ? (md).blk_off(h0, h1, m2) \
         : /* ndims >= 3 ? */ (md).blk_off(h0, h1))

        constexpr int i_mult = order_keep ? blksize : 1;
        constexpr int o_mult = order_keep ? 1 : blksize;

        if (blk_idx == 0) {
            const dim_t BH0 = pdims[0] / blksize;
            parallel_nd(BH0, H1, M0, M1, M2,
                [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
                auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)];
                auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)];
                const int block = nstl::min<int>(blksize, H0 - bh0 * blksize);
                ker(i, o, block);
            });
        } else if (blk_idx == 1) {
            const dim_t BH1 = pdims[1] / blksize;
            parallel_nd(H0, BH1, M0, M1, M2,
                [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
                auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)];
                auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)];
                const int block = nstl::min<int>(blksize, H1 - bh1 * blksize);
                ker(i, o, block);
            });
        } else {
            assert(!"unimplemented");
        }

#       undef off

        return success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == any
    && (tag_traits<tag_o>::block_dims == bd::_AB
            || tag_traits<tag_o>::block_dims == bd::_BC)
    && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
            tag_traits<tag_o>::ndims >= 3 && tag_traits<tag_o>::ndims <= 5)
    && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
            tag_traits<tag_o>::ndims >= 4 && tag_traits<tag_o>::ndims <= 6)
    >::type>
{
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims = order_keep
            ? output_d.padded_dims()
            : input_d.padded_dims();

        constexpr int ndims = tag_traits<tag_o>::ndims;

        static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
        const dim_t G = with_g ? dims[0] : 1;

        const dim_t H0 = dims[0 + with_g];
        const dim_t H1 = dims[1 + with_g];

        const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
        const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
        const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;

        constexpr int blksize_0 = false ? 0
            : utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_4b4a, ib::_4b4c, ib::_4c4b)
            ? 4
            : utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
            ? 8
            : utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c,
                    ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b,
                    ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c)
            ? 16 : INT_MIN;

        constexpr int blksize_1 = utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
            ? 8
            : utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b,
                    ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
                    ib::_4c16b4c, ib::_8c16b2c)
            ? 16
            : utils::one_of(tag_traits<tag_o>::inner_blks,
                    ib::_4b4a, ib::_4b4c, ib::_4c4b,
                    ib::_16a4b, ib::_16b4c)
            ? 4
            : INT_MIN;
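        /* INT_MIN serves as a poison value: a tag whose inner blocking is not
         * listed above would produce nonsensical (negative) block counts and
         * fail immediately instead of silently reordering with a wrong block
         * size; such tags should never reach this specialization. */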

        const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
        const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;

        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                const int block_h0, const int block_h1) {
#           define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>

            if (alpha == 1.0 && beta == 0.0) {
                for (int h0 = 0; h0 < block_h0; ++h0)
                for (int h1 = 0; h1 < block_h1; ++h1) {
                    const dim_t flat_off = 0
                        + h0 * flat_d.blocking_desc().strides[with_g + 0]
                        + h1 * flat_d.blocking_desc().strides[with_g + 1];
                    if (order_keep) {
                        o[blk_off(h0, h1)] = _qz_a1b0<type_i, type_o>()(
                                i[flat_off]);
                    } else {
                        o[flat_off] = _qz_a1b0<type_i, type_o>()(
                                i[blk_off(h0, h1)]);
                    }
                }
            } else {
                for (int h0 = 0; h0 < block_h0; ++h0)
                for (int h1 = 0; h1 < block_h1; ++h1) {
                    const dim_t flat_off = 0
                        + h0 * flat_d.blocking_desc().strides[with_g + 0]
                        + h1 * flat_d.blocking_desc().strides[with_g + 1];
                    if (order_keep) {
                        o[blk_off(h0, h1)] = _qz<type_i, type_o>()(i[flat_off],
                                o[blk_off(h0, h1)], alpha, beta);
                    } else {
                        o[flat_off] = _qz<type_i, type_o>()(i[blk_off(h0, h1)],
                                o[flat_off], alpha, beta);
                    }
                }
            }

#           undef blk_off
        };

        constexpr int i_mult_0 = order_keep ? blksize_0 : 1;
        constexpr int o_mult_0 = order_keep ? 1 : blksize_0;

        constexpr int i_mult_1 = order_keep ? blksize_1 : 1;
        constexpr int o_mult_1 = order_keep ? 1 : blksize_1;

#       define off(md, g, h0, h1, m0, m1, m2) \
        (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
         : ndims >= 4 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
         : /* ndims >= 3 + with_g ? */ (md).blk_off<!with_g>(g, h0, h1, m2))

        parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
            [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) {
            auto i = &input[off(input_d,
                    g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)];
            auto o = &output[off(output_d,
                    g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)];
            const int block_h0 = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
            const int block_h1 = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
            ker(i, o, block_h0, block_h1);
        });

#       undef off

        return success;
    }
};

/* generic and direct-copy reorders */

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
    typename utils::enable_if<
        tag_i == any && tag_o == any && order_keep == fmt_order::any,
    spec::direct_copy>::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* FIXME: is the formula correct? */
        return input_d.similar_to(output_d, true, false, 0)
            && input_d.is_dense() && output_d.is_dense()
            && simple_attr_check(attr, false);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        assert(input_d.is_dense());

        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        const size_t nelems = input_d.nelems();

        constexpr int block_size = 16;
        const auto num_blocks = nelems / block_size;
        const auto rem_elems = nelems % block_size;

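        /* The elements are split across threads in whole multiples of
         * block_size so each thread iterates over a SIMD-friendly chunk;
         * the last thread additionally handles the remainder tail. */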
        parallel(0, [&](const int ithr, const int nthr) {
            size_t start{0}, end{0};
            balance211(num_blocks, nthr, ithr, start, end);
            start = start * block_size;
            end = end * block_size;

            if (alpha == 1.0 && beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
                                (input[e]);
                }
            } else if (alpha == 1.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
                                (input[e], output[e], beta);
                }
            } else if (beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
                                (input[e], alpha);
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz<data_t<type_i>, data_t<type_o>>()
                                (input[e], output[e], alpha, beta);
                }
            }

            if (rem_elems != 0 && ithr == nthr - 1) {
                if (alpha == 1.0 && beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1b0<data_t<type_i>,
                                data_t<type_o>>()(input[e]);
                    }
                } else if (alpha == 1.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1<data_t<type_i>,
                                data_t<type_o>>()(input[e], output[e], beta);
                    }
                } else if (beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_b0<data_t<type_i>,
                                data_t<type_o>>()(input[e], alpha);
                    }
                } else {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz<data_t<type_i>, data_t<type_o>>()
                                    (input[e], output[e], alpha, beta);
                    }
                }
            }
        });
        return success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
    typename utils::enable_if<
        tag_i == any && tag_o == any && order_keep == fmt_order::any,
    spec::direct_copy_except_dim_0>::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
            return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
        };
        /* FIXME: is the formula correct? */
        return input_d.similar_to(output_d, true, false, 1)
            && is_dense_no_0(input_d) && is_dense_no_0(output_d)
            && simple_attr_check(attr, false);
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        const int N = input_d.dims()[0];
        const dim_t is = input_d.blocking_desc().strides[0];
        const dim_t os = output_d.blocking_desc().strides[0];
        const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
        const dim_t work_amount = N * nelems_no_d0;

        if (alpha == 1.0 && beta == 0.0) {
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n{0}, dim1_s{0};
                dim_t start{0}, end{0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
                            ? nelems_no_d0 : dim1_s + work_rem;
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e] = _qz_a1b0<type_i, type_o>()(
                                input[is * n + e]);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        } else {
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n{0}, dim1_s{0};
                dim_t start{0}, end{0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
                            ? nelems_no_d0 : dim1_s + work_rem;
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e] = _qz<type_i, type_o>()(
                                input[is * n + e], output[os * n + e], alpha,
                                beta);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        }

        return success;
    }

private:
    static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
        const int ndims = data_d.ndims();
        if (ndims <= 1) return 1;
        return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
    }

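    /* Physical size (in elements, padding included) spanned by one slice of
     * dim 0; the slice is dense exactly when this equals nelems_no_dim_0,
     * which is what is_dense_no_0() checks in is_applicable(). */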
    static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
        dims_t blocks;
        data_d.compute_blocks(blocks);

        const auto &blk = data_d.blocking_desc();

        dim_t blk_size = 1;
        for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
            blk_size *= blk.inner_blks[iblk];

        dim_t max_size = blk_size;
        for (int d = 1; d < data_d.ndims(); ++d) {
            max_size = nstl::max(max_size,
                    data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
        }

        return max_size;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
    typename utils::enable_if<
        tag_i == any && tag_o == any && order_keep == fmt_order::any,
    spec::reference>::type>
{
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* supported smask: 0x0...011..10...0,
         * i.e. 1 should be contiguous */
        int smask = attr ? attr->output_scales_.mask_ : 0;
        for (; smask > 0 && !(smask & 0x1); smask >>= 1);
        for (; smask > 0 && smask & 0x1; smask >>= 1);
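        /* after stripping the trailing zero bits and then the contiguous run
         * of one bits, anything left means the ones were not contiguous */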
        return true
            && input_d.is_blocking_desc()
            && output_d.is_blocking_desc()
            && !output_d.is_additional_buffer()
            && !input_d.is_additional_buffer()
            && smask == 0;
    }

    static status_t execute(const cpu_reorder_pd_t *pd,
            const data_t<type_i> *input, data_t<type_o> *output) {
        DECLARE_COMMON_PARAMS();

        const size_t nelems = input_d.nelems();

        int ndims_start = 0, ndims_mask = 0;
        int smask = pd->attr()->output_scales_.mask_;
        for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
        for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
        assert(smask == 0);

        const ptrdiff_t D_start
            = utils::array_product(input_d.dims(), ndims_start);
        const ptrdiff_t D_mask
            = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
        const ptrdiff_t D_rest = nelems / D_start / D_mask;

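        /* Every logical element index decomposes as
         * e = (ds * D_mask + dm) * D_rest + dr, where dm enumerates the dims
         * covered by the scales mask; the same scales[dm] is therefore
         * applied to all elements that share dm. */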
        const float *scales = pd->attr()->output_scales_.scales_;

        parallel_nd(D_start, D_mask, D_rest,
            [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
            const float scale = scales[dm];

            const size_t e = (ds * D_mask + dm) * D_rest + dr;
            const auto &i = input[input_d.off_l(e)];
            auto &o = output[output_d.off_l(e)];

            o = _qz<type_i, type_o>()(i, o, scale, beta);
        });

        return success;
    }
};


/* high level class declaration */

template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_t: public cpu_primitive_t {
    struct pd_t: public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);

        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
            bool args_ok = true
                && src_md->data_type == type_i
                && dst_md->data_type == type_o
                && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
                        is_applicable(src_md, dst_md, attr);
            if (!args_ok)
                return status::invalid_arguments;

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init() != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }
    };

    simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}

    virtual status_t execute(const exec_ctx_t &ctx) const override {
        auto input = CTX_IN_MEM(const data_t<type_i> *, MKLDNN_ARG_FROM);
        auto output = CTX_OUT_MEM(data_t<type_o> *, MKLDNN_ARG_TO);
        simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
                pd(), input, output);
        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};

#undef SIMPLE_REORDER_TEMPL_DECL
#undef SIMPLE_REORDER_TEMPL_CALL

}
}
}

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s