/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"
#include "cpu_pooling_pd.hpp"

#include "jit_uni_pool_kernel_f32.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace Xbyak;
using namespace alg_kind;

#define GET_OFF(field) offsetof(jit_pool_call_s, field)

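// Populate the jit_pool_conf_t runtime descriptor from the pooling primitive
// descriptor: reject unsupported isa/algorithm combinations, derive the 2D/3D
// spatial sizes, strides, kernel and padding, and choose the unroll width
// (ur_w) over the output width.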
template <cpu_isa_t isa>
status_t jit_uni_pool_kernel_f32<isa>::init_conf(jit_pool_conf_t &jpp,
        const pooling_pd_t *ppd) {
    const auto &pd = *ppd->desc();
    const memory_desc_wrapper src_d(
            ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md());
    const memory_desc_wrapper dst_d(
            ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md());

    bool args_ok = true
        && mayiuse(isa)
        && utils::one_of(pd.alg_kind, pooling_max,
                pooling_avg_include_padding,
                pooling_avg_exclude_padding);
    if (!args_ok) return status::unimplemented;

    const int simd_w = isa == avx512_common ? 16 : 8;
    const int ndims = src_d.ndims();

    jpp.ndims = ndims;
    jpp.mb = src_d.dims()[0];

    jpp.c = utils::rnd_up(src_d.dims()[1], simd_w);
    if (jpp.c > src_d.padded_dims()[1])
        return status::unimplemented;

    jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jpp.ih = src_d.dims()[ndims-2];
    jpp.iw = src_d.dims()[ndims-1];
    jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jpp.oh = dst_d.dims()[ndims-2];
    jpp.ow = dst_d.dims()[ndims-1];

    jpp.stride_d = (ndims == 5) ? pd.strides[0] : 1;
    jpp.stride_h = pd.strides[ndims-4];
    jpp.stride_w = pd.strides[ndims-3];
    jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
    jpp.kh = pd.kernel[ndims-4];
    jpp.kw = pd.kernel[ndims-3];

    jpp.f_pad = (ndims == 5) ? pd.padding[0][0] : 0;
    jpp.t_pad = pd.padding[0][ndims-4];
    jpp.l_pad = pd.padding[0][ndims-3];

    jpp.alg = pd.alg_kind;

    jpp.is_training = pd.prop_kind == prop_kind::forward_training;
    jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
    jpp.ind_dt = ppd->workspace_md()
        ? ppd->workspace_md()->data_type : data_type::undef;

    jpp.simple_alg = jpp.is_training
        || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);

    jpp.c_block = simd_w;

    jpp.nb_c = jpp.c / jpp.c_block;
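    // Unroll width over ow. The per-isa, per-direction values below appear to
    // be chosen so that the temporaries used per output column (accumulators,
    // index and mask registers) still fit in the vector register file.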
    if (jpp.alg == pooling_max) {
        jpp.ur_w = isa == avx512_common ? 16 : 4;
        if (jpp.is_training)
            jpp.ur_w = isa == avx512_common ? 9 : 3;
        else if (jpp.is_backward)
            jpp.ur_w = isa == avx512_common ? 6 : 3;
    } else {
        if (jpp.is_backward)
            jpp.ur_w = isa == avx512_common ? 12 : 6;
        else
            jpp.ur_w = isa == avx512_common ? 24 : 12;
    }
    if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow;
    if (jpp.l_pad > jpp.ur_w) return status::unimplemented;

    jpp.ur_w_tail = jpp.ow % jpp.ur_w;

    return status::success;
}

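// For pooling_avg_exclude_padding the divisor depends on how many kernel
// columns of output column jj actually overlap the input. Re-broadcast the
// divisor (non-padded kw times the height factor kept in vmm_ker_area_h)
// into vmm_tmp only when it differs from the previously computed one.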
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel_f32<isa>::maybe_recalculate_divisor(int jj,
        int ur_w, int pad_l, int pad_r) {
    if (jpp.alg == pooling_avg_exclude_padding) {
        int kw = jpp.kw;
        int stride_w = jpp.stride_w;

        int non_zero_kw = kw;
        non_zero_kw -= nstl::max(0, pad_l - jj*stride_w);
        non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w);

        if (non_zero_kw != prev_kw) {
            mov(tmp_gpr, float2int((float)non_zero_kw));
            movq(xmm_tmp, tmp_gpr);
            uni_vbroadcastss(vmm_tmp, xmm_tmp);
            uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
            prev_kw = non_zero_kw;
        }
    }
}

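// One unrolled step of average pooling over ur_w output columns.
// Forward: accumulate inputs into vreg(jj), divide by the divisor and store.
// Backward: load diff_dst into vreg(jj), pre-divide, and add it back into
// every diff_src element that contributed to the average.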
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel_f32<isa>::avg_step(int ur_w, int pad_l,
        int pad_r) {

    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    for (int jj = 0; jj < ur_w; jj++) {
        if (jpp.is_backward) {
            uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);
            maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
            uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
        } else {
            uni_vpxor(vreg(jj), vreg(jj), vreg(jj));
        }
    }

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, pad_l - ki);
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj*stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = sizeof(float)*aux_input_offset;
                if (jpp.is_backward) {
                    uni_vmovups(vreg(ur_w+jj),
                            ptr[aux_reg_input + input_offset]);
                    uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj));
                    uni_vmovups(vmmword[aux_reg_input + input_offset],
                            vreg(ur_w+jj));
                } else {
                    uni_vaddps(vreg(jj), vreg(jj),
                            ptr[aux_reg_input + input_offset]);
                }
            }
        }
        add(aux_reg_input, sizeof(float) * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.simple_alg && jpp.ndims == 5)
    {
        add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

    if (!jpp.is_backward) {
        for (int jj = 0; jj < ur_w; jj++) {
            maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
            uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
            uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block],
                    vreg(jj));
        }
    }
}

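// One unrolled step of forward max pooling. vreg(jj) keeps the running
// maximum for output column jj; in training mode vreg(2*ur_w+jj) keeps the
// kernel-local index of that maximum, updated with isa-specific
// compare-and-blend sequences and stored to the workspace afterwards.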
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel_f32<isa>::max_step_fwd(int ur_w, int pad_l,
        int pad_r) {
    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest()));
    movq(xmm_tmp, tmp_gpr);
    uni_vbroadcastss(vmm_tmp, xmm_tmp);

    for (int jj = 0; jj < ur_w; jj++) {
        uni_vmovups(vreg(jj), vmm_tmp);
        if (jpp.is_training)
            uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj));
    }
    if (jpp.is_training)
    {
        movq(xmm_tmp, reg_k_shift);
        uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
    }

    if (jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }
    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, pad_l - ki);
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj*stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = sizeof(float)*aux_input_offset;
                uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]);
                if (isa == sse42) {
                    movups(vmm_mask, vreg(jj));
                    cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os);
                    blendvps(vreg(jj), vreg(ur_w+jj));
                    if (jpp.is_training)
                        blendvps(vreg(2*ur_w+jj), vmm_k_offset);
                } else if (isa == avx) {
                    vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj),
                            _cmp_lt_os);
                    vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj),
                            vreg(3*ur_w+jj));
                    if (jpp.is_training)
                        vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj),
                                vmm_k_offset, vreg(3*ur_w+jj));
                } else {
                    vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os);
                    vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj));
                    if (jpp.is_training)
                        vblendmps(vreg(2*ur_w+jj) | k_store_mask,
                                vreg(2*ur_w+jj), vmm_k_offset);
                }
            }
            if (jpp.is_training) {
                if (isa == avx && !mayiuse(avx2)) {
                    avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
                } else {
                    uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
                }
            }
        }
        add(aux_reg_input, sizeof(float) * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }

    if (jpp.ndims == 5)
    {
        add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
        if (jpp.is_training) {
            mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
            movq(xmm_tmp, tmp_gpr);
            uni_vpbroadcastd(vmm_tmp, xmm_tmp);
            if (isa == avx && !mayiuse(avx2)) {
                Xmm t(vmm_mask.getIdx());
                avx_vpadd1(vmm_k_offset, xmm_tmp, t);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
            }
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        pop(reg_output);
        pop(reg_input);
    }

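    // Store the pooled maxima and, in training mode, the argmax indices.
    // For u8 indices each 32-bit lane is packed down to a byte with an
    // isa-specific shuffle/pack sequence before being written out.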
    for (int jj = 0; jj < ur_w; jj++) {
        uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj));
        if (jpp.is_training) {
            const size_t step_index
                = jj * c_block * types::data_type_size(jpp.ind_dt);

            auto x = xreg(2 * ur_w + jj);
            if (jpp.ind_dt == data_type::u8) {
                if (isa == sse42) {
                    for (int i = 0; i < 4; ++i)
                        pextrb(ptr[reg_index + step_index + i], x, 4*i);
                } else if (isa == avx) {
                    auto y = yreg(2 * ur_w + jj);
                    if (jj == 0) {
                        movd(xmm_tmp, reg_shuf_mask);
                        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
                    }
                    if (mayiuse(avx2)) {
                        vpshufb(y, y, vmm_tmp);
                        movd(ptr[reg_index + step_index], x);
                        vperm2i128(y, y, y, 0x1u);
                        movd(ptr[reg_index + step_index + 4], x);
                    } else {
                        Xmm t(vmm_mask.getIdx());
                        vextractf128(t, y, 0);
                        vpshufb(t, t, xmm_tmp);
                        movd(ptr[reg_index + step_index], t);
                        vextractf128(t, y, 1);
                        vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0]
                        movd(ptr[reg_index + step_index + 4], t);
                    }
                } else {
                    auto v = vreg(2 * ur_w + jj);
                    vpmovusdb(x, v);
                    vmovups(ptr[reg_index + step_index], v | k_index_mask);
                }
            } else {
                uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj));
            }
        }
    }
}

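// One unrolled step of backward max pooling: load diff_dst and the saved
// indices, then for every kernel position add diff_dst into diff_src only in
// the lanes whose saved index matches the current kernel offset, using the
// masked-store facilities available on each isa.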
template <cpu_isa_t isa>
inline void jit_uni_pool_kernel_f32<isa>::max_step_bwd(int ur_w, int pad_l,
        int pad_r) {

    int iw = jpp.iw;
    int kw = jpp.kw;
    int stride_w = jpp.stride_w;
    int c_block = jpp.c_block;
    Label kd_label, kh_label;

    for (int jj = 0; jj < ur_w; jj++) {
        uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);

        const size_t step_index
            = jj * c_block * types::data_type_size(jpp.ind_dt);
        if (jpp.ind_dt == data_type::u8) {
            if (isa == sse42) {
                movd(xreg(ur_w+jj), ptr[reg_index + step_index]);
                pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
            } else if (isa == avx) {
                movq(xreg(ur_w+jj), ptr[reg_index + step_index]);
                if (!mayiuse(avx2)) {
                    avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp);
                } else {
                    vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
                }
            } else {
                vmovups(vreg(ur_w+jj) | k_index_mask,
                        ptr[reg_index + step_index]);
                vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
            }
        } else {
            uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]);
        }
    }
    movq(xmm_tmp, reg_k_shift);
    uni_vpbroadcastd(vmm_k_offset, xmm_tmp);

    if (jpp.simple_alg && jpp.ndims == 5) {
        push(reg_input);
        push(reg_output);
        if (isa == sse42) {
            // Save rdi since it is used in maskmovdqu
            assert(dst_ptr == rdi);
            push(dst_ptr);
        }
        mov(aux_reg_input_d, reg_input);
        mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
        mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
        L(kd_label);
        mov(aux_reg_input, aux_reg_input_d);
    } else {
        mov(aux_reg_input, reg_input);
    }

    xor_(kj, kj);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = nstl::max(0, pad_l - ki);
            int jj_end = ur_w
                - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
            for (int jj = jj_start; jj < jj_end; jj++) {
                int aux_input_offset = (ki + jj*stride_w - pad_l) * c_block;
                if (aux_input_offset > iw * c_block)
                    continue;
                int input_offset = sizeof(float)*aux_input_offset;
                uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]);
                if (isa == sse42) {
                    mov(dst_ptr, aux_reg_input);
                    add(dst_ptr, input_offset);

                    movups(vreg(3*ur_w+jj), vreg(ur_w+jj));
                    pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset);
                    addps(vreg(2*ur_w+jj), vreg(jj));
                    maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj));
                } else if (isa == avx) {
                    if (mayiuse(avx2)) {
                        vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset);
                    } else {
                        avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj),
                                vmm_k_offset, xmm_tmp);
                    }
                    vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj));
                    vmaskmovps(vmmword[aux_reg_input + input_offset],
                            vreg(3*ur_w+jj), vreg(2*ur_w+jj));
                } else {
                    vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset);
                    vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj));
                    vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp);
                    vmovups(vmmword[aux_reg_input +
                            sizeof(float)*aux_input_offset], vreg(2*ur_w+jj));
                }
            }
            if (isa == avx && !mayiuse(avx2)) {
                avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
            } else {
                uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
            }
        }
        add(aux_reg_input, sizeof(float) * iw * c_block);
        inc(kj);
        cmp(kj, reg_kh);
        jl(kh_label, T_NEAR);
    }
    if (jpp.simple_alg && jpp.ndims == 5)
    {
        add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);

        mov(tmp_gpr, reg_kd_pad_shift);
        movq(xmm_tmp, tmp_gpr);
        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
        if (isa == avx && !mayiuse(avx2)) {
            Xmm t(vmm_mask.getIdx());
            avx_vpadd1(vmm_k_offset, vmm_tmp, t);
        } else {
            uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
        }

        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
        if (isa == sse42) {
            // Restore rdi saved above (it is used in maskmovdqu)
            assert(dst_ptr == rdi);
            pop(dst_ptr);
        }
        pop(reg_output);
        pop(reg_input);
    }
}

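// Zero the diff_src region this kernel invocation accumulates into. Zeroing
// is skipped when the 'oh' field of the call arguments is zero; in the 3D
// case that field also scales how many ih*iw planes are cleared.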
template <cpu_isa_t isa>
void jit_uni_pool_kernel_f32<isa>::maybe_zero_diff_src() {
    assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0);
    Label l_skip, l_zero;

    auto reg_oh = tmp_gpr;
    mov(reg_oh, ptr[reg_param + GET_OFF(oh)]);
    cmp(reg_oh, 0);
    jz(l_skip, T_NEAR);

    if (jpp.ndims == 5) {
        mov(zero_size, ptr[reg_param + GET_OFF(oh)]);
        mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float));
        imul(zero_size, tmp_gpr);
    }

    auto vzero = vmm_tmp;
    uni_vpxor(vzero, vzero, vzero);

    auto reg_off = tmp_gpr;
    xor_(reg_off, reg_off);

    L(l_zero);
    {
        const int dim = jpp.iw * jpp.c_block * sizeof(float);
        for (int i = 0; i < dim; i += cpu_isa_traits<isa>::vlen)
            uni_vmovups(ptr[reg_input + reg_off + i], vzero);
        add(reg_off, dim);
        if (jpp.ndims == 5) cmp(reg_off, zero_size);
        else cmp(reg_off, jpp.ih * dim);
        jl(l_zero, T_NEAR);
    }

    L(l_skip);
}

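// Emit the kernel body: load the arguments from jit_pool_call_s, optionally
// zero diff_src, then walk one output row in ur_w-wide steps: a peeled step
// for the left padding, the main loop without padding, a peeled step for the
// right padding, and finally the ur_w_tail remainder. For sse42 the 8-wide
// channel block is processed in two 4-wide halves (step / step_high_half).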
template <cpu_isa_t isa>
void jit_uni_pool_kernel_f32<isa>::generate() {

    this->preamble();

    int ow = jpp.ow;
    int iw = jpp.iw;
    int kw = jpp.kw;
    int kh = jpp.kh;
    int ur_w = jpp.ur_w;
    int c_block = jpp.c_block;
    int stride_w = jpp.stride_w;
    int l_pad = jpp.l_pad;
    int ur_w_tail = jpp.ur_w_tail;

    int n_oi = ow / ur_w;

    prev_kw = 0;

    int vlen = cpu_isa_traits<isa>::vlen;

#if defined(_WIN32)
    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
    // file).
    xor_(rdi, rcx);
    xor_(rcx, rdi);
    xor_(rdi, rcx);
#endif

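    // Load the kernel arguments passed through jit_pool_call_s.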
    mov(reg_input, ptr[reg_param + GET_OFF(src)]);
    mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
        mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
    mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
    mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
    mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);

    if (jpp.is_backward)
        maybe_zero_diff_src();

    if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
        mov(tmp_gpr, 1);
        movq(xmm_one, tmp_gpr);
        uni_vpbroadcastd(vmm_one, xmm_one);

        if (isa == avx) {
            mov(reg_shuf_mask, 0x0c080400);
        } else if (isa >= avx512_common) {
            mov(tmp_gpr.cvt32(), 0x000f);
            kmovw(k_index_mask, tmp_gpr.cvt32());
        }
    }

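    // Right padding seen by the very last output column (r_pad) and by the
    // last full ur_w block (r_pad1); a block that needs right padding is
    // peeled out of the main loop below.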
    int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1));
    int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
    if (r_pad1 > 0) n_oi--;

    if (jpp.alg == pooling_avg_exclude_padding) {
        movq(xmm_ker_area_h, reg_ker_area_h);
        uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h);
    }

    if (jpp.alg == pooling_avg_include_padding) {
        mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd)));
        movq(xmm_tmp, tmp_gpr);
        uni_vpbroadcastd(vmm_tmp, xmm_tmp);
    }
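    // First block, peeled out of the main loop to handle the left padding.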
    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0) {
            step(ur_w, l_pad, r_pad1);
        } else {
            step(ur_w, l_pad, 0);
        }

        if (isa == sse42) {
            if (n_oi < 0 && r_pad1 > 0) {
                step_high_half(ur_w, l_pad, r_pad1);
            } else {
                step_high_half(ur_w, l_pad, 0);
            }
        }

        if (isa == sse42) {
            add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block - vlen);
            add(reg_output, sizeof(float)*ur_w*c_block - vlen);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, (2 * ur_w - 1) * c_block / 2
                        * types::data_type_size(jpp.ind_dt));
        } else {
            add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block);
            add(reg_output, sizeof(float)*ur_w*c_block);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, ur_w * c_block
                        * types::data_type_size(jpp.ind_dt));
        }
    }

    xor_(oi_iter, oi_iter);
    if (n_oi > 0) {
        Label ow_loop;
        L(ow_loop); {
            step(ur_w, 0, 0);

            if (isa == sse42) {
                step_high_half(ur_w, 0, 0);
            }

            if (isa == sse42) {
                add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
                add(reg_output, sizeof(float)*ur_w*c_block - vlen);
                if (jpp.alg == pooling_max &&
                    (jpp.is_training || jpp.is_backward))
                    add(reg_index, (2 * ur_w - 1) * c_block / 2
                            * types::data_type_size(jpp.ind_dt));
            } else {
                add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
                add(reg_output, sizeof(float)*ur_w*c_block);
                if (jpp.alg == pooling_max &&
                    (jpp.is_training || jpp.is_backward))
                    add(reg_index, ur_w * c_block
                            * types::data_type_size(jpp.ind_dt));
            }

            inc(oi_iter);
            cmp(oi_iter, n_oi);
            jl(ow_loop, T_NEAR);
        }
    }

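    // Peeled last full block that needs right padding.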
    if (r_pad1 > 0 && n_oi >= 0) {
        step(ur_w, 0, r_pad1);

        if (isa == sse42) {
            step_high_half(ur_w, 0, r_pad1);
        }

        if (isa == sse42) {
            add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
            add(reg_output, sizeof(float)*ur_w*c_block - vlen);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, (2 * ur_w - 1) * c_block / 2
                        * types::data_type_size(jpp.ind_dt));
        } else {
            add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
            add(reg_output, sizeof(float)*ur_w*c_block);
            if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
                add(reg_index, ur_w * c_block
                        * types::data_type_size(jpp.ind_dt));
        }
    }

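    // Remainder of ur_w_tail output columns, carrying the full right padding.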
    if (ur_w_tail != 0) {
        step(ur_w_tail, 0, r_pad);

        if (isa == sse42) {
            step_high_half(ur_w_tail, 0, r_pad);
        }
    }

    this->postamble();
}

template struct jit_uni_pool_kernel_f32<sse42>;
template struct jit_uni_pool_kernel_f32<avx>; // implements both <avx> and <avx2>
template struct jit_uni_pool_kernel_f32<avx512_common>;

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s