/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "c_types_map.hpp"
#include "memory_desc_wrapper.hpp"
#include "mkldnn_debug.h"
#include "nstl.hpp"
#include "type_helpers.hpp"

#include "cpu_primitive.hpp"
#include "cpu_reorder_pd.hpp"
#include "jit_uni_reorder.hpp"

#include "jit_generator.hpp"

// #define TR_DEBUG
#if defined(TR_DEBUG)
#define DEBUg(...) do { __VA_ARGS__ } while (0)
#else
#define DEBUg(...)
#endif
#define DEBUG(...) DEBUg(__VA_ARGS__)

#ifdef _WIN32
/* seems like s_addr is a reserved macro on Windows */
#undef s_addr
#endif

using namespace Xbyak;
using namespace mkldnn::impl::types;

namespace mkldnn {
namespace impl {
namespace cpu {

namespace tr {

/** Minimal reasonable/desirable kernel size.
 * The constant might be used to determine how a problem should be split
 * between kernel and threading driver. */
const size_t ker_prb_size_min = 64;

/* kernel */
struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)

    enum {
        len_unroll_max = 256,
        ndims_jit_loop_max = 3,
    };

    struct simple_impl_desc_t {
        int ndims_full_unroll;
        int len_last_dim_unroll;
        int len_unroll;
    };

    static bool simple_impl_desc_init(const prb_t &prb,
            simple_impl_desc_t *desc) {
        const int ndims = prb.ndims;

        int ndims_full_unroll = 0;
        int len_last_dim_unroll = 1;
        int len_unroll = 1;

        for (int d = 0; d < ndims; ++d) {
            auto &node = prb.nodes[d];
            if (len_unroll * node.n <= len_unroll_max) {
                ndims_full_unroll++;
                len_unroll *= node.n;
            } else {
                len_last_dim_unroll = len_unroll_max / len_unroll;
                while (node.n % len_last_dim_unroll)
                    --len_last_dim_unroll;
                len_unroll *= len_last_dim_unroll;
                break;
            }
        }

        if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max)
            return false;

        if (desc) {
            desc->ndims_full_unroll = ndims_full_unroll;
            desc->len_last_dim_unroll = len_last_dim_unroll;
            desc->len_unroll = len_unroll;
        }

        return true;
    }
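
    /* A worked example (hypothetical sizes): for nodes of sizes {4, 8, 16}
     * and len_unroll_max = 256, dimensions 0 and 1 unroll fully
     * (len_unroll = 32), while 32 * 16 = 512 exceeds the budget, so the
     * last dimension unrolls by 256 / 32 = 8 (the largest divisor of 16
     * that fits). Result: {ndims_full_unroll = 2, len_last_dim_unroll = 8,
     * len_unroll = 256}, with one dimension left to the jitted loops. */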

    static bool applicable(const prb_t &p) {
        using namespace data_type;

        bool ok = true
            && p.ndims > 0
            && utils::one_of(p.itype, f32, s32, s8, u8)
            && utils::one_of(p.otype, f32, s32, s8, u8)
            && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
            && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
            && simple_impl_desc_init(p, nullptr)
            && mayiuse(sse42)
            && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype),
                    mayiuse(avx));
        if (!ok) return false;

        const ptrdiff_t max_stride = (1LL<<31) - 1;
        for (int d = 0; d < p.ndims; ++d) {
            const ptrdiff_t cms = max_stride / p.nodes[d].n;
            bool strides_ok = true
                && p.nodes[d].is < cms / (int)data_type_size(p.itype)
                && p.nodes[d].os < cms / (int)data_type_size(p.otype);
            if (!strides_ok) return false;
        }

        return true;
    }
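
    /* The stride check above bounds n * stride * data_type_size by 2^31
     * for every node, so all byte offsets computed by the kernel should
     * fit into the plain ints (and 32-bit displacements) used throughout. */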

    int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; }
    int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; }
    int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; }
    int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; }

    Address i_addr(int i_off)
    { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; }

    Address o_addr(int o_off)
    { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; }

    Address s_addr(int s_off)
    { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; }

    void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
            int &i_off, int &o_off, int &s_off, int step_size = 1) {
        i_off = prev_i_off;
        o_off = prev_o_off;
        s_off = prev_s_off;

        if (off == 0) return;

        int start_dim = 0, dims_prod = 1;
        for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
            dims_prod *= n(start_dim);
        assert(start_dim < prb_.ndims);
        off /= step_size;

        for (int d = start_dim; d < prb_.ndims; ++d) {
            i_off += is(d);
            o_off += os(d);
            s_off += ss(d);

            if (off % n(d)) break;

            i_off += - n(d) * is(d);
            o_off += - n(d) * os(d);
            s_off += - n(d) * ss(d);
            off /= n(d);

            if (off == 0) break; /* FIXME: is it really required? */
        }
    }
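
    /* A sketch of how step() walks offsets (hypothetical sizes): with
     * n = {2, 3} and input strides is = {1, 10}, successive calls yield
     * i_off = 0, 1, 10, 11, 20, 21. At off = 2 the carry out of dimension
     * 0 adds is(0), retracts n(0) * is(0), and adds is(1):
     * 1 + 1 - 2 + 10 = 10, i.e. the offset of index (0, 1). */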

    void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
            int step_size = 1) {
        int dummy = 0;
        step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
                step_size);
    }

    void tr8x8_avx2(int i_off, int o_off) {
        for (int i = 0; i < 8; i++)
            vmovups(Ymm(i), i_addr(i_off + i * 8));

        for (int i = 0; i < 8 / 2; i++) {
            vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1));
            vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
        }

        const unsigned int lfloat = 0x44;
        const unsigned int ufloat = 0xee;
        for (int i = 0; i < 8 / 2; i++) {
            int j = i % 2 == 0 ? 8 + i : i - 1;
            vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
            vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
        }

        const unsigned int lquad = 0x20;
        for (int i = 0; i < 8 / 2; i++)
            vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad);

        const unsigned int uquad = 0x31;
        for (int i = 8 / 2; i < 8; i++)
            vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad);

        for (int i = 0; i < 8; i++)
            vmovups(o_addr(o_off + i * 8), Ymm(i));
    }
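
    /* tr8x8_avx2 is the classic three-stage 8x8 f32 transpose:
     * vunpck{l,h}ps interleaves adjacent rows at 32-bit granularity,
     * vshufps merges the results at 64-bit granularity, and vperm2f128
     * exchanges 128-bit halves between registers: 24 shuffle-class
     * instructions for 64 elements, plus 16 loads/stores. */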

    bool process_unroll_tr8x8(int len) {
        bool can_do = true
            && mayiuse(avx2)
            && prb_.ndims >= 2
            && utils::everyone_is(4, itype_sz, otype_sz)
            && utils::everyone_is(8, n(0), n(1))
            && utils::everyone_is(1, os(0), is(1))
            && utils::everyone_is(8, os(1), is(0))
            && prb_.scale_type == scale_type_t::NONE
            && prb_.beta == 0.f;
        if (!can_do) return false;

        const int step_size = n(0) * n(1);
        int i_off = 0, o_off = 0;
        for (int off = 0; off < len; off += step_size) {
            step(off, i_off, o_off, i_off, o_off, step_size);
            tr8x8_avx2(i_off, o_off);
        }

        return true;
    }

    template <cpu_isa_t isa>
    bool process_direct_copy(int len) {
        using namespace data_type;

        using Vmm = typename cpu_isa_traits<isa>::Vmm;
        const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;

        bool can_do = true
            && mayiuse(isa)
            && utils::everyone_is(1, os(0), is(0))
            && (false
                || prb_.itype == prb_.otype
                || (prb_.itype == s32 && prb_.otype == f32)
                || (prb_.itype == f32 && prb_.otype == s32)
                )
            && len % simd_w == 0
            && n(0) % len == 0
            && prb_.scale_type == scale_type_t::NONE
            && prb_.beta == 0.f;
        if (!can_do) return false;

        for (int off = 0; off < len;) {
            const int unroll = nstl::min(16, (len - off) / simd_w);

            for (int ur = 0; ur < unroll; ++ur)
                uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w));

            if (prb_.itype != prb_.otype) {
                for (int ur = 0; ur < unroll; ++ur) {
                    if (prb_.itype == s32 && prb_.otype == f32)
                        uni_vcvtdq2ps(Vmm(ur), Vmm(ur));
                    else if (prb_.itype == f32 && prb_.otype == s32)
                        uni_vcvtps2dq(Vmm(ur), Vmm(ur));
                    else assert(!"unreachable");
                }
            }

            for (int ur = 0; ur < unroll; ++ur)
                uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur));

            off += unroll * simd_w;
        }

        return true;
    }
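
    /* process_direct_copy handles the fully contiguous case
     * (is(0) == os(0) == 1): it streams whole vectors through up to 16
     * registers per iteration, converting between s32 and f32 in-register
     * when the types differ. */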

    void process_unroll_generic_step(int reg_unroll, const int *i_off,
            const int *o_off, const int *s_off) {
        using namespace data_type;

        auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) {
            Xmm dst_pure = Xmm(dst.getIdx());
            switch (idt) {
            case f32:
                if (src.isMEM() || src.getIdx() != dst.getIdx())
                    vmovups(dst, src);
                break;
            case s32: vcvtdq2ps(dst, src); break;
            case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
            case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
            default: assert(!"unreachable");
            }
        };

        auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) {
            switch (odt) {
            case s32:
                if (idt == f32) vcvtps2dq(xmm, xmm);
                else if (idt == s8) vpmovsxbd(xmm, xmm);
                else if (idt == u8) vpmovzxbd(xmm, xmm);
                break;
            case s8:
                if (idt == f32) vcvtps2dq(xmm, xmm);
                if (idt == f32 || idt == s32) {
                    if (mayiuse(avx512_core)) {
                        vpmovsdb(xmm, xmm);
                    } else {
                        vpackssdw(xmm, xmm, xmm_zero);
                        vpacksswb(xmm, xmm, xmm_zero);
                    }
                }
                if (idt == u8) vpminub(xmm, xmm, xmm_4x127b);
                break;
            case u8:
                if (idt == f32) vcvtps2dq(xmm, xmm);
                if (idt == f32 || idt == s32) {
                    if (mayiuse(avx512_core)) {
                        vpmaxsd(xmm, xmm, xmm_zero);
                        vpmovusdb(xmm, xmm);
                    } else {
                        vpackssdw(xmm, xmm, xmm_zero);
                        vpackuswb(xmm, xmm, xmm_zero);
                    }
                }
                if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero);
                break;
            default: assert(!"unreachable");
            }
        };
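
        /* cvt2int saturates as follows: f32/s32 -> s8 goes through
         * vpmovsdb on avx512_core and through the signed packs otherwise;
         * f32/s32 -> u8 additionally clamps negatives (vpmaxsd or the
         * unsigned pack); u8 -> s8 clamps to 127 via vpminub with
         * xmm_4x127b; s8 -> u8 clamps to 0 via vpmaxsb with xmm_zero. */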

        auto load = [=](const Xmm &xmm, const Address &addr, int size) {
            switch (size) {
            case 16: movups(xmm, addr); break;
            case 4: movss(xmm, addr); break;
            case 1: pinsrb(xmm, addr, 0x0); break;
            default: assert(!"unreachable");
            }
        };

        auto store = [=](const Address &addr, const Xmm &xmm, int size) {
            switch (size) {
            case 16: movups(addr, xmm); break;
            case 4: movss(addr, xmm); break;
            case 1: pextrb(addr, xmm, 0x0); break;
            default: assert(!"unreachable");
            }
        };

        /* check whether loading 4 values at once is possible */
        bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (i_off[ur] != i_off[ur - 1] + 1)
                can_load_xmm = false;
        const int load_step = can_load_xmm ? 4 : 1;

        /* check whether storing 4 values at once is possible */
        bool can_store_xmm = reg_unroll % 4 == 0;
        for (int ur = 1; ur < reg_unroll; ++ur)
            if (o_off[ur] != o_off[ur - 1] + 1)
                can_store_xmm = false;
        const int ur_step = can_store_xmm ? 4 : 1;

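        /* the scale/beta arithmetic below is done in f32, so an interim
         * f32 representation is required whenever either data type is f32
         * or scaling/accumulation is requested */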
        const bool interim_f32 = false
            || utils::one_of(f32, prb_.itype, prb_.otype)
            || prb_.scale_type != scale_type_t::NONE
            || prb_.beta != 0.f;

        if (!can_load_xmm && can_store_xmm) {
            assert(ur_step == 4);
            /* load with stride */
            for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                for (int r = 0; r < ur_step; ++r) {
                    if (itype_sz == 4)
                        pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r);
                    else
                        pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r);
                }
            }
        } else {
            for (int ur = 0; ur < reg_unroll; ur += load_step)
                load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz);
        }

        /* xmm[:] <-- (f32)xmm[:] */
        if (interim_f32) {
            const int cvt_step = nstl::max(load_step, ur_step);
            for (int ur = 0; ur < reg_unroll; ur += cvt_step)
                cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
        }

        if (can_load_xmm && !can_store_xmm) {
            const bool fast_return = true // transposition on the fly
                && prb_.scale_type != scale_type_t::MANY
                && prb_.beta == 0.f;
            if (fast_return) {
                for (int ur = 0; ur < reg_unroll; ur += load_step) {
                    if (prb_.scale_type == scale_type_t::COMMON)
                        mulps(Xmm(ur), xmm_scale);
                    if (prb_.otype != f32)
                        cvt2int(Xmm(ur), prb_.otype,
                                interim_f32 ? f32 : prb_.itype);
                    for (int r = 0; r < load_step; ++r) {
                        if (otype_sz == 4)
                            pextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
                        else
                            pextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
                    }
                }
                return;
            }

            /* scatter elements of xmm into 4 xmms */
            if (itype_sz == 4 || interim_f32) {
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r)
                        vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
            } else {
                for (int ur = 0; ur < reg_unroll; ur += load_step)
                    for (int r = 1; r < load_step; ++r)
                        vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
            }
        }

        /* scale and beta processing */
        if (can_store_xmm) {
            /* xmm <-- scale * xmm[:] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step)
                    mulps(Xmm(ur), xmm_scale);
            } else if (prb_.scale_type == scale_type_t::MANY) {
                enum class scale_load_type_t { bcast, load, gather };

                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    scale_load_type_t scale_load_type =
                        scale_load_type_t::bcast; // the best case

                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 0)
                            scale_load_type = scale_load_type_t::load;

                    if (scale_load_type == scale_load_type_t::bcast) {
                        movss(xmm_scale, s_addr(s_off[ur]));
                        shufps(xmm_scale, xmm_scale, 0x0);
                        mulps(Xmm(ur), xmm_scale);
                        continue;
                    }

                    // bcast doesn't work; next, try a full vector load
                    for (int r = ur + 1; r < ur + ur_step; ++r)
                        if (s_off[r] != s_off[r - 1] + 1)
                            scale_load_type = scale_load_type_t::gather;

                    if (scale_load_type == scale_load_type_t::load) {
                        movups(xmm_scale, s_addr(s_off[ur]));
                        mulps(Xmm(ur), xmm_scale);
                        continue;
                    }

                    // load doesn't work either,
                    // so gather the scale factors one by one
                    for (int r = ur; r < ur + ur_step; ++r)
                        pinsrd(xmm_scale, s_addr(s_off[r]), r - ur);
                    mulps(Xmm(ur), xmm_scale);
                }
            }

            /* dst <-- beta * dst + xmm[:] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    if (prb_.otype == f32) {
                        /* non-VEX instructions do not support unaligned
                         * memory operands, with the exception of movups. */
                        if (mayiuse(avx)) {
                            vaddps(Xmm(ur), o_addr(o_off[ur]));
                        } else {
                            /* register xmm(1) is unused */
                            movups(Xmm(1), o_addr(o_off[ur]));
                            addps(Xmm(ur), Xmm(1));
                        }
                    } else {
                        cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
                        vaddps(Xmm(ur), Xmm(1));
                    }
                }
            }
        } else {
            /* xmm[0] <-- scale * xmm[0] */
            if (prb_.scale_type == scale_type_t::COMMON) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step)
                    mulss(Xmm(ur), xmm_scale);
            } else if (prb_.scale_type == scale_type_t::MANY) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    mulss(Xmm(ur), s_addr(s_off[ur]));
                }
            }

            /* dst <-- beta * dst + xmm[0] */
            assert(prb_.beta == 0.f || prb_.beta == 1.f);
            if (prb_.beta == 1.f) {
                for (int ur = 0; ur < reg_unroll; ur += ur_step) {
                    if (prb_.otype == f32) {
                        addss(Xmm(ur), o_addr(o_off[ur]));
                    } else {
                        if (prb_.otype == s32) {
                            vmovss(xmm_tmp, o_addr(o_off[ur]));
                        } else if (utils::one_of(prb_.otype, s8, u8)) {
                            pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0);
                        } else {
                            assert(!"unsupported o_type");
                        }
                        cvt2ps(xmm_tmp, xmm_tmp, prb_.otype);
                        addps(Xmm(ur), xmm_tmp);
                    }
                }
            }
        }

        for (int ur = 0; ur < reg_unroll; ur += ur_step) {
            if (prb_.otype != f32)
                cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
            store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz);
        }
    }

    void process_unroll_generic(int len) {
        const int blk = 8;

        int i_off[2 * blk] = {0};
        int o_off[2 * blk] = {0};
        int s_off[2 * blk] = {0};

        int curr = 0; // will switch between 0 and 1

        for (int off = 0; off < len; off += blk) {
            const int reg_unroll = nstl::min(off + blk, len) - off;

            /* compute offsets */
            for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
                const int ur_c = curr * blk + ur;
                const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
                step(off + ur,
                        i_off[ur_p], o_off[ur_p], s_off[ur_p],
                        i_off[ur_c], o_off[ur_c], s_off[ur_c]);
            }

            process_unroll_generic_step(reg_unroll, i_off + curr * blk,
                    o_off + curr * blk, s_off + curr * blk);

            curr = 1 - curr;
        }
    }
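
    /* The offset arrays are double-buffered (hence 2 * blk and the curr
     * flip): step() for the first element of a block needs the offsets of
     * the last element of the previous block, which would otherwise have
     * been overwritten already. */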

    void loop_begin(Label &l, Reg64 reg_cnt, int len) {
        mov(reg_cnt, len);
        L(l);
    }

    void loop_end(Label &l, Reg64 reg_cnt, int len,
            int i_step, int o_step, int s_step) {
        add(reg_off_in, i_step * itype_sz);
        add(reg_off_out, o_step * otype_sz);
        if (prb_.scale_type == scale_type_t::MANY)
            add(reg_off_scale, s_step * stype_sz);
        dec(reg_cnt);
        jnz(l);

        sub(reg_off_in, len * i_step * itype_sz);
        sub(reg_off_out, len * o_step * otype_sz);
        if (prb_.scale_type == scale_type_t::MANY)
            sub(reg_off_scale, len * s_step * stype_sz);
    }

    bool simple_impl() {
        simple_impl_desc_t d;
        if (!simple_impl_desc_init(prb_, &d)) return false;

        const int nfu = d.ndims_full_unroll;
        const int ldu = d.len_last_dim_unroll;
        const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
        assert(n_jit_loops <= ndims_jit_loop_max);

        xor_(reg_off_in, reg_off_in);
        xor_(reg_off_out, reg_off_out);
        if (prb_.scale_type == scale_type_t::MANY)
            xor_(reg_off_scale, reg_off_scale);

        Label l_loop[3];
        Reg64 reg_cnt[3] = {r15, r14, r13};

        if (n_jit_loops > 2)
            loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));

        if (n_jit_loops > 1)
            loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));

        if (n_jit_loops > 0)
            loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);

        const bool optimized = false
            || process_direct_copy<avx>(d.len_unroll)
            || process_direct_copy<sse42>(d.len_unroll)
            || process_unroll_tr8x8(d.len_unroll);
        if (!optimized)
            process_unroll_generic(d.len_unroll);

        if (n_jit_loops > 0)
            loop_end(l_loop[0], reg_cnt[0],
                    n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu,
                    ss(nfu + 0) * ldu);

        if (n_jit_loops > 1)
            loop_end(l_loop[1], reg_cnt[1],
                    n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1));

        if (n_jit_loops > 2)
            loop_end(l_loop[2], reg_cnt[2],
                    n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2));

        return true;
    }

    void impl() {
        if (simple_impl()) return;
        assert(!"no implementation available");
    }

    jit_uni_reorder_kernel_f32(const desc_t &desc)
        : kernel_t(desc), jit_generator() {
        itype_sz = data_type_size(prb_.itype);
        otype_sz = data_type_size(prb_.otype);
        stype_sz = sizeof(float);

        preamble();
#       define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)]
        if (prb_.scale_type == scale_type_t::COMMON) {
            auto reg_ptr_scale_tmp = reg_ptr_in;
            mov(reg_ptr_scale_tmp, PARAM(scale));
            movups(xmm_scale, ptr[reg_ptr_scale_tmp]);
        } else if (prb_.scale_type == scale_type_t::MANY) {
            mov(reg_ptr_scale, PARAM(scale));
        }
        mov(reg_ptr_in, PARAM(in));
        mov(reg_ptr_out, PARAM(out));
#       undef PARAM

        if (mayiuse(avx)) {
            vxorps(xmm_zero, xmm_zero, xmm_zero);

            if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
                mov(reg_tmp.cvt32(), 0x7f7f7f7f);
                movd(xmm_4x127b, reg_tmp.cvt32());
            }
        }

        impl();
        postamble();
        ker_ = (void (*)(const call_param_t *))getCode();
    }

private:
    int itype_sz;
    int otype_sz;
    int stype_sz;

    Reg64 reg_ptr_in = rsi;
    Reg64 reg_ptr_out = rdx;
    Reg64 reg_ptr_scale = abi_not_param1;

    Reg64 reg_off_in = r8;
    Reg64 reg_off_out = r9;
    Reg64 reg_off_scale = r10;

    Reg64 reg_tmp = rax;

    Xmm xmm_scale = xmm15;
    Xmm xmm_zero = xmm14;
    Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero
    Xmm xmm_tmp = xmm12;
};

status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb,
        int ndims_ker_max) {
    desc.prb = prb;
    desc.prb.ioff = desc.prb.ooff = 0;

    if (ndims_ker_max > prb.ndims)
        return status::invalid_arguments;

    auto ndims_ker_max_f = [&]() {
        size_t cur_size = 1;
        for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
            if (cur_size >= ker_prb_size_min) return d;
        return prb.ndims;
    };

    if (ndims_ker_max <= 0)
        ndims_ker_max = ndims_ker_max_f();

    /* traverse through kernel implementations */
    /* TODO: find a better way to do that... */
    desc.id = 0;
    for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
        desc.prb.ndims = ndims_ker;
        if (jit_uni_reorder_kernel_f32::applicable(desc.prb))
            return status::success;
    }

    return status::unimplemented;
}

kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
    switch (desc.id) {
    case 0: return new jit_uni_reorder_kernel_f32(desc);
    default: assert(!"unknown kernel id"); return nullptr;
    }

    return nullptr;
}

}

static void prb_block_for_cache(tr::prb_t &prb) {
    if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) {
        /** an attempt to use caches more efficiently and to
         * address the 4K-aliasing issue */
        /* TODO: improve the logic around here */
        int j = 1;
        for (; j < prb.ndims && prb.nodes[j].is != 1; ++j);
        if (j == prb.ndims) return;

        /* it makes sense to re-prioritize a sequential read over a
         * sequential write if the former would not thrash the cache,
         * i.e. is == 1 and os % 2^k != 0, with k currently set to 2 */
        const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1;
        if (j == move_to) return;

        if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0)
            prb_node_split(prb, j, 16);

        prb_node_move(prb, j, move_to);
        DEBUG({ printf("cache: "); prb_dump(prb); });
    }
}

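/* The idea behind prb_block_for_cache: when the outermost node reads with
 * a large, highly aligned stride, successive reads may map to the same
 * cache sets (4K aliasing). Moving the unit-input-stride node (split down
 * to a 16-element chunk if large) towards the front makes reads
 * sequential; it goes to position 1 instead when its output stride is a
 * multiple of 4, leaving the current innermost node in place. */
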
/** finds the maximum number of dimensions the kernel should process and
 * optionally splits one of the dimensions to achieve a better balance
 * between the parallel driver and the kernel. */
static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) {
    size_t sz_total = 1;
    for (int d = 0; d < prb.ndims; ++d)
        sz_total *= prb.nodes[d].n;

    /* sz_drv_min is the minimal size for the parallel
     * driver required for good parallelization */
    const size_t sz_drv_min = nstl::min<size_t>(
            16 * mkldnn_get_max_threads(),
            utils::div_up(sz_total, 1024));

    /* kdims -- # of dimensions processed by a kernel
     * sz_ker_cur -- product of the dimensions processed by a kernel
     * sz_drv_cur -- product of the dimensions processed by a driver */

    int kdims = prb.ndims;
    size_t sz_drv_cur = 1;
    for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
        sz_drv_cur *= prb.nodes[kdims - 1].n;

    size_t sz_ker_cur = 1;
    for (int d = 0; d < kdims; ++d)
        sz_ker_cur *= prb.nodes[d].n;

    /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
     *
     * It might happen that for chosen kdims the sz_ker_cur is too small
     * (less than tr::ker_prb_size_min). In that case try to split the
     * innermost driver dimension into two, to increase sz_ker_cur. */
    bool want_borrow_ker_from_drv = true
        && kdims < prb.ndims
        && sz_ker_cur < tr::ker_prb_size_min
        && sz_drv_cur > sz_drv_min;
    if (want_borrow_ker_from_drv) {
        /* sz_want_borrow is the minimal sz, so that:
         *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
         *  o) current innermost driver dimension is divisible by
         *     sz_want_borrow (so that we can evenly split that
         *     dimension into two)
         *
         * In the worst case the minimal sz_want_borrow is equal
         * to the innermost driver dimension itself. In that case
         * we will sacrifice it in favor of the kernel (is that fine?). */
        size_t sz_want_borrow
            = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
        for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow);
        if (sz_want_borrow != prb.nodes[kdims].n)
            prb_node_split(prb, kdims, sz_want_borrow);
        kdims += 1;
    }

    /* On the other hand it might happen that for chosen kdims
     * the sz_drv_cur is too small (less than sz_drv_min). In that case
     * try to split the outermost kernel dimension into two, to increase
     * sz_drv_cur. */
    bool want_borrow_drv_from_ker = true
        && sz_ker_cur > tr::ker_prb_size_min
        && sz_drv_cur < sz_drv_min;
    if (want_borrow_drv_from_ker) {
        size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
        for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow);
        if (sz_want_borrow != prb.nodes[kdims - 1].n)
            prb_node_split(prb, kdims - 1,
                    prb.nodes[kdims - 1].n / sz_want_borrow);
    }

    ndims_ker_max = kdims;

    if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
        DEBUG({ printf("split: "); prb_dump(prb);
                printf("ndims_ker_max = %d\n", ndims_ker_max); });
    }
}
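
/* A hypothetical sizing example for prb_thread_kernel_balance: for a
 * 32x32x32x32 problem with mkldnn_get_max_threads() == 16, sz_total = 2^20
 * and sz_drv_min = min(16 * 16, 2^20 / 1024) = 256. The kdims loop stops
 * at kdims = 2 (sz_drv_cur = 1024 >= 256) with sz_ker_cur = 1024 >=
 * tr::ker_prb_size_min, so neither borrowing branch fires: the kernel gets
 * the two innermost dimensions and the driver the two outermost. */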

struct jit_uni_reorder_t : public cpu_primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);

        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
            auto prb = tr::prb_t();

            status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
            if (prb_init_status != status::success) return prb_init_status;

            DEBUG({ printf("init : "); prb_dump(prb); });
            prb_normalize(prb);
            DEBUG({ printf("norm : "); prb_dump(prb); });
            prb_simplify(prb);
            DEBUG({ printf("smpl : "); prb_dump(prb); });

            prb_block_for_cache(prb);

            int ndims_ker_max;
            prb_thread_kernel_balance(prb, ndims_ker_max);

            tr::kernel_t::desc_t ker_desc;
            status_t ker_init_status
                = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
            if (ker_init_status != status::success) return ker_init_status;

            const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
            if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
                return status::unimplemented;

            DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); });

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init() != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            _pd->prb_ = prb;
            _pd->ker_desc_ = ker_desc;
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        tr::prb_t prb_;
        tr::kernel_t::desc_t ker_desc_;
    };

    jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {
        kernel_ = tr::kernel_t::create(pd()->ker_desc_);
        assert(kernel_);
    }
    ~jit_uni_reorder_t() { delete kernel_; }

    void omp_driver_0d(int off, const char *in, char *out,
            const float *scale) const {
        tr::call_param_t c{in, out, scale};
        (*kernel_)(&c);
    }

    void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
            const float *scale) const {
        const tr::node_t *ns = pd()->prb_.nodes + off;
        for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
            auto c = tr::call_param_t();
            c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
            c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
            c.scale = scale + d0 * ns[0].ss;
            (*kernel_)(&c);
        });
    }

    void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
            const float *scale) const {
        const tr::node_t *ns = pd()->prb_.nodes + off;
        for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
                [&](ptrdiff_t d1, ptrdiff_t d0) {
            auto c = tr::call_param_t();
            c.in = in + (d0 * ns[0].is + d1 * ns[1].is)
                * data_type_size(pd()->prb_.itype);
            c.out = out + (d0 * ns[0].os + d1 * ns[1].os)
                * data_type_size(pd()->prb_.otype);
            c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
            (*kernel_)(&c);
        });
    }

    void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
            const float *scale) const {
        const tr::node_t *ns = pd()->prb_.nodes + off;
        for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
                (ptrdiff_t)ns[0].n,
                [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
            auto c = tr::call_param_t();
            c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
                * data_type_size(pd()->prb_.itype);
            c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
                * data_type_size(pd()->prb_.otype);
            c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
            (*kernel_)(&c);
        });
    }

    void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
            const float *scale) const {
        const tr::node_t *ns = pd()->prb_.nodes + off;
        for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
                (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
                [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
            auto c = tr::call_param_t();
            c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
                    + d3 * ns[3].is) * data_type_size(pd()->prb_.itype);
            c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
                    + d3 * ns[3].os) * data_type_size(pd()->prb_.otype);
            c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
                + d3 * ns[3].ss;
            (*kernel_)(&c);
        });
    }

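    /* Driver/kernel split: the first ker_desc_.prb.ndims nodes are handled
     * inside the JIT kernel; the remaining outer nodes (at most
     * ndims_driver_max) are parallelized by the omp_driver_*d functions
     * above, each thread invoking the kernel on a base-pointer-adjusted
     * sub-problem. */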
    void omp_driver(const char *in, char *out, const float *scale) const {
        in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
        out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);

        DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); });
        DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); });

        int ndims = pd()->prb_.ndims;
        int ndims_ker = pd()->ker_desc_.prb.ndims;
        assert(ndims - ndims_ker <= ndims_driver_max);

        if (ndims - ndims_ker == 0) {
            omp_driver_0d(ndims_ker, in, out, scale);
        } else {
            parallel(0, [&](const int ithr, const int nthr) {
                switch (ndims - ndims_ker) {
                case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break;
                case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break;
                case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break;
                case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break;
                default: assert(!"unimplemented");
                }
            });
        }
    }

    virtual status_t execute(const exec_ctx_t &ctx) const override {
        auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM);
        auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);

        omp_driver(in, out, pd()->attr()->output_scales_.scales_);

        return status::success;
    }

    enum { ndims_driver_max = 4 };

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    tr::kernel_t *kernel_;
};

status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
        engine_t *engine, const primitive_attr_t *attr,
        engine_t *src_engine, const memory_desc_t *src_md,
        engine_t *dst_engine, const memory_desc_t *dst_md) {
    return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr,
            src_engine, src_md, dst_engine, dst_md);
}

}
}
}