1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * Copyright 2018 YANDEX LLC |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #ifndef JIT_UNI_POOL_KERNEL_F32_HPP |
19 | #define JIT_UNI_POOL_KERNEL_F32_HPP |
20 | |
21 | #include <cfloat> |
22 | |
23 | #include "c_types_map.hpp" |
24 | #include "pooling_pd.hpp" |
25 | #include "type_helpers.hpp" |
26 | |
27 | #include "jit_generator.hpp" |
28 | #include "jit_primitive_conf.hpp" |
29 | |
30 | namespace mkldnn { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | using namespace Xbyak; |
35 | |
36 | template <cpu_isa_t isa> |
37 | struct jit_uni_pool_kernel_f32: public jit_generator { |
38 | jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) |
39 | { |
40 | this->generate(); |
41 | jit_ker = (decltype(jit_ker))this->getCode(); |
42 | } |
43 | |
44 | jit_pool_conf_t jpp; |
45 | |
46 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) |
47 | |
48 | void operator()(jit_pool_call_s *arg) { jit_ker(arg); } |
49 | static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); |
50 | |
51 | private: |
52 | using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx, |
53 | Ymm, Zmm>::type; |
54 | Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } |
55 | Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } |
56 | Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } |
57 | |
58 | const AddressFrame &vmmword = (isa == sse42) ? xword : |
59 | (isa == avx) ? yword : zword; |
60 | |
61 | Xmm vmm_mask = Xmm(0); |
62 | Xmm xmm_ker_area_h = Xmm(2); |
63 | Xmm xmm_one = Xmm(2); |
64 | Xmm xmm_tmp = Xmm(3); |
65 | |
66 | Vmm vmm_ker_area_h = Vmm(2); |
67 | Vmm vmm_one = Vmm(2); |
68 | Vmm vmm_tmp = Vmm(3); |
69 | |
70 | Vmm vmm_k_offset = Vmm(1); |
71 | |
72 | Opmask k_index_mask = Opmask(6); |
73 | Opmask k_store_mask = Opmask(7); |
74 | |
75 | // Here be some (tame) dragons. This kernel does not follow the regular |
76 | // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu |
77 | // instruction which has its destination hardcoded in rdi. Therefore: |
78 | // - all registers are hardcoded |
79 | // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI |
80 | // |
81 | // While this is only required by the backward pass, the quirk above |
82 | // is applied to the forward pass as well to keep things simpler. |
83 | |
84 | using reg64_t = const Xbyak::Reg64; |
85 | reg64_t reg_param = rdi; // Always mimic the Unix ABI |
86 | reg64_t reg_input = r8; |
87 | reg64_t aux_reg_input = r9; |
88 | reg64_t reg_index = r10; |
89 | reg64_t reg_output = r12; |
90 | reg64_t reg_kd_pad_shift = r13; |
91 | reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu |
92 | |
93 | reg64_t kj = r14; |
94 | reg64_t oi_iter = r15; |
95 | reg64_t reg_kh = rax; |
96 | reg64_t reg_k_shift = rbx; |
97 | reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above |
98 | reg64_t reg_ker_area_h = rdx; |
99 | |
100 | reg64_t zero_size = r15; |
101 | reg64_t ki = r12; |
102 | reg64_t aux_reg_input_d = r8; |
103 | |
104 | Xbyak::Reg32 reg_shuf_mask = esi; |
105 | |
106 | int prev_kw; |
107 | void (*jit_ker)(jit_pool_call_s *); |
108 | |
109 | void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); |
110 | void avg_step(int ur_w, int pad_l, int pad_r); |
111 | void max_step_fwd(int ur_w, int pad_l, int pad_r); |
112 | void max_step_bwd(int ur_w, int pad_l, int pad_r); |
113 | |
114 | void maybe_zero_diff_src(); |
115 | |
116 | void step(int ur_w, int pad_l, int pad_r) { |
117 | if (jpp.alg == alg_kind::pooling_max) { |
118 | if(jpp.is_backward) |
119 | max_step_bwd(ur_w, pad_l, pad_r); |
120 | else |
121 | max_step_fwd(ur_w, pad_l, pad_r); |
122 | } |
123 | else |
124 | avg_step(ur_w, pad_l, pad_r); |
125 | } |
126 | |
127 | void step_high_half(int ur_w, int pad_l, int pad_r) { |
128 | add(reg_input, sizeof(float) * 4); |
129 | add(reg_output, sizeof(float) * 4); |
130 | if (jpp.alg == alg_kind::pooling_max && |
131 | (jpp.is_training || jpp.is_backward)) |
132 | add(reg_index, types::data_type_size(jpp.ind_dt) * 4); |
133 | |
134 | step(ur_w, pad_l, pad_r); |
135 | } |
136 | |
137 | void generate(); |
138 | |
139 | void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { |
140 | assert(y0.getIdx() != x1.getIdx()); |
141 | vextractf128(xtmp, y0, 0); |
142 | vpaddd(xtmp, xtmp, x1); |
143 | vinsertf128(y0, y0, xtmp, 0); |
144 | vextractf128(xtmp, y0, 1); |
145 | vpaddd(xtmp, xtmp, x1); |
146 | vinsertf128(y0, y0, xtmp, 1); |
147 | } |
148 | |
149 | void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { |
150 | assert(false /*function should not be used*/); |
151 | paddd(x0, x1); |
152 | } |
153 | |
154 | void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { |
155 | Xmm x0(y0.getIdx()); |
156 | pshufd(xmm_tmp, x1, 1); |
157 | pmovzxbd(x0, x1); |
158 | pmovzxbd(xmm_tmp, xmm_tmp); |
159 | vinsertf128(y0, y0, xmm_tmp, 1); |
160 | } |
161 | |
162 | void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { |
163 | assert(false /*function should not be used*/); |
164 | pmovzxbd(x0, x1); |
165 | } |
166 | |
167 | void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { |
168 | assert(y0.getIdx() != y1.getIdx()); |
169 | assert(y0.getIdx() != y2.getIdx()); |
170 | Xmm x0(y0.getIdx()); |
171 | Xmm x2(y2.getIdx()); |
172 | vextractf128(x0, y1, 1); |
173 | vextractf128(xtmp, y2, 1); |
174 | pcmpeqd(xtmp, x0); |
175 | vextractf128(x0, y1, 0); |
176 | pcmpeqd(x0, x2); |
177 | vinsertf128(y0, y0, xtmp, 1); |
178 | } |
179 | |
180 | void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { |
181 | assert(false /*function should not be used*/); |
182 | pcmpeqd(x0, x1); |
183 | } |
184 | }; |
185 | |
186 | } |
187 | } |
188 | } |
189 | |
190 | #endif |
191 | |
192 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
193 | |