1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "c_types_map.hpp"
18#include "mkldnn_thread.hpp"
19#include "nstl.hpp"
20#include "utils.hpp"
21
22#include "jit_uni_eltwise.hpp"
23
24#define GET_OFF(field) offsetof(jit_args, field)
25
26namespace mkldnn {
27namespace impl {
28namespace cpu {
29
30using namespace Xbyak;
31
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
        size_t end_idx) {
    // Chooses which vector registers will serve as scratch (vmm_mask /
    // vmm_aux*) while the injector computes on Vmm[start_idx, end_idx),
    // and, when save_state_ is set, spills them (plus p_table) to the
    // stack. Scratch registers are preferentially taken from outside the
    // caller's range; any shortfall is borrowed from the head of the range
    // itself — those head registers are then processed in a second pass
    // (see injector_preamble_tail / compute_vector_range).
    preserved_vecs_count = 0;
    vecs_to_preserve = (size_t)aux_vecs_count(alg_);
    start_idx_tail = start_idx;

    // For sse42 mask register has to be Xmm(0): blendvps implicitly uses
    // xmm0 as its mask operand.
    if (isa == sse42 && vecs_to_preserve > 0) {
        size_t idx = 0;
        assert(idx < start_idx);
        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    // First pass: take scratch registers lying outside [start_idx, end_idx).
    for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
        if (preserved_vecs_count >= vecs_to_preserve) break;
        if (start_idx <= idx && idx < end_idx) continue;

        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    // Second pass: not enough free registers — borrow from the head of the
    // caller's range, advancing start_idx_tail past each borrowed index.
    size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
    for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
        preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
    }

    assert(preserved_vecs_count == vecs_to_preserve);

    if (save_state_) {
        h->push(p_table);

        if (preserved_vecs_count)
            h->sub(h->rsp, preserved_vecs_count * vlen);

        // Spill each scratch register to its own vlen-sized stack slot.
        for (size_t i = 0; i < preserved_vecs_count; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen],
                    Vmm(preserved_vec_idxs[i]));

        load_table_addr();
    }

    assign_regs();
}
75
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
{
    // Re-homes the scratch registers that injector_preamble borrowed from
    // the head of the caller's range ([start_idx, start_idx_tail)), so
    // that range can now be processed: each borrowed slot is shifted up by
    // tail_vecs_to_preserve registers, and (when save_state_ is set) the
    // corresponding stack spill slots are restored/re-spilled to match.
    size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
    if (tail_vecs_to_preserve == 0) return;

    // Index of the first borrowed slot inside preserved_vec_idxs.
    const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;

    if (save_state_) {
        // Temporarily skip over the non-borrowed spill slots.
        if (idx_off)
            h->add(h->rsp, idx_off * vlen);

        // Restore the caller's values of the borrowed registers.
        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
                    h->ptr[h->rsp + i * vlen]);
    }

    // Switch scratch to the registers just above the borrowed ones.
    for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
        preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;

    if (save_state_) {
        // Spill the replacement scratch registers into the same slots.
        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen],
                    Vmm(preserved_vec_idxs[idx_off + i]));

        if (idx_off)
            h->sub(h->rsp, idx_off * vlen);
    }

    assign_regs();
}
107
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
    // Undoes injector_preamble: reloads the preserved vector registers
    // from their stack slots, releases the spill area and pops p_table.
    // No-op when the caller did not request state saving.
    if (!save_state_) return;

    for (size_t i = 0; i < preserved_vecs_count; ++i)
        h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
                h->ptr[h->rsp + i * vlen]);

    if (preserved_vecs_count)
        h->add(h->rsp, preserved_vecs_count * vlen);

    h->pop(p_table);
}
121
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
    // Binds the scratch-register aliases to the indices selected by
    // injector_preamble / injector_preamble_tail. vmm_mask deliberately
    // shares slot 0 with vmm_aux0 (on sse42 this slot is forced to Xmm(0),
    // which blendvps uses implicitly).
    // NOTE(review): all five aux aliases are assigned even when
    // aux_vecs_count() < 5, so the higher aliases may read entries of
    // preserved_vec_idxs that were never written this round; they are
    // unused by the corresponding compute routines in that case.
    vmm_mask = Vmm(preserved_vec_idxs[0]);
    vmm_aux0 = Vmm(preserved_vec_idxs[0]);
    vmm_aux1 = Vmm(preserved_vec_idxs[1]);
    vmm_aux2 = Vmm(preserved_vec_idxs[2]);
    vmm_aux3 = Vmm(preserved_vec_idxs[3]);
    vmm_aux4 = Vmm(preserved_vec_idxs[4]);
}
131
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
    // Computes exp(x) in-place on vmm_src via the classic range reduction
    // x = n*ln2 + r, exp(x) = 2^n * P(r), with P a degree-5 polynomial.
    // Table indices refer to the layout built by elu_prepare_table().
    // Clobbers vmm_aux0, vmm_aux1, vmm_aux3 (and k_mask on avx512).

    // clamp x to [min logf, max logf] to avoid overflow/underflow
    h->uni_vminps(vmm_src, vmm_src, table_val(10));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
    h->uni_vmovups(vmm_aux0, vmm_src);
    //calculate exp(x)
    // fx = x * log2ef + 0.5
    h->uni_vmulps(vmm_src, vmm_src, table_val(2));
    h->uni_vaddps(vmm_src, vmm_src, table_val(1));

    // tmp = floorf(fx)
    if (isa == avx512_common) {
        // avx512_common lacks uni_vroundps: emulate floor with a
        // round-down convert plus a -1 correction where the converted
        // value still exceeds fx.
        h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
        h->vcvtdq2ps(vmm_aux1, vmm_aux1);

        h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
        h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));

        h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
    } else {
        h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
    }

    //keep fx for further computations
    h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx

    //x = x - fx * ln2  (range-reduced argument r)
    h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));

    // compute 2^n: add the exponent bias (0x7f) to n and shift into the
    // float exponent field
    h->uni_vcvtps2dq(vmm_aux1, vmm_src);
    h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
    h->uni_vpslld(vmm_aux1, vmm_aux1, 23); // vmm_aux1 = 2^n

    // evaluate P(r) with Horner's scheme
    // y = p5
    h->uni_vmovups(vmm_src, table_val(9));
    // y = y * x + p4
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
    // y = y * x + p3
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
    // y = y * x + p2
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
    // y = y * x + p1
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
    // y = y * x + p0
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q)
    // y = y * 2^n
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
}
181
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
{
    // Leaky relu: dst = x > 0 ? x : alpha * x.
    // Table layout from relu_prepare_table(): [0] alpha, [1] 0.0f.
    const int alpha_off = 0, zero_off = 1;

    h->uni_vmovups(vmm_aux1, vmm_src);
    if (isa == sse42) {
        // blendvps reads its mask implicitly from xmm0 (== vmm_mask here)
        h->movups(vmm_mask, vmm_src);
        h->mulps(vmm_src, table_val(alpha_off));
        h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
        h->blendvps(vmm_src, vmm_aux1);
    } else if (isa == avx2) {
        h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
        h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
        h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
    } else if (isa == avx512_common) {
        // avx512 uses an opmask register instead of a vector mask
        h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
        h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
        h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
    }
}
203
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
        const Vmm &vmm_src) {
    // Relu with zero negative slope (alpha == 0): dst = max(x, 0).
    // Table layout from relu_prepare_table(): [1] 0.0f.
    const int zero_off = 1;
    h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
}
210
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
    // ELU: dst = x > 0 ? x : alpha * (exp(x) - 1).
    // Table layout from elu_prepare_table(): [0] 1.0f, [23] alpha, [24] 0.0f.
    const int alpha_off = 23, zero_off = 24;

    // compute exponent (exp_compute_vector clobbers aux0/aux1/aux3,
    // so the original x is parked in vmm_aux2)
    h->uni_vmovups(vmm_aux2, vmm_src);
    exp_compute_vector(vmm_src);

    // alpha * (exp(x) - 1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(0));
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));

    // combine with mask: keep the original x where x > 0
    if (isa == sse42) {
        // mask = (0 <= x); blendvps implicitly uses xmm0 (== vmm_mask)
        h->pxor(vmm_mask, vmm_mask);
        h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os);
        h->blendvps(vmm_src, vmm_aux2);
    } else if (isa == avx2) {
        h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
        h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
    } else if (isa == avx512_common) {
        h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
        h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
    }
}
236
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
{
    // Piecewise tanh(x):
    //   |x| <  linear_sat_point : tanh(x) ~= x
    //   |x| <  exp_bound_point  : odd degree-9 polynomial (Sollya fit below)
    //   |x| <  one_sat_point    : 1 - 2 / (1 + exp(2|x|))
    //   otherwise               : +/- 1
    // Lanes exit early (jz to end_tanh_label) once all of them have been
    // resolved by a cheaper branch.

    // # comes from Taylor expansion error bound
    // > linear_sat_point = single(sqrt(3) * 1b-12);
    // # comes from the exp formula cancellation
    // > exp_bound_point = (single(log(3)/2));
    // # comes from rounding accuracy in float
    // > one_sat_point = round(atanh(1 - 1b-25), single, RU);
    // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
    //            [linear_sat_point, exp_bound_point], relative, floating);
    // > err_bound = D(sup(supnorm(P, tanh(x),
    //          [linear_sat_point, exp_bound_point], relative, theta)));
    // 0x1.fffd6f00b9539p-25
    // > P;
    // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
    //     (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
    //     + x^0x1p1 * 0x1.09fa1p-6))))

    // register mapping
    // vmm_src contains input
    // vmm_aux0 contains mask of currently valid results.
    //     1 is need computation, 0 is already computed
    // vmm_aux1 contains current output
    // vmm_aux2, vmm_aux3 contains auxiliary values
    // vmm_aux4 contains the original sign of inputs

    Label end_tanh_label;

    // Jump to the exit when no lane has |x| >= threshold (all results
    // already final); otherwise k_mask / vmm_aux0 flags the lanes that
    // still need a better approximation.
    auto test_exit =[&](Xbyak::Address threshold){
        // is not necessary for >AVX, but should not matter on perf
        h->uni_vmovups(vmm_aux0, vmm_src);
        if (isa == avx512_common){
            h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
            h->kortestw(k_mask, k_mask);
        } else {
            h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
            h->uni_vtestps(vmm_aux0, vmm_aux0);
        }
        h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
    };

    // Merge vmm_partial_res into the accumulated output (vmm_aux1) on
    // the lanes selected by the current mask.
    auto blend_results=[&](Vmm vmm_partial_res){
        if (isa == avx512_common)
            h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
        else
            h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
    };

    // because tanh(x) = -tanh(-x), we extract sign to make x postive
    // and reapply sign at the end
    // mov is not necessary for >AVX, but should not matter for performance
    h->uni_vmovups(vmm_aux4, vmm_src);
    h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
    h->uni_vandps(vmm_src, vmm_src, table_val(17));

    // if x < linear_sat_point for all inputs, we just return the input
    h->uni_vmovups(vmm_aux1, vmm_src);
    test_exit(table_val(13));

    // if one of the mask is one, we have to compute an better approx
    h->uni_vmovups(vmm_aux2, vmm_src);
    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
    h->uni_vmovups(vmm_aux3, table_val(22));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
    h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);

    // we blend only the result that need update
    blend_results(vmm_aux3);

    // if x < exp_bound_point, we go to return point
    test_exit(table_val(14));

    // if not we use a better approx 1 - 2 / (1 + exp(2x))
    // compute 2x
    h->uni_vmovups(vmm_aux3, vmm_src);
    h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);

    // Compute exp(2x)
    // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
    // vmm_src is not more read afterwards, so we do not have to save it
    auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
    h->sub(h->rsp, stack_size);
    h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
    h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
    h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
    if (isa == avx512_common)
        h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);

    exp_compute_vector(vmm_aux3);

    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
    h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
    if (isa == avx512_common)
        h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
    h->add(h->rsp, stack_size);

    // 1 + exp(2x)
    h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));

    // 1 - 2 / (1 + exp(2x))   (table [16] = -2.0f, [0] = 1.0f)
    h->uni_vmovups(vmm_aux2, table_val(16));
    h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
    h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));

    // we blend only the result that need update
    blend_results(vmm_aux2);

    // finally, we saturate to 1 if needed
    // TODO: maybe move that up if most inputs saturate in practice
    if (isa == avx512_common)
        h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
    else {
        h->uni_vmovups(vmm_aux0, vmm_src);
        h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
    }
    h->uni_vmovups(vmm_aux2, table_val(0));
    blend_results(vmm_aux2);

    h->L(end_tanh_label);
    {
        // we apply the sign of x to the result and we are done
        h->uni_vmovups(vmm_src, vmm_aux1);
        h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
    }
}
367
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
        const Vmm &vmm_src) {
    // dst = x * x; needs no auxiliary registers or table.
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
}
373
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
    // compute abs(x) = _mm_and_ps(x, 01111..111));
    // table [0] = 0x7fffffff (abs_prepare_table) clears the sign bit.
    h->uni_vandps(vmm_src, vmm_src, table_val(0));
}
379
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
{
    // dst = x > 0 ? sqrt(x) : 0. Blending against zero (table [0]) avoids
    // propagating NaN that vsqrtps would produce for negative inputs.
    if (isa == avx512_common) {
        h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
        h->uni_vsqrtps(vmm_aux1, vmm_src);
        h->uni_vmovups(vmm_src, table_val(0));
        h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
    } else {
        h->uni_vmovups(vmm_mask, vmm_src);
        h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
        h->uni_vsqrtps(vmm_aux1, vmm_src);
        h->uni_vmovups(vmm_src, table_val(0));
        h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
    }
}
396
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
        const Vmm &vmm_src) {
    // compute x = alpha * x + beta;
    // table layout (linear_prepare_table): [0] alpha, [1] beta.
    h->uni_vmovups(vmm_aux0, table_val(0));
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
}
404
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
        const Vmm &vmm_src) {
    // compute bounded relu: dst = clamp(x, 0, alpha)
    // table layout (bounded_relu_prepare_table): [0] alpha, [1] 0.0f.
    h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
    h->uni_vminps(vmm_src, vmm_src, table_val(0));
}
412
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
        const Vmm &vmm_src) {
    // Soft relu: dst = log(1 + exp(x)), computed as a fused
    // exp + log1p pipeline; for x above max-logf the result falls back to
    // x itself (final blend). Table indices refer to the layout built by
    // soft_relu_prepare_table(). Uses vmm_aux0..vmm_aux3 and vmm_mask.

    // duplicate src (needed for the final "large x" blend)
    h->uni_vmovups(vmm_aux2, vmm_src);

    // clamp to [min logf, max logf] for the exp evaluation
    h->uni_vminps(vmm_src, vmm_src, table_val(24));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
    h->uni_vmovups(vmm_aux1, vmm_src);
    // calculate exp(x)
    // fx = x * log2ef + 0.5
    h->uni_vmulps(vmm_src, vmm_src, table_val(2));
    h->uni_vaddps(vmm_src, vmm_src, table_val(1));

    // tmp = floorf(fx) — same avx512 floor emulation as exp_compute_vector
    if (isa == avx512_common) {
        h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
        h->vcvtdq2ps(vmm_aux0, vmm_aux0);

        h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
        h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));

        h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
    } else {
        h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
    }

    // keep fx for further computations
    h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
    // calculation fx * ln2
    h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
    // x = x - fx * ln2
    h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
    // y = p5
    h->uni_vmovups(vmm_aux3, table_val(22));
    // y = y * x + p4
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
    // y = y * x + p3
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
    // y = y * x + p2
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
    // y = y * x + p1
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
    // y = y * x + p0
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));

    // compute 2^(-n) ([23] = -1.0f is used to negate n)
    if (isa == avx512_common) {
        h->vmulps(vmm_aux1, vmm_src, table_val(23));
        h->vcvtps2dq(vmm_aux1, vmm_aux1);
    } else {
        h->uni_vcvtps2dq(vmm_aux1, vmm_src);
        h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
    }

    h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
    h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
    // calculate ln(1 + y)
    h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
    // x = y; y is free; keep x for further computations
    h->uni_vmovups(vmm_src, vmm_aux3);
    // frexp(): split into exponent and mantissa
    h->uni_vpsrld(vmm_src, vmm_src, 23);
    h->uni_vcvtdq2ps(vmm_src, vmm_src);
    // got n. where n is x = 2^n * y. y = 0.5 .. 1
    h->uni_vsubps(vmm_src, vmm_src, table_val(5));

    h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
    // got y. (mantisa) 0.5 < y < 1
    h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
    // y = y - 1
    h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
    // evaluate the log polynomial with Horner's scheme
    // y = p8
    h->uni_vmovups(vmm_aux1, table_val(16));
    // y = y * x + p7
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
    // y = y * x + p6
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
    // y = y * x + p5
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
    // y = y * x + p4
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
    // y = y * x + p3
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
    // y = y * x + p2
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
    // y = y * x + p1
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
    // y = y * x + p0 ; p0 = 0
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
    //calculate ln(2) * n
    h->uni_vmulps(vmm_src, vmm_src, table_val(3));
    h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
    h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);

    // get vmm_mask = src > max logf
    h->uni_vmovups(vmm_mask, vmm_aux2);
    if (isa == avx512_common) {
        // y = (x < max log f) ? soft_relu(x) : x
        h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
        h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
    } else {
        // y = (x < max log f) ? soft_relu(x) : x
        h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
        h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
    }

    h->uni_vmovups(vmm_src, vmm_aux1);
}
522
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
        const Vmm &vmm_src) {
    // Logistic (sigmoid): dst = 1 / (1 + exp(-x)), computed as
    // exp(-|x|) / (1 + exp(-|x|)) on the forced-negative input, then the
    // symmetry 1 - sigmoid(-x) is applied on originally-positive lanes.
    // we store the original sign and make x negative
    // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required
    // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
    h->uni_vmovups(vmm_aux2, vmm_src);
    h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
    h->uni_vorps(vmm_src, vmm_src, table_val(12));

    exp_compute_vector(vmm_src);
    // dup exp(x)
    h->uni_vmovups(vmm_aux1, vmm_src);
    // (exp(x) + 1)
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
    // y = exp(x) / (exp(x) + 1)
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);

    // Now we have to apply the "symmetry" based on original sign:
    // lanes whose sign bit was set keep y, the rest get 1 - y
    h->uni_vmovups(vmm_aux3, table_val(0));
    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
    if (isa == avx512_common) {
        h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
        h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
    } else {
        h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
        h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
    }
    h->uni_vmovups(vmm_src, vmm_aux3);
}
553
554template <cpu_isa_t isa>
555void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
556 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
557 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
558}
559
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
    // Shared constant table for elu, tanh and logistic (see prepare_table):
    // entries [0..22] below, then a vector of alpha at index [23] and a
    // vector of zeros at index [24]. Each constant is broadcast across one
    // full vector (vlen bytes) so table_val(i) addresses a ready operand.
    const unsigned int cvals[] = {
            0x3f800000, // [0] 1.0f
            0x3f000000, // [1] 0.5f
            0x3fb8aa3b, // [2] log2ef = 1.44269502f
            0x3f317218, // [3] ln2f =  0.69314718f
            0x0000007f, // [4] 0x7f
            // exp(x) polynom
            0x3f800001, // [5] p0 = 1.0000001f
            0x3efffe85, // [6] p2 = 0.4999887f
            0x3e2aaa3e, // [7] p3 = 0.16666505f
            0x3d2bb1b1, // [8] p4 = 0.041917507f
            0x3c091ec1, // [9] p5 = 0.008369149f
            0x42b0c0a5, //[10] max logf = 88.3762589f
            0xc1766666, //[11] min logf = -14.5f
            // tanh(x) constants,
            0x80000000, //[12] mask to extract sign
            0x39ddb3d7, //[13] arg below which tanh(x) = x
            0x3f0c9f54, //[14] arg below which pol approx is valid
            0x41102cb4, //[15] arg after which tanh(x) = 1
            0xc0000000, //[16] -2.0f
            0x7fffffff, //[17] mask to make positive
            // tanh pol approx
            0x3f7fffff, //[18] p0
            0xbeaaa9cf, //[19] p1
            0x3e085f1f, //[20] p2
            0xbd572bda, //[21] p3
            0x3c84fd08, //[22] p4
    };

    for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
        for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
    }

    // [23] alpha (elu), [24] 0.0f
    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
}
598
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
    // Constant table for soft_relu_compute_vector; each constant is
    // broadcast across one full vector so table_val(i) is a ready operand.
    const unsigned int cvals[] = {
            0x3f800000, // [0] 1.0f
            0x3f000000, // [1] 0.5f
            0x3fb8aa3b, // [2] log2ef = 1.44269502f
            0x3f317218, // [3] ln2f =  0.69314718f
            0x0000007f, // [4] 0x7f
            0x42fc0000, // [5] 126
            0x807fffff, // [6] and with (to get 0.5 * mantissa)
            0x3f000000, // [7] or with (to get 0.5 * mantissa)
            // ln(1 + x) polynomial
            0xb2b4637d, // [8]  p0 = 0.0000000244f
            0x3f7fff8e, // [9]  p1 = 0.9999976971f
            0xbf001759, //[10] p2 = -0.5002478215f
            0x3ea70608, //[11] p3 = 0.3272714505f
            0xbea3d7bf, //[12] p4 = -0.3153830071f
            0xbe361d04, //[13] p5 = -0.1701777461f
            0xbfa8f1e6, //[14] p6 = -1.3254635147f
            0xbfe1e812, //[15] p7 = -1.7971917960f
            0xbfc4d30e, //[16] p8 = -1.5652673123f
            // exp(x) polynomial
            0x3f800001, //[17] p0 = 1.0000001f
            0x3f800000, //[18] p1 = 1.0f
            0x3efffe85, //[19] p2 = 0.4999887f
            0x3e2aaa3e, //[20] p3 = 0.16666505f
            0x3d2bb1b1, //[21] p4 = 0.041917507f
            0x3c091ec1, //[22] p5 = 0.008369149f
            0xbf800000, //[23] is required for sign changing
            0x42b0c0a5, //[24] max logf = 88.3762589f
            0xc1766666  //[25] min logf = -14.5f
    };

    for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
        for (size_t d = 0; d < vlen / sizeof(float); ++d) {
            h->dd(cvals[i]);
        }
    }
}
638
639template <cpu_isa_t isa>
640void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
641 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
642}
643
644template <cpu_isa_t isa>
645void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
646 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
647}
648
649template <cpu_isa_t isa>
650void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
651 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
652 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
653}
654
655template <cpu_isa_t isa>
656void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
657 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
658 for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
659}
660
661template <cpu_isa_t isa>
662int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
663 switch (alg_) {
664 case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
665 case alg_kind::eltwise_elu: return 4;
666 case alg_kind::eltwise_tanh: return 5;
667 case alg_kind::eltwise_square: return 0;
668 case alg_kind::eltwise_abs: return 0;
669 case alg_kind::eltwise_sqrt: return 2;
670 case alg_kind::eltwise_linear: return 1;
671 case alg_kind::eltwise_bounded_relu: return 0;
672 case alg_kind::eltwise_soft_relu: return 4;
673 case alg_kind::eltwise_logistic: return 4;
674 default: assert(!"unsupported eltwise algorithm");
675 }
676
677 return 0;
678}
679
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
        size_t end_idx) {
    // Dispatches the configured algorithm over each vector register in
    // [start_idx, end_idx), computing in-place. The zero-alpha relu gets
    // a cheaper single-instruction path.
    using namespace alg_kind;
    for (size_t idx = start_idx; idx < end_idx; idx++) {
        switch (alg_) {
        case eltwise_relu:
            if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
            else relu_compute_vector(Vmm(idx));
            break;
        case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
        case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
        case eltwise_square: square_compute_vector(Vmm(idx)); break;
        case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
        case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
        case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
        case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
        case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
        case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
        default: assert(!"unsupported eltwise algorithm");
        }
    }
}
703
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
        size_t end_idx) {
    // Applies the eltwise op in-place to Vmm registers [start_idx, end_idx).
    // Two passes: the main body runs on [start_idx_tail, end_idx) while
    // registers at the head of the range may be serving as scratch; the
    // preamble_tail then rotates the scratch set so the head
    // [start_idx, start_idx_tail) can be processed as well.
    assert(start_idx < end_idx && end_idx <= vecs_count);

    injector_preamble(start_idx, end_idx);
    compute_body(start_idx_tail, end_idx);
    injector_preamble_tail(start_idx);
    compute_body(start_idx, start_idx_tail);
    injector_postamble();
}
715
template <cpu_isa_t isa>
void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
    // Emits the 64-byte-aligned constant table referenced by table_val().
    // Must be called after the code that uses the injector has been
    // generated (the table lives inline in the code buffer at l_table).
    // With gen_table == false only the label is bound, no data is emitted.
    using namespace alg_kind;

    h->align(64);
    h->L(l_table);

    if (gen_table) {
        switch (alg_) {
        case eltwise_relu: relu_prepare_table(); break;
        // elu, tanh and logistic share one table (see elu_prepare_table)
        case eltwise_elu:
        case eltwise_tanh:
        case eltwise_logistic:
            elu_prepare_table(); break;
        case eltwise_soft_relu: soft_relu_prepare_table(); break;
        case eltwise_abs: abs_prepare_table(); break;
        case eltwise_sqrt: sqrt_prepare_table(); break;
        case eltwise_linear: linear_prepare_table(); break;
        case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
        case eltwise_square: break; // square needs no constants
        default: assert(!"unsupported eltwise algorithm");
        }
    }
}
740
741template struct jit_uni_eltwise_injector_f32<avx512_common>;
742template struct jit_uni_eltwise_injector_f32<avx2>;
743template struct jit_uni_eltwise_injector_f32<sse42>;
744
745
/* Argument bundle passed by value-pointer to a generated eltwise kernel
 * (offsets taken via GET_OFF in the generators below). */
struct jit_args {
    const float *from;           // input data
    const float *for_comparison; // reference data read by the backward
                                 // relu kernel for the x > 0 test
    const float *to;             // output data
    size_t work_amount;          // number of f32 elements to process
};
752
/* Common base for the generated eltwise kernels: stores the descriptor
 * and the entry point of the JIT-generated code. Derived classes set
 * ker_ at the end of code generation. */
struct jit_uni_eltwise_kernel_f32 : public c_compatible {
    const eltwise_desc_t &desc_;

    // Entry point of the generated code; invoking the object runs it.
    void (*ker_)(const jit_args *);
    void operator()(const jit_args *args) { assert(ker_); ker_(args); }

    jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
        : desc_(desc), ker_(nullptr) {}
    virtual ~jit_uni_eltwise_kernel_f32() {}

protected:
    bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
};
766
767/* jit kernels */
768namespace {
769
/* Specialized (leaky) relu kernel handling both forward and backward
 * data. The generated code runs a vectorized loop over full vectors and
 * then a scalar loop over the remainder. */
template <cpu_isa_t isa>
struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
    public jit_generator
{
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)

    /* Emits one unrolled step of the relu loop.
     * vectorize: full-vector loads/stores vs. scalar movss
     * uf:        unroll factor (vectors/scalars per step)
     * shift:     byte stride between consecutive elements
     * Register plan: Vmm(1..uf) holds inputs, Vmm(uf+1..2uf) the
     * comparison data (backward only), Vmm(2uf+1..3uf) results. */
    void compute_step(bool vectorize, const int uf, const int shift) {
        for (int i = 0; i < uf; i++) {
            if (vectorize) {
                uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
                if (is_bwd())
                    uni_vmovups(Vmm(uf + i + 1),
                                ptr[reg_for_comparison + i * shift]);
            } else {
                movss(Xmm(i + 1), ptr[reg_from + i * shift]);
                if (is_bwd())
                    movss(Xmm(uf + i + 1),
                          ptr[reg_for_comparison + i * shift]);
            }
        }

        if (isa == sse42) {
            // sse42: blendvps implicitly uses xmm0 as the mask register
            for (int i = 0; i < uf; i++) {
                movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
                mulps(Vmm(2 * uf + i + 1), vmm_ns);

                Vmm mask = Vmm(0);
                if (is_bwd()) {
                    movups(mask, Vmm(uf + i + 1));
                    cmpps(mask, vmm_zero, _cmp_nle_us);
                } else {
                    movups(mask, Vmm(i + 1));
                    cmpps(mask, vmm_zero, _cmp_nle_us);
                }
                blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
            }
        } else {
            for (int i = 0; i < uf; i++) {
                // result = x * alpha, then keep x where the reference
                // value (fwd: x itself; bwd: fwd data) is positive
                vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
                if (isa == avx2) {
                    if (is_bwd())
                        vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
                    else
                        vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);

                    vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
                              Vmm(i + 1), vmm_mask);

                } else {
                    // avx512: compare into an opmask register
                    if (is_bwd())
                        vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
                    else
                        vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
                    vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
                              Vmm(i + 1));
                }
            }
        }

        for (int i = 0; i < uf; i++) {
            if (vectorize) {
                uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
            } else {
                movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
            }
        }
    }

    /* Generates the kernel: loads the argument struct, broadcasts alpha,
     * then emits two loop variants — id 0 vectorized (simd_w elements per
     * iteration), id 1 scalar (one element per iteration) — falling
     * through to loop_label[2] when work is exhausted. */
    jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
        : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
        assert(desc.alg_kind == alg_kind::eltwise_relu);
        assert(isa == sse42 || isa == avx2 || isa == avx512_common);

        Reg64 param = abi_param1;

        const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
        const int loop_dec[] = {simd_w, 1};
        const int uf[] = {1, 1};
        const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
        const bool loop_vectorize[] = {true, false};

        this->preamble();

        mov(reg_from, ptr[param + GET_OFF(from)]);
        if (is_bwd())
            mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);

        // broadcast the negative slope alpha into vmm_ns
        mov(imm_addr64, float2int(desc.alpha));
        movq(xmm_ns, imm_addr64);
        uni_vbroadcastss(vmm_ns, xmm_ns);

        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

        Label loop_label[3];

        for (int id = 0; id < 2; id++) {
            L(loop_label[id]);
            cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
            jle(loop_label[id + 1], T_NEAR);

            compute_step(loop_vectorize[id], uf[id], shift[id]);

            add(reg_from, uf[id] * shift[id]);
            add(reg_to, uf[id] * shift[id]);
            if (is_bwd())
                add(reg_for_comparison, uf[id] * shift[id]);

            sub(reg_work_amount, uf[id] * loop_dec[id]);
            jmp(loop_label[id]);
        }

        L(loop_label[2]);
        this->postamble();

        ker_ = (decltype(ker_))this->getCode();
    }

private:
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                                             isa == avx2, Ymm, Zmm>::type;

    Reg64 reg_from = rax;
    // forward pass compares against the input itself, so the register
    // simply aliases reg_from there
    Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
    Reg64 reg_to = r8;
    Reg64 reg_work_amount = rsi;
    Reg64 imm_addr64 = rbx;

    Xmm xmm_ns = Xmm(14);

    // high registers keep constants clear of the unrolled work registers
    Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);   // negative slope
    Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); // 0.0f

    Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
    Opmask k_mask = Opmask(1);
};
907
/* Generic forward eltwise kernel: a thin JIT loop that delegates the actual
 * math to jit_uni_eltwise_injector_f32. Used for every supported forward
 * algorithm that does not have a dedicated hand-written kernel. */
template <cpu_isa_t isa>
struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
    public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)

    /* The constructor generates the whole kernel: a vectorized main loop
     * processing simd_w floats per iteration, followed by a scalar remainder
     * loop, then materializes the injector's constant table. */
    jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
        : jit_uni_eltwise_kernel_f32(desc), jit_generator() {

        // r9 is handed to the injector as its table-address register;
        // Opmask(1) is its mask register on avx512.
        eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
                desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));

        using namespace alg_kind;

        assert(is_bwd() == false);
        // relu is intentionally absent: it is served by a dedicated kernel.
        assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
                    eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
                    eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));

        preamble();

        // Unpack the jit_args structure passed by the host-side caller.
        Reg64 param = abi_param1;
        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
        eltwise_injector_->load_table_addr();

        Label reminder_loop_start, reminder_loop_end;
        Label vectorized_loop_start, vectorized_loop_end;

        // Fewer than simd_w elements left? Go straight to the scalar loop.
        cmp(reg_work_amount, simd_w);
        jl(reminder_loop_start, T_NEAR);

        L(vectorized_loop_start);

        // Full-vector iteration: load, apply eltwise in-register, store.
        uni_vmovups(vmm_src, ptr[reg_from]);
        eltwise_injector_->compute_vector(vmm_src.getIdx());
        uni_vmovups(ptr[reg_to], vmm_src);

        add(reg_from, vlen);
        add(reg_to, vlen);

        sub(reg_work_amount, simd_w);
        cmp(reg_work_amount, simd_w);
        jge(vectorized_loop_start, T_NEAR);

        L(vectorized_loop_end);

        L(reminder_loop_start);

        cmp(reg_work_amount, 0);
        jle(reminder_loop_end, T_NEAR);

        // Scalar tail: one float at a time through the same injector
        // (it operates on the low lane of the xmm register).
        movss(xmm_src, ptr[reg_from]);
        eltwise_injector_->compute_vector(xmm_src.getIdx());
        movss(ptr[reg_to], xmm_src);

        add(reg_from, sizeof(float));
        add(reg_to, sizeof(float));

        dec(reg_work_amount);
        jmp(reminder_loop_start, T_NEAR);

        L(reminder_loop_end);

        postamble();

        // Emit the constants the injector references via its table register;
        // must come after postamble so the table sits past the code.
        eltwise_injector_->prepare_table();

        ker_ = (decltype(ker_))this->getCode();
    }

    ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }

private:
    // Vector register type matching the target ISA width.
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
                isa == avx2, Ymm, Zmm>::type;

    const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); // floats/vec
    const int vlen   = cpu_isa_traits<isa>::vlen;                 // bytes/vec

    Reg64 reg_from = rax;        // current read pointer
    Reg64 reg_to = r8;           // current write pointer
    Reg64 reg_work_amount = rsi; // remaining element count
    Reg64 imm_addr64 = rbx;

    // Same physical register viewed as scalar (tail) and vector (main loop).
    Xmm xmm_src = Xmm(1);
    Vmm vmm_src = Vmm(1);

    jit_uni_eltwise_injector_f32<isa> *eltwise_injector_; // owned, see dtor
};
998
999} /* namespace */
1000
1001template <cpu_isa_t isa>
1002status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
1003 using namespace alg_kind;
1004
1005 bool ok = true
1006 && mayiuse(isa)
1007 && is_fwd()
1008 && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
1009 && !has_zero_dim_memory()
1010 && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
1011 eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
1012 eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
1013 eltwise_logistic)
1014 && memory_desc_wrapper(src_md()).is_dense(true)
1015 && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false),
1016 math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
1017 && attr()->has_default_values();
1018
1019 return ok ? status::success : status::unimplemented;
1020}
1021
1022template <cpu_isa_t isa>
1023jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd)
1024 : cpu_primitive_t(apd), kernel_(nullptr) {
1025 const auto &desc = *pd()->desc();
1026 switch (desc.alg_kind) {
1027 case alg_kind::eltwise_relu:
1028 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1029 default:
1030 kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
1031 }
1032}
1033
1034template <cpu_isa_t isa>
1035jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
1036{ delete kernel_; }
1037
1038template <cpu_isa_t isa>
1039void jit_uni_eltwise_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
1040 auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
1041 auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
1042
1043 const memory_desc_wrapper data_d(pd()->src_md());
1044
1045 const size_t nelems = data_d.nelems(true);
1046
1047 src += data_d.offset0();
1048 dst += data_d.offset0();
1049
1050 parallel(0, [&](const int ithr, const int nthr) {
1051 size_t start{0}, end{0};
1052
1053 const int cache_line = 16;
1054
1055 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1056 start = nstl::min(nelems, start * cache_line);
1057 end = nstl::min(nelems, end * cache_line);
1058
1059 auto arg = jit_args();
1060 arg.from = &src[start];
1061 arg.for_comparison = &src[start];
1062 arg.to = &dst[start];
1063 arg.work_amount = end - start;
1064 if (arg.work_amount)
1065 (*kernel_)(&arg);
1066 });
1067}
1068
1069template <cpu_isa_t isa>
1070status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
1071 bool ok = true
1072 && !is_fwd()
1073 && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
1074 && src_md()->data_type == data_type::f32
1075 && !has_zero_dim_memory()
1076 && mayiuse(isa)
1077 && memory_desc_wrapper(src_md()).is_dense()
1078 && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md())
1079 && attr()->has_default_values();
1080
1081 return ok ? status::success : status::unimplemented;
1082}
1083
1084template <cpu_isa_t isa>
1085jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd)
1086 : cpu_primitive_t(apd), kernel_(nullptr) {
1087 const auto &desc = *pd()->desc();
1088 switch (desc.alg_kind) {
1089 case alg_kind::eltwise_relu:
1090 kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
1091 default: assert(!"unknown eltwise alg_kind");
1092 }
1093}
1094
1095template <cpu_isa_t isa>
1096jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
1097{ delete kernel_; }
1098
1099template <cpu_isa_t isa>
1100void jit_uni_eltwise_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
1101 auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
1102 auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
1103 auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
1104
1105 const memory_desc_wrapper data_d(pd()->src_md());
1106 const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
1107
1108 const size_t nelems = data_d.nelems();
1109
1110 src += data_d.offset0();
1111 diff_dst += diff_data_d.offset0();
1112 diff_src += diff_data_d.offset0();
1113
1114 parallel(0, [&](const int ithr, const int nthr) {
1115 size_t start{0}, end{0};
1116
1117 const int cache_line = 16;
1118
1119 balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
1120 start = nstl::min(nelems, start * cache_line);
1121 end = nstl::min(nelems, end * cache_line);
1122
1123 auto arg = jit_args();
1124 arg.from = &diff_dst[start];
1125 arg.to = &diff_src[start];
1126 arg.for_comparison = &src[start];
1127 arg.work_amount = end - start;
1128 if (arg.work_amount)
1129 (*kernel_)(&arg);
1130 });
1131}
1132
// Explicit instantiations for every ISA this primitive supports.
template struct jit_uni_eltwise_fwd_t<sse42>;
template struct jit_uni_eltwise_bwd_t<sse42>;
template struct jit_uni_eltwise_fwd_t<avx2>;
template struct jit_uni_eltwise_bwd_t<avx2>;
template struct jit_uni_eltwise_fwd_t<avx512_common>;
template struct jit_uni_eltwise_bwd_t<avx512_common>;
1139
1140}
1141}
1142}
1143