1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "c_types_map.hpp" |
18 | #include "nstl.hpp" |
19 | #include "type_helpers.hpp" |
20 | #include "utils.hpp" |
21 | |
22 | #include "cpu_barrier.hpp" |
23 | |
24 | #include "jit_avx512_common_conv_kernel.hpp" |
25 | |
/* Byte offset of `field` inside jit_conv_call_s. The kernels in this file add
 * it to param1 to read individual call arguments (src, dst, filt, ...). */
#define GET_OFF(field) offsetof(jit_conv_call_s, field)
/* (512 - 64) * 1024.
 * NOTE(review): presumably the usable L2 capacity in bytes on KNx parts with
 * 64 KB held back -- confirm against the code that consumes this macro. */
#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
28 | |
29 | namespace mkldnn { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | using namespace mkldnn::impl::format_tag; |
34 | using namespace mkldnn::impl::memory_tracking::names; |
35 | using namespace mkldnn::impl::utils; |
36 | using namespace Xbyak; |
37 | |
38 | namespace { |
39 | |
/* Width/height threshold used by pick_loop_order(): a problem whose width and
 * height are both <= this value is treated as "small spatial". */
constexpr auto small_spatial = 14;
/* Value returned by get_cache_size(1, true), evaluated once during static
 * initialization.
 * NOTE(review): presumably the L1 data cache size in bytes (per core) --
 * confirm against the get_cache_size() declaration. */
unsigned int L1_cache_size = get_cache_size(1, true);
42 | |
43 | inline void pick_loop_order(jit_conv_conf_t &jcp) { |
44 | using namespace prop_kind; |
45 | assert(one_of(jcp.prop_kind, |
46 | forward_training, forward_inference, backward_data)); |
47 | auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; |
48 | auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; |
49 | |
50 | // ow-threading is currently implemented for forward only |
51 | // TODO: single code for fwd and bwd after ow-thr for bwd |
52 | // meaningless switch was removed |
53 | if (jcp.prop_kind == backward_data) { |
54 | jcp.loop_order = (w <= small_spatial && h <= small_spatial) |
55 | ? loop_cgn : loop_gnc; |
56 | } else { |
57 | jcp.loop_order = (w <= small_spatial && h <= small_spatial) |
58 | ? loop_cwgn : loop_gncw; |
59 | } |
60 | } |
61 | |
62 | inline bool is_1stconv(const jit_conv_conf_t &jcp) { |
63 | if (mayiuse(avx512_core)) |
64 | return (jcp.ic < 16 && jcp.ngroups == 1); |
65 | else |
66 | return one_of(jcp.ic, 1, 3); |
67 | } |
68 | |
69 | inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { |
70 | return (jcp.nb_ow > 1); |
71 | } |
72 | |
73 | inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { |
74 | return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); |
75 | } |
76 | |
77 | } |
78 | |
79 | template<typename Vmm> |
80 | void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w) |
81 | { |
82 | for (int k = 0; k < jcp.nb_oc_blocking; k++) |
83 | for (int j = 0; j < ur_w; j++) { |
84 | Vmm vmm = vmm_out(j, k); |
85 | vpxord(vmm, vmm, vmm); |
86 | if (!is_owb_prefetching(jcp)) { |
87 | size_t aux_output_offset = get_output_offset(j, k); |
88 | mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, |
89 | aux_output_offset, reg_out_long_offt)); |
90 | } |
91 | } |
92 | } |
93 | |
94 | template<typename Vmm> |
95 | void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w) |
96 | { |
97 | Label no_update_label, store_label, eltwise_label; |
98 | |
99 | mov(reg_channel, ptr[param1 + GET_OFF(channel)]); |
100 | if (jcp.with_bias) { |
101 | mov(reg_bias, ptr[param1 + GET_OFF(bias)]); |
102 | } |
103 | |
104 | if (!jcp.with_sum) { |
105 | cmp(reg_channel, 0); |
106 | je(no_update_label, T_NEAR); |
107 | } |
108 | |
109 | for (int k = 0; k < jcp.nb_oc_blocking; k++) |
110 | for (int j = 0; j < ur_w; j++) { |
111 | Vmm vmm = vmm_out(j, k); |
112 | size_t aux_output_offset = get_output_offset(j, k); |
113 | vaddps(vmm, |
114 | make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); |
115 | } |
116 | |
117 | if (!jcp.with_sum) { |
118 | jmp(eltwise_label, T_NEAR); |
119 | } else { |
120 | cmp(reg_channel, 0); |
121 | jne(eltwise_label, T_NEAR); |
122 | } |
123 | |
124 | L(no_update_label); |
125 | if (jcp.with_bias) { |
126 | for (int k = 0; k < jcp.nb_oc_blocking; k++) { |
127 | int bias_offset = jcp.typesize_out * k * jcp.oc_block; |
128 | for (int j = 0; j < ur_w; j++) { |
129 | Vmm vmm = vmm_out(j, k); |
130 | vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); |
131 | } |
132 | mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); |
133 | } |
134 | } |
135 | |
136 | L(eltwise_label); |
137 | if (jcp.with_eltwise) { |
138 | cmp(reg_channel, jcp.nb_ic - 1); |
139 | jl(store_label, T_NEAR); |
140 | |
141 | if (ur_w == jcp.ur_w) { |
142 | eltwise_injector_->compute_vector_range(0, |
143 | jcp.nb_oc_blocking * jcp.ur_w); |
144 | } else { |
145 | for (int k = 0; k < jcp.nb_oc_blocking; k++) |
146 | eltwise_injector_->compute_vector_range(k * jcp.ur_w, |
147 | k * jcp.ur_w + ur_w); |
148 | } |
149 | } |
150 | |
151 | L(store_label); |
152 | for (int k = 0; k < jcp.nb_oc_blocking; k++) |
153 | for (int j = 0; j < ur_w; j++) { |
154 | Vmm vmm = vmm_out(j, k); |
155 | size_t aux_output_offset = (size_t)typesize * |
156 | ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; |
157 | vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, |
158 | reg_out_long_offt), vmm); |
159 | if (!is_owb_prefetching(jcp)) |
160 | mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, |
161 | aux_output_offset, reg_out_long_offt)); |
162 | } |
163 | } |
164 | |
/* Primary template of compute_loop_4fma_1st: the body is empty, so no
 * instructions are generated for vector types other than the explicit Zmm
 * specialization that follows. */
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
        int pad_l, int pad_r)
{
}
170 | |
171 | template<> |
172 | void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w, |
173 | int pad_l, int pad_r) |
174 | { |
175 | assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); |
176 | |
177 | int iw = jcp.iw; |
178 | int ih = jcp.ih; |
179 | int kw = jcp.kw; |
180 | int stride_w = jcp.stride_w; |
181 | int ic_block = jcp.ic_block; |
182 | int oc_block = jcp.oc_block; |
183 | |
184 | Label kh_label, kd_label; |
185 | |
186 | if (one_of(jcp.ndims, 3, 4)) { |
187 | mov(aux_reg_inp, reg_inp); |
188 | mov(aux_reg_ker, reg_ker); |
189 | mov(aux_reg_inp_prf, reg_inp_prf); |
190 | } |
191 | |
192 | size_t max_input_offset = (size_t)jcp.typesize_in |
193 | * ((size_t)(kw + ur_w * stride_w - pad_l) |
194 | + (size_t)ic_block * iw * ih * jcp.id); |
195 | assert(reg_inp_prf == reg_long_offt); |
196 | if (max_input_offset > INT_MAX) push(reg_inp_prf); |
197 | |
198 | if (jcp.ndims == 5) { |
199 | push(reg_out_prf); |
200 | push(reg_out); |
201 | |
202 | mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); |
203 | mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); |
204 | mov(aux_reg_inp_d, reg_inp); |
205 | mov(aux_reg_inp_d_prf, reg_inp_prf); |
206 | |
207 | L(kd_label); |
208 | } |
209 | mov(reg_kj, reg_kh); |
210 | if (jcp.ndims == 5) { |
211 | mov(aux_reg_inp, aux_reg_inp_d); |
212 | mov(aux_reg_ker, aux_reg_ker_d); |
213 | mov(aux_reg_inp_prf, aux_reg_inp_d_prf); |
214 | } |
215 | |
216 | L(kh_label); |
217 | for (int ki = 0; ki < kw; ki += 4) { |
218 | for (int ic = 0; ic < ic_block; ic++) { |
219 | for (int i = 0; i < 4; i++) { |
220 | int aux_ker_offset |
221 | = jcp.typesize_in |
222 | * ((ki + i) * oc_block |
223 | + ic * kw * jcp.kh * jcp.kd * oc_block); |
224 | if (ki + i < kw) |
225 | vmovups(vmm_ker(i), |
226 | EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); |
227 | else |
228 | vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); |
229 | } |
230 | |
231 | int j_start = get_ow_start(ki, pad_l); |
232 | int j_end = get_ow_end(ur_w, ki, pad_r); |
233 | |
234 | for (int j = j_start, prf_count=0; j < j_end; j++) { |
235 | size_t aux_input_offset = (size_t)jcp.typesize_in |
236 | * ((size_t)(ki + j * stride_w |
237 | - pad_l) + (size_t)ic * iw * ih * jcp.id); |
238 | v4fmaddps(vmm_out(j, 0), vmm_ker(0), |
239 | EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, |
240 | reg_long_offt)); |
241 | if (ki + prf_count < kw && prf_count < 4 |
242 | && ((ki < 2 && j % 4) || j % 2)) { |
243 | int aux_ker_offset = jcp.typesize_in |
244 | * ((ki + prf_count) * oc_block |
245 | + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); |
246 | mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, |
247 | aux_ker_offset)); |
248 | prf_count++; |
249 | } |
250 | if (ki == 0 |
251 | && j % (64 / (stride_w * jcp.typesize_in)) == 0) { |
252 | mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, |
253 | aux_input_offset, reg_long_offt)); |
254 | } |
255 | if (ki == 1 |
256 | && j % (64 / (stride_w * jcp.typesize_in)) == 0) { |
257 | mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, |
258 | aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); |
259 | } |
260 | } |
261 | } |
262 | } |
263 | add(aux_reg_ker, jcp.typesize_in * kw * oc_block); |
264 | add(aux_reg_inp, jcp.typesize_in * iw); |
265 | add(aux_reg_inp_prf, jcp.typesize_in * iw); |
266 | |
267 | dec(reg_kj); |
268 | cmp(reg_kj, 0); |
269 | jg(kh_label, T_NEAR); |
270 | |
271 | if (jcp.ndims == 5) { |
272 | add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); |
273 | add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); |
274 | add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); |
275 | |
276 | dec(reg_ki); |
277 | cmp(reg_ki, 0); |
278 | jg(kd_label, T_NEAR); |
279 | |
280 | pop(reg_out); |
281 | pop(reg_out_prf); |
282 | } |
283 | |
284 | if (max_input_offset > INT_MAX) pop(reg_inp_prf); |
285 | } |
286 | |
/* Primary template of compute_loop_4fma: the body is empty, so no
 * instructions are generated for vector types other than the explicit Zmm
 * specialization that follows. */
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
        int pad_l, int pad_r)
{
}
292 | |
/* Emits the kd/kh/kw reduction loops of the 4fma kernel for blocked input
 * (Zmm only).
 *
 * ur_w  - number of output columns accumulated by this tile
 * pad_l - left padding seen by this tile (passed to get_ow_start)
 * pad_r - right padding seen by this tile (passed to get_ow_end)
 *
 * Input channels of a block are consumed four at a time: kernel_loads() fills
 * vmm_ker(0..3) and each v4fmaddps is given vmm_ker(0) plus one input
 * address. Kernel and input prefetches are interleaved with the FMAs. */
template<>
void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
        int pad_l, int pad_r)
{
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    Label kh_label, last_iter_label, loop_end_label, kd_label;
    // number of kernel registers filled per kernel_loads() call
    int ker_load_number = 4;
    // byte increments applied to the kernel / input pointers per kh iteration
    int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
    int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;

    // kh > 3 selects the variant that emits a separate copy of the last
    // kernel-load group for the final kh iteration (reg_kj == 1)
    bool check_last_kh = (jcp.kh > 3);
    // selects which input prefetch flavour the non-owb path below emits
    bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);

    // running output-column cursor for prefetch_inp_next_kh and its upper bound
    int oi_ipref_t0 = get_ow_start(0, pad_l);
    int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);

    assert(jcp.oc % jcp.nb_oc_blocking == 0);

    // byte offset of kernel element (oc block `ocb`, input channel `ic`,
    // kernel column `ki`) relative to aux_reg_ker
    auto kernel_offset = [=](int ocb, int ic, int ki) {
        int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
        int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
        int ic_offset = ic * jcp.oc_block;
        return typesize * (blk_offset + ic_offset);
    };
    // loads kernel vectors for channels ic .. ic + 3 into vmm_ker(0..3)
    auto kernel_loads = [=](int ki, int ic, int kk) {
        for (int ii = 0; ii < ker_load_number; ii++) {
            int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
            vmovups(vmm_ker(ii),
                EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
        }
    };
    // once both kernel prefetch counters are exhausted and ki >= ki_start,
    // emits one prefetcht0 of the input one kh step ahead and advances the
    // shared cursor oi_ipref_t0 (captured by reference)
    auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
        if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
            && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
            int aux_inp_offset
                    = typesize
                    * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
                              + (jcp.dilate_h + 1) * jcp.iw * ic_block);
            prefetcht0(EVEX_compress_addr(aux_reg_inp,
                    aux_inp_offset));
            oi_ipref_t0++;
        }
    };

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_ker_prf, reg_ker_prf);
        mov(aux_reg_inp_prf, reg_inp_prf);
    }

    if (jcp.ndims == 5) {
        // NOTE(review): reg_out_prf / reg_out are saved here and restored
        // after the kd loop; presumably they alias registers used inside
        // the loop -- confirm against the register map in the header.
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);
        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }
    if (jcp.ndims == 5) {
        // restart the kh-level pointers from the current kd position
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    align(16);
    L(kh_label);
    int kw = jcp.kw;
    if (check_last_kh) {
        for (int ki = 0; ki < kw; ki++)
        for (int ic = 0; ic < ic_block; ic += 4)
        for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
            // true only for the very last (ki, ic, kk) group of a kh iteration
            bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
                && ki == kw - 1 && (ic + 4) == ic_block);

            if (last_kernel_loads) {
                // on the final kh iteration run the alternative copy below
                cmp(reg_kj, 1);
                je(last_iter_label, T_NEAR);
            }

            kernel_loads(ki, ic, kk);
            for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
                     prf_count_t0 = 0;
                 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                int aux_input_offset = typesize
                        * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                   - pad_l) * ic_block
                                  + ic);
                v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                    EVEX_compress_addr(aux_reg_inp, aux_input_offset));

                // odd columns carry kernel prefetches (first up to four t0,
                // then up to four t1); even columns may carry an input prefetch
                if (oi % 2) {
                    if (prf_count_t0 < 4) {
                        int aux_kernel_prf;
                        if (last_kernel_loads)
                            // wrap around to the start of the next kh row
                            aux_kernel_prf= kernel_offset(0,
                                prf_count_t0 + ic + 4
                                - ic_block, 0) + typesize * kw
                                * oc_block * ic_block;
                        else
                            aux_kernel_prf = kernel_offset(kk, ic + 4
                                + prf_count_t0, ki);
                        mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
                            aux_kernel_prf));
                        prf_count_t0++;
                    } else if (prf_count_t1 < 4) {
                        mic_prefetcht1(EVEX_compress_addr(
                            aux_reg_ker_prf, kernel_offset(kk, ic
                            + prf_count_t1, ki)));
                        prf_count_t1++;
                    }
                } else
                    prefetch_inp_next_kh(ki, 2, prf_count_t0,
                        prf_count_t1);
            }

            if (last_kernel_loads) {
                jmp(loop_end_label, T_NEAR);

                // copy of the group used when reg_kj == 1: same loads and
                // FMAs, but the t0 kernel prefetches go through
                // aux_reg_ker_prf and no next-kh input prefetch is emitted
                L(last_iter_label);

                kernel_loads(ki, ic, kk);
                for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
                         prf_count_t0 = 0;
                     oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                    int aux_input_offset = typesize
                            * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                       - pad_l) * ic_block
                                      + ic);
                    v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                        EVEX_compress_addr(aux_reg_inp,
                            aux_input_offset));
                    if (oi % 2) {
                        if (prf_count_t0 < 4) {
                            mic_prefetcht0(EVEX_compress_addr(
                                aux_reg_ker_prf, kernel_offset(0,
                                prf_count_t0, 0)));
                            prf_count_t0++;
                        } else if (prf_count_t1 < 4) {
                            mic_prefetcht1(EVEX_compress_addr(
                                aux_reg_ker_prf, kernel_offset(kk,
                                ic + prf_count_t1, ki)));
                            prf_count_t1++;
                        }
                    }
                }
                L(loop_end_label);
            }
        }
    } else {
        for (int ki = 0; ki < kw; ki++)
        for (int ic = 0; ic < ic_block; ic += 4)
        for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
            kernel_loads(ki, ic, kk);
            for (int oi = get_ow_start(ki, pad_l),
                     prf_count_t1 = 0, prf_count_t0 = 0;
                 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                int aux_input_offset = typesize
                        * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                   - pad_l) * ic_block + ic);
                v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                    EVEX_compress_addr(aux_reg_inp,
                        aux_input_offset));

                // kernel prefetches
                if (!is_owb_prefetching(jcp)) {
                    if ((oi % 2) && (prf_count_t1 < 4)) {
                        mic_prefetcht1(EVEX_compress_addr(
                            aux_reg_ker_prf, kernel_offset(kk,
                            ic + prf_count_t1, ki)));
                        prf_count_t1++;
                    }
                } else {
                    // NOTE(review): the guard tests prf_count_t1, which is
                    // never incremented on this path, while the offset and
                    // the increment use prf_count_t0 -- confirm whether
                    // prf_count_t0 < 4 was intended.
                    if (!(ki == 0 && ic == 0)
                        && !(ki == kw-1 && ic == 0) &&
                        (oi % 2) && (prf_count_t1 < 4)
                        ) {
                        mic_prefetcht0(EVEX_compress_addr(
                            aux_reg_ker, kernel_offset(kk,
                            ic + 4 + prf_count_t0, ki)));
                        prf_count_t0++;
                    }
                }
                // input prefetches
                if (!is_owb_prefetching(jcp)) {
                    if (pref_current_inp) {
                        // next kh row through the current input pointer
                        if (ki == 0 && ic == 0 && kk == 0)
                            mic_prefetcht0(EVEX_compress_addr(
                                aux_reg_inp,
                                aux_input_offset + shift_input_ptr));
                    } else {
                        // same offset through the prefetch pointer
                        if (ki == 1 && ic == 0 && kk == 0)
                            mic_prefetcht1(EVEX_compress_addr(
                                aux_reg_inp_prf, aux_input_offset));
                    }
                } else {
                    // ow-block prefetching: look ahead by one ur_w tile
                    int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
                    int inp_shift
                        = jcp.typesize_in * ur_w * stride_w * inp_mult;
                    // alternate prefetch columns between oc blocks
                    bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
                    if (ki == 0 && ic == 0 && kk_pref_slot)
                        mic_prefetcht1(EVEX_compress_addr(
                            aux_reg_inp,
                            aux_input_offset + inp_shift));

                    if (ki == kw - 1 && ic == 0 && kk_pref_slot)
                        mic_prefetcht0(EVEX_compress_addr(
                            aux_reg_inp,
                            aux_input_offset + inp_shift));
                }
            }
        }
    }

    // advance to the next kernel row / input row and loop over kh
    add(aux_reg_ker, shift_kernel_ptr);
    add(aux_reg_inp, shift_input_ptr);
    add(aux_reg_ker_prf, shift_kernel_ptr);
    add(aux_reg_inp_prf, shift_input_ptr);

    dec(reg_kj);
    cmp(reg_kj, 0);
    jg(kh_label, T_NEAR);

    if (jcp.ndims == 5) {
        // advance the kd-level pointers by one (dilated) depth slice
        add(aux_reg_inp_d,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);
        add(aux_reg_inp_d_prf,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }
}
541 | |
/* Emits the kd/kh/kw reduction loops of the plain-FMA kernel that keeps a
 * rotating set of kernel registers and spreads kernel/input prefetches
 * evenly between the FMAs.
 *
 * ur_w  - number of output columns accumulated by this tile
 * pad_l - left padding seen by this tile (passed to get_ow_start)
 * pad_r - right padding seen by this tile (passed to get_ow_end)
 *
 * Only vmm_out(j, 0) is accumulated here, i.e. a single oc block. */
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
        int pad_l, int pad_r)
{
    // both prefetch streams are unconditionally enabled
    bool prf_ker = true;
    bool prf_inp = true;
    int ih = jcp.ih;
    int stride_w = jcp.stride_w;
    int id = jcp.id;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int nb_oc_block = jcp.nb_oc_blocking;
    Label kh_label, kd_label;

    // number of kernel registers kept in flight
    int ker_pipeline_depth = 4;
    assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
    assert(oc_block >= ker_pipeline_depth);

    // prefetch budget for one kh iteration
    int num_ker_loads = ic_block * nb_oc_block * kw;
    int num_ker_prfs = prf_ker ? num_ker_loads : 0;
    int num_inp_prfs = prf_inp ?
            ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
            0;
    if (jcp.is_1stconv && prf_inp) {
        num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
    }
    int num_prfs = num_ker_prfs + num_inp_prfs;
    int num_fmas = num_ker_loads * ur_w;
    // number of FMAs per prefetch slot (at least 1) ...
    int prf_inst_spacing
            = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
    // ... and the position inside each slot at which the prefetch is emitted
    int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
    // row stride multiplier: ic_block for blocked input, 1 for 1st conv
    int inp_mul = !jcp.is_1stconv ? ic_block : 1;

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_inp_prf, reg_inp_prf);
        mov(aux_reg_ker_prf, reg_ker_prf);
    }

    // reg_inp_prf is the same register as reg_long_offt (asserted), so it is
    // saved on the stack when an input offset may not fit in 32 bits
    size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
    assert(reg_inp_prf == reg_long_offt);
    if (max_input_offset > INT_MAX) push(reg_inp_prf);


    if (jcp.ndims == 5) {
        // NOTE(review): reg_out_prf / reg_out are saved here and restored
        // after the kd loop; presumably they alias registers used inside
        // the loop -- confirm against the register map in the header.
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        // restart the kh-level pointers from the current kd position
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    align(16);
    L(kh_label);
    {
        // step counts (ki, ic) pairs; ker_prfs counts emitted kernel prefetches
        int step = 0;
        int ker_prfs = 0;
        for (int ki = 0; ki < kw; ki++) {
            for (int ic = 0; ic < ic_block; ic++) {
                int aux_kernel_offset = 0;
                if (step == 0) {
                    // prologue: fill the whole register pipeline
                    for (int i = 0; i < ker_pipeline_depth; i++) {
                        aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
                        vmovups(vmm_ker(i), EVEX_compress_addr(
                                        aux_reg_ker, aux_kernel_offset));
                    }
                } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
                    // steady state: load the vector that will be consumed
                    // ker_pipeline_depth - 1 steps from now into the register
                    // that was used by the previous step
                    int load_offset = ker_pipeline_depth - 1;
                    int ker_load_reg_idx
                        = (step + load_offset) % ker_pipeline_depth;
                    aux_kernel_offset
                            = get_kernel_offset(ki, ic, 0, load_offset);
                    vmovups(vmm_ker(ker_load_reg_idx),
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                }

                // at most one kernel prefetch per step
                bool ker_prf_inserted = false;
                Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
                int j_start = get_ow_start(ki, pad_l);
                int j_end = get_ow_end(ur_w, ki, pad_r);
                for (int j = j_start; j < j_end; j++) {
                    size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
                    // NOTE(review): the trailing `true` presumably requests an
                    // embedded-broadcast memory operand -- confirm against
                    // EVEX_compress_addr_safe().
                    auto addr = EVEX_compress_addr_safe(aux_reg_inp,
                            aux_input_offset, reg_long_offt, true);
                    vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
                    int fma_idx = step * ur_w + j;
                    int prf_slot_idx = fma_idx / prf_inst_spacing;
                    if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
                        if (prf_ker && !ker_prf_inserted
                                && ker_prfs < num_ker_prfs) {
                            // kernel prefetches walk aux_reg_ker_prf in
                            // oc_block-element steps
                            int ker_prf_offset
                                    = jcp.typesize_in * ker_prfs * jcp.oc_block;
                            mic_prefetcht2(EVEX_compress_addr(
                                    aux_reg_ker_prf, ker_prf_offset));
                            ker_prf_inserted = true;
                            ker_prfs++;
                        } else if (prf_inp) {
                            // slots not taken by kernel prefetches are used
                            // for input prefetches
                            int inp_prf_idx = prf_slot_idx - ker_prfs;
                            if (inp_prf_idx < num_inp_prfs) {
                                size_t inp_prf_stride = nstl::max(kw, stride_w);
                                size_t inp_prf_offset;
                                if (!jcp.is_1stconv) {
                                    inp_prf_offset
                                            = ic_block * jcp.typesize_in
                                            * ((inp_prf_idx / kw)
                                            * inp_prf_stride
                                            + (inp_prf_idx % kw));
                                } else {
                                    // 1st conv: channels are separate planes
                                    // of iw * ih * id elements
                                    size_t ic_prf_stride =
                                        (size_t)jcp.typesize_in * iw * ih * id;
                                    size_t iw_prf_stride
                                            = jcp.typesize_in * jcp.simd_w;
                                    inp_prf_offset = ((inp_prf_idx / ic_block)
                                            * iw_prf_stride
                                            + (inp_prf_idx % ic_block)
                                            * ic_prf_stride);
                                }
                                mic_prefetcht0(EVEX_compress_addr_safe(
                                        aux_reg_inp_prf, inp_prf_offset,
                                        reg_long_offt));
                            }
                        }
                    }
                }
                step++;
            }
        }
        // advance to the next kernel row / (dilated) input row
        add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
        if (prf_ker)
            add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
        add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
        if (prf_inp)
            add(aux_reg_inp_prf,
                    jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }


    if (jcp.ndims == 5) {
        // advance the kd-level pointers by one (dilated) depth slice
        add(aux_reg_inp_d,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);
        add(aux_reg_inp_d_prf,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }
    if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
719 | |
720 | template<typename Vmm> |
721 | void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w, |
722 | int pad_l, int pad_r) |
723 | { |
724 | int kw = jcp.kw; |
725 | int stride_w = jcp.stride_w; |
726 | int ic_block = jcp.ic_block; |
727 | int oc_block = jcp.oc_block; |
728 | int nb_oc_block = jcp.nb_oc_blocking; |
729 | Label kh_label, kd_label; |
730 | int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block |
731 | * jcp.ic_block; |
732 | int inp_mul = !jcp.is_1stconv ? ic_block : 1; |
733 | int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw |
734 | * inp_mul; |
735 | |
736 | |
737 | auto input_offset = [=](int oi, int ic, int ki) { |
738 | return (size_t)jcp.typesize_in |
739 | * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) |
740 | * inp_mul + (size_t)ic |
741 | * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); |
742 | }; |
743 | |
744 | if (one_of(jcp.ndims, 3, 4)) { |
745 | mov(aux_reg_inp, reg_inp); |
746 | mov(aux_reg_ker, reg_ker); |
747 | } |
748 | |
749 | if (jcp.ndims == 5) { |
750 | push(reg_out); |
751 | |
752 | mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); |
753 | mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); |
754 | mov(aux_reg_inp_d, reg_inp); |
755 | |
756 | L(kd_label); |
757 | mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); |
758 | } else { |
759 | mov(reg_kj, reg_kh); |
760 | } |
761 | |
762 | if (jcp.ndims == 5) { |
763 | mov(aux_reg_inp, aux_reg_inp_d); |
764 | mov(aux_reg_ker, aux_reg_ker_d); |
765 | } |
766 | |
767 | L(kh_label); |
768 | { |
769 | for (int ki = 0; ki < kw; ki++) { |
770 | int jj_start = get_ow_start(ki, pad_l); |
771 | int jj_end = get_ow_end(ur_w, ki, pad_r); |
772 | for (int ic = 0; ic < ic_block; ic++) { |
773 | if (jcp.kernel_kind == expl_bcast) { |
774 | for (int jj = jj_start; jj < jj_end; jj++) { |
775 | size_t aux_input_offset = input_offset(jj, ic, ki); |
776 | vbroadcastss(vmm_inp(jj, nb_oc_block), |
777 | EVEX_compress_addr_safe(aux_reg_inp, |
778 | aux_input_offset, reg_long_offt)); |
779 | } |
780 | } |
781 | for (int ii = 0; ii < nb_oc_block; ii++) { |
782 | int aux_kernel_offset = jcp.typesize_in |
783 | * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block |
784 | * oc_block + ki * ic_block * oc_block + ic * oc_block); |
785 | if (jj_end - jj_start > 0) |
786 | vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, |
787 | aux_kernel_offset)); |
788 | for (int jj = jj_start; jj < jj_end; jj++) |
789 | if (jcp.kernel_kind == expl_bcast) |
790 | vfmadd231ps(vmm_out(jj, ii), |
791 | vmm_inp(jj, nb_oc_block), vmm_wei); |
792 | else { |
793 | size_t aux_input_offset = input_offset(jj, ic, ki); |
794 | vfmadd231ps(vmm_out(jj, ii), vmm_wei, |
795 | EVEX_compress_addr_safe(aux_reg_inp, |
796 | aux_input_offset, reg_long_offt, true)); |
797 | } |
798 | } |
799 | } |
800 | } |
801 | add(aux_reg_ker, shift_kernel_ptr); |
802 | add(aux_reg_inp, shift_input_ptr); |
803 | dec(reg_kj); |
804 | cmp(reg_kj, 0); |
805 | jg(kh_label, T_NEAR); |
806 | } |
807 | |
808 | if (jcp.ndims == 5) { |
809 | add(aux_reg_inp_d, |
810 | typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); |
811 | add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block |
812 | * jcp.ic_block); |
813 | |
814 | dec(reg_ki); |
815 | cmp(reg_ki, 0); |
816 | jg(kd_label, T_NEAR); |
817 | |
818 | pop(reg_out); |
819 | } |
820 | } |
821 | |
822 | template<typename Vmm> |
823 | void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w, |
824 | int pad_l, int pad_r) |
825 | { |
826 | if (jcp.ndims == 5) push(reg_oi); |
827 | |
828 | prepare_output(ur_w); |
829 | |
830 | Label skip_compute_loop; |
831 | if (jcp.ndims == 5) { |
832 | if ((jcp.dilate_d >= jcp.id) |
833 | || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { |
834 | mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); |
835 | cmp(reg_kj, 0); |
836 | je(skip_compute_loop, T_NEAR); |
837 | } |
838 | } |
839 | if ((jcp.dilate_h >= jcp.ih) |
840 | || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { |
841 | mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); |
842 | cmp(reg_kj, 0); |
843 | je(skip_compute_loop, T_NEAR); |
844 | } |
845 | |
846 | if (jcp.ver == ver_4fma) |
847 | if(jcp.is_1stconv) |
848 | compute_loop_4fma_1st(ur_w, pad_l, pad_r); |
849 | else |
850 | compute_loop_4fma(ur_w, pad_l, pad_r); |
851 | else if (jcp.ver == ver_fma) |
852 | if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) |
853 | || mayiuse(avx512_mic)) |
854 | compute_loop_fma(ur_w, pad_l, pad_r); |
855 | else |
856 | if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) |
857 | compute_loop_fma(ur_w, pad_l, pad_r); |
858 | else |
859 | compute_loop_fma_core(ur_w, pad_l, pad_r); |
860 | else |
861 | assert(!"unknown convolution version" ); |
862 | |
863 | L(skip_compute_loop); |
864 | store_output(ur_w); |
865 | if (jcp.ndims == 5) pop(reg_oi); |
866 | } |
867 | |
868 | template<typename Vmm> |
869 | void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate() |
870 | { |
871 | int iw = jcp.iw; |
872 | int ow = jcp.ow; |
873 | int ow_block = jcp.ow_block; |
874 | int nb_ow = jcp.nb_ow; |
875 | int kw = jcp.kw; |
876 | int l_pad = jcp.l_pad; |
877 | int ur_w = jcp.ur_w; |
878 | int ur_w_tail = jcp.ur_w_tail; |
879 | int dilate_w = jcp.dilate_w + 1; |
880 | int stride_w = jcp.stride_w; |
881 | |
882 | int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; |
883 | int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; |
884 | int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; |
885 | int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; |
886 | int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; |
887 | |
888 | preamble(); |
889 | mov(reg_inp, ptr[param1 + GET_OFF(src)]); |
890 | mov(reg_out, ptr[param1 + GET_OFF(dst)]); |
891 | mov(reg_ker, ptr[param1 + GET_OFF(filt)]); |
892 | mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); |
893 | mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); |
894 | |
895 | int r_pad = nstl::max( |
896 | 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); |
897 | int n_oi = ow / ur_w; |
898 | int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w |
899 | - (iw + l_pad - 1); |
900 | |
901 | if (!is_ow_threading_on(jcp)) { |
902 | // ow is being processed as a whole - with left and right paddings |
903 | if (r_pad1 > 0) n_oi--; |
904 | |
905 | if (ow == ur_w) { |
906 | mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); |
907 | mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); |
908 | compute_loop(ur_w, l_pad, r_pad); |
909 | } else { |
910 | mov(reg_inp_prf, reg_inp); |
911 | mov(reg_out_prf, reg_out); |
912 | if (n_oi == 0) { |
913 | add(reg_inp_prf, inp_shift_pad); |
914 | add(reg_out_prf, out_shift); |
915 | compute_loop(ur_w, l_pad, r_pad1); |
916 | add(reg_inp, inp_shift_pad); |
917 | add(reg_out, out_shift); |
918 | if (ur_w_tail != 0) { |
919 | add(reg_inp_prf, inp_shift); |
920 | add(reg_out_prf, out_shift); |
921 | compute_loop(ur_w_tail, 0, r_pad); |
922 | } |
923 | } else { |
924 | xor_(reg_oi, reg_oi); |
925 | if (l_pad > 0) { |
926 | add(reg_inp_prf, inp_shift_pad); |
927 | add(reg_out_prf, out_shift); |
928 | compute_loop(ur_w, l_pad, 0); |
929 | add(reg_inp, inp_shift_pad); |
930 | add(reg_out, out_shift); |
931 | inc(reg_oi); |
932 | } |
933 | if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { |
934 | Label ow_loop_label; |
935 | L(ow_loop_label); |
936 | { |
937 | add(reg_inp_prf, inp_shift); |
938 | add(reg_out_prf, out_shift); |
939 | compute_loop(ur_w, 0, 0); |
940 | add(reg_inp, inp_shift); |
941 | add(reg_out, out_shift); |
942 | inc(reg_oi); |
943 | cmp(reg_oi, n_oi); |
944 | jl(ow_loop_label, T_NEAR); |
945 | } |
946 | } |
947 | if (r_pad1 > 0) { |
948 | add(reg_inp_prf, inp_shift); |
949 | add(reg_out_prf, out_shift); |
950 | compute_loop(ur_w, 0, r_pad1); |
951 | add(reg_inp, inp_shift); |
952 | add(reg_out, out_shift); |
953 | } |
954 | if (ur_w_tail != 0) { |
955 | add(reg_inp_prf, inp_shift); |
956 | add(reg_out_prf, out_shift); |
957 | compute_loop(ur_w_tail, 0, r_pad); |
958 | } |
959 | } |
960 | } |
961 | } else { |
962 | // ow block is only processed. |
963 | // Number of block is passed as parameter owb, |
964 | // and padding processing depends on this number. |
965 | |
966 | Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; |
967 | Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; |
968 | |
969 | assert(ow_block % ur_w == 0); |
970 | int n_oi_not_last_ow_block = ow_block / ur_w; |
971 | // to simplify code (and general regs usage), |
972 | // size of ow block must be >= 2 * ur_w |
973 | assert(n_oi_not_last_ow_block > 1); |
974 | int n_oi_next_last_ow_block = n_oi_not_last_ow_block; |
975 | int n_oi_first_ow_block = n_oi_not_last_ow_block; |
976 | |
977 | int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; |
978 | |
979 | // prepare right padding |
980 | bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; |
981 | bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; |
982 | bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; |
983 | |
984 | if (last_ow_block_padded) n_oi_last_ow_block--; |
985 | else if (first_ow_block_padded) n_oi_first_ow_block--; |
986 | else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; |
987 | |
988 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
989 | cmp(reg_owb, 0); // is that the first ow-block ? |
990 | jg(middle_ow_blocks_label, T_NEAR); |
991 | |
992 | // the first ow block, compute left padding |
993 | |
994 | mov(reg_oi, n_oi_first_ow_block); |
995 | mov(reg_inp_prf, reg_inp); |
996 | mov(reg_out_prf, reg_out); |
997 | |
998 | if (l_pad > 0) { |
999 | mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); |
1000 | add(reg_inp_prf, inp_shift_pad); |
1001 | add(reg_out_prf, out_shift); |
1002 | compute_loop(ur_w, l_pad, 0); |
1003 | add(reg_inp, inp_shift_pad); |
1004 | add(reg_out, out_shift); |
1005 | dec(reg_oi); |
1006 | } |
1007 | jmp(oi_loop_label, T_NEAR); |
1008 | |
1009 | // middle or last ow block entry |
1010 | |
1011 | L(middle_ow_blocks_label); |
1012 | |
1013 | if (l_pad > 0) { |
1014 | // just to consider left padding, not compute |
1015 | add(reg_inp, inp_shift_pad_second_block); |
1016 | add(reg_inp_prf, inp_shift_pad_second_block); |
1017 | } |
1018 | |
1019 | // set number of iteration for oi-loop |
1020 | cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? |
1021 | mov(reg_oi, n_oi_last_ow_block); |
1022 | je(oi_loop_label, T_NEAR); |
1023 | cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? |
1024 | mov(reg_oi, n_oi_next_last_ow_block); |
1025 | je(oi_loop_label, T_NEAR); |
1026 | mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks |
1027 | |
1028 | // oi loop w/o padding |
1029 | L(oi_loop_label); |
1030 | mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); |
1031 | L(oi_loop_start_label); |
1032 | cmp(reg_oi, 0); |
1033 | jle(oi_loop_end_label, T_NEAR); |
1034 | |
1035 | add(reg_inp_prf, inp_shift); |
1036 | add(reg_out_prf, out_shift); |
1037 | compute_loop(ur_w, 0, 0); |
1038 | add(reg_inp, inp_shift); |
1039 | add(reg_out, out_shift); |
1040 | dec(reg_oi); |
1041 | jmp(oi_loop_start_label, T_NEAR); |
1042 | L(oi_loop_end_label); |
1043 | |
1044 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
1045 | |
1046 | cmp(reg_owb, 0); // first ow-block ? |
1047 | if (first_ow_block_padded) { |
1048 | je(last_oi_label, T_NEAR); |
1049 | } else { |
1050 | je(end_label, T_NEAR); |
1051 | } |
1052 | cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? |
1053 | jl(end_label, T_NEAR); |
1054 | if (next_last_ow_block_padded) { |
1055 | je(last_oi_label, T_NEAR); |
1056 | } else { |
1057 | je(end_label, T_NEAR); |
1058 | } |
1059 | // that is last block |
1060 | if (!last_ow_block_padded) { |
1061 | jmp(tail_label, T_NEAR); |
1062 | } |
1063 | |
1064 | // last oi block with right padding |
1065 | L(last_oi_label); |
1066 | mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); |
1067 | add(reg_inp_prf, inp_shift); |
1068 | add(reg_out_prf, out_shift); |
1069 | compute_loop(ur_w, 0, r_pad1); |
1070 | add(reg_inp, inp_shift); |
1071 | add(reg_out, out_shift); |
1072 | |
1073 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
1074 | cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? |
1075 | jl(end_label, T_NEAR); |
1076 | |
1077 | L(tail_label); |
1078 | mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); |
1079 | if (ur_w_tail != 0) { |
1080 | add(reg_inp_prf, inp_shift); |
1081 | add(reg_out_prf, out_shift); |
1082 | compute_loop(ur_w_tail, 0, r_pad); |
1083 | } |
1084 | L(end_label); |
1085 | } |
1086 | postamble(); |
1087 | |
1088 | if (jcp.with_eltwise) |
1089 | eltwise_injector_->prepare_table(); |
1090 | } |
1091 | |
1092 | bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( |
1093 | jit_conv_conf_t &jcp, const primitive_attr_t &attr) { |
1094 | const auto &p = attr.post_ops_; |
1095 | |
1096 | auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; |
1097 | auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; |
1098 | |
1099 | switch (p.len_) { |
1100 | case 0: return true; // no post_ops |
1101 | case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise |
1102 | case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise |
1103 | default: return false; |
1104 | } |
1105 | |
1106 | return false; |
1107 | } |
1108 | |
1109 | status_t jit_avx512_common_conv_fwd_kernel::init_conf( |
1110 | jit_conv_conf_t &jcp, const convolution_desc_t &cd, |
1111 | memory_desc_t &src_md, memory_desc_t &weights_md, |
1112 | memory_desc_t &dst_md, memory_desc_t &bias_md, |
1113 | const primitive_attr_t &attr, int nthreads) |
1114 | { |
1115 | using namespace prop_kind; |
1116 | |
1117 | if (!mayiuse(avx512_common)) |
1118 | return status::unimplemented; |
1119 | |
1120 | const memory_desc_wrapper src_d(&src_md); |
1121 | const memory_desc_wrapper weights_d(&weights_md); |
1122 | const memory_desc_wrapper dst_d(&dst_md); |
1123 | const memory_desc_wrapper bias_d(&bias_md); |
1124 | |
1125 | const int regs = 28; |
1126 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1127 | int ndims = src_d.ndims(); |
1128 | |
1129 | jcp = zero<decltype(jcp)>(); |
1130 | jcp.ndims = ndims; |
1131 | jcp.prop_kind = cd.prop_kind; |
1132 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1133 | jcp.mb = src_d.dims()[0]; |
1134 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
1135 | jcp.oc_without_padding = jcp.oc; |
1136 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1137 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
1138 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; |
1139 | jcp.iw = src_d.dims()[ndims-1]; |
1140 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
1141 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; |
1142 | jcp.ow = dst_d.dims()[ndims-1]; |
1143 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
1144 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; |
1145 | jcp.kw = weights_d.dims()[with_groups + ndims-1]; |
1146 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
1147 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; |
1148 | jcp.l_pad = cd.padding[0][ndims-3]; |
1149 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
1150 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; |
1151 | jcp.stride_w = cd.strides[ndims-3]; |
1152 | |
1153 | jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; |
1154 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; |
1155 | jcp.dilate_w = cd.dilates[ndims-3]; |
1156 | |
1157 | jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) |
1158 | - (jcp.ih + jcp.t_pad - 1); |
1159 | jcp.back_pad = (jcp.od - 1) * jcp.stride_d |
1160 | + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); |
1161 | |
1162 | jcp.is_1stconv = is_1stconv(jcp); |
1163 | |
1164 | bool ok_to_pad_channels = true |
1165 | && jcp.ngroups == 1 |
1166 | && src_d.data_type() == data_type::f32; |
1167 | |
1168 | const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); |
1169 | jcp.simd_w = full_simd_w; |
1170 | bool ok_to_try_xmm = true |
1171 | && mayiuse(avx512_core) |
1172 | && src_d.data_type() == data_type::f32 |
1173 | && !jcp.is_1stconv |
1174 | && !ok_to_pad_channels |
1175 | && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) |
1176 | && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); |
1177 | if (ok_to_try_xmm) |
1178 | jcp.simd_w = 4; |
1179 | |
1180 | jcp.oc_block = jcp.simd_w; |
1181 | jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; |
1182 | jcp.aligned_threads = 0; |
1183 | |
1184 | if (ok_to_pad_channels) { |
1185 | jcp.oc = rnd_up(jcp.oc, jcp.oc_block); |
1186 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
1187 | } |
1188 | bool args_ok = true |
1189 | && jcp.oc % jcp.oc_block == 0 |
1190 | && jcp.ic % jcp.ic_block == 0; |
1191 | if (!args_ok) |
1192 | return status::unimplemented; |
1193 | |
1194 | if (!post_ops_ok(jcp, attr)) |
1195 | return status::unimplemented; |
1196 | |
1197 | const auto &p = attr.post_ops_; |
1198 | jcp.with_sum = p.find(primitive_kind::sum) != -1; |
1199 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
1200 | jcp.with_eltwise = eltwise_ind != -1; |
1201 | if (jcp.with_eltwise) { |
1202 | jcp.eltwise = p.entry_[eltwise_ind].eltwise; |
1203 | if (dst_d.data_type() == data_type::s32) return status::unimplemented; |
1204 | } |
1205 | |
1206 | auto src_tag = jcp.is_1stconv |
1207 | ? pick(ndims - 3, ncw, nchw, ncdhw) |
1208 | : ((jcp.simd_w == 4) |
1209 | ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) |
1210 | : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); |
1211 | auto dst_tag = (jcp.simd_w == 4) |
1212 | ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) |
1213 | : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
1214 | auto wei_tag = with_groups |
1215 | ? ((jcp.simd_w == 4) |
1216 | ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) |
1217 | : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) |
1218 | : ((jcp.simd_w == 4) |
1219 | ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) |
1220 | : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); |
1221 | |
1222 | if (src_d.format_kind() == format_kind::any) { |
1223 | CHECK(memory_desc_init_by_tag(src_md, src_tag)); |
1224 | jcp.src_tag = src_tag; |
1225 | } else { |
1226 | jcp.src_tag = src_d.matches_one_of_tag(src_tag); |
1227 | } |
1228 | if (jcp.src_tag != src_tag) |
1229 | return status::unimplemented; |
1230 | |
1231 | if (dst_d.format_kind() == format_kind::any) { |
1232 | CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); |
1233 | jcp.dst_tag = dst_tag; |
1234 | } else { |
1235 | jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); |
1236 | } |
1237 | if (jcp.dst_tag != dst_tag) |
1238 | return status::unimplemented; |
1239 | |
1240 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
1241 | if (jcp.with_bias) { |
1242 | if (bias_d.format_kind() == format_kind::any) |
1243 | CHECK(memory_desc_init_by_tag(bias_md, x)); |
1244 | } |
1245 | |
1246 | if (mayiuse(avx512_common) && |
1247 | src_d.data_type() == data_type::f32 |
1248 | && weights_d.data_type() == data_type::f32 |
1249 | && dst_d.data_type() == data_type::f32) { |
1250 | jcp.ver = ver_fma; |
1251 | jcp.typesize_in = sizeof(float); |
1252 | jcp.typesize_out = sizeof(float); |
1253 | if (mayiuse(avx512_mic_4ops)) |
1254 | jcp.ver = ver_4fma; |
1255 | |
1256 | if (jcp.is_1stconv) { |
1257 | // TODO: fix & remove constraints below |
1258 | bool not_for_4fma |
1259 | = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), |
1260 | nstl::max(jcp.kw, jcp.kh) < 7); |
1261 | bool is_dilated |
1262 | = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); |
1263 | if (one_of(true, not_for_4fma, is_dilated)) |
1264 | jcp.ver = ver_fma; |
1265 | if (jcp.ver == ver_4fma) { |
1266 | wei_tag = with_groups |
1267 | ? ((jcp.simd_w == 4) |
1268 | ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) |
1269 | : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) |
1270 | : ((jcp.simd_w == 4) |
1271 | ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) |
1272 | : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); |
1273 | } else { |
1274 | wei_tag = with_groups |
1275 | ? ((jcp.simd_w == 4) |
1276 | ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) |
1277 | : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) |
1278 | : ((jcp.simd_w == 4) |
1279 | ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) |
1280 | : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); |
1281 | } |
1282 | } |
1283 | } else { |
1284 | return status::unimplemented; |
1285 | } |
1286 | |
1287 | if (weights_d.format_kind() == format_kind::any) { |
1288 | CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); |
1289 | jcp.wei_tag = wei_tag; |
1290 | } else { |
1291 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
1292 | } |
1293 | if (jcp.wei_tag != wei_tag) |
1294 | return status::unimplemented; |
1295 | |
1296 | if (jcp.is_1stconv) { |
1297 | jcp.ur_w = nstl::min(jcp.ow, regs); |
1298 | } else { |
1299 | // avx512_core guard - just to avoid possible regression for other archs |
1300 | if (jcp.ver == ver_fma && mayiuse(avx512_core)) { |
1301 | jcp.ur_w = nstl::min(jcp.ow, regs); |
1302 | } else { |
1303 | for (int ur_w = regs; ur_w > 0; --ur_w) { |
1304 | if (jcp.ow % ur_w == 0) { |
1305 | jcp.ur_w = ur_w; |
1306 | break; |
1307 | } |
1308 | } |
1309 | } |
1310 | if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { |
1311 | jcp.ur_w = nstl::min(jcp.ow, regs); |
1312 | } |
1313 | } |
1314 | // TODO (Tanya): currently applied to Segnet convolutions only. |
1315 | // Need to try for other topologies |
1316 | if (jcp.ow > 150 && jcp.ur_w < regs/2) |
1317 | jcp.ur_w = regs; |
1318 | |
1319 | int n_oi = (jcp.ow / jcp.ur_w); |
1320 | int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w |
1321 | + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); |
1322 | if (jcp.l_pad > 0 && r_pad > 0) |
1323 | n_oi--; |
1324 | |
1325 | bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 |
1326 | && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); |
1327 | if (large_code_size) { |
1328 | const int max_code_size = 24 * 1024; |
1329 | const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; |
1330 | int mult = 1; |
1331 | if (jcp.l_pad > 0) mult += 1; |
1332 | if (r_pad > 0) mult += 1; |
1333 | for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { |
1334 | if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { |
1335 | jcp.ur_w = ur_w; |
1336 | break; |
1337 | } |
1338 | } |
1339 | } |
1340 | |
1341 | /* Grouped channel offset to support 'non-blocked data' format for |
1342 | * convolution sizes with '(input_channel / ngroups) < simd' */ |
1343 | jcp.nonblk_group_off |
1344 | = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? |
1345 | jcp.ic : |
1346 | 1; |
1347 | |
1348 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
1349 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
1350 | jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; |
1351 | |
1352 | auto is_ow_threading_applicable = [=]() { |
1353 | return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) |
1354 | && IMPLICATION(mayiuse(avx512_mic), |
1355 | jcp.ver == ver_4fma |
1356 | && IMPLICATION(jcp.mb != 1, |
1357 | jcp.ih == 1 && jcp.kh == 1))); |
1358 | }; |
1359 | |
1360 | if (jcp.ver == ver_4fma && !jcp.is_1stconv) { |
1361 | if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 |
1362 | && jcp.oh <= 8 && jcp.ow == jcp.oh) |
1363 | || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { |
1364 | if (jcp.nb_oc % 2 == 0) { |
1365 | jcp.nb_oc_blocking = 2; |
1366 | jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); |
1367 | } |
1368 | } else { |
1369 | for (int i = jcp.nb_oc; i > 0; i--) |
1370 | if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { |
1371 | jcp.nb_oc_blocking = i; |
1372 | break; |
1373 | } |
1374 | } |
1375 | if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { |
1376 | if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow |
1377 | && jcp.ow != 2 * jcp.ur_w) { |
1378 | jcp.nb_oc_blocking = 2; |
1379 | jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); |
1380 | } |
1381 | } |
1382 | } |
1383 | |
1384 | jcp.ow_block = jcp.ow; |
1385 | |
1386 | auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { |
1387 | int nb_ow = div_up(jcp.ow, ow_block); |
1388 | int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); |
1389 | int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; |
1390 | float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); |
1391 | float thr_eff = disbalance * (float)work_amount |
1392 | / rnd_up(work_amount, nthreads); |
1393 | return thr_eff; |
1394 | }; |
1395 | |
1396 | auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { |
1397 | int res_ow_block = jcp.ow; |
1398 | eff = get_thr_eff(nb_oc_blocking, res_ow_block); |
1399 | if (!is_ow_threading_applicable()) |
1400 | return res_ow_block; |
1401 | |
1402 | int L2_part = (get_cache_size(2) * 7 / 8) / typesize; |
1403 | if (jcp.ver == ver_4fma) |
1404 | L2_part /= 2; |
1405 | int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; |
1406 | int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; |
1407 | int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block |
1408 | * jcp.kw * jcp.kh; |
1409 | int nurw_cache = (L2_part - 2 * size_wei_chunk) |
1410 | / (2 * size_dst_chunk + 2 * size_src_chunk); |
1411 | // current design of generate() requires ow_block >= 2 * ur_w |
1412 | int ow_block_cache = ur_w * nstl::max(2, nurw_cache); |
1413 | |
1414 | int ow_block_thr = ow_block_cache; |
1415 | eff = get_thr_eff(nb_oc_blocking, ow_block_thr); |
1416 | |
1417 | int max_nb_ow = div_up(jcp.ow, 2 * ur_w); |
1418 | int start_nb_ow = div_up(jcp.ow, ow_block_thr); |
1419 | for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { |
1420 | int ow_block |
1421 | = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); |
1422 | float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; |
1423 | if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) |
1424 | break; |
1425 | if (div_up(jcp.ow, ow_block) != nb_ow) |
1426 | continue; |
1427 | float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); |
1428 | float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; |
1429 | if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { |
1430 | ow_block_thr = ow_block; |
1431 | eff = thr_eff; |
1432 | } |
1433 | eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f; |
1434 | if (eff > eff_threshold) |
1435 | break; |
1436 | } |
1437 | res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); |
1438 | eff = get_thr_eff(nb_oc_blocking, res_ow_block); |
1439 | return res_ow_block; |
1440 | }; |
1441 | |
1442 | |
1443 | if (jcp.ver == ver_fma && mayiuse(avx512_core)) { |
1444 | int try_nb_oc_blocking = 2; |
1445 | unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) |
1446 | * jcp.ic_block * jcp.kh * jcp.kd; |
1447 | unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block |
1448 | * try_nb_oc_blocking; |
1449 | unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block |
1450 | * jcp.oc_block * try_nb_oc_blocking * jcp.kd; |
1451 | unsigned int ker_total_size = ker_inp_size + ker_out_size |
1452 | + ker_wei_size; |
1453 | |
1454 | bool embd_bcast_condition = true |
1455 | && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) |
1456 | && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) |
1457 | && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); |
1458 | |
1459 | if (jcp.mb == 1) { |
1460 | unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) |
1461 | * div_up(jcp.iw, jcp.stride_w) * jcp.ic; |
1462 | unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; |
1463 | |
1464 | // Estimate whether we need to limit the number of threads |
1465 | // and calculate this number. Includes some heuristic. |
1466 | int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
1467 | int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; |
1468 | int job_size_min = work_amount / nthreads; |
1469 | int job_size_max = div_up(work_amount, nthreads); |
1470 | int ch_max = rnd_up(jcp.oh, job_size_max); |
1471 | int ch_min = (job_size_min == 0) |
1472 | ? jcp.oh |
1473 | : rnd_up(jcp.oh, job_size_min); |
1474 | bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 |
1475 | && (jcp.oh != 8 || ch_max / jcp.oh > 1); |
1476 | bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 |
1477 | && (jcp.oh != 8 || ch_min / jcp.oh > 1); |
1478 | bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) |
1479 | || nthreads > oc_chunks; |
1480 | if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 |
1481 | && wei_size / inp_size > 24 |
1482 | && (not_aligned_max || not_aligned_min) |
1483 | && eligible_case) { |
1484 | // Try to find nthreads > mkldnn_get_max_threads() / 2 such |
1485 | // that oc_chunks is a multiple of nthreads, or nthreads is a |
1486 | // multiple of oc_chunks. Otherwise, keep default value. |
1487 | // TODO: implement a task-based alternative without throttling. |
1488 | jcp.aligned_threads = nthreads; |
1489 | for (int i = nthreads; i > nthreads / 2; i--) { |
1490 | if (oc_chunks % i == 0 || i % oc_chunks == 0) { |
1491 | jcp.aligned_threads = i; |
1492 | break; |
1493 | } |
1494 | } |
1495 | } |
1496 | } |
1497 | |
1498 | if (jcp.kw > 3 |
1499 | || (jcp.stride_w == 1 && jcp.stride_h == 1 |
1500 | && embd_bcast_condition) |
1501 | || ((jcp.stride_w != 1 || jcp.stride_h != 1) |
1502 | && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) |
1503 | && embd_bcast_condition))) |
1504 | || (jcp.mb == 1 |
1505 | && (jcp.ur_w >= jcp.ow || jcp.is_1stconv |
1506 | || (jcp.ow <= 147 && jcp.oc <= 96)))) { |
1507 | jcp.kernel_kind = embd_bcast; |
1508 | jcp.ur_w = nstl::min(jcp.ow, regs); |
1509 | jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; |
1510 | if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 |
1511 | && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 |
1512 | && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) |
1513 | && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { |
1514 | jcp.nb_oc_blocking = try_nb_oc_blocking; |
1515 | jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); |
1516 | } |
1517 | } else { |
1518 | jcp.kernel_kind = expl_bcast; |
1519 | jcp.nb_ic_blocking = 1; |
1520 | if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { |
1521 | float best_thr_eff = 0.f; |
1522 | int best_nb_oc_blocking = 1; |
1523 | for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { |
1524 | if (jcp.nb_oc % i == 0) { |
1525 | float thr_eff; |
1526 | int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); |
1527 | get_ow_block(i, ur_w, thr_eff); |
1528 | if (thr_eff > 1.05f * best_thr_eff) { |
1529 | best_nb_oc_blocking = i; |
1530 | best_thr_eff = thr_eff; |
1531 | } |
1532 | } |
1533 | } |
1534 | jcp.nb_oc_blocking = best_nb_oc_blocking; |
1535 | jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); |
1536 | } |
1537 | } |
1538 | } |
1539 | |
1540 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
1541 | |
1542 | args_ok = true |
1543 | && jcp.l_pad <= jcp.ur_w |
1544 | && jcp.ic <= src_d.padded_dims()[1] |
1545 | && jcp.oc <= dst_d.padded_dims()[1] |
1546 | && jcp.ic <= weights_d.padded_dims()[with_groups + 1] |
1547 | && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; |
1548 | if (!args_ok) |
1549 | return status::unimplemented; |
1550 | |
1551 | int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w |
1552 | + (jcp.kw - 1) * (jcp.dilate_w + 1) |
1553 | - (jcp.iw + jcp.l_pad - 1)); |
1554 | if (r_pad_no_tail > jcp.ur_w) |
1555 | return status::unimplemented; |
1556 | |
1557 | pick_loop_order(jcp); |
1558 | |
1559 | jcp.nb_ic_L2 = jcp.nb_ic; |
1560 | |
1561 | float thr_eff; |
1562 | jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); |
1563 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
1564 | |
1565 | const int L2_size = get_cache_size(2, true) / sizeof(float); |
1566 | // Source and output data needs to fit in L2, |
1567 | // leaving some space for weights and prefetching. |
1568 | int h_L2 = int(((0.6f * L2_size) / jcp.simd_w |
1569 | - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) |
1570 | / (jcp.stride_h * jcp.iw + jcp.ow)); |
1571 | jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); |
1572 | |
1573 | if (jcp.ver == ver_4fma) { |
1574 | if (!is_ow_threading_on(jcp)) { |
1575 | for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; |
1576 | divf++) { |
1577 | size_t l2_src |
1578 | = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; |
1579 | size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking |
1580 | * jcp.oh * jcp.od; |
1581 | size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block |
1582 | * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; |
1583 | if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { |
1584 | if (jcp.kh == 3 && jcp.oh == 7) { |
1585 | jcp.nb_ic_L2 = 1; |
1586 | break; |
1587 | } |
1588 | temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf |
1589 | : jcp.nb_ic_L2); |
1590 | } else { |
1591 | jcp.nb_ic_L2 = temp_nb; |
1592 | break; |
1593 | } |
1594 | } |
1595 | } else if (jcp.ic > 64) { |
1596 | jcp.nb_ic_L2 = 2; /* according to performance data*/ |
1597 | } |
1598 | } |
1599 | |
1600 | return status::success; |
1601 | } |
1602 | |
1603 | void jit_avx512_common_conv_fwd_kernel::init_scratchpad( |
1604 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
1605 | if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) |
1606 | scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); |
1607 | } |
1608 | |
1609 | void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) |
1610 | { |
1611 | for (int k = 0; k < jcp.nb_ic_blocking; k++) { |
1612 | for (int j = 0; j < ur_w; j++) { |
1613 | Zmm zmm = zmm_out(j, k); |
1614 | vpxord(zmm, zmm, zmm); |
1615 | size_t aux_src_offset |
1616 | = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) |
1617 | * jcp.ic_block; |
1618 | mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, |
1619 | reg_long_offt)); |
1620 | } |
1621 | } |
1622 | } |
1623 | |
1624 | void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) |
1625 | { |
1626 | Label no_update_label; |
1627 | |
1628 | mov(reg_channel, ptr[param + GET_OFF(channel)]); |
1629 | cmp(reg_channel, 0); |
1630 | je(no_update_label, T_NEAR); |
1631 | for (int k = 0; k < jcp.nb_ic_blocking; k++) { |
1632 | for (int j = 0; j < ur_w; j++) { |
1633 | Zmm zmm = zmm_out(j, k); |
1634 | size_t aux_src_offset = (size_t)typesize |
1635 | * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; |
1636 | vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, |
1637 | reg_long_offt)); |
1638 | } |
1639 | } |
1640 | |
1641 | L(no_update_label); |
1642 | for (int k = 0; k < jcp.nb_ic_blocking; k++) { |
1643 | for (int j = 0; j < ur_w; j++) { |
1644 | Zmm zmm = zmm_out(j, k); |
1645 | size_t aux_src_offset = (size_t)typesize |
1646 | * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; |
1647 | vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, |
1648 | reg_long_offt), zmm); |
1649 | mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, |
1650 | reg_long_offt)); |
1651 | } |
1652 | } |
1653 | } |
1654 | |
1655 | void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( |
1656 | int ur_w, int l_overflow, int r_overflow) |
1657 | { |
1658 | int ow = jcp.ow; |
1659 | int kw = jcp.kw; |
1660 | int ic_block = jcp.ic_block; |
1661 | int oc_block = jcp.oc_block; |
1662 | Label kh_label, last_iter_label, loop_end_label, kd_label; |
1663 | int ker_load_number = 4; |
1664 | int shift_ker_ptr = typesize * kw * oc_block * ic_block; |
1665 | int shift_dst_ptr = typesize * ow * oc_block; |
1666 | int ii_dpref_t0 = get_iw_start(0, l_overflow); |
1667 | int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); |
1668 | |
1669 | bool check_last_kh = (jcp.kh > 3); |
1670 | auto kernel_offset = [=](int icb, int oc, int ki) { |
1671 | int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; |
1672 | int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; |
1673 | int oc_offset = oc * jcp.oc_block; |
1674 | return typesize * (blk_offset + oc_offset); |
1675 | }; |
1676 | auto kernel_loads = [=](int ki, int oc, int kk) { |
1677 | for (int ii = 0; ii < ker_load_number; ii++) { |
1678 | int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); |
1679 | vmovups(zmm_ker(ii), |
1680 | EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); |
1681 | } |
1682 | }; |
1683 | auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { |
1684 | if (cnt1 >= ker_load_number && cnt0 >= ker_load_number |
1685 | && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { |
1686 | int aux_dst_offset = typesize * ((ii_dpref_t0 |
1687 | + jcp.l_pad) * oc_block + jcp.ow * oc_block); |
1688 | prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); |
1689 | ii_dpref_t0++; |
1690 | } |
1691 | }; |
1692 | |
1693 | if (one_of(jcp.ndims, 3, 4)) { |
1694 | mov(aux_reg_dst, reg_dst); |
1695 | mov(aux_reg_ker, reg_ker); |
1696 | mov(aux_reg_dst_prf, reg_dst_prf); |
1697 | mov(aux_reg_ker_prf, reg_ker_prf); |
1698 | } |
1699 | |
1700 | if (jcp.ndims == 5) { |
1701 | push(reg_src_prf); |
1702 | push(reg_src); |
1703 | |
1704 | mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); |
1705 | mov(aux_reg_dst_d, reg_dst); |
1706 | mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); |
1707 | mov(aux_reg_dst_d_prf, reg_dst_prf); |
1708 | mov(aux_reg_ker_d_prf, reg_ker_prf); |
1709 | |
1710 | L(kd_label); |
1711 | mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); |
1712 | } else { |
1713 | mov(reg_kj, reg_kh); |
1714 | } |
1715 | |
1716 | if (jcp.ndims == 5) { |
1717 | mov(aux_reg_dst, aux_reg_dst_d); |
1718 | mov(aux_reg_ker, aux_reg_ker_d); |
1719 | mov(aux_reg_dst_prf, aux_reg_dst_d_prf); |
1720 | mov(aux_reg_ker_prf, aux_reg_ker_d_prf); |
1721 | } |
1722 | |
1723 | align(16); |
1724 | L(kh_label); |
1725 | if (check_last_kh) { |
1726 | for (int ki = 0; ki < kw; ki++) |
1727 | for (int oc = 0; oc < oc_block; oc += 4) |
1728 | for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { |
1729 | bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 |
1730 | && ki == kw - 1 && (oc + 4) == oc_block); |
1731 | |
1732 | if (last_kernel_loads) { |
1733 | cmp(reg_kj, 1); |
1734 | je(last_iter_label, T_NEAR); |
1735 | } |
1736 | |
1737 | kernel_loads(ki, oc, kk); |
1738 | for (int ii = get_iw_start(ki, l_overflow), |
1739 | prf_count_t0 = 0, prf_count_t1 = 0; |
1740 | ii < get_iw_end(ur_w, ki, r_overflow); ii++) { |
1741 | int aux_dst_offset = typesize |
1742 | * ((ii + jcp.l_pad - ki) * oc_block + oc); |
1743 | v4fmaddps(zmm_out(ii, kk), zmm_ker(0), |
1744 | EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); |
1745 | |
1746 | if (ii % 2) { |
1747 | if (prf_count_t0 < 4) { |
1748 | int aux_kernel_prf; |
1749 | if (last_kernel_loads) |
1750 | aux_kernel_prf= kernel_offset(0, prf_count_t0 |
1751 | + oc + 4 - oc_block, 0) + typesize * kw |
1752 | * oc_block * ic_block; |
1753 | else |
1754 | aux_kernel_prf = kernel_offset(kk, oc + 4 |
1755 | + prf_count_t0, ki); |
1756 | mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, |
1757 | aux_kernel_prf)); |
1758 | prf_count_t0++; |
1759 | } else if (prf_count_t1 < 4) { |
1760 | mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, |
1761 | kernel_offset(kk, oc + prf_count_t1, ki))); |
1762 | prf_count_t1++; |
1763 | } |
1764 | } else |
1765 | prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); |
1766 | } |
1767 | if (last_kernel_loads) { |
1768 | jmp(loop_end_label, T_NEAR); |
1769 | |
1770 | L(last_iter_label); |
1771 | |
1772 | kernel_loads(ki, oc, kk); |
1773 | for (int ii = get_iw_start(ki, l_overflow), |
1774 | prf_count_t0 = 0, prf_count_t1 = 0; |
1775 | ii < get_iw_end(ur_w, ki, r_overflow); ii++) { |
1776 | int aux_dst_offset = typesize |
1777 | * ((ii + jcp.l_pad - ki) * oc_block + oc); |
1778 | v4fmaddps(zmm_out(ii, kk), zmm_ker(0), |
1779 | EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); |
1780 | if (ii % 2) { |
1781 | if (prf_count_t0 < 4) { |
1782 | mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, |
1783 | kernel_offset(0, prf_count_t0, 0))); |
1784 | prf_count_t0++; |
1785 | } else if (prf_count_t1 < 4) { |
1786 | mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, |
1787 | kernel_offset(kk, oc + prf_count_t1, ki))); |
1788 | prf_count_t1++; |
1789 | } |
1790 | } |
1791 | } |
1792 | L(loop_end_label); |
1793 | } |
1794 | } |
1795 | } else { |
1796 | for (int ki = 0; ki < kw; ki++) |
1797 | for (int oc = 0; oc < oc_block; oc += 4) |
1798 | for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { |
1799 | kernel_loads(ki, oc, kk); |
1800 | |
1801 | for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; |
1802 | ii < get_iw_end(ur_w, ki, r_overflow); ii++) { |
1803 | int aux_dst_offset = typesize |
1804 | * ((ii + jcp.l_pad - ki) * oc_block + oc); |
1805 | v4fmaddps(zmm_out(ii, kk), zmm_ker(0), |
1806 | EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); |
1807 | if ((ii % 2) && (prf_count_t1 < 4)) { |
1808 | mic_prefetcht1(EVEX_compress_addr( |
1809 | aux_reg_ker_prf, kernel_offset(kk, |
1810 | oc + prf_count_t1, ki))); |
1811 | prf_count_t1++; |
1812 | } |
1813 | if ( ki == 1 && oc == 0 && kk == 0) |
1814 | mic_prefetcht1(EVEX_compress_addr( |
1815 | aux_reg_dst_prf, aux_dst_offset)); |
1816 | } |
1817 | } |
1818 | } |
1819 | |
1820 | add(aux_reg_ker, shift_ker_ptr); |
1821 | sub(aux_reg_dst, shift_dst_ptr); |
1822 | add(aux_reg_ker_prf, shift_ker_ptr); |
1823 | sub(aux_reg_dst_prf, shift_dst_ptr); |
1824 | |
1825 | dec(reg_kj); |
1826 | cmp(reg_kj, 0); |
1827 | jg(kh_label, T_NEAR); |
1828 | |
1829 | if (jcp.ndims == 5) { |
1830 | sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); |
1831 | add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); |
1832 | sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); |
1833 | add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); |
1834 | |
1835 | dec(reg_ki); |
1836 | cmp(reg_ki, 0); |
1837 | jg(kd_label, T_NEAR); |
1838 | |
1839 | pop(reg_src); |
1840 | pop(reg_src_prf); |
1841 | } |
1842 | } |
1843 | |
/* Emits the FMA accumulation loop of the backward-data kernel for one tile of
 * `ur_w` columns.
 *
 * Generated code structure: an optional kd loop (5D case only, counter
 * reg_ki loaded from kd_padding) around a kh loop (counter reg_kj). Inside
 * the kh loop, every (ki, oc) pair issues one vfmadd231ps per visited tile
 * column into zmm_out(jj, 0). Kernel rows are streamed through a rotating
 * window of `ker_pipeline_depth` zmm_ker registers, and prefetches are
 * interleaved between the FMAs at a fixed spacing.
 *
 * ur_w        number of columns covered by this tile
 * l_overflow  forwarded to get_iw_start() to bound the first visited column
 * r_overflow  forwarded to get_iw_end() to bound the last visited column
 */
void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
        int ur_w, int l_overflow, int r_overflow)
{
    Label kh_label, kd_label;
    int kw = jcp.kw;
    int ow = jcp.ow;

    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int l_pad = jcp.l_pad;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;

    // Kernel rows are kept in a window of 4 registers, loaded ahead of use.
    int ker_pipeline_depth = 4;
    assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
    assert(oc_block >= ker_pipeline_depth);

    // Prefetch schedule: spread num_prfs prefetches evenly over the
    // num_fmas FMAs emitted per kh iteration. One prefetch is issued
    // whenever fma_idx % prf_inst_spacing == prf_inst_trigger.
    int num_ker_loads = oc_block * kw;
    int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
                       + nstl::max(0, kw - stride_w);
    int num_prfs = num_ker_loads + num_inp_prfs;
    int num_fmas = num_ker_loads * ur_w / stride_w;
    int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
    int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;

    // 1D/2D spatial case: the working pointers start at the tile base.
    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_dst, reg_dst);
        mov(aux_reg_ker, reg_ker);

        mov(aux_reg_dst_prf, reg_dst_prf);
        mov(aux_reg_ker_prf, reg_ker_prf);
    }

    if (jcp.ndims == 5) {
        // reg_src_prf / reg_src are saved here and restored at the end of
        // the function. NOTE(review): presumably the depth-loop registers
        // alias them -- confirm against the register map in the header.
        push(reg_src_prf);
        push(reg_src);

        mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
        mov(aux_reg_dst_d, reg_dst);
        mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
        mov(aux_reg_dst_d_prf, reg_dst_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);

        L(kd_label);
        // The kh counter is reloaded from the call arguments on every kd
        // iteration.
        mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    // 5D case: restart the kh-level pointers from the current depth slice.
    if (jcp.ndims == 5) {
        mov(aux_reg_dst, aux_reg_dst_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
    }

    L(kh_label); {
        // `step` counts (ki, oc) pairs; `ker_prfs` counts kernel prefetches
        // already emitted in this kh iteration.
        int step = 0;
        int ker_prfs = 0;
        for (int ki = 0; ki < kw; ki++) {
            for (int oc = 0; oc < oc_block; oc++) {
                if (step == 0) {
                    // Prime the pipeline with the first ker_pipeline_depth
                    // kernel rows.
                    for (int i = 0; i < ker_pipeline_depth; i++) {
                        int aux_kernel_offset = typesize * ((oc + i) * oc_block
                                + ki * ic_block * oc_block);
                        vmovups(zmm_ker(i), EVEX_compress_addr(
                                    aux_reg_ker, aux_kernel_offset));
                    }
                } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
                    // Steady state: load the row needed (depth - 1) steps
                    // from now into the register just released. No loads are
                    // issued for the last (depth - 1) steps.
                    int load_offset = ker_pipeline_depth - 1;
                    int ker_load_reg_idx
                        = (step + load_offset) % ker_pipeline_depth;
                    int aux_kernel_offset = typesize * ((oc + load_offset)
                            * oc_block + ki * ic_block * oc_block);
                    vmovups(zmm_ker(ker_load_reg_idx),
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                }

                // At most one kernel prefetch is emitted per (ki, oc) step.
                bool ker_prf_inserted = false;
                auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);

                int jj_start = get_iw_start(ki, l_overflow);
                int jj_end = get_iw_end(ur_w, ki, r_overflow);
                // Sanity checks of the helper results for the unit-stride
                // case.
                assert(stride_w != 1
                        || jj_start == nstl::max(0,
                            l_overflow - (kw - 1 - ki) * dilate_w));
                assert(stride_w != 1
                        || jj_end == ur_w - nstl::max(0,
                            r_overflow - ki * dilate_w));

                for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                    assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
                    int aux_dst_offset = typesize *
                        (((jj + l_pad - ki * dilate_w)
                                / stride_w) * jcp.oc_block + oc);
                    // Last argument `true` requests the broadcast form of
                    // the memory operand.
                    vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
                        EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));

                    int fma_idx = (step * ur_w + jj) / stride_w;
                    int prf_slot_idx = fma_idx / prf_inst_spacing;
                    if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
                        if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
                            // Kernel prefetch (t1) through the kernel
                            // prefetch pointer.
                            int ker_prf_offset = typesize
                                * ker_prfs * jcp.oc_block;
                            mic_prefetcht1(EVEX_compress_addr(
                                        aux_reg_ker_prf, ker_prf_offset));
                            ker_prf_inserted = true;
                            ker_prfs++;
                        } else {
                            // Remaining slots prefetch (t0) through the dst
                            // prefetch pointer.
                            int inp_prf_idx = prf_slot_idx - ker_prfs;
                            if (inp_prf_idx < num_inp_prfs) {
                                int inp_prf_offset
                                    = ic_block * typesize
                                    * ((inp_prf_idx / kw) * kw
                                            + (inp_prf_idx % kw));
                                mic_prefetcht0(EVEX_compress_addr(
                                            aux_reg_dst_prf, inp_prf_offset));
                            }
                        }
                    }
                }
                step++;
            }
        }

        // Step to the next kh tap: the kernel pointers advance by stride_h
        // filter rows, the dst pointers move back by (dilate_h + 1) rows.
        add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
        sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
        add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
        sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }
    if (jcp.ndims == 5) {
        // Step to the next kd tap.
        // NOTE(review): the dst strides below are scaled by ic_block where
        // the kh step above uses oc_block; both are set to simd_w in
        // init_conf, so the values coincide.
        sub(aux_reg_dst_d,
                typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
        add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
                * oc_block * ic_block);
        sub(aux_reg_dst_d_prf,
                typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
        add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
                * oc_block * ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);
    }

    // Restore the registers saved before the kd loop (reverse push order).
    if (jcp.ndims == 5)
    {
        pop(reg_src);
        pop(reg_src_prf);
    }
}
2000 | |
2001 | void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( |
2002 | int ur_w, int l_overflow, int r_overflow) |
2003 | { |
2004 | int kw = jcp.kw; |
2005 | int ow = jcp.ow; |
2006 | int dilate_w = jcp.dilate_w + 1; |
2007 | int stride_w = jcp.stride_w; |
2008 | int ic_block = jcp.ic_block; |
2009 | int oc_block = jcp.oc_block; |
2010 | int nb_ic_block = jcp.nb_ic_blocking; |
2011 | Label kh_label, kd_label; |
2012 | |
2013 | int shift_ker_ptr = typesize * kw * oc_block * ic_block; |
2014 | int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; |
2015 | |
2016 | auto output_offset = [=](int oi, int oc, int ki) { |
2017 | return typesize * |
2018 | (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); |
2019 | }; |
2020 | auto kernel_offset = [=](int icb, int oc, int ki) { |
2021 | int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; |
2022 | int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; |
2023 | int oc_offset = oc * jcp.oc_block; |
2024 | return typesize * (blk_offset + oc_offset); |
2025 | }; |
2026 | |
2027 | if (one_of(jcp.ndims, 3, 4)) { |
2028 | mov(aux_reg_dst, reg_dst); |
2029 | mov(aux_reg_ker, reg_ker); |
2030 | } |
2031 | |
2032 | if (jcp.ndims == 5) { |
2033 | push(reg_src_prf); |
2034 | push(reg_src); |
2035 | |
2036 | mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); |
2037 | mov(aux_reg_dst_d, reg_dst); |
2038 | mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); |
2039 | |
2040 | L(kd_label); |
2041 | mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); |
2042 | } else { |
2043 | mov(reg_kj, reg_kh); |
2044 | } |
2045 | |
2046 | if (jcp.ndims == 5) { |
2047 | mov(aux_reg_dst, aux_reg_dst_d); |
2048 | mov(aux_reg_ker, aux_reg_ker_d); |
2049 | } |
2050 | |
2051 | L(kh_label); |
2052 | { |
2053 | for (int ki = 0; ki < kw; ki++) { |
2054 | int jj_start = get_iw_start(ki, l_overflow); |
2055 | int jj_end = get_iw_end(ur_w, ki, r_overflow); |
2056 | for (int oc = 0; oc < oc_block; oc++) { |
2057 | if (jcp.kernel_kind == expl_bcast) { |
2058 | for (int jj = jj_start; jj < jj_end; jj++) { |
2059 | int aux_output_offset = output_offset(jj, oc, ki); |
2060 | vbroadcastss(zmm_inp(jj, nb_ic_block), |
2061 | ptr[aux_reg_dst + aux_output_offset]); |
2062 | } |
2063 | } |
2064 | for (int ii = 0; ii < nb_ic_block; ii++) { |
2065 | int aux_kernel_offset = kernel_offset(ii, oc, ki); |
2066 | if (jj_end - jj_start > 0) |
2067 | vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, |
2068 | aux_kernel_offset)); |
2069 | for (int jj = jj_start; jj < jj_end; jj += stride_w) |
2070 | if (jcp.kernel_kind == expl_bcast) |
2071 | vfmadd231ps(zmm_out(jj, ii), |
2072 | zmm_inp(jj, nb_ic_block), zmm_wei); |
2073 | else |
2074 | vfmadd231ps(zmm_out(jj, ii), zmm_wei, |
2075 | EVEX_compress_addr(aux_reg_dst, |
2076 | output_offset(jj, oc, ki), true)); |
2077 | } |
2078 | } |
2079 | } |
2080 | add(aux_reg_ker, shift_ker_ptr); |
2081 | sub(aux_reg_dst, shift_dst_ptr); |
2082 | dec(reg_kj); |
2083 | cmp(reg_kj, 0); |
2084 | jg(kh_label, T_NEAR); |
2085 | } |
2086 | |
2087 | if (jcp.ndims == 5) { |
2088 | sub(aux_reg_dst_d, |
2089 | typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); |
2090 | add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); |
2091 | |
2092 | dec(reg_ki); |
2093 | cmp(reg_ki, 0); |
2094 | jg(kd_label, T_NEAR); |
2095 | |
2096 | pop(reg_src); |
2097 | pop(reg_src_prf); |
2098 | } |
2099 | } |
2100 | |
2101 | inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( |
2102 | int ur_w, int l_overflow, int r_overflow) |
2103 | { |
2104 | if (jcp.ndims == 5) push(reg_oi); |
2105 | |
2106 | prepare_output(ur_w); |
2107 | |
2108 | Label skip_compute_loop; |
2109 | if (jcp.ndims == 5) { |
2110 | mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); |
2111 | cmp(reg_kj, 0); |
2112 | je(skip_compute_loop, T_NEAR); |
2113 | } |
2114 | mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); |
2115 | cmp(reg_kj, 0); |
2116 | je(skip_compute_loop, T_NEAR); |
2117 | |
2118 | if (jcp.ver == ver_4fma) |
2119 | compute_loop_4fma(ur_w, l_overflow, r_overflow); |
2120 | else if (jcp.ver == ver_fma) |
2121 | if (mayiuse(avx512_mic)) |
2122 | compute_loop_fma(ur_w, l_overflow, r_overflow); |
2123 | else |
2124 | if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) |
2125 | compute_loop_fma(ur_w, l_overflow, r_overflow); |
2126 | else |
2127 | compute_loop_fma_core(ur_w, l_overflow, r_overflow); |
2128 | else |
2129 | assert("!unknown convolution version" ); |
2130 | |
2131 | L(skip_compute_loop); |
2132 | store_output(ur_w); |
2133 | if (jcp.ndims == 5) pop(reg_oi); |
2134 | } |
2135 | |
2136 | void jit_avx512_common_conv_bwd_data_kernel_f32::generate() |
2137 | { |
2138 | int iw = jcp.iw; |
2139 | int kw = jcp.kw; |
2140 | int ur_w = jcp.ur_w; |
2141 | int ic_block = jcp.ic_block; |
2142 | int oc_block = jcp.oc_block; |
2143 | int ur_w_tail = jcp.ur_w_tail; |
2144 | int dilate_w = jcp.dilate_w + 1; |
2145 | int stride_w = jcp.stride_w; |
2146 | |
2147 | int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; |
2148 | int src_shift = jcp.typesize_out * ur_w * oc_block; |
2149 | |
2150 | preamble(); |
2151 | |
2152 | mov(reg_src, ptr[param + GET_OFF(src)]); |
2153 | mov(reg_dst, ptr[param + GET_OFF(dst)]); |
2154 | mov(reg_ker, ptr[param + GET_OFF(filt)]); |
2155 | |
2156 | mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); |
2157 | mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); |
2158 | mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); |
2159 | mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); |
2160 | |
2161 | int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); |
2162 | int r_overflow = nstl::max(0, ((kw - 1) * dilate_w |
2163 | - nstl::max(0, jcp.r_pad)) / stride_w); |
2164 | int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w |
2165 | - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); |
2166 | |
2167 | int n_oi = iw / ur_w; |
2168 | if (r_overflow1 > 0) n_oi--; |
2169 | |
2170 | if (ur_w == iw) { |
2171 | compute_loop(ur_w, l_overflow, r_overflow); |
2172 | } else if (n_oi == 0) { |
2173 | compute_loop(ur_w, l_overflow, r_overflow1); |
2174 | add(reg_src, src_shift); |
2175 | add(reg_dst, dst_shift); |
2176 | add(reg_src_prf, src_shift); |
2177 | add(reg_dst_prf, dst_shift); |
2178 | if (ur_w_tail != 0) |
2179 | compute_loop(ur_w_tail, 0, r_overflow); |
2180 | } else { |
2181 | xor_(reg_oi, reg_oi); |
2182 | if (l_overflow > 0) { |
2183 | compute_loop(ur_w, l_overflow, 0); |
2184 | add(reg_src, src_shift); |
2185 | add(reg_dst, dst_shift); |
2186 | add(reg_src_prf, src_shift); |
2187 | add(reg_dst_prf, dst_shift); |
2188 | |
2189 | inc(reg_oi); |
2190 | } |
2191 | if ((l_overflow <= 0 && n_oi > 0) |
2192 | || (l_overflow > 0 && n_oi > 1)) { |
2193 | Label ow_loop_label; |
2194 | L(ow_loop_label); { |
2195 | compute_loop(ur_w, 0, 0); |
2196 | add(reg_src, src_shift); |
2197 | add(reg_dst, dst_shift); |
2198 | add(reg_src_prf, src_shift); |
2199 | add(reg_dst_prf, dst_shift); |
2200 | |
2201 | inc(reg_oi); |
2202 | cmp(reg_oi, n_oi); |
2203 | jl(ow_loop_label, T_NEAR); |
2204 | } |
2205 | } |
2206 | if (r_overflow1 > 0) { |
2207 | compute_loop(ur_w, 0, r_overflow1); |
2208 | add(reg_src, src_shift); |
2209 | add(reg_dst, dst_shift); |
2210 | add(reg_src_prf, src_shift); |
2211 | add(reg_dst_prf, dst_shift); |
2212 | } |
2213 | if (ur_w_tail != 0) { |
2214 | compute_loop(ur_w_tail, 0, r_overflow); |
2215 | } |
2216 | } |
2217 | |
2218 | postamble(); |
2219 | } |
2220 | |
/* Fills `jcp` with the configuration of the backward-data kernel for the
 * given convolution descriptor and memory descriptors.
 *
 * jcp          [out] zeroed and then populated
 * cd           convolution descriptor (prop kind, padding, strides, dilates)
 * diff_src_d   descriptor of the gradient w.r.t. the convolution source
 * weights_d    descriptor of the weights
 * diff_dst_d   descriptor of the gradient w.r.t. the convolution destination
 *
 * Returns status::success when the kernel supports the problem, otherwise
 * status::unimplemented (no AVX-512, unsupported layout or data type,
 * dilation combined with stride, or overflow larger than one tile).
 */
status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
        jit_conv_conf_t &jcp,
        const convolution_desc_t &cd,
        const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d)
{
    if (!mayiuse(avx512_common)) return status::unimplemented;

    jcp = zero<decltype(jcp)>();

    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
    // Grouped weights carry one extra leading dimension.
    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    int ndims = diff_src_d.ndims();

    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];

    // Channel counts are per group.
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;

    // Spatial sizes; dimensions missing from 3D/4D tensors default to 1.
    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
    jcp.iw = diff_src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    // Leading (front/top/left) padding, strides and dilations; missing
    // dimensions get the neutral value.
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];
    // Dilation together with a non-unit stride along the same axis is not
    // supported.
    if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
            || (jcp.dilate_d != 0 && jcp.stride_d != 1)
            || (jcp.dilate_h != 0 && jcp.stride_h != 1))
        return status::unimplemented;

    // Trailing (right/bottom/back) padding is derived from the shapes rather
    // than read from the descriptor.
    jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1);
    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);
    jcp.back_pad = (jcp.od - 1) * jcp.stride_d
            + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);

    jcp.aligned_threads = 0;

    jcp.is_1stconv = false;

    // Both channel blocks are one SIMD vector wide (is_1stconv is always
    // false here).
    jcp.oc_block = jcp.simd_w;
    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;

    // Channels are rounded up to a whole block only for non-grouped f32
    // problems.
    bool ok_to_pad_channels = true
        && jcp.ngroups == 1
        && diff_src_d.data_type() == data_type::f32;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }

    // Required layouts: 16-channel blocked data, 16o16i blocked weights.
    auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto wei_tag = with_groups
        ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
        : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
    jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);

    bool args_ok = true
        && jcp.oc % jcp.oc_block == 0
        && jcp.ic % jcp.ic_block == 0
        && jcp.src_tag == dat_tag
        && jcp.dst_tag == dat_tag;
    if (!args_ok)
        return status::unimplemented;

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    // Initial width unroll: the whole row if it fits into `regs`, otherwise
    // the largest multiple of stride_w not exceeding `regs`.
    // NOTE(review): `regs` is presumably the number of zmm accumulators
    // available to the kernel -- confirm against the register map.
    jcp.ur_w = jcp.stride_w;

    int regs = 28;
    if (jcp.iw <= regs)
        jcp.ur_w = jcp.iw;
    else {
        for (int ur_w = regs; ur_w > 0; --ur_w)
            if (ur_w % jcp.stride_w == 0) {
                jcp.ur_w = ur_w;
                break;
            }
    }
    // Same overflow / tile-count formulas as in generate(), evaluated with
    // the initial ur_w.
    int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
                    - jcp.l_pad) / jcp.stride_w);
    int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
                    - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
    int n_oi = jcp.iw / jcp.ur_w;
    if (r_overflow1 > 0) n_oi--;

    // Only all-f32 problems are supported. The 4fma version additionally
    // needs avx512_mic_4ops and unit strides.
    if (mayiuse(avx512_common)
         && diff_dst_d.data_type() == data_type::f32
         && weights_d.data_type() == data_type::f32
         && diff_src_d.data_type() == data_type::f32) {
        jcp.ver = ver_fma;
        jcp.typesize_in = sizeof(float);
        jcp.typesize_out = sizeof(float);
        if (mayiuse(avx512_mic_4ops)
            && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
                jcp.ver = ver_4fma;
            }
    } else {
        return status::unimplemented;
    }

    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
    if (jcp.wei_tag != wei_tag)
        return status::unimplemented;

    // Dilation is only implemented by the plain fma version.
    if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
            && jcp.ver != ver_fma)
        return status::unimplemented;

    // 4fma: choose how many ic blocks share one tile. 3x3 kernels on 7x7
    // images are special-cased; otherwise the largest divisor of nb_ic with
    // i * ur_w <= regs is taken.
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
    if (jcp.ver == ver_4fma) {
        if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
            jcp.nb_ic_blocking = 2;
        } else {
            for (int i = jcp.nb_ic; i > 0; i--)
                if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
                    jcp.nb_ic_blocking = i;
                    break;
                }
        }
    }

    jcp.loop_order = loop_gnc;

    // When generate() would emit left-edge, interior and right-edge tiles
    // all at once, shrink ur_w (but not below regs / 2) until an estimate of
    // the generated code size drops under max_code_size.
    // NOTE(review): 9.2 looks like an empirical bytes-per-instruction
    // factor -- confirm.
    bool large_code_size = (jcp.ur_w != jcp.ow)
         && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1))
         && (r_overflow1 > 0) && (l_overflow > 0);
    if (large_code_size) {
        const int max_code_size = 24 * 1024;
        const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
        int mult = 1;
        if (l_overflow > 0) mult += 1;
        if (r_overflow1 > 0) mult += 1;
        for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
            if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
                    < max_code_size) {
                if (ur_w % jcp.stride_w == 0) {
                    jcp.ur_w = ur_w;
                    break;
                }
            }
        }
    }

    // avx512_core + fma: choose between the embedded-broadcast and the
    // explicit-broadcast kernels from the kernel/image sizes and from the
    // working-set estimate compared against the L1 cache size, then redo the
    // ic blocking and ur_w for the selected kernel.
    if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
        int try_nb_ic_blocking = 2;
        unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
            * try_nb_ic_blocking * jcp.kh;
        unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
        unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
            * jcp.oc_block * try_nb_ic_blocking;
        unsigned int ker_total_size = ker_inp_size + ker_out_size
            + ker_wei_size;
        if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
            || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
            || ker_total_size > L1_cache_size )))
                || jcp.stride_h > 1 || jcp.stride_d > 1) {
            jcp.kernel_kind = embd_bcast;
            jcp.ur_w = nstl::min(jcp.iw, regs);
            jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
            // Two ic blocks per tile are used only when nb_ic is even and
            // the size conditions below hold; ur_w is then re-derived from
            // 31 / (nb_ic_blocking + 1).
            if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
                && jcp.ow > 8)) && jcp.stride_h == 1)
                if (jcp.nb_ic % try_nb_ic_blocking == 0) {
                    jcp.nb_ic_blocking = try_nb_ic_blocking;
                    jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
                    if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
                }
         } else {
            jcp.kernel_kind = expl_bcast;
            jcp.nb_oc_blocking = 1;
            // Up to 4 ic blocks per tile, reduced to the largest divisor of
            // nb_ic that does not exceed that.
            jcp.nb_ic_blocking = 4;
            if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
            if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
                for (int i = jcp.nb_ic_blocking; i > 0; i--)
                    if (jcp.nb_ic % i == 0) {
                        jcp.nb_ic_blocking = i;
                        break;
                    }
            jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
            if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
        }
    }
    jcp.ur_w_tail = jcp.iw % jcp.ur_w;

    // The kernel handles each edge within a single tile, so an overflow
    // wider than ur_w is rejected. l_overflow was computed before the ur_w
    // adjustments but does not depend on ur_w.
    if (l_overflow * jcp.stride_w > jcp.ur_w)
        return status::unimplemented;
    int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
    if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
        return status::unimplemented;
    // Multi-tile rows require ur_w to be a multiple of stride_w.
    if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
        return status::unimplemented;

    pick_loop_order(jcp);

    // 4fma with small kernels: reduce nb_oc_L2 by successive divisors until
    // the estimated src + dst + filter footprint (in 4-byte elements) fits
    // into KNx_L2_EFFECTIVE_CAPACITY.
    jcp.nb_oc_L2 = jcp.nb_oc;
    if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
        for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
              divf++) {
            size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
                * jcp.id;
            size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
            size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
                * jcp.kd * jcp.nb_ic_blocking * temp_nb;
            if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
                if (jcp.kh == 3 && jcp.ih == 7) {
                    jcp.nb_oc_L2 = 1;
                    break;
                }
                temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
                                : jcp.nb_oc_L2);
            } else {
                jcp.nb_oc_L2 = temp_nb;
                break;
            }
        }
    }

    // The (possibly rounded-up) channel counts must fit into the padded
    // dimensions of every tensor.
    args_ok = true
        && jcp.ic <= diff_src_d.padded_dims()[1]
        && jcp.oc <= diff_dst_d.padded_dims()[1]
        && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
        && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    return status::success;
}
2474 | |
2475 | void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( |
2476 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
2477 | UNUSED(scratchpad); |
2478 | UNUSED(jcp); |
2479 | } |
2480 | |
// Out-of-class definition of the static constant max_ur_w.
// NOTE(review): presumably the upper bound on the ur_w unroll factor of the
// backward-weights kernel -- confirm against its uses.
const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
2482 | |
2483 | void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() |
2484 | { |
2485 | Label kd_comeback_label; |
2486 | |
2487 | /* 'depth' loop count bound by 'kd_work_size' */ |
2488 | mov(kj, reg_kd_count); |
2489 | L(kd_comeback_label); { |
2490 | int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; |
2491 | int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; |
2492 | sub(reg_input, |
2493 | jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); |
2494 | sub(reg_kernel, |
2495 | jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); |
2496 | dec(kj); |
2497 | cmp(kj, 0); |
2498 | jg(kd_comeback_label, T_NEAR); |
2499 | } |
2500 | } |
2501 | |
2502 | void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() |
2503 | { |
2504 | Label kh_comeback_label, kd_comeback_label; |
2505 | mov(kj, reg_kh); |
2506 | L(kh_comeback_label); { |
2507 | int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; |
2508 | int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; |
2509 | sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); |
2510 | sub(reg_kernel, |
2511 | jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); |
2512 | dec(kj); |
2513 | cmp(kj, 0); |
2514 | jg(kh_comeback_label, T_NEAR); |
2515 | } |
2516 | } |
2517 | |
2518 | void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( |
2519 | int ur_w, int pad_l, int pad_r, |
2520 | int ic_block_step, int input_offset, int kernel_offset, |
2521 | int output_offset, bool input_wraparound) |
2522 | { |
2523 | |
2524 | int kw = jcp.kw; |
2525 | int ic_block = jcp.ic_block; |
2526 | int oc_block = jcp.oc_block; |
2527 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2528 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) |
2529 | vmovups(Zmm(i_kw * ic_block_step + i_ic), |
2530 | EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block |
2531 | + i_ic) * jcp.oc_block + kernel_offset)); |
2532 | |
2533 | for (int i_ur = 0; i_ur < ur_w; i_ur++) { |
2534 | if (i_ur == 0) { |
2535 | vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), |
2536 | EVEX_compress_addr(reg_output, typesize * (i_ur + 0) |
2537 | * oc_block + output_offset)); |
2538 | if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), |
2539 | EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block |
2540 | + output_offset)); |
2541 | if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), |
2542 | EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block |
2543 | + output_offset)); |
2544 | if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), |
2545 | EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block |
2546 | + output_offset)); |
2547 | } else if (i_ur + 3 < ur_w) |
2548 | vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), |
2549 | EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block |
2550 | + output_offset)); |
2551 | |
2552 | for (int i_kw = 0; i_kw < kw; i_kw++) { |
2553 | int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); |
2554 | if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + |
2555 | (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; |
2556 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2557 | const size_t i_offset = (size_t)input_offset |
2558 | + (size_t)typesize * (jcp.ver == ver_4fma |
2559 | ? (i_iw - pad_l + i_ic * jcp.tr_iw) |
2560 | : (jcp.is_1stconv |
2561 | ? (i_iw - pad_l) + (size_t)i_ic |
2562 | * ((size_t)jcp.ih*jcp.iw*jcp.id) |
2563 | : (i_iw - pad_l) * ic_block + i_ic)); |
2564 | vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), |
2565 | Zmm(kw * ic_block_step + i_ur % 4), |
2566 | EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, |
2567 | true)); |
2568 | } |
2569 | } |
2570 | } |
2571 | |
2572 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2573 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) |
2574 | vmovups(EVEX_compress_addr(reg_kernel, typesize |
2575 | * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), |
2576 | Zmm(i_kw * ic_block_step + i_ic)); |
2577 | } |
2578 | |
2579 | void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( |
2580 | int ur_w, int pad_l, int pad_r, |
2581 | int ic_block_step, int input_offset, int kernel_offset, |
2582 | int output_offset, bool input_wraparound) |
2583 | { |
2584 | // TODO: add prefetches to fma version as well |
2585 | |
2586 | assert(jcp.ver == ver_4fma); |
2587 | |
2588 | int kw = jcp.kw; |
2589 | int ic_block = jcp.ic_block; |
2590 | int oc_block = jcp.oc_block; |
2591 | |
2592 | auto zmm_ker = [=](int i_kw, int i_ic) { |
2593 | return Zmm(i_kw * ic_block_step + i_ic); |
2594 | }; |
2595 | |
2596 | auto ker_addr = [=](int i_kw, int i_ic) { |
2597 | size_t local_offset |
2598 | = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; |
2599 | return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); |
2600 | }; |
2601 | |
2602 | auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t = 0) { |
2603 | int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1); |
2604 | int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); |
2605 | return EVEX_compress_addr(reg_input, |
2606 | local_offset + input_offset + extra_offset); |
2607 | }; |
2608 | |
2609 | auto zmm_out = [=](int i_iw) { |
2610 | // TODO: move reg calc to global member funcs |
2611 | const int out_zmm_base_idx = 28; |
2612 | return Zmm(out_zmm_base_idx + i_iw % 4); |
2613 | }; |
2614 | |
2615 | auto out_addr = [=](int i_ur) { |
2616 | return EVEX_compress_addr(reg_output, |
2617 | jcp.typesize_in * i_ur * oc_block + output_offset); |
2618 | }; |
2619 | |
2620 | auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { |
2621 | assert(i_ur % 4 == 0); |
2622 | if (i_ur == 0) |
2623 | prefetcht1(ker_addr(i_kw, i_ic)); |
2624 | if (i_ur + 4 >= ur_w) |
2625 | prefetcht0(ker_addr(i_kw, i_ic)); |
2626 | |
2627 | const ptrdiff_t next_input_block_offset |
2628 | = jcp.typesize_in * ic_block_step * jcp.tr_iw; |
2629 | if (i_ur % 16 == 4 && i_kw == 0) { |
2630 | if (i_ur + 16 < ur_w) |
2631 | prefetcht0(inp_addr(i_ur + 16, i_ic)); |
2632 | else |
2633 | prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); |
2634 | } |
2635 | if (i_ur % 16 == 4 && i_kw == 1) { |
2636 | if (input_wraparound) |
2637 | prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); |
2638 | else |
2639 | prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); |
2640 | } |
2641 | }; |
2642 | |
2643 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2644 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2645 | auto zmm = zmm_ker(i_kw, i_ic); |
2646 | vpxord(zmm, zmm, zmm); |
2647 | } |
2648 | |
2649 | for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { |
2650 | |
2651 | for (int i = 0; i < 4; i++) { |
2652 | auto zmm = zmm_out(i_ur + i); |
2653 | if (i_ur + i < ur_w) |
2654 | vmovups(zmm, out_addr(i_ur + i)); |
2655 | else |
2656 | vpxord(zmm, zmm, zmm); |
2657 | prefetcht0(out_addr(i_ur + i + 4)); |
2658 | } |
2659 | |
2660 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2661 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2662 | int i_iw = i_ur + i_kw; |
2663 | v4fmaddps(zmm_ker(i_kw, i_ic), |
2664 | zmm_out(i_ur), inp_addr(i_iw, i_ic)); |
2665 | pf_callback(i_ur, i_kw, i_ic); |
2666 | } |
2667 | } |
2668 | |
2669 | for (int i_kw = 0; i_kw < kw; i_kw++) |
2670 | for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { |
2671 | auto addr = ker_addr(i_kw, i_ic); |
2672 | auto zmm = zmm_ker(i_kw, i_ic); |
2673 | vaddps(zmm, zmm, addr); |
2674 | vmovups(addr, zmm); |
2675 | } |
2676 | } |
2677 | |
2678 | void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( |
2679 | int ur_w, int pad_l, int pad_r, |
2680 | int ic_block_step, int input_offset, int kernel_offset, |
2681 | int output_offset, bool input_wraparound) |
2682 | { |
2683 | if (jcp.ver == ver_4fma) |
2684 | compute_ic_block_step_4fma(ur_w, pad_l, pad_r, |
2685 | ic_block_step, input_offset, kernel_offset, output_offset, |
2686 | input_wraparound); |
2687 | else if (jcp.ver == ver_fma) |
2688 | compute_ic_block_step_fma(ur_w, pad_l, pad_r, |
2689 | ic_block_step, input_offset, kernel_offset, output_offset, |
2690 | input_wraparound); |
2691 | else |
2692 | assert(!"unknown convolution version" ); |
2693 | } |
2694 | |
2695 | void jit_avx512_common_conv_bwd_weights_kernel_f32 |
2696 | ::compute_oh_step_unroll_ow_icblock( |
2697 | int ic_block_step, int max_ur_w) |
2698 | { |
2699 | UNUSED(max_ur_w); |
2700 | |
2701 | Label kh_label, kd_label; |
2702 | |
2703 | int ic_block = jcp.ic_block; |
2704 | int oc_block = jcp.oc_block; |
2705 | int inp_mul = !jcp.is_1stconv ? ic_block : 1; |
2706 | int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; |
2707 | int ow = jcp.ow; |
2708 | |
2709 | int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w |
2710 | + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); |
2711 | int l_pad = jcp.l_pad; |
2712 | |
2713 | if (jcp.ndims == 5) { |
2714 | L(kd_label); |
2715 | mov(reg_input, aux_reg_input); |
2716 | mov(reg_kernel, aux_reg_kernel); |
2717 | } |
2718 | |
2719 | mov(kj, reg_kh); |
2720 | L(kh_label); |
2721 | { |
2722 | for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { |
2723 | const int input_offset = jcp.typesize_in |
2724 | * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic); |
2725 | compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, |
2726 | input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, |
2727 | i_b_ic + ic_block_step >= jcp.ic_block); |
2728 | } |
2729 | add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); |
2730 | add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); |
2731 | dec(kj); |
2732 | cmp(kj, 0); |
2733 | jg(kh_label, T_NEAR); |
2734 | } |
2735 | |
2736 | if (jcp.ndims == 5) { |
2737 | add(aux_reg_input, |
2738 | jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); |
2739 | add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block |
2740 | * oc_block); |
2741 | dec(ki); |
2742 | cmp(ki, 0); |
2743 | jg(kd_label, T_NEAR); |
2744 | } |
2745 | } |
2746 | |
2747 | void jit_avx512_common_conv_bwd_weights_kernel_f32 |
2748 | ::compute_oh_step_unroll_ow( |
2749 | int ic_block_step, int max_ur_w) |
2750 | { |
2751 | Label kh_label, ic_block_label, kd_label; |
2752 | |
2753 | UNUSED(max_ur_w); |
2754 | |
2755 | int ic_block = jcp.ic_block; |
2756 | int oc_block = jcp.oc_block; |
2757 | |
2758 | int ow = jcp.ow; |
2759 | |
2760 | int r_pad = nstl::max(0, |
2761 | (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) |
2762 | - (jcp.iw + jcp.l_pad - 1)); |
2763 | int l_pad = jcp.l_pad; |
2764 | |
2765 | if (jcp.ndims == 5) { |
2766 | L(kd_label); |
2767 | mov(reg_input, aux_reg_input); |
2768 | mov(reg_kernel, aux_reg_kernel); |
2769 | } |
2770 | |
2771 | mov(kj, reg_kh); |
2772 | L(kh_label); |
2773 | { |
2774 | xor_(b_ic, b_ic); |
2775 | L(ic_block_label); { |
2776 | compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, |
2777 | 0, 0, 0); |
2778 | size_t inp_icblk_stride = jcp.is_1stconv |
2779 | ? (size_t)jcp.ih * jcp.iw * jcp.id |
2780 | : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); |
2781 | size_t input_offset |
2782 | = inp_icblk_stride * jcp.typesize_in * ic_block_step; |
2783 | safe_add(reg_input, input_offset, reg_long_offt); |
2784 | add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); |
2785 | add(b_ic, ic_block_step); |
2786 | cmp(b_ic, jcp.ic_block); |
2787 | jl(ic_block_label, T_NEAR); |
2788 | } |
2789 | |
2790 | if (jcp.is_1stconv) { |
2791 | size_t input_offset |
2792 | = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; |
2793 | safe_sub(reg_input, input_offset, reg_long_offt); |
2794 | add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); |
2795 | } else if (jcp.ver != ver_4fma) { |
2796 | add(reg_input, jcp.typesize_in |
2797 | * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); |
2798 | } |
2799 | add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); |
2800 | dec(kj); |
2801 | cmp(kj, 0); |
2802 | jg(kh_label, T_NEAR); |
2803 | } |
2804 | if (jcp.ndims == 5) { |
2805 | add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih |
2806 | * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); |
2807 | add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block |
2808 | * oc_block); |
2809 | dec(ki); |
2810 | cmp(ki, 0); |
2811 | jg(kd_label, T_NEAR); |
2812 | } |
2813 | } |
2814 | |
2815 | void jit_avx512_common_conv_bwd_weights_kernel_f32 |
2816 | ::compute_oh_step_common( |
2817 | int ic_block_step, int max_ur_w) |
2818 | { |
2819 | Label kh_label, ic_block_label, ow_block_label, kd_label; |
2820 | |
2821 | int ic_block = jcp.ic_block; |
2822 | int oc_block = jcp.oc_block; |
2823 | |
2824 | int ow = jcp.ow; |
2825 | int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w |
2826 | + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); |
2827 | int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; |
2828 | |
2829 | int ur_w = nstl::min(ow, max_ur_w); |
2830 | int ur_w_trips = ow / ur_w; |
2831 | int ur_w_tail = ow % ur_w; |
2832 | if ((ur_w_tail == 0 && r_pad != 0) |
2833 | || r_pad >= ur_w_tail) { |
2834 | if (ur_w_trips > 1) { |
2835 | ur_w_tail += ur_w; |
2836 | ur_w_trips--; |
2837 | } else { |
2838 | ur_w_tail += (ur_w - ur_w / 2); |
2839 | ur_w = ur_w / 2; |
2840 | } |
2841 | } |
2842 | |
2843 | int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block; |
2844 | int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; |
2845 | int output_comeback = ur_w_trips * ur_w * oc_block; |
2846 | |
2847 | if (jcp.ndims == 5) { |
2848 | L(kd_label); |
2849 | mov(reg_input, aux_reg_input); |
2850 | mov(reg_kernel, aux_reg_kernel); |
2851 | } |
2852 | |
2853 | mov(kj, reg_kh); |
2854 | L(kh_label); { |
2855 | xor_(b_ic, b_ic); |
2856 | L(ic_block_label); { |
2857 | if (l_pad != 0) { |
2858 | ur_w_trips--; |
2859 | compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); |
2860 | add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) |
2861 | * inp_mult); |
2862 | add(reg_output, jcp.typesize_in * ur_w * oc_block); |
2863 | } |
2864 | |
2865 | if (ur_w_trips > 0) { |
2866 | xor_(reg_ur_w_trips, reg_ur_w_trips); |
2867 | L(ow_block_label); { |
2868 | compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); |
2869 | add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w |
2870 | * inp_mult); |
2871 | add(reg_output, jcp.typesize_in * ur_w * oc_block); |
2872 | |
2873 | inc(reg_ur_w_trips); |
2874 | cmp(reg_ur_w_trips, ur_w_trips); |
2875 | jl(ow_block_label, T_NEAR); |
2876 | } |
2877 | } |
2878 | |
2879 | if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, |
2880 | ic_block_step, 0, 0, 0); |
2881 | |
2882 | sub(reg_input, jcp.typesize_in * input_comeback); |
2883 | sub(reg_output, jcp.typesize_in * output_comeback); |
2884 | int inp_icblk_stride = jcp.is_1stconv |
2885 | ? jcp.ih * jcp.iw * jcp.id |
2886 | : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); |
2887 | size_t input_offset |
2888 | = inp_icblk_stride * jcp.typesize_in * ic_block_step; |
2889 | safe_add(reg_input, input_offset, reg_long_offt); |
2890 | add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); |
2891 | |
2892 | add(b_ic, ic_block_step); |
2893 | cmp(b_ic, jcp.ic_block); |
2894 | jl(ic_block_label, T_NEAR); |
2895 | } |
2896 | if (jcp.is_1stconv) { |
2897 | size_t input_offset |
2898 | = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; |
2899 | safe_sub(reg_input, input_offset, reg_long_offt); |
2900 | add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); |
2901 | } else if (jcp.ver != ver_4fma) { |
2902 | add(reg_input, jcp.typesize_in |
2903 | * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); |
2904 | } |
2905 | add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); |
2906 | dec(kj); |
2907 | cmp(kj, 0); |
2908 | jg(kh_label, T_NEAR); |
2909 | } |
2910 | if (jcp.ndims == 5) { |
2911 | add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih |
2912 | * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); |
2913 | add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block |
2914 | * oc_block); |
2915 | dec(ki); |
2916 | cmp(ki, 0); |
2917 | jg(kd_label, T_NEAR); |
2918 | } |
2919 | } |
2920 | |
2921 | void jit_avx512_common_conv_bwd_weights_kernel_f32 |
2922 | ::compute_oh_step_disp() |
2923 | { |
2924 | int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); |
2925 | if (jcp.is_1stconv) { |
2926 | bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); |
2927 | ic_block_step |
2928 | = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; |
2929 | } |
2930 | |
2931 | bool too_large_to_unroll |
2932 | = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) |
2933 | && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); |
2934 | |
2935 | int ow = jcp.ow; |
2936 | if (jcp.ndims == 5) { |
2937 | /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of |
2938 | * 'movs' must be guaranteed. */ |
2939 | mov(ki, reg_kd_count); |
2940 | push(reg_kd_count); |
2941 | mov(aux_reg_input, reg_input); |
2942 | mov(aux_reg_kernel, reg_kernel); |
2943 | } |
2944 | |
2945 | if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) |
2946 | compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); |
2947 | else if (ow <= max_ur_w) |
2948 | compute_oh_step_unroll_ow(ic_block_step, max_ur_w); |
2949 | else |
2950 | compute_oh_step_common(ic_block_step, max_ur_w); |
2951 | |
2952 | if (jcp.ndims == 5) { |
2953 | mov(reg_input, aux_reg_input); |
2954 | mov(reg_kernel, aux_reg_kernel); |
2955 | pop(reg_kd_count); |
2956 | od_step_comeback_pointers(); |
2957 | } else { |
2958 | oh_step_comeback_pointers(); |
2959 | } |
2960 | } |
2961 | |
2962 | void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() |
2963 | { |
2964 | Label skip_zeroing, zeroing_loop; |
2965 | |
2966 | mov(reg_tmp, ptr[param + GET_OFF(channel)]); |
2967 | cmp(reg_tmp, 0); |
2968 | jz(skip_zeroing, T_NEAR); |
2969 | |
2970 | Zmm zero = Zmm(0); |
2971 | vpxord(zero, zero, zero); |
2972 | xor_(reg_tmp, reg_tmp); |
2973 | L(zeroing_loop); { |
2974 | assert(jcp.oc_block * jcp.typesize_out |
2975 | == cpu_isa_traits<avx512_common>::vlen); |
2976 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) |
2977 | vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block |
2978 | * jcp.typesize_out], zero); |
2979 | add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); |
2980 | cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd |
2981 | * jcp.typesize_out); |
2982 | jnz(zeroing_loop); |
2983 | } |
2984 | |
2985 | L(skip_zeroing); |
2986 | } |
2987 | |
2988 | void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() |
2989 | { |
2990 | Label skip_bias, bias_loop, skip_load_bias; |
2991 | |
2992 | mov(reg_tmp, ptr[param + GET_OFF(flags)]); |
2993 | test(reg_tmp,reg_tmp); |
2994 | jne(skip_bias, T_NEAR); |
2995 | |
2996 | mov(reg_bias, ptr[param + GET_OFF(bias)]); |
2997 | mov(reg_output, ptr[param + GET_OFF(dst)]); |
2998 | vpxord(Zmm(1), Zmm(1), Zmm(1)); |
2999 | |
3000 | mov(reg_tmp, ptr[param + GET_OFF(channel)]); |
3001 | cmp(reg_tmp, 0); |
3002 | jne(skip_load_bias, T_NEAR); |
3003 | vmovups(Zmm(1), ptr[reg_bias]); |
3004 | |
3005 | L(skip_load_bias); |
3006 | |
3007 | mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); |
3008 | sub(reg_oi, ptr[param + GET_OFF(d_index)]); |
3009 | mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); |
3010 | imul(reg_oi, reg_tmp); |
3011 | |
3012 | xor_(reg_tmp, reg_tmp); |
3013 | L(bias_loop); { |
3014 | vmovups(Zmm(0), ptr[reg_output + reg_tmp]); |
3015 | vaddps(Zmm(1), Zmm(1), Zmm(0)); |
3016 | add(reg_tmp, jcp.oc_block * jcp.typesize_out); |
3017 | cmp(reg_tmp, reg_oi); |
3018 | jl(bias_loop); |
3019 | } |
3020 | vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); |
3021 | |
3022 | L(skip_bias); |
3023 | } |
3024 | |
/* Emits the loop over output rows for one 2D plane.
 *
 * The generated code has up to three phases, each of which calls
 * compute_oh_step_disp() once per output row:
 *   1. top edge    (only if t_pad > 0): the kernel partly overlaps the top
 *                  padding, so reg_kh (number of kernel rows processed) and
 *                  reg_kernel are adjusted row by row;
 *   2. middle      full kernel height (reg_kh = jcp.kh);
 *   3. bottom edge (only if b_pad > 0): reg_kh shrinks as the kernel runs
 *                  into the bottom padding.
 * Run-time counters: reg_oj counts output rows, reg_ih_count advances by
 * stride_h per output row, reg_tmp tracks the dilation phase. */
void jit_avx512_common_conv_bwd_weights_kernel_f32
    ::compute_oh_loop_common()
{
    int b_pad = jcp.b_pad;
    int t_pad = jcp.t_pad;
    bool is_dilated = jcp.dilate_h != 0;
    // From here on dilate_h is the distance between kernel rows
    // (jcp.dilate_h + 1), not the raw dilation value.
    int dilate_h = jcp.dilate_h + 1;
    int stride_h = jcp.stride_h;
    const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
    // 4fma reads the transposed input, whose row width is tr_iw.
    int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
    // NOTE: od_label and od_label_end are declared but not used in this
    // function.
    Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
            oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
            oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end;

    int ow = jcp.ow;

    mov(reg_kh, jcp.kh);
    xor_(reg_ih_count, reg_ih_count);
    xor_(reg_oj, reg_oj);
    /* Compute 'top' edge */
    if (t_pad > 0) {
        // Extent of the dilated kernel in input rows.
        const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
        // Kernel rows that fall below the input (overflow) and above it,
        // inside the top padding (underflow), for the first output row.
        const int overflow
            = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
        const int underflow = div_up(t_pad, dilate_h);
        const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
        mov(reg_kh, initial_inp_ker_overlap);
        // Skip the kernel rows that lie in the top padding.
        add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
            * jcp.oc_block);
        // generate loop to process kernel while it remains within t_pad + ih
        if (kh_range < t_pad + jcp.ih) {
            if (is_dilated) {
                const int tail = t_pad % dilate_h;
                const int shift = tail == 0 ? 0 : dilate_h - tail;
                // reg_tmp holds the current phase within a dilation period.
                mov(reg_tmp, shift);
                if (tail != 0)
                    add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
            }
            L(oh_tpad_label); {
                compute_oh_step_disp();
                add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
                if (is_dilated) {
                    inc(reg_tmp);
                    cmp(reg_tmp, dilate_h);
                    jl(oh_dilate_label_shift, T_NEAR);
                    // unshift input as new kernel element enters
                    sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
                    xor_(reg_tmp, reg_tmp);
                }
                // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
                sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
                    * jcp.ic_block * jcp.oc_block);
                add(reg_kh, stride_h);
                if (is_dilated) {
                    jmp(oh_dilate_label_noshift, T_NEAR);
                    L(oh_dilate_label_shift);
                    // shift input as old kernel element progresses
                    add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
                    L(oh_dilate_label_noshift);
                }
                inc(reg_oj);
                add(reg_ih_count, stride_h);

                // final number of kernel elements that overlap with input
                const int final_inp_ker_overlap
                    = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
                cmp(reg_kh, final_inp_ker_overlap);
                jl(oh_tpad_label, T_NEAR);
            }
        }
        // need second loop to process kernel if it is larger than the input
        // (does not apply to dilations as they must have unit stride)
        if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
                                                        t_pad % stride_h)) {
            assert(!is_dilated);
            // The whole input height is covered by the kernel.
            mov(reg_kh, jcp.ih);
            L(oh_tpad_tail_label); {
                compute_oh_step_disp();
                add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
                sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
                    * jcp.ic_block * jcp.oc_block);

                inc(reg_oj);
                add(reg_ih_count, stride_h);

                cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
                jl(oh_tpad_tail_label, T_NEAR);
            }
        }
        // correct any excess shifts to kernel and input
        // (does not apply to dilations as they must have unit stride,
        // kernel must fit inside input, and padding is smaller than input)
        if (t_pad <= jcp.oh * stride_h) {
            // kernel has moved beyond padding (adjust for stride effects)
            if (t_pad % stride_h != 0) {
                assert(!is_dilated);
                int inp_corr = stride_h - t_pad % stride_h;
                add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
                    * jcp.ic_block * jcp.oc_block);
                add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
            }
        } else {
            // kernel still overlaps padding (complete reset)
            assert(!is_dilated);
            sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
                * jcp.kw * jcp.ic_block * jcp.oc_block);
        }
    }

    // Skip the middle phase when the kernel already reaches the bottom
    // padding.
    cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
    jge(oh_label_end, T_NEAR);
    // NOTE(review): when reg_oj >= jcp.oh this jumps to oh_label, i.e. INTO
    // the middle loop past the reg_kh reload, so one more row is processed.
    // Looks like oh_label_end may have been intended — confirm.
    cmp(reg_oj, jcp.oh);
    jge(oh_label, T_NEAR);

    /* Compute middle block(s) */
    mov(reg_kh, jcp.kh);
    L(oh_label); {
        compute_oh_step_disp();
        add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
        add(reg_output, jcp.typesize_in * ow * jcp.oc_block);

        inc(reg_oj);
        add(reg_ih_count, stride_h);

        // Leave the middle phase once the kernel reaches the bottom padding.
        cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
        jge(oh_label_end, T_NEAR);

        cmp(reg_oj, jcp.oh);
        jl(oh_label, T_NEAR);
    }
    L(oh_label_end);

    /* Compute bottom edge */
    if (b_pad > 0) {
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_label_end, T_NEAR);

        if (is_dilated) {
            mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
            mov(reg_tmp, 0);
        } else {
            // Kernel rows that still overlap the input:
            // (ihp - b_pad) - reg_ih_count.
            mov(reg_kh, jcp.ihp - b_pad);
            sub(reg_kh, reg_ih_count);
        }
        L(oh_bpad_label);
        {
            compute_oh_step_disp();
            add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
            add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
            if (is_dilated) {
                // reg_kh only shrinks once per dilation period.
                inc(reg_tmp);
                cmp(reg_tmp, dilate_h);
                jl(oh_dilate_label_end, T_NEAR);
                xor_(reg_tmp, reg_tmp);
            }
            sub(reg_kh, stride_h);
            cmp(reg_kh, 0);
            jle(oh_bpad_label_end, T_NEAR);
            if (is_dilated)
                L(oh_dilate_label_end);

            inc(reg_oj);
            cmp(reg_oj, jcp.oh);
            jl(oh_bpad_label, T_NEAR);
        }
        L(oh_bpad_label_end);
    }
}
3193 | |
/* Emits the loop over the depth dimension.
 *
 * The generated code iterates reg_d_index from the `d_index` call argument
 * up to (exclusive) `d_worksize`. Each iteration runs compute_oh_loop_common()
 * on the current input / output planes (reg_input_d / reg_output_d) and then
 * updates reg_kernel, reg_kd_count and reg_input_d depending on whether the
 * kernel is inside the front padding, inside the back padding, or fully
 * within the input. When jcp.with_bias is set, bias_kernel() is emitted
 * first. */
void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() {
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
    // 4fma reads the transposed input, whose row width is tr_iw.
    int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
    int ow = jcp.ow;
    // Compared against reg_d_index below to detect entry into the back
    // padding region.
    const int input_backpad_overlap
            = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);

    // Byte sizes of one kd-slice of the kernel, one input plane and one
    // output plane.
    const size_t filter_shift
            = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block;
    const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult;
    const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block;

    Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
            backpad_end_label, backpad_label;

    if (jcp.with_bias) bias_kernel();

    /* initially offset 'kd' by f_pad */
    add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);

    mov(reg_input_d, ptr[param + GET_OFF(src)]);
    mov(reg_output_d, ptr[param + GET_OFF(dst)]);
    mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
    mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);

    // Nothing to do when the starting index is already past the work size.
    cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
    jge(loop_end_label, T_NEAR);

    L(d_loop_label);

    mov(reg_input, reg_input_d);
    mov(reg_output, reg_output_d);

    // Preserve the depth-loop state across the oh loop.
    push(reg_input_d);
    push(reg_output_d);
    push(reg_d_index);

    compute_oh_loop_common();

    pop(reg_d_index);
    pop(reg_output_d);
    pop(reg_input_d);

    /* Compute 'front' edge */
    if (jcp.f_pad > 0) {

        /* Check if within fpad region */
        cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
        jge(fpad_end_label, T_NEAR);

        /* Fpad steps */
        sub(reg_kernel, filter_shift * jcp.stride_d);
        add(reg_kd_count, jcp.stride_d);

        /* Final number of kernel elements that overlap with input */
        const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id);
        cmp(reg_kd_count, inp_ker_overlap);
        // Still inside the front padding: skip the input advance below.
        jl(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and input */
        if (jcp.f_pad <= jcp.od * jcp.stride_d) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.f_pad % jcp.stride_d != 0) {
                int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
                add(reg_kernel, filter_shift * inp_corr);
                add(reg_input_d, input_shift * inp_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kd_count, jcp.kd);
        jmp(common_block_label);

        L(fpad_end_label);
    }

    /* Compute 'back' edge */
    if (jcp.back_pad > 0) {

        /* Check if within back_pad region */
        // d_index <  overlap - 1: not yet in the back padding;
        // d_index == overlap - 1: first step into it (falls through);
        // d_index >  overlap - 1: already inside it.
        cmp(reg_d_index, input_backpad_overlap - 1);
        jl(backpad_end_label, T_NEAR);
        jg(backpad_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * back_pad region. */
        mov(reg_kd_count,
                jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d);
        jmp(backpad_end_label, T_NEAR);

        L(backpad_label);
        sub(reg_kd_count, jcp.stride_d);
        // No kernel element overlaps the input any more: leave the loop.
        cmp(reg_kd_count, 0);
        jle(loop_end_label, T_NEAR);

        L(backpad_end_label);
    }

    /* Compute middle block */
    add(reg_input_d, input_shift * jcp.stride_d);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_output_d, output_shift);
    inc(reg_d_index);
    cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
    jl(d_loop_label, T_NEAR);

    L(loop_end_label);
}
3309 | |
3310 | bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { |
3311 | // FIXME: use register mapping from the class declaration |
3312 | bool ok = jcp.ver == ver_4fma |
3313 | && everyone_is(0, jcp.dilate_h, jcp.dilate_w) |
3314 | && everyone_is(1, jcp.stride_h, jcp.stride_w); |
3315 | if (!ok) return false; |
3316 | if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) |
3317 | return false; |
3318 | |
3319 | // General code layout: |
3320 | // |
3321 | // Blocking over OH -- top level |
3322 | // (Reduces L2 pressure; not very useful right now) |
3323 | // Loop over all KHxKW kernel -- emit_kh_kw_loop() |
3324 | // Loop over OH block -- emit_h_loop() |
3325 | // Loop over OW blocks -- emit_fma_block() |
3326 | // (Supports both fully unrolled and partially unrolled versions to |
3327 | // reduce code size) |
3328 | // Loop over OW block -- emit_fma_step() |
3329 | |
3330 | int max_working_set_size = 128 * 1024; |
3331 | int pad_ow = jcp.ow; |
3332 | |
3333 | int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; |
3334 | int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; |
3335 | int row_size = inp_row_size + out_row_size; |
3336 | |
3337 | int h_block_size = jcp.oh; |
3338 | int working_set_size = row_size * h_block_size; |
3339 | |
3340 | if (working_set_size > max_working_set_size) { |
3341 | int opt_working_set_size = 48 * 1024; |
3342 | assert(opt_working_set_size < max_working_set_size); |
3343 | |
3344 | while (working_set_size > opt_working_set_size) { |
3345 | for (int i = 2; i <= h_block_size; i++) |
3346 | if (i == h_block_size) |
3347 | h_block_size = h_block_size / 2; |
3348 | else if (h_block_size % i == 0) { |
3349 | h_block_size = h_block_size / i; |
3350 | break; |
3351 | } |
3352 | working_set_size = row_size * h_block_size; |
3353 | |
3354 | if (h_block_size == 1 && working_set_size > opt_working_set_size) |
3355 | return false; |
3356 | } |
3357 | } |
3358 | |
3359 | // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) |
3360 | if (h_block_size < nstl::max(1, jcp.t_pad) |
3361 | || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size |
3362 | : jcp.oh % h_block_size)) |
3363 | return false; |
3364 | |
3365 | // check that we can use simple arithmetic for prefetch address |
3366 | // calculations |
3367 | // TODO: we need some traits for this check (Roma) |
3368 | int cache_line_size = 64; |
3369 | assert(jcp.ic_block * typesize == 64); |
3370 | assert(jcp.oc_block * typesize == 64); |
3371 | |
3372 | int num_inp_l2_pfs = jcp.tr_iw * h_block_size; |
3373 | int avg_h_loop_len = h_block_size; |
3374 | int num_inp_l2_pfs_per_fma_block |
3375 | = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); |
3376 | int num_out_l2_pfs = pad_ow * h_block_size; |
3377 | int num_out_l2_pfs_per_fma_block |
3378 | = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); |
3379 | |
3380 | Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors |
3381 | Reg64 reg_kh = rax; |
3382 | Reg64 reg_kw = rbx; |
3383 | Reg64 reg_tmp = abi_not_param1; |
3384 | Reg32 reg_tmp_w = reg_tmp.cvt32(); |
3385 | Reg64 reg_ohs = rdx; |
3386 | Reg64 reg_ihs = rsi; |
3387 | Reg64 reg_h = r8; |
3388 | Reg64 reg_i = r9; |
3389 | Reg64 reg_j = r10; |
3390 | |
3391 | Reg64 reg_inp = r13; |
3392 | Reg64 reg_out = r14; |
3393 | Reg64 reg_ker = r15; |
3394 | |
3395 | Reg64 reg_inp_pf_l1 = rbp; |
3396 | |
3397 | Reg64 reg_inp_pf_l2 = r11; |
3398 | Reg64 reg_out_pf_l2 = r12; |
3399 | |
3400 | Xmm reg_inp_pf_save = xmm17; |
3401 | Xmm reg_out_pf_save = xmm18; |
3402 | |
3403 | Reg64 reg_inp_save = abi_param1; |
3404 | Reg64 reg_out_save = reg_tmp; |
3405 | |
3406 | auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; |
3407 | auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; |
3408 | auto inp_addr = [&](int oi, int ic1) { |
3409 | return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; |
3410 | }; |
3411 | auto out_addr = [&](int oi, int oj = 0) { |
3412 | assert(jcp.ver == ver_4fma); |
3413 | return ptr[reg_out |
3414 | + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; |
3415 | }; |
3416 | auto ker_addr = [&](int ic1) { |
3417 | return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; |
3418 | }; |
3419 | |
3420 | auto emit_block = [&](int h_block_size, |
3421 | bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) |
3422 | { |
3423 | // TODO: add an fma version (Roma) |
3424 | auto pad_ow = jcp.ow; |
3425 | |
3426 | int ow4u = rnd_up(pad_ow, 4); |
3427 | int def_step_size = 16; |
3428 | |
3429 | bool has_w_tail = (pad_ow % def_step_size != 0 |
3430 | || pad_ow % 4 != 0); |
3431 | bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; |
3432 | |
3433 | auto emit_step = [&](int ur_ow, |
3434 | int num_inp_l1_pfs_per_fma_step, |
3435 | int num_inp_l2_pfs_per_fma_step, |
3436 | int num_out_l2_pfs_per_fma_step, bool is_w_tail) |
3437 | { |
3438 | bool block_wraparound = is_w_tail && is_last_row; |
3439 | |
3440 | assert(ur_ow % 4 == 0); |
3441 | int tail_size = ow4u % ur_ow; |
3442 | int this_ur_ow |
3443 | = (is_w_tail && tail_size) ? tail_size : ur_ow; |
3444 | int ow_last_chunk4 = pad_ow % 4; |
3445 | int ow_zero_tail4 = ow_last_chunk4 |
3446 | ? 4 - ow_last_chunk4 : 0; |
3447 | |
3448 | auto emit_out_pf = [&](int oi) { |
3449 | #if 1 |
3450 | if (oi + def_step_size < ur_ow || !block_wraparound) |
3451 | mic_prefetcht0(ptr[reg_out |
3452 | + ((def_step_size + oi) |
3453 | * jcp.oc_block * jcp.typesize_in)]); |
3454 | else { |
3455 | assert(block_wraparound); |
3456 | assert(oi + def_step_size >= ur_ow); |
3457 | mic_prefetcht0(ptr[reg_out_save |
3458 | + ((oi + def_step_size - ur_ow) |
3459 | * jcp.oc_block * jcp.typesize_in)]); |
3460 | } |
3461 | #else |
3462 | // XXX: This is an alternative prefetching strategy that |
3463 | // always prefetches the next row. Keeping it here for |
3464 | // future experiments (Roma) |
3465 | if (!block_wraparound) |
3466 | mic_prefetcht0(ptr[reg_out |
3467 | + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); |
3468 | else |
3469 | mic_prefetcht0(ptr[reg_out + reg_ohs |
3470 | - ((h_block_size - 1) * jcp.ow |
3471 | - oi) * jcp.oc_block * jcp.typesize_in]); |
3472 | #endif |
3473 | if (oi < num_out_l2_pfs_per_fma_step) |
3474 | mic_prefetcht1(ptr[reg_out_pf_l2 |
3475 | + oi * jcp.oc_block * jcp.typesize_in]); |
3476 | }; |
3477 | |
3478 | auto emit_inp_pf = [&](int oi4, int ic1) { |
3479 | int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; |
3480 | int num_pf_slots = jcp.ic_block * ur_ow / 4; |
3481 | |
3482 | int num_pfs = num_inp_l1_pfs_per_fma_step |
3483 | + num_inp_l2_pfs_per_fma_step; |
3484 | int pf_freq = nstl::max(1, num_pf_slots / num_pfs); |
3485 | |
3486 | if (pf_slot_idx % pf_freq) |
3487 | return; |
3488 | |
3489 | int pf_idx = pf_slot_idx / pf_freq; |
3490 | |
3491 | if (pf_idx < num_inp_l2_pfs_per_fma_step) |
3492 | mic_prefetcht1(ptr[reg_inp_pf_l2 |
3493 | + pf_idx * jcp.ic_block * jcp.typesize_in]); |
3494 | else { |
3495 | pf_idx -= num_inp_l2_pfs_per_fma_step; |
3496 | // prefetch the 'tail' of the cache line because most of |
3497 | // the accesses are not aligned |
3498 | mic_prefetcht0(ptr[reg_inp_pf_l1 |
3499 | + pf_idx * jcp.ic_block * jcp.typesize_in |
3500 | + cache_line_size - jcp.typesize_in]); |
3501 | } |
3502 | }; |
3503 | |
3504 | auto numloads = 4; |
3505 | |
3506 | int steps = this_ur_ow; |
3507 | for (int oi4 = 0; oi4 < steps; oi4 += numloads) { |
3508 | for (int oi1 = 0; oi1 < numloads; oi1++) { |
3509 | int oi = oi4 + oi1; |
3510 | if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { |
3511 | vmovups(zmm_out(oi), out_addr(oi)); |
3512 | emit_out_pf(oi); |
3513 | } else { |
3514 | auto zmm = zmm_out(oi); |
3515 | vpxord(zmm, zmm, zmm); |
3516 | } |
3517 | } |
3518 | |
3519 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { |
3520 | if (jcp.ver == ver_4fma) { |
3521 | v4fmaddps(zmm_ker(ic1), |
3522 | zmm_out(oi4), inp_addr(oi4, ic1)); |
3523 | } else { |
3524 | assert(!"unknown convolution version" ); |
3525 | } |
3526 | emit_inp_pf(oi4, ic1); |
3527 | } |
3528 | } |
3529 | }; |
3530 | |
3531 | // Input is transposed and padded but we only access about jcp.iw |
3532 | // elements so use that to compute the # of cache lines in each 'row' |
3533 | int num_inp_l1_pfs |
3534 | = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; |
3535 | |
3536 | if (full_w_unroll) { |
3537 | emit_step(ow4u, num_inp_l1_pfs, |
3538 | num_inp_l2_pfs_per_fma_block, |
3539 | num_out_l2_pfs_per_fma_block, true); |
3540 | add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); |
3541 | add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); |
3542 | } else { |
3543 | Label w_loop; |
3544 | int num_w_iters = pad_ow / def_step_size; |
3545 | int num_w_iters_full = num_w_iters + has_w_tail; |
3546 | int num_inp_l1_pfs_per_fma_step |
3547 | = div_up(num_inp_l1_pfs, num_w_iters_full); |
3548 | int num_inp_l2_pfs_per_fma_step |
3549 | = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); |
3550 | int num_out_l2_pfs_per_fma_step |
3551 | = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); |
3552 | mov(reg_i, num_w_iters); |
3553 | L(w_loop); { |
3554 | emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, |
3555 | num_inp_l2_pfs_per_fma_step, |
3556 | num_out_l2_pfs_per_fma_step, false); |
3557 | add(reg_inp, def_step_size * jcp.typesize_in); |
3558 | add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); |
3559 | add(reg_inp_pf_l1, |
3560 | num_inp_l1_pfs_per_fma_step * cache_line_size); |
3561 | add(reg_inp_pf_l2, |
3562 | num_inp_l2_pfs_per_fma_step * cache_line_size); |
3563 | add(reg_out_pf_l2, |
3564 | num_out_l2_pfs_per_fma_step * cache_line_size); |
3565 | sub(reg_i, 1); |
3566 | jnz(w_loop); |
3567 | } |
3568 | if (has_w_tail) { |
3569 | emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, |
3570 | num_inp_l2_pfs_per_fma_step, |
3571 | num_out_l2_pfs_per_fma_step, true); |
3572 | add(reg_inp_pf_l2, |
3573 | num_inp_l2_pfs_per_fma_step * cache_line_size); |
3574 | add(reg_out_pf_l2, |
3575 | num_out_l2_pfs_per_fma_step * cache_line_size); |
3576 | } |
3577 | // reset reg_inp and reg_out because emit_h_loop expects |
3578 | // unmodified pointers |
3579 | int w_offset = num_w_iters * def_step_size; |
3580 | sub(reg_inp, w_offset * jcp.typesize_in); |
3581 | sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); |
3582 | } |
3583 | }; |
3584 | |
3585 | auto emit_h_loop = [&](int h_block_size, |
3586 | bool is_last_block, bool is_last_kh_kw_iter) |
3587 | { |
3588 | Label h_loop, skip_h_loop; |
3589 | mov(reg_j, 1); |
3590 | cmp(reg_j, reg_h); |
3591 | je(skip_h_loop, T_NEAR); |
3592 | L(h_loop); { |
3593 | |
3594 | lea(reg_inp_pf_l1, |
3595 | ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); |
3596 | emit_block(h_block_size, |
3597 | is_last_block, is_last_kh_kw_iter, false); |
3598 | |
3599 | add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); |
3600 | add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); |
3601 | add(reg_j, 1); |
3602 | cmp(reg_j, reg_h); |
3603 | jb(h_loop); |
3604 | } |
3605 | |
3606 | L(skip_h_loop); |
3607 | |
3608 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) |
3609 | mic_prefetcht0(ker_addr(ic1)); |
3610 | |
3611 | lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); |
3612 | emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); |
3613 | }; |
3614 | |
3615 | auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, |
3616 | int h_block_size) |
3617 | { |
3618 | xor_(reg_kh, reg_kh); |
3619 | Label kh_loop, kh_loop_end; |
3620 | |
3621 | int last_oh_block_size |
3622 | = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); |
3623 | int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; |
3624 | // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size |
3625 | int ih_block_size = oh_block_size - 1 + jcp.kh |
3626 | - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; |
3627 | |
3628 | L(kh_loop); { |
3629 | // determine starting indices for this block |
3630 | if (is_first_block) { |
3631 | xor_(reg_tmp, reg_tmp); |
3632 | mov(reg_ohs, jcp.t_pad); |
3633 | sub(reg_ohs, reg_kh); |
3634 | cmovb(reg_ohs, reg_tmp); |
3635 | |
3636 | mov(reg_ihs, reg_ohs); |
3637 | sub(reg_ihs, jcp.t_pad); |
3638 | add(reg_ihs, reg_kh); |
3639 | } else { |
3640 | xor_(reg_ohs, reg_ohs); |
3641 | mov(reg_ihs, reg_kh); |
3642 | } |
3643 | |
3644 | // determine effective size of block based on padding |
3645 | mov(reg_tmp, oh_block_size); |
3646 | sub(reg_tmp, reg_ohs); |
3647 | mov(reg_h, ih_block_size); |
3648 | sub(reg_h, reg_ihs); |
3649 | cmp(reg_tmp, reg_h); |
3650 | cmovb(reg_h, reg_tmp); |
3651 | |
3652 | Label kh_loop_work; |
3653 | cmp(reg_h, 0); |
3654 | jg(kh_loop_work, T_NEAR); |
3655 | |
3656 | // empty h loop for this jcp.kh: |
3657 | // - set the output to 0 if necessary |
3658 | // - move ker pt |
3659 | // - jump to the end |
3660 | sub(reg_h, 1); |
3661 | Label skip_ker_zeroing; |
3662 | |
// The reg_ker ptr has its lowest bit set if the output needs to be
// zeroed. Those who have byte-aligned their data will suffer the
// consequences :(
// TODO: move the flag to a mask register? (Roma)
3667 | test(reg_ker, 1); |
3668 | jz(skip_ker_zeroing, T_NEAR); |
3669 | |
3670 | Label zeroing_loop; |
3671 | vpxord(zmm0, zmm0, zmm0); |
3672 | and_(reg_ker, ~1); // temporarily clear the zeroing flag |
3673 | mov(reg_tmp, jcp.kw); |
3674 | L(zeroing_loop); { |
3675 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) |
3676 | vmovups(ker_addr(ic1), zmm0); |
3677 | add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); |
3678 | sub(reg_tmp, 1); |
3679 | jnz(zeroing_loop, T_NEAR); |
3680 | } |
3681 | // restore the zeroing flag (it will be cleared after the end of |
3682 | // emit_kh_kw_loop, but we may need it until then) |
3683 | or_(reg_ker, 1); |
3684 | jmp(kh_loop_end, T_NEAR); |
3685 | |
3686 | L(skip_ker_zeroing); |
3687 | add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw |
3688 | * jcp.typesize_out); |
3689 | jmp(kh_loop_end, T_NEAR); |
3690 | |
3691 | L(kh_loop_work); |
3692 | |
3693 | mul_by_const(reg_ihs, reg_tmp, |
3694 | jcp.tr_iw * jcp.ic_block * jcp.typesize_in); |
3695 | mul_by_const(reg_ohs, reg_tmp, |
3696 | pad_ow * jcp.oc_block * jcp.typesize_in); |
3697 | |
3698 | add(reg_inp, reg_ihs); |
3699 | add(reg_out, reg_ohs); |
3700 | |
3701 | Label kw_loop; |
3702 | xor_(reg_kw, reg_kw); |
3703 | L(kw_loop); { |
3704 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { |
3705 | auto zmm = zmm_ker(ic1); |
3706 | vpxord(zmm, zmm, zmm); |
3707 | mic_prefetcht1(ker_addr(ic1)); |
3708 | } |
3709 | |
3710 | mov(reg_out_save, reg_out); |
3711 | mov(reg_inp_save, reg_inp); |
3712 | lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); |
3713 | |
3714 | #if 0 |
3715 | // XXX: Generate code with special prefetches when switching |
3716 | // blocks or at the end of the last block. Disabled to reduce |
3717 | // code size and because there's no performance benefit (Roma) |
3718 | Label regular_h_loop, end_h_loop; |
3719 | cmp(reg_kw, jcp.kw - 1); |
3720 | jne(regular_h_loop, T_NEAR); |
3721 | cmp(reg_kh, jcp.kh - 1); |
3722 | jne(regular_h_loop, T_NEAR); |
3723 | |
3724 | emit_h_loop(oh_block_size, is_last_block, true); |
3725 | jmp(end_h_loop, T_NEAR); |
3726 | |
3727 | L(regular_h_loop); |
3728 | emit_h_loop(oh_block_size, is_last_block, false); |
3729 | |
3730 | L(end_h_loop); |
3731 | #else |
3732 | emit_h_loop(oh_block_size, is_last_block, false); |
3733 | #endif |
3734 | |
3735 | mov(reg_out, reg_out_save); |
3736 | mov(reg_inp, reg_inp_save); |
3737 | |
3738 | Label do_store; |
// The reg_ker ptr has its lowest bit set if the output needs to
// be zeroed. Those who have byte-aligned their data will
// suffer the consequences :(
3742 | mov(reg_tmp, reg_ker); |
3743 | and_(reg_ker, ~1); |
3744 | test(reg_tmp, 1); |
3745 | jnz(do_store, T_NEAR); |
3746 | |
3747 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { |
3748 | auto zmm = zmm_ker(ic1); |
3749 | if (jcp.ver == ver_4fma) { |
3750 | vaddps(zmm, ker_addr(ic1)); |
3751 | } else { |
3752 | assert(!"unknown convolution version" ); |
3753 | } |
3754 | } |
3755 | |
3756 | L(do_store); |
3757 | for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { |
3758 | auto zmm = zmm_ker(ic1); |
3759 | vmovups(ker_addr(ic1), zmm); |
3760 | } |
3761 | |
3762 | mov(reg_ker, reg_tmp); |
3763 | add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); |
3764 | add(reg_kw, 1); |
3765 | cmp(reg_kw, jcp.kw); |
3766 | jl(kw_loop); |
3767 | } |
3768 | |
3769 | sub(reg_inp, reg_ihs); |
3770 | sub(reg_out, reg_ohs); |
3771 | |
3772 | |
3773 | L(kh_loop_end); |
3774 | add(reg_kh, 1); |
3775 | cmp(reg_kh, jcp.kh); |
3776 | jl(kh_loop); |
3777 | } |
3778 | }; |
3779 | |
3780 | mov(reg_inp, ptr[param + GET_OFF(src)]); |
3781 | mov(reg_out, ptr[param + GET_OFF(dst)]); |
3782 | mov(reg_ker, ptr[param + GET_OFF(filt)]); |
3783 | mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); |
3784 | mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); |
3785 | mov(reg_tmp, ptr[param + GET_OFF(channel)]); |
3786 | or_(reg_ker, reg_tmp); |
3787 | |
3788 | bool single_kh_kw_loop = (h_block_size == jcp.oh); |
3789 | |
3790 | size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; |
3791 | size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); |
3792 | size_t inp_block_step = inp_row_step * h_block_size; |
3793 | size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in |
3794 | * h_block_size; |
3795 | |
3796 | if (!single_kh_kw_loop) { |
3797 | // Save the original prefetch pointers from the OpenMP driver |
3798 | vmovq(reg_inp_pf_save, reg_inp_pf_l2); |
3799 | vmovq(reg_out_pf_save, reg_out_pf_l2); |
3800 | mov(reg_inp_pf_l2, reg_inp); |
3801 | add(reg_inp_pf_l2, first_inp_block_step); |
3802 | mov(reg_out_pf_l2, reg_out); |
3803 | add(reg_out_pf_l2, out_block_step); |
3804 | } |
3805 | emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); |
3806 | |
3807 | if (!single_kh_kw_loop) { |
3808 | size_t ker_reset_offset |
3809 | = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; |
3810 | sub(reg_ker, ker_reset_offset); |
3811 | and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates |
3812 | |
3813 | add(reg_inp, first_inp_block_step); |
3814 | add(reg_out, out_block_step); |
3815 | mov(reg_inp_pf_l2, reg_inp); |
3816 | add(reg_inp_pf_l2, inp_block_step); |
3817 | mov(reg_out_pf_l2, reg_out); |
3818 | add(reg_out_pf_l2, out_block_step); |
3819 | |
3820 | int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; |
3821 | if (num_innermost_iters > 0) { |
3822 | Label h_block_loop; |
3823 | |
3824 | mov(reg_tmp_w, num_innermost_iters); |
3825 | kmovw(reg_h_block, reg_tmp_w); |
3826 | L(h_block_loop); { |
3827 | emit_kh_kw_loop(false, false, h_block_size); |
3828 | sub(reg_ker, ker_reset_offset); |
3829 | add(reg_inp, inp_row_step * h_block_size); |
3830 | add(reg_out, out_block_step); |
3831 | mov(reg_inp_pf_l2, reg_inp); |
3832 | add(reg_inp_pf_l2, inp_block_step); |
3833 | mov(reg_out_pf_l2, reg_out); |
3834 | add(reg_out_pf_l2, out_block_step); |
3835 | kmovw(reg_tmp_w, reg_h_block); |
3836 | sub(reg_tmp_w, 1); |
3837 | kmovw(reg_h_block, reg_tmp_w); |
3838 | jnz(h_block_loop); |
3839 | } |
3840 | } |
3841 | |
3842 | // Restore the original prefetch pointers that came from the OpenMP |
3843 | // driver |
3844 | vmovq(reg_inp_pf_l2, reg_inp_pf_save); |
3845 | vmovq(reg_out_pf_l2, reg_out_pf_save); |
3846 | emit_kh_kw_loop(false, true, h_block_size); |
3847 | } |
3848 | |
3849 | return true; |
3850 | } |
3851 | |
3852 | bool jit_avx512_common_conv_bwd_weights_kernel_f32 |
3853 | ::flat_4ops_compute() { |
3854 | const auto &j = jcp; |
3855 | const bool ok = j.ver == ver_4fma && j.is_1stconv |
3856 | && everyone_is(0, j.dilate_h, j.dilate_w); |
3857 | if (!ok) return false; |
3858 | |
3859 | Reg64 reg_ptr_tr_src = r8; |
3860 | Reg64 reg_ptr_dst = r9; |
3861 | Reg64 reg_ptr_wei = r10; |
3862 | Reg64 reg_ptr_bia = r11; |
3863 | |
3864 | Reg64 reg_kh_step = rax; |
3865 | Reg64 reg_oh = abi_not_param1; |
3866 | Reg64 reg_kh = rdx; |
3867 | |
3868 | Reg32 reg_flag_save = ebx; |
3869 | Reg32 reg_flag = esi; |
3870 | |
3871 | Zmm vbia(31); |
3872 | |
3873 | auto zmm_wei = [&](int kh, int kw) { |
3874 | return Zmm(8 + kh * j.kw + kw); |
3875 | }; |
3876 | auto zmm_dst = [&](int ow) { |
3877 | return Zmm(ow % 8); |
3878 | }; |
3879 | |
3880 | auto addr_tr_src = [&](int kh, int iw) { |
3881 | return ptr[reg_ptr_tr_src |
3882 | + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; |
3883 | }; |
3884 | auto addr_dst = [&](int ow) { |
3885 | return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; |
3886 | }; |
3887 | auto addr_wei = [&](int kh, int kw) { |
3888 | return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block |
3889 | * jcp.typesize_out]; |
3890 | }; |
3891 | |
3892 | auto emit_fma_block = [&](int kh_step) { |
3893 | for (int kh = 0; kh < kh_step; ++kh) { |
3894 | for (int kw = 0; kw < j.kw; ++kw) { |
3895 | auto vwei = zmm_wei(kh, kw); |
3896 | vpxord(vwei, vwei, vwei); |
3897 | } |
3898 | } |
3899 | |
3900 | for (int ow = 0; ow < j.ow; ow += 4) { |
3901 | for (int _ow = ow; _ow < ow + 4; ++_ow) { |
3902 | auto vdst = zmm_dst(_ow); |
3903 | if (_ow < j.ow) |
3904 | vmovups(vdst, addr_dst(_ow)); |
3905 | else |
3906 | vpxord(vdst, vdst, vdst); |
3907 | } |
3908 | |
3909 | for (int kh = 0; kh < kh_step; ++kh) { |
3910 | for (int kw = 0; kw < j.kw; ++kw) { |
3911 | const int iw = ow + (kw % j.stride_w) * j.tr_ld |
3912 | + (kw / j.stride_w); |
3913 | v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), |
3914 | addr_tr_src(kh, iw)); |
3915 | if (1 && kh == 0 && kw < 4) { |
3916 | prefetcht1(ptr[reg_ptr_dst |
3917 | + (j.ow + ow + kw) * jcp.oc_block |
3918 | * jcp.typesize_in]); |
3919 | } |
3920 | if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ |
3921 | const int off = kw + 4 - j.kw; |
3922 | if (off >= 0 && ow + off < j.ow) |
3923 | vaddps(vbia, vbia, zmm_dst(ow + off)); |
3924 | } |
3925 | } |
3926 | } |
3927 | } |
3928 | |
3929 | Label l_store; |
3930 | test(reg_flag, FLAG_MB_FIRST); |
3931 | jnz(l_store, T_NEAR); |
3932 | for (int kh = 0; kh < kh_step; ++kh) { |
3933 | for (int kw = 0; kw < j.kw; ++kw) |
3934 | vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); |
3935 | } |
3936 | L(l_store); |
3937 | for (int kh = 0; kh < kh_step; ++kh) { |
3938 | for (int kw = 0; kw < j.kw; ++kw) |
3939 | vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); |
3940 | } |
3941 | }; |
3942 | |
3943 | auto emit_kh_loop = [&]() { |
3944 | const int kh_step_rem = j.kh % j.kh_step; |
3945 | xor_(reg_kh, reg_kh); |
3946 | mov(reg_kh_step, j.kh_step); |
3947 | |
3948 | Label l_kh_loop; |
3949 | L(l_kh_loop); { |
3950 | Label l_done; |
3951 | |
3952 | if (kh_step_rem != 0) { |
3953 | Label l_keep_kh_step; |
3954 | cmp(reg_kh, j.kh - j.kh_step); |
3955 | jle(l_keep_kh_step, T_NEAR); |
3956 | |
3957 | mov(reg_kh_step, kh_step_rem); |
3958 | emit_fma_block(kh_step_rem); |
3959 | jmp(l_done, T_NEAR); |
3960 | |
3961 | L(l_keep_kh_step); |
3962 | } |
3963 | |
3964 | emit_fma_block(j.kh_step); |
3965 | |
3966 | L(l_done); |
3967 | |
3968 | add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld |
3969 | * jcp.typesize_in); |
3970 | add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); |
3971 | add(reg_kh, j.kh_step); |
3972 | |
3973 | cmp(reg_kh, j.kh); |
3974 | jl(l_kh_loop, T_NEAR); |
3975 | } |
3976 | |
3977 | const int kh_steps = rnd_up(j.kh, j.kh_step); |
3978 | sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); |
3979 | sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); |
3980 | }; |
3981 | |
3982 | auto emit_oh_loop = [&]() { |
3983 | mov(reg_oh, j.oh); |
3984 | |
3985 | Label l_oh_loop; |
3986 | L(l_oh_loop); { |
3987 | Label l_restore_mb_flag, l_jump; |
3988 | |
3989 | cmp(reg_oh, j.oh); |
3990 | je(l_restore_mb_flag, T_NEAR); |
3991 | |
3992 | and_(reg_flag, ~FLAG_MB_FIRST); |
3993 | jmp(l_jump, T_NEAR); |
3994 | |
3995 | L(l_restore_mb_flag); |
3996 | mov(reg_flag, reg_flag_save); |
3997 | |
3998 | L(l_jump); |
3999 | |
4000 | emit_kh_loop(); |
4001 | |
4002 | add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld |
4003 | * jcp.typesize_in); |
4004 | add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); |
4005 | |
4006 | dec(reg_oh); |
4007 | jnz(l_oh_loop, T_NEAR); |
4008 | } |
4009 | }; |
4010 | |
4011 | auto emit_bia_store = [&]() { |
4012 | if (!j.with_bias) return; |
4013 | |
4014 | Label l_bia_store, l_bia_skip; |
4015 | test(reg_flag, FLAG_IC_FIRST); |
4016 | jz(l_bia_skip); |
4017 | |
4018 | test(reg_flag, FLAG_MB_FIRST); |
4019 | jnz(l_bia_store, T_NEAR); |
4020 | vaddps(vbia, ptr[reg_ptr_bia]); |
4021 | L(l_bia_store); |
4022 | vmovups(ptr[reg_ptr_bia], vbia); |
4023 | L(l_bia_skip); |
4024 | }; |
4025 | |
4026 | mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); |
4027 | mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); |
4028 | mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); |
4029 | mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); |
4030 | mov(reg_flag_save, ptr[param + GET_OFF(flags)]); |
4031 | |
4032 | vpxord(vbia, vbia, vbia); |
4033 | emit_oh_loop(); |
4034 | emit_bia_store(); |
4035 | |
4036 | return true; |
4037 | } |
4038 | |
4039 | void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() |
4040 | { |
4041 | if (flat_4ops_compute()) |
4042 | return; |
4043 | if (compute_full_spat_loop()) |
4044 | return; |
4045 | |
4046 | maybe_zero_kernel(); |
4047 | |
4048 | if (jcp.ndims == 5) compute_d_loop_common(); |
4049 | else compute_oh_loop_common(); |
4050 | } |
4051 | |
4052 | void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() |
4053 | { |
4054 | preamble(); |
4055 | |
4056 | mov(reg_input, ptr[param + GET_OFF(src)]); |
4057 | mov(reg_output, ptr[param + GET_OFF(dst)]); |
4058 | mov(reg_kernel, ptr[param + GET_OFF(filt)]); |
4059 | |
4060 | compute_loop(); |
4061 | |
4062 | postamble(); |
4063 | } |
4064 | |
4065 | status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( |
4066 | jit_conv_conf_t &jcp, const convolution_desc_t &cd, |
4067 | memory_desc_t &src_md, memory_desc_t &diff_weights_md, |
4068 | memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { |
4069 | if (!mayiuse(avx512_common)) |
4070 | return status::unimplemented; |
4071 | |
4072 | const memory_desc_wrapper src_d(&src_md); |
4073 | const memory_desc_wrapper diff_weights_d(&diff_weights_md); |
4074 | const memory_desc_wrapper diff_bias_d(&diff_bias_md); |
4075 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
4076 | |
4077 | const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; |
4078 | int ndims = src_d.ndims(); |
4079 | |
4080 | jcp = zero<decltype(jcp)>(); |
4081 | |
4082 | jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); |
4083 | jcp.ndims = ndims; |
4084 | jcp.prop_kind = cd.prop_kind; |
4085 | |
4086 | jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; |
4087 | jcp.mb = src_d.dims()[0]; |
4088 | |
4089 | jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; |
4090 | jcp.oc_without_padding = jcp.oc; |
4091 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
4092 | |
4093 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
4094 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; |
4095 | jcp.iw = src_d.dims()[ndims-1]; |
4096 | jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; |
4097 | jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; |
4098 | jcp.ow = diff_dst_d.dims()[ndims-1]; |
4099 | |
4100 | jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; |
4101 | jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; |
4102 | jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; |
4103 | |
4104 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
4105 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; |
4106 | jcp.l_pad = cd.padding[0][ndims-3]; |
4107 | |
4108 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
4109 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; |
4110 | jcp.stride_w = cd.strides[ndims-3]; |
4111 | |
4112 | jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; |
4113 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; |
4114 | jcp.dilate_w = cd.dilates[ndims-3]; |
4115 | |
4116 | const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); |
4117 | bool ok = true |
4118 | // general condition to simplify dilations |
4119 | && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) |
4120 | && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) |
4121 | && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) |
4122 | // special condition to simplify dilations in compute_oh_loop_common |
4123 | && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); |
4124 | if (!ok) |
4125 | return status::unimplemented; |
4126 | |
4127 | jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w |
4128 | + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); |
4129 | jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h |
4130 | + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); |
4131 | jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d |
4132 | + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); |
4133 | |
4134 | /* XXX: currently, does not support dilation_d > 0 */ |
4135 | if (ndims == 5) |
4136 | if (jcp.dilate_d > 0) |
4137 | return status::unimplemented; |
4138 | |
4139 | jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; |
4140 | jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; |
4141 | jcp.ohp = jcp.oh; |
4142 | jcp.owp = jcp.ow; |
4143 | jcp.aligned_threads = 0; |
4144 | |
4145 | /* check for the 1st convolution */ |
4146 | jcp.is_1stconv = is_1stconv(jcp); |
4147 | |
4148 | jcp.oc_block = jcp.simd_w; |
4149 | |
4150 | bool ok_to_pad_channels = true |
4151 | && jcp.ngroups == 1 |
4152 | && src_d.data_type() == data_type::f32; |
4153 | |
4154 | if (ok_to_pad_channels) |
4155 | jcp.oc = rnd_up(jcp.oc, jcp.simd_w); |
4156 | |
4157 | if (jcp.oc % jcp.oc_block) |
4158 | return status::unimplemented; |
4159 | |
4160 | auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
4161 | auto wei_tag = with_groups |
4162 | ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) |
4163 | : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); |
4164 | |
4165 | if (diff_dst_d.format_kind() == format_kind::any) { |
4166 | CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); |
4167 | jcp.dst_tag = dst_tag; |
4168 | } else { |
4169 | jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); |
4170 | } |
4171 | if (jcp.dst_tag != dst_tag) |
4172 | return status::unimplemented; |
4173 | |
4174 | /* conditions on bias memory */ |
4175 | jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; |
4176 | if (jcp.with_bias) { |
4177 | if (diff_bias_d.format_kind() == format_kind::any) |
4178 | CHECK(memory_desc_init_by_tag(diff_bias_md, x)); |
4179 | } |
4180 | |
4181 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
4182 | |
4183 | /* kernel applicability check wrt boundaries |
4184 | * the conditions are quite general across the kernels we have, |
4185 | * but ideally the check should belong to a specific kernel... */ |
4186 | const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; |
4187 | const bool boundaries_ok = true |
4188 | && jcp.t_pad <= max_pad |
4189 | && jcp.b_pad <= max_pad |
4190 | && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) |
4191 | && jcp.f_pad < jcp.kd; |
4192 | if (!boundaries_ok) |
4193 | return status::unimplemented; |
4194 | |
4195 | /* yet another common check */ |
4196 | if (jcp.kw > 14) |
4197 | return status::unimplemented; |
4198 | |
4199 | /* setting register strategy */ |
4200 | for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { |
4201 | if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } |
4202 | } |
4203 | |
4204 | if (jcp.is_1stconv) { |
4205 | auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); |
4206 | if (src_d.format_kind() == format_kind::any) { |
4207 | CHECK(memory_desc_init_by_tag(src_md, src_tag)); |
4208 | jcp.src_tag = src_tag; |
4209 | } else { |
4210 | jcp.src_tag = src_d.matches_one_of_tag(src_tag); |
4211 | if (jcp.ic == 1 && jcp.src_tag != src_tag) |
4212 | jcp.src_tag = src_d.matches_one_of_tag( |
4213 | pick(ndims - 3, nwc, nhwc, ndhwc)); |
4214 | } |
4215 | if (jcp.src_tag == format_tag::undef) |
4216 | return status::unimplemented; |
4217 | |
4218 | const bool src_ok = true |
4219 | && utils::everyone_is(data_type::f32, |
4220 | src_d.data_type(), diff_weights_d.data_type(), |
4221 | diff_dst_d.data_type()) |
4222 | && one_of(jcp.ic, 1, 2, 3) |
4223 | && jcp.ngroups == 1; |
4224 | if (!src_ok) |
4225 | return status::unimplemented; |
4226 | |
4227 | const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, |
4228 | jcp.stride_w), 16); |
4229 | const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); |
4230 | const int kh_step_rem = jcp.kh % kh_step; |
4231 | |
4232 | const auto wei_4fma_tag = with_groups |
4233 | ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) |
4234 | : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); |
4235 | |
4236 | auto current_wei_tag = format_tag::undef; |
4237 | if (diff_weights_d.format_kind() != format_kind::any) |
4238 | current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); |
4239 | |
4240 | const bool use_4fma = true |
4241 | && one_of(ndims, 3, 4) |
4242 | && mayiuse(avx512_mic_4ops) |
4243 | && mkldnn_thr_syncable() |
4244 | && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) |
4245 | && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) |
4246 | && jcp.kw <= 28 - jcp.with_bias |
4247 | && jcp.stride_w == 4 |
4248 | && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ |
4249 | && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ |
4250 | && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, |
4251 | current_wei_tag == wei_4fma_tag); |
4252 | |
4253 | if (use_4fma) { |
4254 | jcp.ver = ver_4fma; |
4255 | jcp.kh_step = kh_step; |
4256 | jcp.tr_ld = tr_ld; |
4257 | jcp.ic_block = 1; |
4258 | if (diff_weights_d.format_kind() == format_kind::any) |
4259 | CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); |
4260 | jcp.wei_tag = wei_4fma_tag; |
4261 | } else { |
4262 | jcp.ver = ver_fma; |
4263 | jcp.ic_block = jcp.ic; |
4264 | |
4265 | wei_tag = with_groups |
4266 | ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) |
4267 | : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); |
4268 | |
4269 | if (diff_weights_d.format_kind() == format_kind::any) { |
4270 | CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); |
4271 | jcp.wei_tag = wei_tag; |
4272 | } else { |
4273 | jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); |
4274 | } |
4275 | if (jcp.wei_tag != wei_tag) |
4276 | return status::unimplemented; |
4277 | } |
4278 | |
4279 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
4280 | } else { |
4281 | auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
4282 | if (src_d.format_kind() == format_kind::any) { |
4283 | CHECK(memory_desc_init_by_tag(src_md, src_tag)); |
4284 | jcp.src_tag = src_tag; |
4285 | } else { |
4286 | jcp.src_tag = src_d.matches_one_of_tag(src_tag); |
4287 | } |
4288 | if (jcp.src_tag != src_tag) |
4289 | return status::unimplemented; |
4290 | |
4291 | if (diff_weights_d.format_kind() == format_kind::any) { |
4292 | CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); |
4293 | jcp.wei_tag = wei_tag; |
4294 | } else { |
4295 | jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); |
4296 | } |
4297 | if (jcp.wei_tag != wei_tag) |
4298 | return status::unimplemented; |
4299 | |
4300 | jcp.ic_block = jcp.simd_w; |
4301 | if (ok_to_pad_channels) |
4302 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
4303 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
4304 | if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) |
4305 | && utils::everyone_is(data_type::f32, |
4306 | src_d.data_type(), diff_weights_d.data_type(), |
4307 | diff_dst_d.data_type())) { |
4308 | jcp.ver = ver_fma; |
4309 | if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && |
4310 | everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && |
4311 | mkldnn_thr_syncable()) { |
4312 | jcp.ver = ver_4fma; |
4313 | } |
4314 | } else { |
4315 | return status::unimplemented; |
4316 | } |
4317 | if (jcp.ver == ver_4fma) { |
4318 | jcp.ur_w = jcp.ow; |
4319 | // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to |
4320 | // cross the right boundary. The only requirement is not to have |
4321 | // NaNs there because another multiplicand is always guaranteed to |
4322 | // be zero. This also may require the top-level driver to allocate |
4323 | // four extra guarding elements at the very end of the buffer. |
4324 | // I'm not proud of this hack, but it improves performance by |
4325 | // about 5-10% depending on the dimensions (Roma) |
4326 | |
4327 | const int tr_round = 4; |
4328 | |
4329 | jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); |
4330 | jcp.tr_src_num_guard_elems = tr_round; // upper bound |
4331 | } |
4332 | } |
4333 | |
4334 | if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { |
4335 | jcp.typesize_in = sizeof(float); |
4336 | jcp.typesize_out = sizeof(float); |
4337 | } else |
4338 | return status::unimplemented; |
4339 | |
4340 | bool args_ok = true |
4341 | && jcp.ic % jcp.ic_block == 0 |
4342 | && jcp.oc % jcp.oc_block == 0 |
4343 | && jcp.ic <= src_d.padded_dims()[1] |
4344 | && jcp.oc <= diff_dst_d.padded_dims()[1] |
4345 | && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] |
4346 | && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; |
4347 | if (!args_ok) return status::unimplemented; |
4348 | |
4349 | { // balancing |
4350 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
4351 | balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); |
4352 | jcp.nthr = nthr; |
4353 | jcp.nthr_mb = nthr_mb; |
4354 | jcp.nthr_g = nthr_g; |
4355 | jcp.nthr_oc_b = nthr_oc_b; |
4356 | jcp.nthr_ic_b = nthr_ic_b; |
4357 | } |
4358 | |
4359 | return status::success; |
4360 | } |
4361 | |
4362 | void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( |
4363 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { |
4364 | if (jcp.ver == ver_4fma) { |
4365 | if (jcp.is_1stconv) { |
4366 | const size_t tr_src_size = |
4367 | jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; |
4368 | scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); |
4369 | } else { |
4370 | // XXX: See the comment about tr_iw and guarding elements in |
4371 | // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() |
4372 | const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; |
4373 | const size_t min_tr_src_size_per_thr |
4374 | = jcp.ih * jcp.ic_block * jcp.tr_iw; |
4375 | const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr |
4376 | + jcp.tr_src_num_guard_elems; |
4377 | scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); |
4378 | } |
4379 | |
4380 | /* prepare synchronization contexts */ |
4381 | if (jcp.nthr_oc_b > 1) { |
4382 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
4383 | scratchpad.book(key_conv_tr_src_bctx, |
4384 | sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); |
4385 | } |
4386 | } |
4387 | |
4388 | if (jcp.nthr_mb > 1) { |
4389 | const int wei_size = jcp.ngroups * jcp.oc * jcp.ic |
4390 | * jcp.kh * jcp.kw * jcp.kd; |
4391 | const int bia_size = jcp.ngroups * jcp.oc; |
4392 | const size_t wei_bia_reduction_size = wei_size + bia_size; |
4393 | |
4394 | scratchpad.book(key_conv_wei_bia_reduction, |
4395 | jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); |
4396 | scratchpad.book(key_conv_wei_bia_reduction_bctx, |
4397 | sizeof(simple_barrier::ctx_t)); |
4398 | } |
4399 | |
4400 | if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) |
4401 | scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); |
4402 | } |
4403 | |
4404 | void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( |
4405 | const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, |
4406 | int &nthr_oc_b_, int &nthr_ic_b_) |
4407 | { |
4408 | nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; |
4409 | |
4410 | const int max_threads = mkldnn_get_max_threads(); |
4411 | |
4412 | if (max_threads < j.ngroups) { |
4413 | /* simplification... fortunately it doesn't hurt much */ |
4414 | return; |
4415 | } |
4416 | |
4417 | if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { |
4418 | // should not happen -- the driver is not ready |
4419 | // for TBB-like non-synchronous threading yet |
4420 | return; |
4421 | } |
4422 | |
4423 | if (j.ver == ver_4fma && j.is_1stconv) { |
4424 | nthr_g_ = 1; |
4425 | nthr_oc_b_ = 1; |
4426 | nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); |
4427 | nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); |
4428 | nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; |
4429 | return; |
4430 | } |
4431 | |
4432 | nthr_g_ = j.ngroups; |
4433 | const int nthr = max_threads / nthr_g_; |
4434 | |
4435 | auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { |
4436 | /* calculate per thread memory cost (read/write). high level optimizer |
4437 | * tries to minimize memory consumption. few notes: |
4438 | * (n1) unclear why, but that essentially helps first convolution... |
4439 | * (n2) assuming the reduction over minibatch is always there: |
4440 | * - instead of 8 it should be 5 here (write ~= 2 read): |
4441 | * kernel: temporal workspace 1 write |
4442 | * reduction: 1 read from workspace and 1 write to the diff_wei |
4443 | * - but experiments showed 8 works better than 5 or 6... */ |
4444 | |
4445 | const int src_coef = j.ver == ver_4fma ? 4 : 1; |
4446 | const int dst_coef = 1; |
4447 | const int wei_coef = 8; |
4448 | |
4449 | return 0 |
4450 | + src_coef |
4451 | * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) |
4452 | * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id |
4453 | / j.stride_d / j.stride_h / j.stride_w /* (n1) */ |
4454 | + dst_coef |
4455 | * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) |
4456 | * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od |
4457 | + wei_coef /* (n2) */ |
4458 | * div_up(j.ngroups, nthr_g_) |
4459 | * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) |
4460 | * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; |
4461 | }; |
4462 | |
4463 | int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); |
4464 | |
4465 | /* step 1: find the best thread distribution with lowest memory cost */ |
4466 | const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); |
4467 | for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
4468 | const int nthr_par = nthr / nthr_mb; |
4469 | const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); |
4470 | for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
4471 | int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); |
4472 | |
4473 | int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
4474 | if (mem_cost <= best_mem_cost) { |
4475 | best_mem_cost = mem_cost; |
4476 | nthr_mb_ = nthr_mb; |
4477 | nthr_oc_b_ = nthr_oc_b; |
4478 | nthr_ic_b_ = nthr_ic_b; |
4479 | } |
4480 | } |
4481 | |
4482 | if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } |
4483 | } |
4484 | |
4485 | if (!mayiuse(avx512_mic)) { |
4486 | auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { |
4487 | return 1 |
4488 | * div_up(j.mb, nthr_mb) |
4489 | * div_up(j.ngroups, nthr_g_) |
4490 | * div_up(j.nb_oc, nthr_oc_b) |
4491 | * div_up(j.nb_ic, nthr_ic_b); |
4492 | }; |
4493 | |
4494 | /* step 2: search for a thread distribution with lower compute cost. |
4495 | * the constrains: |
4496 | * - memory cost cannot exceed 110% of the best found in the step 1 |
4497 | * - unless compute cost is 133% lower than the current best case |
4498 | * note: both constants were found empirically */ |
4499 | int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); |
4500 | for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
4501 | const int nthr_par = nthr / nthr_mb; |
4502 | const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); |
4503 | for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
4504 | int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); |
4505 | int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
4506 | int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
4507 | |
4508 | const bool opt1 = comp_cost <= best_comp_cost |
4509 | && mem_cost < 1.1 * best_mem_cost; |
4510 | const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; |
4511 | |
4512 | if (opt1 || opt2) { |
4513 | best_comp_cost = comp_cost; |
4514 | nthr_mb_ = nthr_mb; |
4515 | nthr_oc_b_ = nthr_oc_b; |
4516 | nthr_ic_b_ = nthr_ic_b; |
4517 | } |
4518 | } |
4519 | |
4520 | if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } |
4521 | } |
4522 | } |
4523 | |
4524 | if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) |
4525 | nthr_mb_ = nstl::min(j.mb * j.od, max_threads); |
4526 | nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; |
4527 | |
4528 | assert(nthr_ <= max_threads); |
4529 | assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); |
4530 | } |
4531 | |
4532 | template struct _jit_avx512_common_conv_fwd_kernel<Zmm>; |
4533 | template struct _jit_avx512_common_conv_fwd_kernel<Xmm>; |
4534 | |
4535 | } |
4536 | } |
4537 | } |
4538 | |
4539 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
4540 | |