1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3* Copyright 2018 YANDEX LLC
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#ifndef JIT_UNI_POOL_KERNEL_F32_HPP
19#define JIT_UNI_POOL_KERNEL_F32_HPP
20
21#include <cfloat>
22
23#include "c_types_map.hpp"
24#include "pooling_pd.hpp"
25#include "type_helpers.hpp"
26
27#include "jit_generator.hpp"
28#include "jit_primitive_conf.hpp"
29
30namespace mkldnn {
31namespace impl {
32namespace cpu {
33
34using namespace Xbyak;
35
// JIT assembly generator for f32 pooling (max / avg, forward and backward).
// One instantiation per ISA: sse42 (Xmm, 4 floats), avx (Ymm, 8 floats),
// avx512 (Zmm, 16 floats). The generated code is driven per call through a
// jit_pool_call_s argument block.
template <cpu_isa_t isa>
struct jit_uni_pool_kernel_f32: public jit_generator {
    // Generates the kernel immediately in the constructor; jit_ker then
    // points at the emitted code buffer owned by the jit_generator base.
    jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp)
    {
        this->generate();
        jit_ker = (decltype(jit_ker))this->getCode();
    }

    // Problem descriptor (shapes, strides, algorithm, ...) captured at
    // generation time; the emitted code is specialized for it.
    jit_pool_conf_t jpp;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32)

    // Invokes the generated kernel on one argument block.
    void operator()(jit_pool_call_s *arg) { jit_ker(arg); }

    // Fills jbp from the pooling primitive descriptor; returns whether this
    // kernel supports the configuration. (Defined in the .cpp.)
    static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd);

private:
    // Widest vector type for the ISA: Xmm / Ymm / Zmm.
    using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx,
            Ymm, Zmm>::type;
    // Accumulator registers are allocated top-down from the highest vector
    // register index (31 on avx512, 15 otherwise) so they never collide with
    // the low-numbered scratch registers below.
    Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); }
    Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); }
    Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); }

    // Memory-operand width matching Vmm for this ISA.
    const AddressFrame &vmmword = (isa == sse42) ? xword :
                                  (isa == avx) ? yword : zword;

    Xmm vmm_mask = Xmm(0);
    // NOTE: xmm_ker_area_h and xmm_one (and their Vmm twins below) alias
    // register 2 — presumably intentional reuse since the two values serve
    // different code paths (avg divisor recalc vs. max index bookkeeping);
    // confirm against the .cpp before repurposing either.
    Xmm xmm_ker_area_h = Xmm(2);
    Xmm xmm_one = Xmm(2);
    Xmm xmm_tmp = Xmm(3);

    Vmm vmm_ker_area_h = Vmm(2);
    Vmm vmm_one = Vmm(2);
    Vmm vmm_tmp = Vmm(3);

    Vmm vmm_k_offset = Vmm(1);

    // AVX-512 predicate registers used for index computation and masked
    // stores in the max-pooling paths.
    Opmask k_index_mask = Opmask(6);
    Opmask k_store_mask = Opmask(7);

    // Here be some (tame) dragons. This kernel does not follow the regular
    // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu
    // instruction which has its destination hardcoded in rdi. Therefore:
    // - all registers are hardcoded
    // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
    //
    // While this is only required by the backward pass, the quirk above
    // is applied to the forward pass as well to keep things simpler.

    using reg64_t = const Xbyak::Reg64;
    reg64_t reg_param      = rdi; // Always mimic the Unix ABI
    reg64_t reg_input      = r8;
    reg64_t aux_reg_input  = r9;
    reg64_t reg_index      = r10;
    reg64_t reg_output     = r12;
    reg64_t reg_kd_pad_shift = r13;
    reg64_t dst_ptr        = rdi; // Must be rdi due to maskmovdqu

    reg64_t kj      = r14;
    reg64_t oi_iter = r15;
    reg64_t reg_kh  = rax;
    reg64_t reg_k_shift  = rbx;
    reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
    reg64_t reg_ker_area_h = rdx;

    // NOTE(review): zero_size/oi_iter (r15), ki/reg_output (r12) and
    // aux_reg_input_d/reg_input (r8) intentionally share registers — the
    // aliased pairs appear to be live in disjoint phases (e.g. 3D depth
    // handling vs. the inner width loop); verify in generate() before
    // reassigning any of them.
    reg64_t zero_size = r15;
    reg64_t ki = r12;
    reg64_t aux_reg_input_d = r8;

    Xbyak::Reg32 reg_shuf_mask = esi;

    int prev_kw;
    // Entry point of the generated code; set once in the constructor.
    void (*jit_ker)(jit_pool_call_s *);

    // Emitters for one unrolled row of output (defined in the .cpp):
    // re-derives the avg divisor when padding changes the window area.
    void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r);
    void avg_step(int ur_w, int pad_l, int pad_r);
    void max_step_fwd(int ur_w, int pad_l, int pad_r);
    void max_step_bwd(int ur_w, int pad_l, int pad_r);

    // Emits code zeroing diff_src when required (backward pass).
    void maybe_zero_diff_src();

    // Dispatches to the right emitter for the configured algorithm and
    // propagation direction.
    void step(int ur_w, int pad_l, int pad_r) {
        if (jpp.alg == alg_kind::pooling_max) {
            if(jpp.is_backward)
                max_step_bwd(ur_w, pad_l, pad_r);
            else
                max_step_fwd(ur_w, pad_l, pad_r);
        }
        else
            avg_step(ur_w, pad_l, pad_r);
    }

    // Advances the data (and, for trained/backward max pooling, the index)
    // pointers by 4 elements and emits another step — presumably used on
    // sse42 to process the upper half of an 8-wide channel block with
    // 4-wide Xmm registers; confirm against generate().
    void step_high_half(int ur_w, int pad_l, int pad_r) {
        add(reg_input, sizeof(float) * 4);
        add(reg_output, sizeof(float) * 4);
        if (jpp.alg == alg_kind::pooling_max &&
                (jpp.is_training || jpp.is_backward))
            add(reg_index, types::data_type_size(jpp.ind_dt) * 4);

        step(ur_w, pad_l, pad_r);
    }

    // Emits the full kernel (prologue, loops over kh/ow, epilogue).
    void generate();

    // The avx_* helpers below emulate 256-bit integer ops on AVX(1), which
    // lacks them, by operating on the two 128-bit lanes separately via
    // vextractf128/vinsertf128. Each has an Xmm overload that only exists
    // to satisfy overload resolution when Vmm == Xmm; those assert(false).

    // y0 += broadcast of x1 into both lanes (lane-wise paddd).
    void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
        assert(y0.getIdx() != x1.getIdx());
        vextractf128(xtmp, y0, 0);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 0);
        vextractf128(xtmp, y0, 1);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 1);
    }

    // Dead overload: Vmm == Xmm instantiations must never reach this.
    void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) {
        assert(false /*function should not be used*/);
        paddd(x0, x1);
    }

    // Zero-extends 8 bytes of x1 into 8 dwords across y0's two lanes.
    // Clobbers xmm_tmp (used to stage the high 4 bytes via pshufd).
    void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
        Xmm x0(y0.getIdx());
        pshufd(xmm_tmp, x1, 1);
        pmovzxbd(x0, x1);
        pmovzxbd(xmm_tmp, xmm_tmp);
        vinsertf128(y0, y0, xmm_tmp, 1);
    }

    // Dead overload: Vmm == Xmm instantiations must never reach this.
    void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) {
        assert(false /*function should not be used*/);
        pmovzxbd(x0, x1);
    }

    // y0 = (y1 == y2) lane-wise: high lanes compared via xtmp, low lanes
    // in place (x0/x2 alias the low 128 bits of y0/y2). y0 must be distinct
    // from both sources since its low lane is used as scratch.
    void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) {
        assert(y0.getIdx() != y1.getIdx());
        assert(y0.getIdx() != y2.getIdx());
        Xmm x0(y0.getIdx());
        Xmm x2(y2.getIdx());
        vextractf128(x0, y1, 1);
        vextractf128(xtmp, y2, 1);
        pcmpeqd(xtmp, x0);
        vextractf128(x0, y1, 0);
        pcmpeqd(x0, x2);
        vinsertf128(y0, y0, xtmp, 1);
    }

    // Dead overload: Vmm == Xmm instantiations must never reach this.
    void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) {
        assert(false /*function should not be used*/);
        pcmpeqd(x0, x1);
    }
};
185
186}
187}
188}
189
190#endif
191
192// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
193