1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "c_types_map.hpp" |
18 | #include "mkldnn_thread.hpp" |
19 | #include "nstl.hpp" |
20 | #include "type_helpers.hpp" |
21 | #include "utils.hpp" |
22 | #include "cpu_memory.hpp" |
23 | |
24 | #include <math.h> |
25 | |
26 | #include "jit_avx512_common_conv_winograd_kernel_f32.hpp" |
27 | |
// Build-time tunable: the #ifndef guard lets the build system override the
// default of 16 (e.g. with -DKERNEL_SIZE_THRESHOLD=<n>).
// NOTE(review): not referenced in the part of the file reviewed here --
// confirm it is still needed.
#ifndef KERNEL_SIZE_THRESHOLD
#define KERNEL_SIZE_THRESHOLD 16
#endif

// Smallest dimN_reg_block accepted by the first divisor search in
// set_wsched_DATA_W_S_G_D_avx512_common().
#define MIN_REQUIRED_DIMN_REG_BLOCK 14
33 | |
34 | namespace mkldnn { |
35 | namespace impl { |
36 | namespace cpu { |
37 | |
38 | namespace { |
39 | |
40 | using namespace mkldnn::impl::utils; |
41 | |
// Cache sizes obtained from get_cache_size() when this translation unit's
// globals are initialized. L1/L2 are the budgets used by the check_cond*
// blocking predicates; LLC_data_size is compared against the output size when
// the data kernel chooses between regular and non-temporal stores.
// NOTE(review): the boolean argument presumably selects a per-core figure
// (true for L1/L2, false for the last level) -- confirm against the
// declaration of get_cache_size().
unsigned int L1_cache_size = get_cache_size(1, true);
unsigned int L2_cache_size = get_cache_size(2, true);
unsigned int LLC_data_size = get_cache_size(3, false);
45 | |
46 | // the test funtion takes jcp, the candidate and the current best. |
47 | // it returns true if the new candidate is better |
48 | int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, |
49 | int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) |
50 | { |
51 | int best_divisor = default_best; |
52 | auto test_num |
53 | = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { |
54 | if (test(jcp, num, best_divisor)) { |
55 | best_divisor = num; |
56 | } |
57 | }; |
58 | |
59 | for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { |
60 | if (number % divisor == 0) { |
61 | test_num(jcp, divisor); |
62 | test_num(jcp, number / divisor); |
63 | } |
64 | } |
65 | |
66 | return best_divisor; |
67 | } |
68 | |
69 | namespace { |
70 | bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { |
71 | if (jcp.ver == ver_4fma) |
72 | return jcp.mb >= 32; |
73 | else |
74 | return jcp.mb >= 16; |
75 | } |
76 | } |
77 | |
/* assumes 512-bit registers */
/* TODO: add support for strides */
/* TODO: handle the prefetch distance automatically */
/* Cache level targeted by a prefetch. */
enum cache_t_ { L1, L2, L3 };
typedef enum cache_t_ cache_t;
82 | |
83 | template <typename data_t> |
84 | struct prefetcher_t { |
85 | prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, |
86 | cache_t cache_type, size_t block_size, /* in number of elements*/ |
87 | int nb_instructions_in_block, int fma_ipc) |
88 | : cg_(generator) |
89 | , reg_base_addr_(reg_base_addr) |
90 | , cache_type_(cache_type) |
91 | , cache_block_size_(block_size) |
92 | { |
93 | nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); |
94 | prefetch_spread_ |
95 | = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); |
96 | prefetch_blk_ |
97 | = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); |
98 | |
99 | /* assumption: when fetch in Li, data is already in L(i+1) */ |
100 | int cache_latency; |
101 | switch (cache_type_) { |
102 | case L1: cache_latency = 14; break; |
103 | case L2: |
104 | case L3: |
105 | default: cache_latency = 250; break; |
106 | } |
107 | |
108 | prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); |
109 | } |
110 | |
111 | void prefetch(int instruction_number) |
112 | { |
113 | if (instruction_number % prefetch_spread_ == 0) { |
114 | for (int i = 0; (i < prefetch_blk_) |
115 | && (prefetches_issued_ < nb_cache_lines_to_prefetch_); |
116 | i++, prefetches_issued_++) { |
117 | prefetch_inst_(cg_->EVEX_compress_addr( |
118 | reg_base_addr_, (cache_block_size_ * prefetch_distance_) |
119 | * sizeof(data_t) |
120 | + (prefetches_issued_ * 64))); |
121 | } |
122 | } |
123 | } |
124 | |
125 | private: |
126 | void prefetch_inst_(const Xbyak::Address &addr) |
127 | { |
128 | switch (cache_type_) { |
129 | case L1: cg_->prefetcht0(addr); break; |
130 | case L2: cg_->prefetcht1(addr); break; |
131 | case L3: cg_->prefetcht2(addr); break; |
132 | default: |
133 | break; // TODO: raise an exception or put an assert |
134 | } |
135 | } |
136 | |
137 | jit_generator *cg_; |
138 | Xbyak::Reg64 reg_base_addr_; |
139 | cache_t cache_type_; |
140 | int cache_block_size_ = 0; |
141 | int nb_cache_lines_to_prefetch_ = 0; |
142 | int prefetches_issued_ = 0; |
143 | int prefetch_spread_ = 0; |
144 | int prefetch_blk_ = 0; |
145 | int prefetch_distance_ = 0; |
146 | }; |
147 | |
148 | // utilities to support kernel parameter selection |
149 | bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, |
150 | int dimM_block, int dimM_simd_block, float C) |
151 | { |
152 | float lhs = (dimM_block * dimN_reg_block * dimM_simd_block |
153 | + dimM_block * dimK_block * dimK_reg_block |
154 | * dimM_simd_block |
155 | + dimK_block * dimN_reg_block * dimK_reg_block) |
156 | * (float)sizeof(float); |
157 | float rhs = C * L1_cache_size; |
158 | return (lhs < rhs); |
159 | } |
160 | |
161 | bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, |
162 | int dimM_block, int dimM_simd_block, float C) |
163 | { |
164 | float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block |
165 | + dimK_block * dimN_reg_block * dimK_reg_block) |
166 | * (float)sizeof(float); |
167 | float rhs = C * L1_cache_size; |
168 | return (lhs < rhs); |
169 | } |
170 | |
171 | bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, |
172 | int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, |
173 | float C) |
174 | { |
175 | float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block |
176 | + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block |
177 | * dimM_simd_block |
178 | + nb_dimN_reg_block * dimK_nb_block * dimK_block |
179 | * dimN_reg_block * dimK_reg_block) |
180 | * (float)sizeof(float); |
181 | float rhs = C * L2_cache_size; |
182 | return (lhs < rhs); |
183 | } |
184 | } |
185 | |
186 | using namespace mkldnn::impl::format_tag; |
187 | using namespace mkldnn::impl::utils; |
188 | using namespace Xbyak; |
189 | |
/* Emits the GEMM micro-kernel shared by the Winograd data kernels:
 * preamble, the blocked loop nest described below, postamble and ret.
 *
 * is_beta_zero selects how the accumulators are initialized: zeroed when
 * true, loaded from the destination (reg_dstC) when false.
 *
 * The dimM_block and dimK_block loops become run-time loops in the generated
 * code (and are omitted when their trip count is 1); the dimK_reg_block and
 * tile loops are C++ loops here, i.e. fully unrolled in the generated code. */
void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
        bool is_beta_zero)
{
    // const int dimK_simd_block = jcp.dimK_reg_block;

    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    //     for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    //         for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    //         dimK_reg_block++)
    //                 for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    //                     C[dimM_block][tile] +=
    //                     A[dimM_block][dimK_block][dimK_reg_block] *
    //                     broadcast(B[dimK_block][tile][dimK_reg_block]);
    // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
    // so we load it before the loop on tile
    // 2) the loop on tile must be fully unrolled. Don't know about the one on
    // dimK_reg_block. I think it should be

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;
        // The 4fma instruction consumes four A vectors at once, so the
        // unrolled K loop advances by 4 in that version and by 1 otherwise.
        const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
        const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;

        // Both prefetchers cover one dimN_reg_block x dimK_reg_block block of
        // B and spread their prefetches over the FMAs of one unrolled block.
        prefetcher_t<float> L1_pf(this, reg_srcB, L1,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);
        prefetcher_t<float> L2_pf(this, reg_srcB, L2,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);

        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }
        {
            // First, we zero the accumulators if first nb_ic iteration,
            // otherwise we load them
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm(jcp.zmm_start + tile);
                if (is_beta_zero)
                    vpxord(zmm, zmm, zmm);
                else
                    vmovups(zmm, zword[reg_dstC + 64 * tile]);
            }

            if (jcp.dimK_block > 1) {
                mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                L(dimK_block_loop);
            }
            {
                // Loads inc_dimK_reg_block consecutive 64-byte vectors of A,
                // starting at vector index `offset`, into registers
                // reg_idx, reg_idx + 1, ...
                auto load_A = [=](int reg_idx, int offset) {
                    for (int i = 0; i < inc_dimK_reg_block; i++)
                        vmovups(Zmm(reg_idx + i),
                                zword[reg_srcA + 64 * (offset + i)]);
                };

                // Used when doing double buffering
                int next = 0;
                if (jcp.double_buffering) {
                    load_A(next, 0);
                }
                for (int dimK_reg_block = 0;
                        dimK_reg_block < jcp.dimK_reg_block;
                        dimK_reg_block += inc_dimK_reg_block) {
                    int current;
                    /* Loading the next vector from A */
                    current = next;
                    if (jcp.double_buffering) {
                        // Alternate between the two register groups starting
                        // at 0 and at inc_dimK_reg_block.
                        // NOTE(review): on the last iteration this loads the
                        // vectors at offset jcp.dimK_reg_block, i.e. just
                        // past the block processed here -- confirm that the
                        // A buffer tolerates this read at its very end.
                        next = (dimK_reg_block + inc_dimK_reg_block)
                                % (2 * inc_dimK_reg_block);
                        load_A(next, dimK_reg_block + inc_dimK_reg_block);
                    } else {
                        next = 0;
                        load_A(next, dimK_reg_block);
                    }
                    /* Performing the fmas */
                    for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                        Zmm zmm(jcp.zmm_start + tile);
                        // No software prefetch is emitted for the avx512_core
                        // version.
                        if (jcp.ver != ver_avx512_core)
                            L1_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                        if (jcp.ver == ver_4fma)
                            v4fmaddps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                              64 * tile + dimK_reg_block * 4));
                        else
                            // `true` requests the broadcast form of the
                            // memory operand (one B element for all lanes).
                            vfmadd231ps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                                64 * tile + dimK_reg_block * 4,
                                                true));
                        if (jcp.ver != ver_avx512_core)
                            L2_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                    }
                }

                // Advance A and B to the next dimK block.
                add(reg_srcA, jcp.dimK_reg_block * 64);
                add(reg_srcB, jcp.dimN_reg_block * 64);
                if (jcp.dimK_block > 1) {
                    sub(reg_dimK_block_loop_cnt, 1);
                    jnz(dimK_block_loop);
                }
            }


            // Writes the accumulators back to C. Non-temporal stores are used
            // only when the destination is known to be aligned, K is not split
            // into several outer blocks (dimK_nb_block == 1) and the
            // dimN x dimM x alpha x alpha float output is larger than twice
            // LLC_data_size.
            auto store_output = [=](bool output_is_aligned) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    Zmm zmm(jcp.zmm_start + tile);
                    if (output_is_aligned
                        && jcp.dimK_nb_block == 1
                        && (jcp.dimN * jcp.dimM * alpha * alpha
                            * sizeof(float) > 2 * LLC_data_size))
                        vmovntps(zword[reg_dstC + 64 * tile], zmm);
                    else
                        vmovups(zword[reg_dstC + 64 * tile], zmm);
                }
            };

            // Run-time alignment check on the destination pointer: both store
            // sequences are generated and the aligned one is taken only when
            // the low bits of reg_dstC are zero.
            Label unaligned_store, end_store;
            test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
            jnz(unaligned_store, T_NEAR);
            store_output(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                store_output(false);
            }
            L(end_store);

            if (jcp.dimM_block > 1) {
                // Rewind B to the start of the K range (every M block reuses
                // it) and move C to the next M block; A is left where the K
                // loop ended.
                sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
                add(reg_dstC, jcp.dimN_reg_block * 64);
                sub(reg_dimM_block_loop_cnt, 1);
                jnz(dimM_block_loop);
            }
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
339 | |
340 | status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( |
341 | jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, |
342 | const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, |
343 | const memory_desc_wrapper &dst_d) |
344 | { |
345 | |
346 | if (mayiuse(avx512_core)) |
347 | return status::unimplemented; |
348 | else if (!mayiuse(avx512_common)) |
349 | return status::unimplemented; |
350 | else if (mayiuse(avx512_mic_4ops)) |
351 | jcp.ver = ver_4fma; |
352 | else |
353 | jcp.ver = ver_fma; |
354 | |
355 | jcp.nthr = mkldnn_get_max_threads(); |
356 | |
357 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
358 | |
359 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
360 | jcp.mb = src_d.dims()[0]; |
361 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
362 | jcp.oc_without_padding = jcp.oc; |
363 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
364 | jcp.ih = src_d.dims()[2]; |
365 | jcp.iw = src_d.dims()[3]; |
366 | jcp.oh = dst_d.dims()[2]; |
367 | jcp.ow = dst_d.dims()[3]; |
368 | jcp.kh = weights_d.dims()[with_groups + 2]; |
369 | jcp.kw = weights_d.dims()[with_groups + 3]; |
370 | jcp.t_pad = cd.padding[0][0]; |
371 | jcp.l_pad = cd.padding[0][1]; |
372 | jcp.stride_h = cd.strides[0]; |
373 | jcp.stride_w = cd.strides[1]; |
374 | jcp.dilate_h = cd.dilates[0]; |
375 | jcp.dilate_w = cd.dilates[1]; |
376 | jcp.r_pad = nstl::max( |
377 | 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); |
378 | jcp.b_pad = nstl::max( |
379 | 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); |
380 | jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; |
381 | jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; |
382 | jcp.ohp = jcp.oh; |
383 | jcp.owp = jcp.ow; |
384 | |
385 | bool ok_to_pad_channels = jcp.ngroups == 1; |
386 | if (ok_to_pad_channels) { |
387 | jcp.oc = rnd_up(jcp.oc, simd_w); |
388 | jcp.ic = rnd_up(jcp.ic, simd_w); |
389 | } |
390 | |
391 | if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, |
392 | is_winograd_faster_than_direct(jcp))) |
393 | return status::unimplemented; |
394 | |
395 | // Checking conditions not supported by these kernels |
396 | if (jcp.ngroups != 1) |
397 | return status::unimplemented; |
398 | if ((jcp.kh != 3) || (jcp.kw != 3)) |
399 | return status::unimplemented; |
400 | if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) |
401 | return status::unimplemented; |
402 | if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) |
403 | return status::unimplemented; |
404 | if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) |
405 | return status::unimplemented; |
406 | |
407 | format_tag_t dat_tag = nChw16c; |
408 | format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; |
409 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
410 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
411 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
412 | |
413 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
414 | if (jcp.wei_tag != wei_tag) return status::unimplemented; |
415 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
416 | |
417 | bool layout_consistency = true |
418 | && jcp.ic <= src_d.padded_dims()[1] |
419 | && jcp.oc <= dst_d.padded_dims()[1] |
420 | && jcp.ic <= weights_d.padded_dims()[with_groups + 1] |
421 | && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; |
422 | if (!layout_consistency) return status::unimplemented; |
423 | |
424 | return status::success; |
425 | } |
426 | |
427 | |
428 | status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { |
429 | |
430 | auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, |
431 | int dimN_reg_block, int current_best) { |
432 | return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) |
433 | && (dimN_reg_block < jcp.nb_reg) |
434 | && (dimN_reg_block < current_best); |
435 | }; |
436 | jcp.dimN_reg_block = get_divisor_satisfying_cond( |
437 | jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); |
438 | |
439 | if (jcp.dimN_reg_block >= jcp.nb_reg) { |
440 | auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, |
441 | int dimN_reg_block, int current_best) { |
442 | return (dimN_reg_block < jcp.nb_reg) |
443 | && (dimN_reg_block > current_best); |
444 | }; |
445 | |
446 | jcp.dimN_reg_block = get_divisor_satisfying_cond( |
447 | jcp, jcp.dimN, 1, test_cond_dimN_reg_block); |
448 | } |
449 | |
450 | //********************* Choosing dimK_block **********************// |
451 | auto test_cond1_dimK_block = []( |
452 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
453 | return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, |
454 | 1, jcp.dimM_simd_block, .75f) |
455 | && (dimK_block > current_best); |
456 | }; |
457 | |
458 | auto test_cond1_bis_dimK_block = []( |
459 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
460 | return check_cond1_bis(jcp.dimN_reg_block, dimK_block, |
461 | jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) |
462 | && (dimK_block > current_best); |
463 | }; |
464 | |
465 | jcp.dimK_block = get_divisor_satisfying_cond( |
466 | jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); |
467 | // If we are not able to use streams, we fall back to condition [1] |
468 | if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) |
469 | jcp.dimK_block = get_divisor_satisfying_cond( |
470 | jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); |
471 | jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; |
472 | |
473 | //********************* Choosing dimM_block **********************// |
474 | jcp.dimM_simd_block = 16; |
475 | /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ |
476 | auto test_cond1_dimM_block = []( |
477 | jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { |
478 | return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, |
479 | jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) |
480 | && (dimM_block > current_best); |
481 | }; |
482 | |
483 | auto test_cond1_bis_dimM_block = []( |
484 | jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { |
485 | return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, |
486 | jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) |
487 | && (dimM_block > current_best); |
488 | }; |
489 | |
490 | if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) |
491 | jcp.dimM_block = get_divisor_satisfying_cond( |
492 | jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); |
493 | else |
494 | jcp.dimM_block = get_divisor_satisfying_cond(jcp, |
495 | jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); |
496 | jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; |
497 | |
498 | //******************* Choosing dimN_block *******************// |
499 | auto test_cond2_dimN_block = []( |
500 | jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { |
501 | return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, |
502 | jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, |
503 | jcp.dimM_simd_block, .5f) |
504 | && (dimN_block > current_best); |
505 | }; |
506 | |
507 | jcp.dimN_block = get_divisor_satisfying_cond( |
508 | jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); |
509 | jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); |
510 | jcp.sched_policy = WSCHED_DATA_W_S_G_D; |
511 | return status::success; |
512 | } |
513 | |
514 | status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( |
515 | jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) |
516 | { |
517 | jcp.dimK_reg_block = 16; |
518 | jcp.dimM_simd_block = 16; |
519 | |
520 | // TODO: replace double buffering with nuple buffering to maximize register |
521 | // usage. |
522 | // the choice of the number of buffers will then come after choosing |
523 | // dimN_reg_block |
524 | jcp.double_buffering = true; |
525 | if (jcp.double_buffering) |
526 | jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); |
527 | else |
528 | jcp.zmm_start = 1; |
529 | jcp.nb_reg = 32 - jcp.zmm_start; |
530 | |
531 | jcp.dimN = dimN; |
532 | jcp.dimK = dimK; |
533 | jcp.dimM = dimM; |
534 | |
535 | jcp.sched_policy = WSCHED_INVALID; |
536 | set_wsched_DATA_W_S_G_D_avx512_common(jcp); |
537 | |
538 | assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); |
539 | return status::success; |
540 | } |
541 | |
542 | bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( |
543 | jit_conv_conf_t &jcp, const primitive_attr_t &attr) { |
544 | const auto &p = attr.post_ops_; |
545 | |
546 | auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; |
547 | auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; |
548 | |
549 | switch (p.len_) { |
550 | case 0: return true; // no post_ops |
551 | case 1: return is_relu(0) || is_sum(0); // relu or sum |
552 | case 2: return (is_sum(0) && is_relu(1)) || |
553 | (is_relu(0) && is_sum(1)); // sum->relu or relu->sum |
554 | case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu |
555 | default: return false; |
556 | } |
557 | |
558 | return false; |
559 | } |
560 | |
561 | status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( |
562 | jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, |
563 | const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, |
564 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { |
565 | status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); |
566 | |
567 | if (st != status::success) |
568 | return st; |
569 | |
570 | // Winograd specific initialization |
571 | jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; |
572 | jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; |
573 | jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; |
574 | |
575 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
576 | |
577 | if (!post_ops_ok(jcp, attr)) |
578 | return status::unimplemented; |
579 | |
580 | const auto &p = attr.post_ops_; |
581 | const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); |
582 | jcp.with_eltwise = eltwise_ind != -1; |
583 | if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; |
584 | jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; |
585 | |
586 | status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); |
587 | jcp.ic_simd_block = jcp.dimK_reg_block; |
588 | jcp.ic_block = jcp.dimK_block; |
589 | jcp.nb_ic = jcp.dimK_nb_block; |
590 | jcp.oc_simd_block = jcp.dimM_simd_block; |
591 | jcp.oc_block = jcp.dimM_block; |
592 | jcp.nb_oc = jcp.dimM_nb_block; |
593 | jcp.tile_block_ur = jcp.dimN_reg_block; |
594 | jcp.nb_tile_block_ur = jcp.dimN_block; |
595 | jcp.tile_block = jcp.dimN_nb_block; |
596 | jcp.tile_4fma_padding = 0; // only relevant for backward weights |
597 | |
598 | return res; |
599 | } |
600 | |
601 | status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( |
602 | jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, |
603 | const memory_desc_wrapper &diff_src_d, |
604 | const memory_desc_wrapper &weights_d, |
605 | const memory_desc_wrapper &diff_dst_d) |
606 | { |
607 | status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); |
608 | |
609 | if (st != status::success) |
610 | return st; |
611 | |
612 | jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; |
613 | jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; |
614 | jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; |
615 | |
616 | status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); |
617 | jcp.oc_simd_block = jcp.dimK_reg_block; |
618 | jcp.oc_block = jcp.dimK_block; |
619 | jcp.nb_oc = jcp.dimK_nb_block; |
620 | jcp.ic_simd_block = jcp.dimM_simd_block; |
621 | jcp.ic_block = jcp.dimM_block; |
622 | jcp.nb_ic = jcp.dimM_nb_block; |
623 | jcp.tile_block_ur = jcp.dimN_reg_block; |
624 | jcp.nb_tile_block_ur = jcp.dimN_block; |
625 | jcp.tile_block = jcp.dimN_nb_block; |
626 | jcp.tile_4fma_padding = 0; // only relevant for backward weights |
627 | |
628 | return res; |
629 | } |
630 | |
/* Emits the kernel that reads B from reg_origB four vectors at a time,
 * transposes each group of four with unpack instructions and writes the
 * result to reg_transB with non-temporal stores. The two loops over alpha and
 * the loop over dimK_4fma are C++ loops, i.e. fully unrolled in the generated
 * code. Registers 0-3 and 4-7 hold the two input groups (double buffering);
 * registers 8 and 9 are temporaries. */
void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
{
    // Loads four vectors from reg_origB into registers reg_idx .. reg_idx + 3;
    // vector k of the group is read at byte offset
    // (offset + k) * jcp.dimN_reg_block * sizeof(float).
    auto load_B = [=](int reg_idx, int offset) {
        for (int i = 0; i < 4; i++) {
            vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
        }
    };

    preamble();
    // First register of the group transposed in the current step (0 or 4).
    int curr = 0;
    for (int j = 0; j < alpha; j++) {
        for (int i = 0; i < alpha; i++) {
            int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
            // Byte offset of element (j, i) in the transposed buffer.
            size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
                jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
                jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
            mov(reg_transB_idx, transB_offset);
            for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
                /*double buffering to hide load latencies*/
                int next = (curr + 4) % 8;
                // NOTE(review): the first group of every row j is loaded into
                // registers 0-3 while the transpose below reads Zmm(curr);
                // this relies on curr being 0 whenever i == 0 && tb == 0 --
                // confirm for the dimK_4fma values in use.
                if (i == 0 && tb == 0) {
                    load_B(0, origB_offset);
                }
                // Load the group for the next step: the next four vectors of
                // this (j, i) element, or the first four of element (j, i + 1).
                // NOTE(review): the offset origB_offset + 4 does not depend on
                // tb, so it only addresses the following group when tb == 0
                // -- confirm dimK_4fma never exceeds 8 here.
                if (tb + 4 < (jcp.dimK_4fma -1)) {
                    load_B(next, origB_offset + 4);
                } else if (i < alpha - 1) {
                    load_B(next, origB_offset + jcp.dimK_4fma);
                }

                // Transpose of four vectors, performed independently in every
                // 128-bit lane: interleave 32-bit elements first ...
                vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
                vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
                vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
                vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));

                // ... then interleave 64-bit pairs. The four results end up in
                // Zmm(curr + 2), Zmm(curr + 3), Zmm(8) and Zmm(9).
                vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
                vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));

                vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
                vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));

                // Non-temporal stores of the four transposed vectors.
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * tb * jcp.dimN_reg_block],
                        Zmm(curr+2));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
                        Zmm(curr+3));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
                        Zmm(8));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
                        Zmm(9));
                curr = next;

            }
        }
    }
    postamble();
    ret();
}
/* Emits the GEMM micro-kernel of the backward-weights Winograd convolution:
 * preamble, a copy of reg_srcA_const into the working register reg_srcA, the
 * blocked loop nest described below, postamble and ret.
 *
 * is_first_tile selects how the accumulators are initialized: zeroed when
 * true, loaded from the destination (reg_dstC) when false.
 *
 * The dimM_block, dimN_block and dimK_block loops become run-time loops in the
 * generated code (and are omitted when their trip count is 1); the
 * dimK_reg_block, dimK_4fma and dimN_reg_block loops are C++ loops here, i.e.
 * fully unrolled in the generated code. */
void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
        bool is_first_tile)
{
    // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
    //     for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
    //             for (int nb_tile_block_ur = 0; nb_tile_block_ur <
    //             jcp.nb_tile_block_ur; nb_tile_block_ur++)
    //                 for (int tile_block_ur = 0; tile_block_ur <
    //                 jcp.tile_block_ur; tile_block_ur++)
    //                     for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
    //                         U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
    //                             M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
    //                              *
    //                              broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
    auto inner_loops = [=]() {
        // The 4fma instruction consumes four A vectors at once, so the
        // unrolled loop over dimK_4fma advances by 4 in that version.
        int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
        const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
        // Both prefetchers cover one
        // dimK_reg_block x dimN_reg_block x dimK_4fma block of B and spread
        // their prefetches over the FMAs of one unrolled block.
        prefetcher_t<float> L1_pf(this, reg_srcB, L1,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
                        / inc_fma,
                fma_ipc);
        prefetcher_t<float> L2_pf(this, reg_srcB, L2,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
                jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
                        / inc_fma,
                fma_ipc);

        // Loads inc_fma consecutive vectors of dimM_simd_block floats from A,
        // starting at vector index `offset`, into registers
        // reg_idx, reg_idx + 1, ...
        auto load_A = [=](int reg_idx, int offset) {
            for (int i = 0; i < inc_fma; i++) {
                vmovups(Zmm(reg_idx + i),
                        zword[reg_srcA +
                        sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
            }
        };

        Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }
        { /************* OC_block (M) loop ***********/
            if (jcp.dimN_block > 1) {
                mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
                L(dimN_block_loop);
            }
            { /*************** IC_block (N) loop *********/
                // Initialize one accumulator register per dimN_reg_block
                // entry: zero on the first tile, otherwise the current C.
                for (int dimN_reg_block = 0;
                        dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
                    Zmm zmm(jcp.zmm_start + dimN_reg_block);
                    if (is_first_tile)
                        vpxord(zmm, zmm, zmm);
                    else
                        vmovups(zmm, zword[reg_dstC +
                                dimN_reg_block * jcp.dimM_simd_block *
                                sizeof(float)]);
                }

                if (jcp.dimK_block > 1) {
                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                }
                { /************* nb_tile_ur(K) loop ********/
                    // With double buffering, the A operands for step s + 1 are
                    // loaded while step s is computed; `next` alternates
                    // between the register groups starting at 0 and inc_fma.
                    int next = 0;
                    if (jcp.double_buffering) {
                        load_A(next, 0);
                    }
                    for (int dimK_reg_block = 0;
                            dimK_reg_block < jcp.dimK_reg_block;
                            dimK_reg_block++) {
                        int srcB_offset = dimK_reg_block * jcp.dimK_4fma
                                * jcp.dimN_reg_block;
                        for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
                                dimK_4fma += inc_fma) {
                            int current = next;
                            if (jcp.double_buffering) {
                                // NOTE(review): on the very last step this
                                // loads the vectors just past the block
                                // processed here -- confirm that the A buffer
                                // tolerates this read at its very end.
                                next = (dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma + inc_fma)
                                        % (2 * inc_fma);
                                load_A(next, dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma + inc_fma);
                            } else {
                                next = 0;
                                load_A(next, dimK_reg_block * jcp.dimK_4fma
                                        + dimK_4fma);
                            }
                            for (int dimN_reg_block = 0;
                                    dimN_reg_block < jcp.dimN_reg_block;
                                    ++dimN_reg_block) {
                                // The prefetchers are indexed by the running
                                // FMA number within the unrolled block.
                                L1_pf.prefetch(srcB_offset / inc_fma
                                        + dimK_4fma / inc_fma
                                        * jcp.dimN_reg_block
                                        + dimN_reg_block);
                                L2_pf.prefetch(srcB_offset / inc_fma
                                        + dimK_4fma / inc_fma
                                        * jcp.dimN_reg_block
                                        + dimN_reg_block);
                                if (jcp.ver == ver_4fma) {
                                    // NOTE(review): this addressing presumably
                                    // matches the layout written by
                                    // transpose_ker_generate() -- confirm.
                                    int srcB_trans_offset = (dimK_4fma / 4) * 64
                                            + dimK_4fma % 4;
                                    v4fmaddps(
                                            Zmm(jcp.zmm_start + dimN_reg_block),
                                            Zmm(current),
                                            EVEX_compress_addr(reg_srcB,
                                                sizeof(float) * (
                                                    srcB_offset +
                                                    srcB_trans_offset +
                                                    (dimN_reg_block % 4) * 16 +
                                                    (dimN_reg_block / 4) * 4)));
                                } else {
                                    // `true` requests the broadcast form of
                                    // the memory operand.
                                    vfmadd231ps(
                                            Zmm(jcp.zmm_start + dimN_reg_block),
                                            Zmm(current),
                                            EVEX_compress_addr(reg_srcB,
                                                sizeof(float) * (srcB_offset + dimN_reg_block),
                                                    true));
                                }
                            }
                        }
                    }
                }

                // Advance A and B to the next dimK block.
                add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
                                * jcp.dimM_simd_block * sizeof(float));
                add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
                                * jcp.dimK_4fma * sizeof(float));
                if (jcp.dimK_block > 1) {
                    sub(reg_dimK_block_loop_cnt, 1);
                    jnz(dimK_block_loop);
                }

                /******** Write C back to memory *******/
                for (int dimN_reg_block = 0;
                        dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
                    Zmm zmm(jcp.zmm_start + dimN_reg_block);
                    vmovups(zword[reg_dstC +
                            dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
                            zmm);
                }

                // Rewind A to the start of the K range (every N block reuses
                // it) and move C to the next N block; B keeps advancing.
                sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
                        jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
                add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
                        * sizeof(float));
                if (jcp.dimN_block > 1) {
                    sub(reg_dimN_block_loop_cnt, 1);
                    jnz(dimN_block_loop);
                }
            }

            if (jcp.dimM_block > 1) {
                // Rewind B over all N and K blocks just processed and move A
                // to the next M block.
                sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
                                * jcp.dimK_reg_block * jcp.dimN_reg_block
                                * jcp.dimK_4fma * sizeof(float));
                add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                                * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
                sub(reg_dimM_block_loop_cnt, 1);
                jnz(dimM_block_loop);
            }
        }
    };

    /* Preamble */
    // register used to handle long fma encoding
    preamble();
    // reg_srcA is modified by the loops; work on a copy of the constant
    // pointer.
    mov(reg_srcA, reg_srcA_const);
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
863 | |
864 | namespace { |
865 | bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, |
866 | int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) |
867 | { |
868 | float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; |
869 | lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; |
870 | lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; |
871 | lhs *= sizeof(float); |
872 | float rhs = C * L1_cache_size; |
873 | return (lhs <= rhs); |
874 | } |
875 | |
876 | bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, |
877 | int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) |
878 | { |
879 | float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma |
880 | * dimM_simdw; |
881 | lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; |
882 | lhs *= sizeof(float); |
883 | float rhs = C * L1_cache_size; |
884 | return (lhs <= rhs); |
885 | } |
886 | |
887 | bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, |
888 | int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, |
889 | float C) |
890 | { |
891 | float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block |
892 | * dimK_4fma; |
893 | lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block |
894 | * dimN_reg_block; |
895 | lhs *= sizeof(float); |
896 | float rhs = C * L2_cache_size; |
897 | return (lhs <= rhs); |
898 | } |
899 | |
900 | bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, |
901 | int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, |
902 | float C) |
903 | { |
904 | float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; |
905 | lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; |
906 | lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block |
907 | * dimN_reg_block; |
908 | lhs *= sizeof(float); |
909 | float rhs = C * L2_cache_size; |
910 | return (lhs <= rhs); |
911 | } |
912 | } // namespace |
913 | |
914 | status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) |
915 | { |
916 | /*************** Choose dimN_reg_block (ic_simd_block) |
917 | * *******************************/ |
918 | jcp.dimN = jcp.ic; |
919 | /*Hardcoded to 16 because N = ic for bwd weights and |
920 | innermost dimension for ic is assumed 16 in src transforms. This |
921 | choice covers load latencies while maintaining simplicity of kernel |
922 | for POR topologies. FIXME in future??: Will not work for future topologies |
923 | when ic%16 != 0*/ |
924 | jcp.dimN_reg_block = jcp.ic_simd_block; |
925 | |
926 | /****************************** Choose dimK_block |
927 | * **************************/ |
928 | // No freedom for choosing dimM_simd_block because ic_simd_block |
929 | // is determined by input data format |
930 | jcp.dimM_simd_block = jcp.oc_simd_block; |
931 | |
932 | auto test_cond1bis_dimK_block = []( |
933 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
934 | return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, |
935 | jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) |
936 | && (dimK_block > current_best); |
937 | }; |
938 | |
939 | auto test_cond1_dimK_block = []( |
940 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
941 | return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, |
942 | jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) |
943 | && (dimK_block > current_best); |
944 | }; |
945 | |
946 | auto test_cond2bis_dimK_block = []( |
947 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
948 | return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, |
949 | jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) |
950 | && (dimK_block > current_best); |
951 | }; |
952 | |
953 | auto test_cond2_dimK_block = []( |
954 | jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { |
955 | return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, |
956 | jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) |
957 | && (dimK_block > current_best); |
958 | }; |
959 | |
960 | jcp.dimK_block = get_divisor_satisfying_cond( |
961 | jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); |
962 | if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) |
963 | jcp.dimK_block = get_divisor_satisfying_cond( |
964 | jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); |
965 | |
966 | jcp.dimK_reg_block = get_divisor_satisfying_cond( |
967 | jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); |
968 | if (jcp.dimK_reg_block < jcp.dimK_block) { |
969 | jcp.dimK_reg_block = get_divisor_satisfying_cond( |
970 | jcp, jcp.dimK_block, 1, test_cond1_dimK_block); |
971 | } |
972 | jcp.dimK_block /= jcp.dimK_reg_block; |
973 | jcp.dimK_nb_block |
974 | = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; |
975 | jcp.tile_block_ur = jcp.dimK_reg_block; |
976 | jcp.nb_tile_block_ur = jcp.dimK_block; |
977 | jcp.tile_block = jcp.dimK_nb_block; |
978 | |
979 | /***************************** Chose dimN block |
980 | * ****************************/ |
981 | auto test_cond2_dimN_block = []( |
982 | jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { |
983 | return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, |
984 | jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, |
985 | jcp.dimN_reg_block, 0.5f) |
986 | && (dimN_block > current_best); |
987 | }; |
988 | |
989 | jcp.dimN_block = get_divisor_satisfying_cond( |
990 | jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); |
991 | jcp.ic_block = jcp.dimN_block; |
992 | jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; |
993 | jcp.nb_ic = jcp.dimN_nb_block; |
994 | |
995 | /********************************* Choose dimM block |
996 | * ************************/ |
997 | jcp.dimM = jcp.oc; |
998 | |
999 | auto test_cond1_dimM_block = []( |
1000 | jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { |
1001 | return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, |
1002 | jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, |
1003 | 1.0f) |
1004 | && (dimM_block > current_best) |
1005 | && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; |
1006 | }; |
1007 | |
1008 | jcp.dimM_block = get_divisor_satisfying_cond( |
1009 | jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); |
1010 | jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; |
1011 | |
1012 | jcp.sched_policy = WSCHED_WEI_S_D_G_W; |
1013 | return status::success; |
1014 | } |
1015 | |
1016 | status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( |
1017 | jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, |
1018 | const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, |
1019 | const memory_desc_wrapper &diff_weights_d) |
1020 | { |
1021 | jcp.nthr = mkldnn_get_max_threads(); |
1022 | |
1023 | const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; |
1024 | |
1025 | jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; |
1026 | jcp.mb = src_d.dims()[0]; |
1027 | jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; |
1028 | jcp.oc_without_padding = jcp.oc; |
1029 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1030 | jcp.ih = src_d.dims()[2]; |
1031 | jcp.iw = src_d.dims()[3]; |
1032 | jcp.oh = diff_dst_d.dims()[2]; |
1033 | jcp.ow = diff_dst_d.dims()[3]; |
1034 | jcp.kh = diff_weights_d.dims()[with_groups + 2]; |
1035 | jcp.kw = diff_weights_d.dims()[with_groups + 3]; |
1036 | jcp.t_pad = cd.padding[0][0]; |
1037 | jcp.l_pad = cd.padding[0][1]; |
1038 | jcp.stride_h = cd.strides[0]; |
1039 | jcp.stride_w = cd.strides[1]; |
1040 | jcp.r_pad = nstl::max( |
1041 | 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); |
1042 | jcp.b_pad = nstl::max( |
1043 | 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); |
1044 | jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; |
1045 | jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; |
1046 | jcp.ohp = jcp.oh; |
1047 | jcp.owp = jcp.ow; |
1048 | jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); |
1049 | jcp.dilate_h = cd.dilates[0]; |
1050 | jcp.dilate_w = cd.dilates[1]; |
1051 | |
1052 | bool ok_to_pad_channels = jcp.ngroups == 1; |
1053 | if (ok_to_pad_channels) { |
1054 | jcp.oc = rnd_up(jcp.oc, simd_w); |
1055 | jcp.ic = rnd_up(jcp.ic, simd_w); |
1056 | } |
1057 | |
1058 | if (mayiuse(avx512_core)) |
1059 | return status::unimplemented; |
1060 | if (!mayiuse(avx512_common)) |
1061 | return status::unimplemented; |
1062 | else if (mayiuse(avx512_mic_4ops)) |
1063 | jcp.ver = ver_4fma; |
1064 | else |
1065 | jcp.ver = ver_fma; |
1066 | |
1067 | if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, |
1068 | is_winograd_faster_than_direct(jcp))) |
1069 | return status::unimplemented; |
1070 | // Winograd specific initialization |
1071 | jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; |
1072 | jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; |
1073 | jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; |
1074 | |
1075 | // Winograd kernel works only for 3x3 convolution with stride 1 |
1076 | if (jcp.ngroups != 1) |
1077 | return status::unimplemented; |
1078 | if ((jcp.kh != 3) || (jcp.kw != 3)) |
1079 | return status::unimplemented; |
1080 | if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) |
1081 | return status::unimplemented; |
1082 | if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) |
1083 | return status::unimplemented; |
1084 | if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) |
1085 | return status::unimplemented; |
1086 | |
1087 | format_tag_t dat_tag = nChw16c; |
1088 | format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; |
1089 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
1090 | jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); |
1091 | jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); |
1092 | |
1093 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
1094 | if (jcp.wei_tag != wei_tag) return status::unimplemented; |
1095 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
1096 | |
1097 | bool layout_consistency = true |
1098 | && jcp.ic <= src_d.padded_dims()[1] |
1099 | && jcp.oc <= diff_dst_d.padded_dims()[1] |
1100 | && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] |
1101 | && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; |
1102 | if (!layout_consistency) return status::unimplemented; |
1103 | |
1104 | /*************************** New Kernel Parameters |
1105 | * *****************************/ |
1106 | jcp.ic_simd_block = simd_w; |
1107 | jcp.oc_simd_block = simd_w; |
1108 | jcp.dimK_4fma = 1; |
1109 | jcp.tile_4fma_padding = 0; |
1110 | |
1111 | #define MAX_4FMA_UR 8 |
1112 | if (jcp.ver == ver_4fma) { |
1113 | auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, |
1114 | int current_best) { |
1115 | return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) |
1116 | && (dimK_4fma > current_best); |
1117 | }; |
1118 | jcp.dimK_4fma = get_divisor_satisfying_cond( |
1119 | jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); |
1120 | if (jcp.dimK_4fma == 1) |
1121 | jcp.dimK_4fma = 4; |
1122 | if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) |
1123 | jcp.tile_4fma_padding = jcp.dimK_4fma |
1124 | - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); |
1125 | } |
1126 | |
1127 | jcp.tile_4fma = jcp.dimK_4fma; |
1128 | /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src |
1129 | * transform |
1130 | * will not work correctly, this is solved by applying padding.*/ |
1131 | jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); |
1132 | jcp.dimN = jcp.ic; |
1133 | jcp.dimM = jcp.oc; |
1134 | |
1135 | jcp.double_buffering = true; |
1136 | if (jcp.double_buffering) |
1137 | jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; |
1138 | else |
1139 | jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; |
1140 | jcp.nb_reg = 32 - jcp.zmm_start; |
1141 | |
1142 | jcp.sched_policy = WSCHED_INVALID; |
1143 | status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); |
1144 | assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); |
1145 | |
1146 | jcp.tile_block_ur = jcp.dimK_reg_block; |
1147 | jcp.nb_tile_block_ur = jcp.dimK_block; |
1148 | jcp.tile_block = jcp.dimK_nb_block; |
1149 | |
1150 | jcp.ic_block = jcp.dimN_block; |
1151 | jcp.nb_ic = jcp.dimN_nb_block; |
1152 | |
1153 | jcp.oc_block = jcp.dimM_block; |
1154 | jcp.nb_oc = jcp.dimM_nb_block; |
1155 | |
1156 | return res; |
1157 | |
1158 | } |
1159 | } |
1160 | } |
1161 | } |
1162 | |
1163 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
1164 | |