1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP
18#define JIT_AVX2_CONV_KERNEL_F32_HPP
19
20#include "c_types_map.hpp"
21#include "memory_tracking.hpp"
22
23#include "cpu_memory.hpp"
24#include "jit_generator.hpp"
25#include "jit_primitive_conf.hpp"
26#include "jit_uni_eltwise.hpp"
27
28namespace mkldnn {
29namespace impl {
30namespace cpu {
31
32struct jit_avx2_conv_fwd_kernel_f32: public jit_generator {
33 jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
34 const primitive_attr_t &attr)
35 : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
36 {
37 if (jcp.with_eltwise)
38 eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this,
39 jcp.eltwise);
40
41 this->generate();
42 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
43 }
44
45 ~jit_avx2_conv_fwd_kernel_f32() {
46 delete eltwise_injector_;
47 }
48
49 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32)
50
51 static bool post_ops_ok(jit_conv_conf_t &jcp,
52 const primitive_attr_t &attr);
53 static status_t init_conf(jit_conv_conf_t &jcp,
54 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
55 const memory_desc_wrapper &weights_d,
56 const memory_desc_wrapper &dst_d,
57 const primitive_attr_t &attr);
58 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
59 const jit_conv_conf_t &jcp);
60
61 jit_conv_conf_t jcp;
62 const primitive_attr_t &attr_;
63 void (*jit_ker)(jit_conv_call_s *);
64
65private:
66 using reg64_t = const Xbyak::Reg64;
67 reg64_t reg_input = rax;
68 reg64_t aux_reg_input = r8;
69 reg64_t reg_kernel = rdx;
70 reg64_t aux_reg_kernel = r9;
71 reg64_t reg_output = rsi;
72 reg64_t reg_bias = rbx;
73
74 reg64_t aux_reg_inp_d = r11;
75 reg64_t aux_reg_ker_d = abi_not_param1;
76
77 reg64_t reg_ki = rsi;
78 reg64_t kj = r10;
79 reg64_t oi_iter = r11;
80 reg64_t ki_iter = r12;
81 reg64_t reg_kh = abi_not_param1;
82 reg64_t reg_oc_blocks = r14;
83 reg64_t imm_addr64 = r15;
84 reg64_t reg_long_offt = r15;
85 Xbyak::Reg32 reg_ci_flag = r13d;
86
87 Xbyak::Ymm ytmp = Xbyak::Ymm(14);
88
89 jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_;
90
91 inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
92 int oc_blocks);
93 inline void oh_step_nopad(int ur_w, int pad_l, int pad_r,
94 char pad_label, int oc_blocks, char oc_blocks_label);
95 inline void width_blk_step(int ur_w, int pad_l, int pad_r,
96 char pad_label, int oc_blocks, char oc_blocks_label);
97 inline void solve_common(int oc_blocks, char oc_blocks_label);
98
99 void generate();
100};
101
102struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator {
103 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32)
104
105 jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
106 {
107 this->generate();
108 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
109 }
110
111 static status_t init_conf(jit_conv_conf_t &jcp,
112 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
113 const memory_desc_wrapper &weights_d,
114 const memory_desc_wrapper &diff_dst_d);
115 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
116 const jit_conv_conf_t &jcp);
117
118 jit_conv_conf_t jcp;
119 void (*jit_ker)(jit_conv_call_s *);
120
121private:
122 using reg64_t = const Xbyak::Reg64;
123
124 reg64_t reg_ddst = rax;
125 reg64_t aux_reg_ddst = r8;
126 reg64_t reg_kernel = rdx;
127 reg64_t aux_reg_kernel = r10;
128 reg64_t reg_dsrc = rsi;
129 reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only
130 reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5
131 case only */
132
133 reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only
134 reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only
135
136 reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only
137 reg64_t kj = r11;
138 reg64_t oi_iter = r12;
139 reg64_t reg_kh = r14;
140 reg64_t reg_channel = r13; // used in ndims < 5 case only
141 reg64_t reg_channel_work = r9; // used in ndims < 5 case only
142 reg64_t reg_long_offt = r15;
143
144 inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
145
146 void generate();
147
148 inline int get_iw_start(int ki, int l_overflow)
149 {
150 int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
151 + l_overflow * jcp.stride_w
152 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
153 while (res < 0)
154 res += jcp.stride_w;
155
156 return res;
157 }
158
159 inline int get_iw_end(int ur_w, int ki, int r_overflow)
160 {
161 if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
162 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
163 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
164 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
165 while (res < 0)
166 res += jcp.stride_w;
167
168 return ur_w - res;
169 }
170};
171
172struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator {
173 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32)
174
175 jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
176 {
177 this->generate();
178 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
179 }
180
181 static status_t init_conf(jit_conv_conf_t &jcp,
182 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
183 const memory_desc_wrapper &diff_weights_d,
184 const memory_desc_wrapper &diff_dst_d);
185 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
186 const jit_conv_conf_t &jcp);
187
188 jit_conv_conf_t jcp;
189 void (*jit_ker)(jit_conv_call_s *);
190
191private:
192 using reg64_t = const Xbyak::Reg64;
193 reg64_t reg_input = rax;
194 reg64_t reg_kernel = rdx;
195 reg64_t reg_output = rsi;
196 reg64_t b_ic = abi_not_param1;
197 reg64_t kj = r8;
198 reg64_t reg_kh = r9;
199 reg64_t reg_ur_w_trips = r10;
200 reg64_t reg_tmp = r11;
201 reg64_t reg_oj = r15;
202 reg64_t reg_ih_count = rbx;
203 reg64_t aux_reg_input = r12;
204 reg64_t aux_reg_kernel = r13;
205 reg64_t ki = r14;
206 reg64_t reg_long_offt = r11;
207
208 inline void od_step_comeback_pointers();
209 inline void oh_step_comeback_pointers();
210 inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
211 int ic_block_step, int input_offset, int kernel_offset,
212 int output_offset);
213 inline void compute_oh_step_disp();
214 inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
215 inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
216 inline void compute_oh_loop_common();
217
218 void generate();
219};
220
221}
222}
223}
224
225#endif
226