1 | /******************************************************************************* |
2 | * Copyright 2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "c_types_map.hpp" |
20 | #include "memory_desc_wrapper.hpp" |
21 | #include "mkldnn_debug.h" |
22 | #include "nstl.hpp" |
23 | #include "type_helpers.hpp" |
24 | |
25 | #include "cpu_primitive.hpp" |
26 | #include "cpu_reorder_pd.hpp" |
27 | #include "jit_uni_reorder.hpp" |
28 | |
29 | #include "jit_generator.hpp" |
30 | |
31 | // #define TR_DEBUG |
32 | #if defined(TR_DEBUG) |
33 | #define DEBUg(...) do { __VA_ARGS__ } while (0) |
34 | #else |
35 | #define DEBUg(...) |
36 | #endif |
37 | #define DEBUG(...) DEBUg(__VA_ARGS__) |
38 | |
39 | #ifdef _WIN32 |
/* s_addr is defined as a macro by the Windows socket headers,
 * which clashes with the s_addr() helper below */
41 | #undef s_addr |
42 | #endif |
43 | |
44 | using namespace Xbyak; |
45 | using namespace mkldnn::impl::types; |
46 | |
47 | namespace mkldnn { |
48 | namespace impl { |
49 | namespace cpu { |
50 | |
51 | namespace tr { |
52 | |
/** Minimal reasonable/desirable kernel size.
 * The constant is used to determine how a problem should be split
 * between the kernel and the threading driver. */
56 | const size_t ker_prb_size_min = 64; |
57 | |
58 | /* kernel */ |
59 | struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { |
60 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) |
61 | |
62 | enum { |
63 | len_unroll_max = 256, |
64 | ndims_jit_loop_max = 3, |
65 | }; |
66 | |
67 | struct simple_impl_desc_t { |
68 | int ndims_full_unroll; |
69 | int len_last_dim_unroll; |
70 | int len_unroll; |
71 | }; |
72 | |
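    /** Decides how much of the problem is handled by plain unrolling.
     * Dimensions are fully unrolled starting from the innermost one for
     * as long as the total unroll stays within len_unroll_max; the first
     * dimension that does not fit is partially unrolled. For illustration,
     * with hypothetical nodes {n=4, n=8, n=16} and len_unroll_max = 256:
     *   d=0: 1 * 4 <= 256  -> fully unrolled, len_unroll = 4
     *   d=1: 4 * 8 <= 256  -> fully unrolled, len_unroll = 32
     *   d=2: 32 * 16 > 256 -> len_last_dim_unroll = 256 / 32 = 8,
     *        len_unroll = 256
     * The leftover dimensions become the jit-ed loops; there may be at
     * most ndims_jit_loop_max of them. */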
73 | static bool simple_impl_desc_init(const prb_t &prb, |
74 | simple_impl_desc_t *desc) { |
75 | const int ndims = prb.ndims; |
76 | |
77 | int ndims_full_unroll = 0; |
78 | int len_last_dim_unroll = 1; |
79 | int len_unroll = 1; |
80 | |
81 | for (int d = 0; d < ndims; ++d) { |
82 | auto &node = prb.nodes[d]; |
83 | if (len_unroll * node.n <= len_unroll_max) { |
84 | ndims_full_unroll++; |
85 | len_unroll *= node.n; |
86 | } else { |
87 | len_last_dim_unroll = len_unroll_max / len_unroll; |
88 | while (node.n % len_last_dim_unroll) |
89 | --len_last_dim_unroll; |
90 | len_unroll *= len_last_dim_unroll; |
91 | break; |
92 | } |
93 | } |
94 | |
95 | if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) |
96 | return false; |
97 | |
98 | if (desc) { |
99 | desc->ndims_full_unroll = ndims_full_unroll; |
100 | desc->len_last_dim_unroll = len_last_dim_unroll; |
101 | desc->len_unroll = len_unroll; |
102 | } |
103 | |
104 | return true; |
105 | } |
106 | |
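    /* Applicability check: supported data types, zero offsets, beta in
     * {0, 1}, and strides small enough that the byte span of every
     * dimension (node.n * stride * type size) stays below 2^31, so all
     * offset arithmetic in the generated code fits in 32 bits. */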
107 | static bool applicable(const prb_t &p) { |
108 | using namespace data_type; |
109 | |
110 | bool ok = true |
111 | && p.ndims > 0 |
112 | && utils::one_of(p.itype, f32, s32, s8, u8) |
113 | && utils::one_of(p.otype, f32, s32, s8, u8) |
114 | && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ |
115 | && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ |
116 | && simple_impl_desc_init(p, nullptr) |
117 | && mayiuse(sse42) |
118 | && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), |
119 | mayiuse(avx)); |
120 | if (!ok) return false; |
121 | |
122 | const ptrdiff_t max_stride = (1LL<<31) - 1; |
123 | for (int d = 0; d < p.ndims; ++d) { |
124 | const ptrdiff_t cms = max_stride / p.nodes[d].n; |
125 | bool strides_ok = true |
126 | && p.nodes[d].is < cms / (int)data_type_size(p.itype) |
127 | && p.nodes[d].os < cms / (int)data_type_size(p.otype); |
128 | if (!strides_ok) return false; |
129 | } |
130 | |
131 | return true; |
132 | } |
133 | |
134 | int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } |
135 | int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } |
136 | int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } |
137 | int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } |
138 | |
139 | Address i_addr(int i_off) |
140 | { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } |
141 | |
142 | Address o_addr(int o_off) |
143 | { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } |
144 | |
145 | Address s_addr(int s_off) |
146 | { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } |
147 | |
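    /** Advances the input/output/scale offsets from flat element
     * (off - step_size) to flat element off, propagating carries through
     * the dimensions like a multi-digit counter. For illustration, with
     * hypothetical nodes {n=2, is=1} and {n=3, is=2}, stepping from
     * off = 1 (i_off = 1) to off = 2: dimension 0 wraps (2 % 2 == 0), so
     * is(0) is added and then n(0) * is(0) subtracted, and dimension 1
     * advances by is(1), giving i_off = 2 -- exactly the offset of
     * coordinates (0, 1). */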
148 | void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, |
149 | int &i_off, int &o_off, int &s_off, int step_size = 1) { |
150 | i_off = prev_i_off; |
151 | o_off = prev_o_off; |
152 | s_off = prev_s_off; |
153 | |
154 | if (off == 0) return; |
155 | |
156 | int start_dim = 0, dims_prod = 1; |
157 | for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) |
158 | dims_prod *= n(start_dim); |
159 | assert(start_dim < prb_.ndims); |
160 | off /= step_size; |
161 | |
162 | for (int d = start_dim; d < prb_.ndims; ++d) { |
163 | i_off += is(d); |
164 | o_off += os(d); |
165 | s_off += ss(d); |
166 | |
167 | if (off % n(d)) break; |
168 | |
            i_off -= n(d) * is(d);
            o_off -= n(d) * os(d);
            s_off -= n(d) * ss(d);
172 | off /= n(d); |
173 | |
174 | if (off == 0) break; /* FIXME: is it really required? */ |
175 | } |
176 | } |
177 | |
178 | void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, |
179 | int step_size = 1) { |
180 | int dummy = 0; |
181 | step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, |
182 | step_size); |
183 | } |
184 | |
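    /* In-register 8x8 f32 transpose, the usual three-stage AVX pattern:
     *   1) vunpcklps/vunpckhps interleave adjacent pairs of rows;
     *   2) vshufps combines the results into rows of 4x4 sub-blocks;
     *   3) vperm2f128 recombines the 128-bit lanes to produce the
     *      transposed rows. */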
185 | void tr8x8_avx2(int i_off, int o_off) { |
186 | for (int i = 0; i < 8; i++) |
187 | vmovups(Ymm(i), i_addr(i_off + i * 8)); |
188 | |
189 | for (int i = 0; i < 8 / 2; i++) { |
190 | vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); |
191 | vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); |
192 | } |
193 | |
194 | const unsigned int lfloat = 0x44; |
195 | const unsigned int ufloat = 0xee; |
196 | for (int i = 0; i < 8 / 2; i++) { |
197 | int j = i % 2 == 0 ? 8 + i : i - 1; |
198 | vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); |
199 | vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); |
200 | } |
201 | |
202 | const unsigned int lquad = 0x20; |
203 | for (int i = 0; i < 8 / 2; i++) |
204 | vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); |
205 | |
206 | const unsigned int uquad = 0x31; |
207 | for (int i = 8 / 2; i < 8; i++) |
208 | vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); |
209 | |
210 | for (int i = 0; i < 8; i++) |
211 | vmovups(o_addr(o_off + i * 8), Ymm(i)); |
212 | } |
213 | |
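    /* Fast path: the two innermost nodes form an 8x8 tile of 4-byte
     * elements with unit input stride along one dimension and unit output
     * stride along the other, so the tile is transposed in registers.
     * Only valid without scales and beta, since tr8x8_avx2() performs a
     * pure copy. */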
214 | bool process_unroll_tr8x8(int len) { |
215 | bool can_do = true |
216 | && mayiuse(avx2) |
217 | && prb_.ndims >= 2 |
218 | && utils::everyone_is(4, itype_sz, otype_sz) |
219 | && utils::everyone_is(8, n(0), n(1)) |
220 | && utils::everyone_is(1, os(0), is(1)) |
221 | && utils::everyone_is(8, os(1), is(0)) |
222 | && prb_.scale_type == scale_type_t::NONE |
223 | && prb_.beta == 0.f; |
224 | if (!can_do) return false; |
225 | |
226 | const int step_size = n(0) * n(1); |
227 | int i_off = 0, o_off = 0; |
228 | for (int off = 0; off < len; off += step_size) { |
229 | step(off, i_off, o_off, i_off, o_off, step_size); |
230 | tr8x8_avx2(i_off, o_off); |
231 | } |
232 | |
233 | return true; |
234 | } |
235 | |
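    /* Fast path: both tensors are dense along the innermost dimension
     * (is(0) == os(0) == 1), so the reorder is a plain vectorized copy,
     * possibly with an f32 <-> s32 conversion on the way. */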
236 | template <cpu_isa_t isa> |
237 | bool process_direct_copy(int len) { |
238 | using namespace data_type; |
239 | |
240 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
241 | const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz; |
242 | |
243 | bool can_do = true |
244 | && mayiuse(isa) |
245 | && utils::everyone_is(1, os(0), is(0)) |
246 | && (false |
247 | || prb_.itype == prb_.otype |
248 | || (prb_.itype == s32 && prb_.otype == f32) |
249 | || (prb_.itype == f32 && prb_.otype == s32) |
250 | ) |
251 | && len % simd_w == 0 |
252 | && n(0) % len == 0 |
253 | && prb_.scale_type == scale_type_t::NONE |
254 | && prb_.beta == 0.f; |
255 | if (!can_do) return false; |
256 | |
257 | for (int off = 0; off < len;) { |
258 | const int unroll = nstl::min(16, (len - off) / simd_w); |
259 | |
260 | for (int ur = 0; ur < unroll; ++ur) |
261 | uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); |
262 | |
263 | if (prb_.itype != prb_.otype) { |
264 | for (int ur = 0; ur < unroll; ++ur) { |
265 | if (prb_.itype == s32 && prb_.otype == f32) |
266 | uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); |
267 | else if (prb_.itype == f32 && prb_.otype == s32) |
268 | uni_vcvtps2dq(Vmm(ur), Vmm(ur)); |
269 | else assert(!"unreachable" ); |
270 | } |
271 | } |
272 | |
273 | for (int ur = 0; ur < unroll; ++ur) |
274 | uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); |
275 | |
276 | off += unroll * simd_w; |
277 | } |
278 | |
279 | return true; |
280 | } |
281 | |
282 | void process_unroll_generic_step(int reg_unroll, const int *i_off, |
283 | const int *o_off, const int *s_off) { |
284 | using namespace data_type; |
285 | |
286 | auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { |
287 | Xmm dst_pure = Xmm(dst.getIdx()); |
288 | switch (idt) { |
289 | case f32: |
290 | if (src.isMEM() || src.getIdx() != dst.getIdx()) |
291 | vmovups(dst, src); |
292 | break; |
293 | case s32: vcvtdq2ps(dst, src); break; |
294 | case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; |
295 | case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; |
            default: assert(!"unreachable");
297 | } |
298 | }; |
299 | |
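        /* Converts the values in xmm from idt to the integer type odt.
         * Narrowing to s8/u8 saturates: via vpmovsdb/vpmovusdb on
         * avx512_core, otherwise via the saturating packs vpackssdw and
         * vpack(s|u)swb. The u8 -> s8 case clamps to 127 with vpminub
         * against xmm_4x127b (low four bytes set to 0x7f), and s8 -> u8
         * clamps negatives to zero with vpmaxsb against xmm_zero. */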
300 | auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { |
301 | switch (odt) { |
302 | case s32: |
303 | if (idt == f32) vcvtps2dq(xmm, xmm); |
304 | else if (idt == s8) vpmovsxbd(xmm, xmm); |
305 | else if (idt == u8) vpmovzxbd(xmm, xmm); |
306 | break; |
307 | case s8: |
308 | if (idt == f32) vcvtps2dq(xmm, xmm); |
309 | if (idt == f32 || idt == s32) { |
310 | if (mayiuse(avx512_core)) { |
311 | vpmovsdb(xmm, xmm); |
312 | } else { |
313 | vpackssdw(xmm, xmm, xmm_zero); |
314 | vpacksswb(xmm, xmm, xmm_zero); |
315 | } |
316 | } |
317 | if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); |
318 | break; |
319 | case u8: |
320 | if (idt == f32) vcvtps2dq(xmm, xmm); |
321 | if (idt == f32 || idt == s32) { |
322 | if (mayiuse(avx512_core)) { |
323 | vpmaxsd(xmm, xmm, xmm_zero); |
324 | vpmovusdb(xmm, xmm); |
325 | } else { |
326 | vpackssdw(xmm, xmm, xmm_zero); |
327 | vpackuswb(xmm, xmm, xmm_zero); |
328 | } |
329 | } |
330 | if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); |
331 | break; |
            default: assert(!"unreachable");
333 | } |
334 | }; |
335 | |
336 | auto load = [=](const Xmm &xmm, const Address &addr, int size) { |
337 | switch (size) { |
338 | case 16: movups(xmm, addr); break; |
339 | case 4: movss(xmm, addr); break; |
340 | case 1: pinsrb(xmm, addr, 0x0); break; |
            default: assert(!"unreachable");
342 | } |
343 | }; |
344 | |
345 | auto store = [=](const Address &addr, const Xmm &xmm, int size) { |
346 | switch (size) { |
347 | case 16: movups(addr, xmm); break; |
348 | case 4: movss(addr, xmm); break; |
349 | case 1: pextrb(addr, xmm, 0x0); break; |
            default: assert(!"unreachable");
351 | } |
352 | }; |
353 | |
354 | /* check whether loading 4 values at once is possible */ |
355 | bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; |
356 | for (int ur = 1; ur < reg_unroll; ++ur) |
357 | if (i_off[ur] != i_off[ur - 1] + 1) |
358 | can_load_xmm = false; |
359 | const int load_step = can_load_xmm ? 4 : 1; |
360 | |
361 | /* check whether storing 4 values at once is possible */ |
362 | bool can_store_xmm = reg_unroll % 4 == 0; |
363 | for (int ur = 1; ur < reg_unroll; ++ur) |
364 | if (o_off[ur] != o_off[ur - 1] + 1) |
365 | can_store_xmm = false; |
366 | const int ur_step = can_store_xmm ? 4 : 1; |
367 | |
368 | const bool interim_f32 = false |
369 | || utils::one_of(f32, prb_.itype, prb_.otype) |
370 | || prb_.scale_type != scale_type_t::NONE |
371 | || prb_.beta != 0.f; |
372 | |
373 | if (!can_load_xmm && can_store_xmm) { |
374 | assert(ur_step == 4); |
375 | /* load with stride */ |
376 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
377 | for (int r = 0; r < ur_step; ++r) { |
378 | if (itype_sz == 4) |
379 | pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); |
380 | else |
381 | pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); |
382 | } |
383 | } |
384 | } else { |
385 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
386 | load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); |
387 | } |
388 | |
389 | /* xmm[:] <-- (f32)xmm[:] */ |
390 | if (interim_f32) { |
391 | const int cvt_step = nstl::max(load_step, ur_step); |
392 | for (int ur = 0; ur < reg_unroll; ur += cvt_step) |
393 | cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); |
394 | } |
395 | |
396 | if (can_load_xmm && !can_store_xmm) { |
397 | const bool fast_return = true // transposition on the fly |
398 | && prb_.scale_type != scale_type_t::MANY |
399 | && prb_.beta == 0.f; |
400 | if (fast_return) { |
401 | for (int ur = 0; ur < reg_unroll; ur += load_step) { |
402 | if (prb_.scale_type == scale_type_t::COMMON) |
403 | mulps(Xmm(ur), xmm_scale); |
404 | if (prb_.otype != f32) |
405 | cvt2int(Xmm(ur), prb_.otype, |
406 | interim_f32 ? f32 : prb_.itype); |
407 | for (int r = 0; r < load_step; ++r) { |
408 | if (otype_sz == 4) |
409 | pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); |
410 | else |
411 | pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); |
412 | } |
413 | } |
414 | return; |
415 | } |
416 | |
417 | /* scatter elements of xmm into 4 xmms */ |
418 | if (itype_sz == 4 || interim_f32) { |
419 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
420 | for (int r = 1; r < load_step; ++r) |
421 | vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); |
422 | } else { |
423 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
424 | for (int r = 1; r < load_step; ++r) |
425 | vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); |
426 | } |
427 | } |
428 | |
429 | /* scale and beta processing */ |
430 | if (can_store_xmm) { |
431 | /* xmm <-- scale * xmm[:] */ |
432 | if (prb_.scale_type == scale_type_t::COMMON) { |
433 | for (int ur = 0; ur < reg_unroll; ur += ur_step) |
434 | mulps(Xmm(ur), xmm_scale); |
435 | } else if (prb_.scale_type == scale_type_t::MANY) { |
436 | enum class scale_load_type_t { bcast, load, gather }; |
437 | |
438 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
439 | scale_load_type_t scale_load_type = |
440 | scale_load_type_t::bcast; // the best case |
441 | |
442 | for (int r = ur + 1; r < ur + ur_step; ++r) |
443 | if (s_off[r] != s_off[r - 1] + 0) |
444 | scale_load_type = scale_load_type_t::load; |
445 | |
446 | if (scale_load_type == scale_load_type_t::bcast) { |
447 | movss(xmm_scale, s_addr(s_off[ur])); |
448 | shufps(xmm_scale, xmm_scale, 0x0); |
449 | mulps(Xmm(ur), xmm_scale); |
450 | continue; |
451 | } |
452 | |
                    // broadcast is not possible, the next try is a load
454 | for (int r = ur + 1; r < ur + ur_step; ++r) |
455 | if (s_off[r] != s_off[r - 1] + 1) |
456 | scale_load_type = scale_load_type_t::gather; |
457 | |
458 | if (scale_load_type == scale_load_type_t::load) { |
459 | movups(xmm_scale, s_addr(s_off[ur])); |
460 | mulps(Xmm(ur), xmm_scale); |
461 | continue; |
462 | } |
463 | |
                    // a contiguous load does not work either,
                    // so gather the scale factors one by one
466 | for (int r = ur; r < ur + ur_step; ++r) |
467 | pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); |
468 | mulps(Xmm(ur), xmm_scale); |
469 | } |
470 | } |
471 | |
472 | /* dst <-- beta * dst + xmm[:] */ |
473 | assert(prb_.beta == 0.f || prb_.beta == 1.f); |
474 | if (prb_.beta == 1.f) { |
475 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
476 | if (prb_.otype == f32) { |
                        /* non-VEX instructions other than movups do not
                         * support unaligned memory operands */
479 | if (mayiuse(avx)) { |
480 | vaddps(Xmm(ur), o_addr(o_off[ur])); |
481 | } else { |
                            /* xmm(1) is free here, use it as a temporary */
483 | movups(Xmm(1), o_addr(o_off[ur])); |
484 | addps(Xmm(ur), Xmm(1)); |
485 | } |
486 | } else { |
487 | cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); |
488 | vaddps(Xmm(ur), Xmm(1)); |
489 | } |
490 | } |
491 | } |
492 | } else { |
493 | /* xmm[0] <-- scale * xmm[0] */ |
494 | if (prb_.scale_type == scale_type_t::COMMON) { |
495 | for (int ur = 0; ur < reg_unroll; ur += ur_step) |
496 | mulss(Xmm(ur), xmm_scale); |
497 | } else if (prb_.scale_type == scale_type_t::MANY) { |
498 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
499 | mulss(Xmm(ur), s_addr(s_off[ur])); |
500 | } |
501 | } |
502 | |
503 | /* dst <-- beta * dst + xmm[0] */ |
504 | assert(prb_.beta == 0.f || prb_.beta == 1.f); |
505 | if (prb_.beta == 1.f) { |
506 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
507 | if (prb_.otype == f32) { |
508 | addss(Xmm(ur), o_addr(o_off[ur])); |
509 | } else { |
510 | if (prb_.otype == s32) { |
511 | vmovss(xmm_tmp, o_addr(o_off[ur])); |
512 | } else if (utils::one_of(prb_.otype, s8, u8)) { |
513 | pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); |
514 | } else { |
515 | assert(!"unsupported o_type" ); |
516 | } |
517 | cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); |
518 | addps(Xmm(ur), xmm_tmp); |
519 | } |
520 | } |
521 | } |
522 | } |
523 | |
524 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
525 | if (prb_.otype != f32) |
526 | cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype); |
527 | store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); |
528 | } |
529 | } |
530 | |
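    /* Generic path: processes the problem in chunks of up to blk = 8
     * elements. The offset arrays are double-buffered (2 * blk entries)
     * so that step() can always read the previous element's offsets, even
     * across chunk boundaries (ur_p wraps around the two halves). */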
531 | void process_unroll_generic(int len) { |
532 | const int blk = 8; |
533 | |
534 | int i_off[2 * blk] = {0}; |
535 | int o_off[2 * blk] = {0}; |
536 | int s_off[2 * blk] = {0}; |
537 | |
538 | int curr = 0; // will switch between 0 and 1 |
539 | |
540 | for (int off = 0; off < len; off += blk) { |
541 | const int reg_unroll = nstl::min(off + blk, len) - off; |
542 | |
543 | /* compute offsets */ |
544 | for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { |
545 | const int ur_c = curr * blk + ur; |
546 | const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur |
547 | step(off + ur, |
548 | i_off[ur_p], o_off[ur_p], s_off[ur_p], |
549 | i_off[ur_c], o_off[ur_c], s_off[ur_c]); |
550 | } |
551 | |
552 | process_unroll_generic_step(reg_unroll, i_off + curr * blk, |
553 | o_off + curr * blk, s_off + curr * blk); |
554 | |
555 | curr = 1 - curr; |
556 | } |
557 | } |
558 | |
559 | void loop_begin(Label &l, Reg64 reg_cnt, int len) { |
560 | mov(reg_cnt, len); |
561 | L(l); |
562 | } |
563 | |
564 | void loop_end(Label &l, Reg64 reg_cnt, int len, |
565 | int i_step, int o_step, int s_step) { |
566 | add(reg_off_in, i_step * itype_sz); |
567 | add(reg_off_out, o_step * otype_sz); |
568 | if (prb_.scale_type == scale_type_t::MANY) |
569 | add(reg_off_scale, s_step * stype_sz); |
570 | dec(reg_cnt); |
571 | jnz(l); |
572 | |
573 | sub(reg_off_in, len * i_step * itype_sz); |
574 | sub(reg_off_out, len * o_step * otype_sz); |
575 | if (prb_.scale_type == scale_type_t::MANY) |
576 | sub(reg_off_scale, len * s_step * stype_sz); |
577 | } |
578 | |
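    /* Emits up to ndims_jit_loop_max nested runtime loops over the outer
     * dimensions around a fully unrolled inner body of d.len_unroll
     * elements, generated by one of the specialized paths above or by the
     * generic fallback. */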
579 | bool simple_impl() { |
580 | simple_impl_desc_t d; |
581 | if (!simple_impl_desc_init(prb_, &d)) return false; |
582 | |
583 | const int nfu = d.ndims_full_unroll; |
584 | const int ldu = d.len_last_dim_unroll; |
585 | const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; |
586 | assert(n_jit_loops <= ndims_jit_loop_max); |
587 | |
588 | xor_(reg_off_in, reg_off_in); |
589 | xor_(reg_off_out, reg_off_out); |
590 | if (prb_.scale_type == scale_type_t::MANY) |
591 | xor_(reg_off_scale, reg_off_scale); |
592 | |
593 | Label l_loop[3]; |
594 | Reg64 reg_cnt[3] = {r15, r14, r13}; |
595 | |
596 | if (n_jit_loops > 2) |
597 | loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); |
598 | |
599 | if (n_jit_loops > 1) |
600 | loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); |
601 | |
602 | if (n_jit_loops > 0) |
603 | loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); |
604 | |
605 | const bool optimized = false |
606 | || process_direct_copy<avx>(d.len_unroll) |
607 | || process_direct_copy<sse42>(d.len_unroll) |
608 | || process_unroll_tr8x8(d.len_unroll); |
609 | if (!optimized) |
610 | process_unroll_generic(d.len_unroll); |
611 | |
612 | if (n_jit_loops > 0) |
613 | loop_end(l_loop[0], reg_cnt[0], |
614 | n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, |
615 | ss(nfu + 0) * ldu); |
616 | |
617 | if (n_jit_loops > 1) |
618 | loop_end(l_loop[1], reg_cnt[1], |
619 | n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); |
620 | |
621 | if (n_jit_loops > 2) |
622 | loop_end(l_loop[2], reg_cnt[2], |
623 | n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); |
624 | |
625 | return true; |
626 | } |
627 | |
628 | void impl() { |
629 | if (simple_impl()) return; |
630 | assert(!"no implementation available" ); |
631 | } |
632 | |
633 | jit_uni_reorder_kernel_f32(const desc_t &desc) |
634 | : kernel_t(desc), jit_generator() { |
635 | itype_sz = data_type_size(prb_.itype); |
636 | otype_sz = data_type_size(prb_.otype); |
637 | stype_sz = sizeof(float); |
638 | |
639 | preamble(); |
640 | # define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] |
641 | if (prb_.scale_type == scale_type_t::COMMON) { |
642 | auto reg_ptr_scale_tmp = reg_ptr_in; |
643 | mov(reg_ptr_scale_tmp, PARAM(scale)); |
644 | movups(xmm_scale, ptr[reg_ptr_scale_tmp]); |
645 | } else if (prb_.scale_type == scale_type_t::MANY) { |
646 | mov(reg_ptr_scale, PARAM(scale)); |
647 | } |
648 | mov(reg_ptr_in, PARAM(in)); |
649 | mov(reg_ptr_out, PARAM(out)); |
650 | # undef PARAM |
651 | |
652 | if (mayiuse(avx)) { |
653 | vxorps(xmm_zero, xmm_zero, xmm_zero); |
654 | |
655 | if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { |
656 | mov(reg_tmp.cvt32(), 0x7f7f7f7f); |
657 | movd(xmm_4x127b, reg_tmp.cvt32()); |
658 | } |
659 | } |
660 | |
661 | impl(); |
662 | postamble(); |
663 | ker_ = (void (*)(const call_param_t *))getCode(); |
664 | } |
665 | |
666 | private: |
667 | int itype_sz; |
668 | int otype_sz; |
669 | int stype_sz; |
670 | |
671 | Reg64 reg_ptr_in = rsi; |
672 | Reg64 reg_ptr_out = rdx; |
673 | Reg64 reg_ptr_scale = abi_not_param1; |
674 | |
675 | Reg64 reg_off_in = r8; |
676 | Reg64 reg_off_out = r9; |
677 | Reg64 reg_off_scale = r10; |
678 | |
679 | Reg64 reg_tmp = rax; |
680 | |
681 | Xmm xmm_scale = xmm15; |
682 | Xmm xmm_zero = xmm14; |
683 | Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero |
684 | Xmm xmm_tmp = xmm12; |
685 | }; |
686 | |
687 | status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, |
688 | int ndims_ker_max) { |
689 | desc.prb = prb; |
690 | desc.prb.ioff = desc.prb.ooff = 0; |
691 | |
692 | if (ndims_ker_max > prb.ndims) |
693 | return status::invalid_arguments; |
694 | |
695 | auto ndims_ker_max_f = [&]() { |
696 | size_t cur_size = 1; |
697 | for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) |
698 | if (cur_size >= ker_prb_size_min) return d; |
699 | return prb.ndims; |
700 | }; |
701 | |
702 | if (ndims_ker_max <= 0) |
703 | ndims_ker_max = ndims_ker_max_f(); |
704 | |
705 | /* traverse through kernel implementations */ |
706 | /* TODO: find a better way to do that... */ |
707 | desc.id = 0; |
708 | for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { |
709 | desc.prb.ndims = ndims_ker; |
710 | if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) |
711 | return status::success; |
712 | } |
713 | |
714 | return status::unimplemented; |
715 | } |
716 | |
717 | kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { |
718 | switch (desc.id) { |
719 | case 0: return new jit_uni_reorder_kernel_f32(desc); |
    default: assert(!"unknown kernel id"); return nullptr;
721 | } |
722 | |
723 | return nullptr; |
724 | } |
725 | |
726 | } |
727 | |
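/** Moves (a 16-element chunk of) the node with unit input stride towards
 * the front of the problem when the outermost node reads with a large
 * power-of-two stride, so that kernel reads become (more) sequential. */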
728 | static void prb_block_for_cache(tr::prb_t &prb) { |
729 | if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { |
        /** an attempt to use caches more efficiently and
         * address the 4K-aliasing issue */
732 | /* TODO: improve the logic around here */ |
733 | int j = 1; |
734 | for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); |
735 | if (j == prb.ndims) return; |
736 | |
        /* it makes sense to re-prioritize the sequential read over the
         * sequential write if the former would not thrash the cache,
         * i.e. is == 1 and os % 2^k != 0, where k is currently set
         * to 2 (hence the check against 4 below) */
741 | const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; |
742 | if (j == move_to) return; |
743 | |
744 | if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) |
745 | prb_node_split(prb, j, 16); |
746 | |
747 | prb_node_move(prb, j, move_to); |
748 | DEBUG({ printf("cache: " ); prb_dump(prb); }); |
749 | } |
750 | } |
751 | |
/** finds the maximum number of dimensions the kernel should process and
 * optionally splits one of the dimensions to achieve a better balance between
 * the parallel driver and the kernel. */
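/* For illustration, with hypothetical nodes {8, 8, 8, 8, 8}
 * (sz_total = 32768) and 16 threads: sz_drv_min = min(256, 32) = 32, the
 * loop below stops at kdims = 3 (sz_drv_cur = 64), so sz_ker_cur = 512;
 * both "borrow" conditions are then false and ndims_ker_max becomes 3. */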
755 | static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { |
756 | size_t sz_total = 1; |
757 | for (int d = 0; d < prb.ndims; ++d) |
758 | sz_total *= prb.nodes[d].n; |
759 | |
760 | /* sz_drv_min is the minimal size for the parallel |
761 | * driver required for good parallelization */ |
762 | const size_t sz_drv_min = nstl::min<size_t>( |
763 | 16 * mkldnn_get_max_threads(), |
764 | utils::div_up(sz_total, 1024)); |
765 | |
    /* kdims -- # of dimensions processed by the kernel
     * sz_ker_cur -- product of the dimensions processed by the kernel
     * sz_drv_cur -- product of the dimensions processed by the driver */
769 | |
770 | int kdims = prb.ndims; |
771 | size_t sz_drv_cur = 1; |
772 | for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) |
773 | sz_drv_cur *= prb.nodes[kdims - 1].n; |
774 | |
775 | size_t sz_ker_cur = 1; |
776 | for (int d = 0; d < kdims; ++d) |
777 | sz_ker_cur *= prb.nodes[d].n; |
778 | |
779 | /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. |
780 | * |
781 | * It might happen that for chosen kdims the sz_ker_cur is too small |
782 | * (less than tr::ker_prb_size_min). In that case try to split the |
783 | * innermost driver dimension into two, to increase sz_ker_cur. */ |
784 | bool want_borrow_ker_from_drv = true |
785 | && kdims < prb.ndims |
786 | && sz_ker_cur < tr::ker_prb_size_min |
787 | && sz_drv_cur > sz_drv_min; |
788 | if (want_borrow_ker_from_drv) { |
789 | /* sz_want_borrow is the minimal sz, so that: |
790 | * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min |
791 | * o) current innermost driver dimension is divisible by |
792 | * sz_want_borrow (so that we can evenly split that |
793 | * dimension into two) |
794 | * |
795 | * In the worst case the minimal sz_want_borrow is equal |
796 | * to the innermost driver dimension itself. In that case |
         * we will sacrifice it in favor of the kernel (is it fine?). */
798 | size_t sz_want_borrow |
799 | = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); |
800 | for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); |
801 | if (sz_want_borrow != prb.nodes[kdims].n) |
802 | prb_node_split(prb, kdims, sz_want_borrow); |
803 | kdims += 1; |
804 | } |
805 | |
806 | /* On the other hand it might happen that for chosen kdims |
807 | * the sz_drv_cur is too small (less than sz_drv_min). In that case |
808 | * try to split the outermost kernel dimension into two, to increase |
809 | * sz_drv_cur. */ |
810 | bool want_borrow_drv_from_ker = true |
811 | && sz_ker_cur > tr::ker_prb_size_min |
812 | && sz_drv_cur < sz_drv_min; |
813 | if (want_borrow_drv_from_ker) { |
814 | size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); |
815 | for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); |
816 | if (sz_want_borrow != prb.nodes[kdims - 1].n) |
817 | prb_node_split(prb, kdims - 1, |
818 | prb.nodes[kdims - 1].n / sz_want_borrow); |
819 | } |
820 | |
821 | ndims_ker_max = kdims; |
822 | |
823 | if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { |
824 | DEBUG({ printf("split: " ); prb_dump(prb); |
825 | printf("ndims_ker_max = %d\n" , ndims_ker_max); }); |
826 | } |
827 | } |
828 | |
829 | struct jit_uni_reorder_t : public cpu_primitive_t { |
830 | struct pd_t : public cpu_reorder_pd_t { |
831 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
832 | |
        DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
834 | |
835 | static status_t create(reorder_pd_t **reorder_pd, |
836 | engine_t *engine, const primitive_attr_t *attr, |
837 | engine_t *src_engine, const memory_desc_t *src_md, |
838 | engine_t *dst_engine, const memory_desc_t *dst_md) { |
839 | auto prb = tr::prb_t(); |
840 | |
841 | status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); |
842 | if (prb_init_status != status::success) return prb_init_status; |
843 | |
844 | DEBUG({ printf("init : " ); prb_dump(prb); }); |
845 | prb_normalize(prb); |
846 | DEBUG({ printf("norm : " ); prb_dump(prb); }); |
847 | prb_simplify(prb); |
848 | DEBUG({ printf("smpl : " ); prb_dump(prb); }); |
849 | |
850 | prb_block_for_cache(prb); |
851 | |
852 | int ndims_ker_max; |
853 | prb_thread_kernel_balance(prb, ndims_ker_max); |
854 | |
855 | tr::kernel_t::desc_t ker_desc; |
856 | status_t ker_init_status |
857 | = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); |
858 | if (ker_init_status != status::success) return ker_init_status; |
859 | |
860 | const int ndims_driver = prb.ndims - ker_desc.prb.ndims; |
861 | if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) |
862 | return status::unimplemented; |
863 | |
864 | DEBUG({ printf("ker : " ); prb_dump(ker_desc.prb); }); |
865 | |
866 | auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, |
867 | dst_md); |
868 | if (_pd == nullptr) return status::out_of_memory; |
869 | if (_pd->init() != status::success) { |
870 | delete _pd; |
871 | return status::unimplemented; |
872 | } |
873 | _pd->prb_ = prb; |
874 | _pd->ker_desc_ = ker_desc; |
875 | return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); |
876 | } |
877 | |
878 | tr::prb_t prb_; |
879 | tr::kernel_t::desc_t ker_desc_; |
880 | }; |
881 | |
882 | jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { |
883 | kernel_ = tr::kernel_t::create(pd()->ker_desc_); |
884 | assert(kernel_); |
885 | } |
886 | ~jit_uni_reorder_t() { delete kernel_; } |
887 | |
888 | void omp_driver_0d(int off, const char *in, char *out, |
889 | const float *scale) const { |
890 | tr::call_param_t c{in, out, scale}; |
891 | (*kernel_)(&c); |
892 | } |
893 | |
894 | void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, |
895 | const float *scale) const { |
896 | const tr::node_t *ns = pd()->prb_.nodes + off; |
897 | for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { |
898 | auto c = tr::call_param_t(); |
899 | c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); |
900 | c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); |
901 | c.scale = scale + d0 * ns[0].ss; |
902 | (*kernel_)(&c); |
903 | }); |
904 | } |
905 | |
906 | void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, |
907 | const float *scale) const { |
908 | const tr::node_t *ns = pd()->prb_.nodes + off; |
909 | for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, |
910 | [&](ptrdiff_t d1, ptrdiff_t d0) { |
911 | auto c = tr::call_param_t(); |
912 | c.in = in + (d0 * ns[0].is + d1 * ns[1].is) |
913 | * data_type_size(pd()->prb_.itype); |
914 | c.out = out + (d0 * ns[0].os + d1 * ns[1].os) |
915 | * data_type_size(pd()->prb_.otype); |
916 | c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; |
917 | (*kernel_)(&c); |
918 | }); |
919 | } |
920 | |
921 | void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, |
922 | const float *scale) const { |
923 | const tr::node_t *ns = pd()->prb_.nodes + off; |
924 | for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, |
925 | (ptrdiff_t)ns[0].n, |
926 | [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { |
927 | auto c = tr::call_param_t(); |
928 | c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) |
929 | * data_type_size(pd()->prb_.itype); |
930 | c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) |
931 | * data_type_size(pd()->prb_.otype); |
932 | c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; |
933 | (*kernel_)(&c); |
934 | }); |
935 | } |
936 | |
937 | void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, |
938 | const float *scale) const { |
939 | const tr::node_t *ns = pd()->prb_.nodes + off; |
940 | for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, |
941 | (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, |
942 | [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { |
943 | auto c = tr::call_param_t(); |
944 | c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is |
945 | + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); |
946 | c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os |
947 | + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); |
948 | c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss |
949 | + d3 * ns[3].ss; |
950 | (*kernel_)(&c); |
951 | }); |
952 | } |
953 | |
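    /* The kernel covers the ndims_ker innermost dimensions; the remaining
     * outer ("driver") dimensions -- at most ndims_driver_max of them --
     * are distributed across threads, each invocation getting in/out/scale
     * pointers shifted by the corresponding node strides. */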
954 | void omp_driver(const char *in, char *out, const float *scale) const { |
955 | in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); |
956 | out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); |
957 | |
958 | DEBUG({ printf("prb : " ); tr::prb_dump(pd()->prb_); }); |
959 | DEBUG({ printf("ker : " ); tr::prb_dump(pd()->ker_desc_.prb); }); |
960 | |
961 | int ndims = pd()->prb_.ndims; |
962 | int ndims_ker = pd()->ker_desc_.prb.ndims; |
963 | assert(ndims - ndims_ker <= ndims_driver_max); |
964 | |
965 | if (ndims - ndims_ker == 0) { |
966 | omp_driver_0d(ndims_ker, in, out, scale); |
967 | } else { |
968 | parallel(0, [&](const int ithr, const int nthr) { |
969 | switch (ndims - ndims_ker) { |
970 | case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; |
971 | case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; |
972 | case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; |
973 | case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; |
                default: assert(!"unimplemented");
975 | } |
976 | }); |
977 | } |
978 | } |
979 | |
980 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
981 | auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); |
982 | auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); |
983 | |
984 | omp_driver(in, out, pd()->attr()->output_scales_.scales_); |
985 | |
986 | return status::success; |
987 | } |
988 | |
989 | enum { ndims_driver_max = 4 }; |
990 | |
991 | private: |
992 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
993 | tr::kernel_t *kernel_; |
994 | }; |
995 | |
996 | status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, |
997 | engine_t *engine, const primitive_attr_t *attr, |
998 | engine_t *src_engine, const memory_desc_t *src_md, |
999 | engine_t *dst_engine, const memory_desc_t *dst_md) { |
1000 | return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, |
1001 | src_engine, src_md, dst_engine, dst_md); |
1002 | } |
1003 | |
1004 | } |
1005 | } |
1006 | } |
1007 | |