/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifdef __INTEL_COMPILER
#include <immintrin.h>
#endif

#include "mkldnn_types.h"

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_avx512_common_convolution_winograd.hpp"

#ifndef _MSC_VER
#define pragma_unroll _Pragma("unroll")
#else
#define pragma_unroll
#endif

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;

namespace {

unsigned int LLC_cache_size = get_cache_size(3, false);

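/* Helpers for moving one simd_w (16-float) vector: with the Intel compiler
 * they map to AVX-512 load/store/stream intrinsics, otherwise to
 * PRAGMA_OMP_SIMD loops. store_output() optionally uses non-temporal stores;
 * accum_output() adds into the destination with an optional ReLU applied
 * after the sum. */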
void inline load_ps(float *dest, const float *src_mem) {
#ifdef __INTEL_COMPILER
    __m512 *Iv512 = (__m512 *)dest;
    Iv512[0] = _mm512_load_ps(src_mem);
#else
    PRAGMA_OMP_SIMD()
    for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v];
#endif
}

void inline store_output(float *dest, const float *data, bool streamout) {
#ifdef __INTEL_COMPILER
    if (streamout)
        _mm512_stream_ps(dest, *((__m512 *)data));
    else
        _mm512_store_ps(dest, *((__m512 *)data));
#else
    PRAGMA_OMP_SIMD()
    for (int v = 0; v < simd_w; v++)
        dest[v] = data[v];
#endif
}

void inline accum_output(
        float *dest, float *data, bool streamout, bool with_relu_postsum) {
#ifdef __INTEL_COMPILER
    __m512 _data = _mm512_loadu_ps(data);
    __m512 _dest = _mm512_loadu_ps(dest);
    _data = _mm512_add_ps(_data, _dest);
    if (with_relu_postsum)
        _data = _mm512_max_ps(_data, _mm512_setzero_ps());
    if (streamout)
        _mm512_stream_ps(dest, _data);
    else
        _mm512_store_ps(dest, _data);
#else
    PRAGMA_OMP_SIMD()
    for (int v = 0; v < simd_w; v++)
        data[v] += dest[v];

    if (with_relu_postsum) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < simd_w; v++)
            if (data[v] < 0.f)
                data[v] = 0.f;
    }

    PRAGMA_OMP_SIMD()
    for (int v = 0; v < simd_w; v++)
        dest[v] = data[v];
#endif
}
}

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::utils;

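/* Winograd F(4x4, 3x3) weight transform: expands a 3x3 filter slice F
 * (kh x kw x 16ic x 16oc) into the 6x6 transformed tile Fw_, conceptually
 * U = G F G^T with the (scaled) transform coefficients hard-coded below,
 * applied first along rows and then along columns. */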
void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) {
    float Fw[6][16];
    float T[6][3][16];
    float t0[16];
    float t1[16];
    float t2[16];

    for (int j = 0; j < 16; j++) {
#pragma unroll
        for (int i = 0; i < 3; i++) {
            PRAGMA_OMP_SIMD()
            for (int k = 0; k < 16; k++) {
                t0[k] = 0.26890756302521f * F[2][i][j][k];
                t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k];
                t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k];

                T[0][i][k] = 1.13777777777778f * F[0][i][j][k];
                T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k];
                T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k];
                T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k];
                T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k];
                T[5][i][k] = F[2][i][j][k];
            }
        }
#pragma unroll
        for (int i = 0; i < 6; i++) {
            PRAGMA_OMP_SIMD()
            for (int k = 0; k < 16; k++) {
                t0[k] = 0.26890756302521f * T[i][2][k];
                t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k];
                t2[k] = t0[k] + 0.119514472455649f * T[i][0][k];

                Fw[0][k] = 1.13777777777778f * T[i][0][k];
                Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k];
                Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k];
                Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k];
                Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k];
                Fw[5][k] = T[i][2][k];
#pragma unroll
                for (int l = 0; l < 6; l++) {
                    Fw_[i][l][j][k] = Fw[l][k];
                }
            }
        }
    }
}

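/* Winograd F(4x4, 3x3) output (inverse) transform: reduces a 6x6x16
 * accumulator tile Mw to the final 4x4x16 spatial tile O, conceptually
 * Y = A^T M A, again in two 1D passes. */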
void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) {
    float T[4][6][16];
    float t0[16];
    float t1[16];
    float t2[16];
    float t3[16];

#pragma unroll
    for (int i = 0; i < 6; i++) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < 16; v++) {
            t0[v] = Mw[1][i][v] + Mw[2][i][v];
            t1[v] = Mw[3][i][v] + Mw[4][i][v];
            t2[v] = Mw[1][i][v] - Mw[2][i][v];
            t3[v] = Mw[3][i][v] - Mw[4][i][v];

            T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v];
            T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f;
            T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
            T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v];
        }
    }
#pragma unroll
    for (int i = 0; i < 4; i++) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < 16; v++) {
            t0[v] = T[i][1][v] + T[i][2][v];
            t1[v] = T[i][3][v] + T[i][4][v];
            t2[v] = T[i][1][v] - T[i][2][v];
            t3[v] = T[i][3][v] - T[i][4][v];

            O[i][0][v] = t0[v] + t1[v] + T[i][0][v];
            O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f;
            O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
            O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v];
        }
    }
}


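/* 6-point transform with reciprocal coefficients (1/3, 1/4, 1/6, 1/12, 1/24)
 * for the F(3x3, 4x4) formulation: expands a 4x6x16 tile F into the 6x6x16
 * tile Fw, row pass followed by column pass. */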
void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16])
{
    const float rcp3 = 1.0f / 3.0f;
    const float rcp4 = 1.0f / 4.0f;
    const float rcp6 = 1.0f / 6.0f;
    const float rcp12 = 1.0f / 12.0f;
    const float rcp24 = 1.0f / 24.0f;
    float t0[16];
    float t1[16];
    float t2[16];
    float t3[16];
    float t4[16];
    float T[6][4][16];

pragma_unroll
    for (int i = 0; i < 4; i++) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < 16; j++) {
            t0[j] = F[2][i][j] * rcp6;
            t1[j] = F[0][i][j] * -rcp6 - t0[j];
            t2[j] = F[0][i][j] * rcp24 + t0[j];
            t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6;
            t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3;

            T[0][i][j] = F[0][i][j] * rcp4;
            T[1][i][j] = t1[j] - t3[j];
            T[2][i][j] = t1[j] + t3[j];
            T[3][i][j] = t2[j] + t4[j];
            T[4][i][j] = t2[j] - t4[j];
            T[5][i][j] = F[3][i][j];
        }
    }
pragma_unroll
    for (int i = 0; i < 6; i++) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < 16; j++) {
            t0[j] = T[i][2][j] * rcp6;
            t1[j] = T[i][0][j] * -rcp6 - t0[j];
            t2[j] = T[i][0][j] * rcp24 + t0[j];
            t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6;
            t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3;

            Fw[i][0][j] = T[i][0][j] * rcp4;
            Fw[i][1][j] = t1[j] - t3[j];
            Fw[i][2][j] = t1[j] + t3[j];
            Fw[i][3][j] = t2[j] + t4[j];
            Fw[i][4][j] = t2[j] - t4[j];
            Fw[i][5][j] = T[i][3][j];
        }
    }
}

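/* Inverse transform for the F(3x3, 4x4) formulation: collapses a 6x6x16x16
 * tile Mw to the 3x3x16x16 result M, processing one 16-wide slice (index j)
 * at a time. */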
void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16])
{
    float T[4][6][16];
    float M_[3][16];
    float t0[16];
    float t1[16];
    float t2[16];

    for (int j = 0; j < 16; j++) {
pragma_unroll
        for (int i = 0; i < 6; i++) {
            PRAGMA_OMP_SIMD()
            for (int l = 0; l < 16; l++) {
                t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l];
                t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l];
                t2[l] = t1[l] * 4.0f + Mw[5][i][j][l];

                T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l];
                T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) +
                             2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]);
                T[2][i][l] = t0[l] + t2[l];
            }
        }
pragma_unroll
        for (int i = 0; i < 3; i++) {
            PRAGMA_OMP_SIMD()
            for (int l = 0; l < 16; l++) {
                t0[l] = T[i][1][l] + T[i][2][l];
                t1[l] = T[i][3][l] + T[i][4][l];
                t2[l] = t1[l] * 4.0f + T[i][5][l];

                M_[0][l] = T[i][0][l] + t0[l] + t1[l];
                M_[1][l] = (T[i][1][l] - T[i][2][l]) +
                           2.0f * (T[i][3][l] - T[i][4][l]);
                M_[2][l] = t0[l] + t2[l];

                for (int k = 0; k < 3; k++) {
                    M[i][k][j][l] = M_[k][l];
                }
            }
        }
    }
}

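/* Winograd F(4x4, 3x3) input transform: maps a (zero-padded) 6x6x16 input
 * tile I to its transformed counterpart Iw, conceptually V = B^T d B; used by
 * the src/diff_dst tile transforms below. */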
void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16])
{
    float T[6][6][16];
    float t0[16];
    float t1[16];
    float t2[16];
    float t3[16];
    float t4[16];
    float t5[16];

pragma_unroll
    for (int i = 0; i < 6; i++) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < 16; v++) {
            t0[v] = I[2][i][v] * -2.25f + I[4][i][v];
            t1[v] = I[1][i][v] * -2.25f + I[3][i][v];
            t2[v] = I[2][i][v] * -0.390625f + I[4][i][v];
            t3[v] = I[1][i][v] * -0.390625f + I[3][i][v];
            t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v];
            t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v];

            T[0][i][v] = I[2][i][v] * -2.640625f + t4[v];
            T[1][i][v] = t1[v] * 0.625f + t0[v];
            T[2][i][v] = t1[v] * -0.625f + t0[v];
            T[3][i][v] = t3[v] * 1.5f + t2[v];
            T[4][i][v] = t3[v] * -1.5f + t2[v];
            T[5][i][v] = I[3][i][v] * -2.640625f + t5[v];
        }
    }

pragma_unroll
    for (int i = 0; i < 6; i++) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < 16; v++) {
            t0[v] = T[i][2][v] * -2.25f + T[i][4][v];
            t1[v] = T[i][1][v] * -2.25f + T[i][3][v];
            t2[v] = T[i][2][v] * -0.390625f + T[i][4][v];
            t3[v] = T[i][1][v] * -0.390625f + T[i][3][v];
            t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v];
            t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v];

            Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v];
            Iw[i][1][v] = t1[v] * 0.625f + t0[v];
            Iw[i][2][v] = t1[v] * -0.625f + t0[v];
            Iw[i][3][v] = t3[v] * 1.5f + t2[v];
            Iw[i][4][v] = t3[v] * -1.5f + t2[v];
            Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v];
        }
    }
}

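/* Weight-update variant of the diff_dst-side transform for F(3x3, 4x4),
 * using the same family of scaled coefficients as trans_W_4x4_3x3 above;
 * called from diff_dst_transform_bwd_weights(). */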
void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16])
{
    float T[6][4][16];
    float t0[16];
    float t1[16];
    float t2[16];
    float t3[16];
    float t4[16];

pragma_unroll
    for (int i = 0; i < 4; i++) {
        PRAGMA_OMP_SIMD()
        for (int v = 0; v < 16; v++) {
            t0[v] = F[2][i][v] * 0.26890756302521f;
            t1[v] = F[0][i][v] * -0.688403361344538f - t0[v];
            t2[v] = F[0][i][v] * 0.119514472455649f + t0[v];
            t3[v] = F[1][i][v] * 0.430252100840336f +
                    F[3][i][v] * 0.168067226890756f;
            t4[v] = F[1][i][v] * 0.179271708683473f +
                    F[3][i][v] * 0.403361344537815f;

            T[0][i][v] = F[0][i][v] * 1.13777777777778f;
            T[1][i][v] = t1[v] - t3[v];
            T[2][i][v] = t1[v] + t3[v];
            T[3][i][v] = t2[v] + t4[v];
            T[4][i][v] = t2[v] - t4[v];
            T[5][i][v] = F[3][i][v];
        }
    }
pragma_unroll
    for (int i = 0; i < 6; i++) {
        for (int v = 0; v < 16; v++) {
            t0[v] = T[i][2][v] * 0.26890756302521f;
            t1[v] = T[i][0][v] * -0.688403361344538f - t0[v];
            t2[v] = T[i][0][v] * 0.119514472455649f + t0[v];
            t3[v] = T[i][1][v] * 0.430252100840336f +
                    T[i][3][v] * 0.168067226890756f;
            t4[v] = T[i][1][v] * 0.179271708683473f +
                    T[i][3][v] * 0.403361344537815f;

            Fw[i][0][v] = T[i][0][v] * 1.13777777777778f;
            Fw[i][1][v] = t1[v] - t3[v];
            Fw[i][2][v] = t1[v] + t3[v];
            Fw[i][3][v] = t2[v] + t4[v];
            Fw[i][4][v] = t2[v] - t4[v];
            Fw[i][5][v] = T[i][3][v];
        }
    }
}

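/* Weight-update output transform: reduces the 6x6 accumulated gemm results
 * Mw to the 3x3x16x16 weight gradient M, one 16-wide slice (index j) at a
 * time; called from diff_weights_transform_bwd_weights(). */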
void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16])
{
    float T[3][6][16];
    float t0[16];
    float t1[16];
    float t2[16];
    float M_[3][16];

    for (int j = 0; j < 16; j++) {
pragma_unroll
        for (int i = 0; i < 6; i++) {
            PRAGMA_OMP_SIMD()
            for (int v = 0; v < 16; v++) {
                t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v];
                t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v];
                t2[v] = t1[v] * 2.25f + Mw[5][i][j][v];

                T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v];
                T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) +
                             1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]);
                T[2][i][v] = t0[v] * 0.390625f + t2[v];
            }
        }
pragma_unroll
        for (int i = 0; i < 3; i++) {
            PRAGMA_OMP_SIMD()
            for (int v = 0; v < 16; v++) {
                t0[v] = T[i][1][v] + T[i][2][v];
                t1[v] = T[i][3][v] + T[i][4][v];
                t2[v] = t1[v] * 2.25f + T[i][5][v];

                M_[0][v] = T[i][0][v] + t0[v] + t1[v];
                M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) +
                           1.5f * (T[i][3][v] - T[i][4][v]);
                M_[2][v] = t0[v] * 0.390625f + t2[v];
            }

pragma_unroll
            for (int k = 0; k < 3; k++) {
                PRAGMA_OMP_SIMD()
                for (int v = 0; v < 16; v++) {
                    M[i][k][j][v] = M_[k][v];
                }
            }
        }
    }
}

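/* Scatters one image into the Winograd V buffer: for every 4x4 output tile
 * the (zero-padded) 6x6 input patch is gathered, transformed with
 * trans_I_4x4_3x3() and stored in the blocked (tile_block, nb_tile_block_ur,
 * tile_block_ur) layout the gemm kernels expect. For the backward-data pass
 * (is_fwd == false) the "input" is diff_dst, so sizes and padding come from
 * the output side. */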
template <bool is_fwd>
void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
        float *inp, float *tinp, bool streamout = true)
{
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
    const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
    const int wp_max = inpw + l_pad;
    const int hp_max = inph + t_pad;
    float Iw[alpha][alpha][simd_w];
    float I[alpha][alpha][simd_w];

    array_offset_calculator<float, 5> input(inp,
            jcp.mb, jcp.dimK/simd_w, inph, inpw,
            simd_w);
    array_offset_calculator<float, 8> output(tinp,
            jcp.dimN_nb_block, alpha, alpha,
            jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
            jcp.dimN_reg_block, jcp.dimK_reg_block);

    int tile_base_index = image * jcp.itiles * jcp.jtiles;
    int tile_block_ur = tile_base_index % jcp.tile_block_ur;
    int nb_tile_block_ur =
            (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
    int tile_block =
            (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;

    for (int tj = 0; tj < jcp.jtiles; tj++) {
        for (int ti = 0; ti < jcp.itiles; ti++) {
            for (int j = 0; j < alpha; j++) {
                int ydim = tj * tile_size + j;
                if ((t_pad <= ydim) && (ydim < hp_max)) {
                    float *pinp_j = inp + (ydim - t_pad) * inpw * 16;
                    for (int i = 0; i < alpha; i++) {
                        int xdim = ti * tile_size + i;
                        if ((l_pad <= xdim) && (xdim < wp_max)) {
                            float *pinp_i = pinp_j + (xdim - l_pad) * 16;
                            load_ps(I[j][i], pinp_i);
                        } else {
                            PRAGMA_OMP_SIMD()
                            for (int v = 0; v < simd_w; v++) {
                                I[j][i][v] = 0.0f;
                            }
                        }
                    }
                } else {
                    for (int i = 0; i < alpha; i++) {
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            I[j][i][v] = 0.0f;
                        }
                    }
                }
            }

            trans_I_4x4_3x3(Iw, I);

            for (int j = 0; j < alpha; j++) {
                for (int i = 0; i < alpha; i++) {
                    store_output(&(output(tile_block, j, i,
                                    nb_tile_block_ur, 0, 0,
                                    tile_block_ur, 0)),
                            Iw[j][i], streamout);
                }
            }
            tile_block_ur++;
            if (tile_block_ur >= jcp.tile_block_ur) {
                tile_block_ur = 0;
                nb_tile_block_ur++;
            }
            if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }
}

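/* Transforms one 3x3 weight block (16ic x 16oc) into the U buffer via
 * trans_W_4x4_3x3(). For the backward-data pass the kernel is rotated by
 * 180 degrees (indices 2 - j, 2 - i) and the ic/oc lanes are swapped,
 * effectively producing the transposed-convolution weights. */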
template <bool is_fwd>
void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
        float *wp, float *twp)
{
    const int kh = 3;
    const int kw = 3;
    array_offset_calculator<float, 6> input(wp,
            jcp.oc/jcp.oc_simd_block,
            jcp.ic/jcp.ic_simd_block,
            jcp.kh, jcp.kw,
            simd_w, simd_w);
    array_offset_calculator<float, 8> output(twp,
            jcp.dimM_nb_block,
            alpha, alpha,
            jcp.dimK_nb_block,
            jcp.dimM_block, jcp.dimK_block,
            simd_w, simd_w);
    float Fw[alpha][alpha][simd_w][simd_w];
    float F[kh][kw][simd_w][simd_w];

    for (int j = 0; j < kh; j++) {
        for (int i = 0; i < kw; i++) {
            for (int v1 = 0; v1 < simd_w; v1++) {
                float *base_inp = is_fwd
                        ? &(input(0, 0, j, i, v1, 0))
                        : &(input(0, 0, 2 - j, 2 - i, v1, 0));
                PRAGMA_OMP_SIMD()
                for (int v2 = 0; v2 < simd_w; v2++) {
                    if (is_fwd)
                        F[j][i][v1][v2] = *(base_inp + v2);
                    else
                        F[j][i][v2][v1] = *(base_inp + v2);
                }
            }
        }
    }

    trans_W_4x4_3x3(Fw, F);

    for (int j = 0; j < alpha; j++) {
        for (int i = 0; i < alpha; i++) {
            for (int v1 = 0; v1 < simd_w; v1++) {
                PRAGMA_OMP_SIMD()
                for (int v2 = 0; v2 < simd_w; v2++) {
                    output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2];
                }
            }
        }
    }
}

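/* Gathers gemm results from the M buffer and produces the final dst (or
 * diff_src) tiles: applies trans_O_4x4_3x3(), optionally adds the bias and
 * the eltwise scaling of negative values by jcp.eltwise.alpha, then either
 * overwrites the output or accumulates into it (with_sum), with a possible
 * post-sum ReLU handled inside accum_output(). The template flags select a
 * statically compiled variant. */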
template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
        const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias,
        bool streamout = true) {
    float Ow[alpha][alpha][simd_w];
    float O[tile_size][tile_size][simd_w];
    int outw = is_fwd ? jcp.ow : jcp.iw;
    int outh = is_fwd ? jcp.oh : jcp.ih;

    /* Prepare for PostOps */
    bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;

    array_offset_calculator<float, 8> input(toutp,
            jcp.dimN_nb_block, jcp.dimM_nb_block,
            alpha, alpha,
            jcp.dimN_block, jcp.dimM_block,
            jcp.dimN_reg_block, jcp.dimM_simd_block);

    int tile_base_index = image * jcp.itiles * jcp.jtiles;
    int tile_block_ur = tile_base_index % jcp.tile_block_ur;
    int nb_tile_block_ur =
            (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
    int tile_block =
            (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;

    for (int tj = 0; tj < jcp.jtiles; tj++) {
        for (int ti = 0; ti < jcp.itiles; ti++) {
            for (int j = 0; j < alpha; j++) {
                for (int i = 0; i < alpha; i++) {
                    PRAGMA_OMP_SIMD()
                    for (int v = 0; v < simd_w; v++) {
                        Ow[j][i][v] = input(tile_block, 0,
                                j, i,
                                nb_tile_block_ur, 0,
                                tile_block_ur, v);
                    }
                }
            }

            trans_O_4x4_3x3(Ow, O);

            for (int j = 0; j < tile_size; j++) {
                int ydim = tj * tile_size + j;
                if (ydim < outh) {
                    float *pout_j = pout_b + ydim * outw * simd_w;
                    for (int i = 0; i < tile_size; i++) {
                        int xdim = ti * tile_size + i;
                        if (xdim < outw) {
                            float *pout_i = pout_j + xdim * simd_w;
                            if (is_fwd) {
                                PRAGMA_OMP_SIMD()
                                for (int v = 0; v < simd_w; v++) {
                                    O[j][i][v] += with_bias ? bias[v] : 0.f;
                                    O[j][i][v] = true
                                        && with_relu_presum && O[j][i][v] < 0.f
                                                ? O[j][i][v]
                                                        * jcp.eltwise.alpha
                                                : O[j][i][v];
                                }
                            }
                            if (with_sum)
                                accum_output(pout_i, O[j][i], streamout,
                                        with_relu_postsum);
                            else
                                store_output(pout_i, O[j][i], streamout);
                        }
                    }
                }
            }
            tile_block_ur++;
            if (tile_block_ur >= jcp.tile_block_ur) {
                tile_block_ur = 0;
                nb_tile_block_ur++;
            }
            if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }
}

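/* Backward-weights src transform: same tiling as input_transform_data(), but
 * writes into the bwd-weights V layout and, for the 4fma kernel version,
 * stages groups of tile_4fma tiles in Iw_temp so that transpose_4fma_ker()
 * can interleave them before they are stored. */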
template <bool ver_4fma>
void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
        float *inp, float *tinp, float *Iw_temp,
        void (*transpose_4fma_ker)(float *, float *))
{

    const int ifwp = conv.iw + conv.l_pad;
    const int ifhp = conv.ih + conv.t_pad;
    float I[alpha][alpha][simd_w];
    float Iw[alpha][alpha][simd_w];

    array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp,
            alpha, alpha, conv.tile_4fma, simd_w);
    array_offset_calculator<float, 5> input(inp,
            conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w);
    array_offset_calculator<float, 8> output(tinp,
            conv.nb_ic, alpha, alpha,
            conv.tile_block, conv.ic_block,
            conv.nb_tile_block_ur, conv.tile_block_ur,
            conv.ic_simd_block * conv.tile_4fma);

    int tile_base_index =
            image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding);
    int tile_4fma = 0;
    int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur;
    int nb_tile_block_ur =
            (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
            % conv.nb_tile_block_ur;
    int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
            / conv.nb_tile_block_ur;

    for (int tj = 0; tj < conv.jtiles; tj++) {
        for (int ti = 0; ti < conv.itiles; ti++) {
            for (int j = 0; j < alpha; j++) {
                int ydim = tj * tile_size + j;
                if ((conv.t_pad <= ydim) && ydim < ifhp) {
                    for (int i = 0; i < alpha; i++) {
                        int xdim = ti * tile_size + i;
                        if ((conv.l_pad <= xdim) && xdim < ifwp) {
                            PRAGMA_OMP_SIMD()
                            for (int v = 0; v < simd_w; v++) {
                                I[j][i][v] = input(0, 0,
                                        ydim - conv.t_pad,
                                        xdim - conv.l_pad, v);
                            }
                        } else {
                            PRAGMA_OMP_SIMD()
                            for (int v = 0; v < simd_w; v++) {
                                I[j][i][v] = 0.0f;
                            }
                        }
                    }
                } else {
                    for (int i = 0; i < alpha; i++) {
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            I[j][i][v] = 0.0f;
                        }
                    }
                }
            }
            trans_I_4x4_3x3(Iw, I);

            if (ver_4fma) {
                for (int j = 0; j < alpha; j++) {
                    for (int i = 0; i < alpha; i++) {
                        float *Iw_temp_base = &(Iw_trans_temp(j, i,
                                tile_4fma, 0));
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            Iw_temp_base[v] = Iw[j][i][v];
                        }
                    }
                }
                tile_4fma++;
                if (tile_4fma == conv.tile_4fma) {
                    float *outp = &(output(0, 0, 0,
                            tile_block, 0,
                            nb_tile_block_ur, tile_block_ur, 0));
                    transpose_4fma_ker(outp, (float *)Iw_temp);
                    tile_4fma = 0;
                    tile_block_ur++;
                }
            } else {
                for (int j = 0; j < alpha; j++) {
                    for (int i = 0; i < alpha; i++) {
                        store_output(&(output(0, j, i,
                                        tile_block, 0,
                                        nb_tile_block_ur, tile_block_ur, 0)),
                                Iw[j][i], true);
                    }
                }
                tile_block_ur++;
            }

            if (tile_block_ur == conv.tile_block_ur) {
                tile_block_ur = 0;
                ++nb_tile_block_ur;
            }
            if (nb_tile_block_ur == conv.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }

    if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) {

        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) {
                    float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0));
                    PRAGMA_OMP_SIMD()
                    for (int v = 0; v < simd_w; v++) {
                        Iw_temp_base[v] = 0;
                    }
                }
            }
        }
        float *outp = &(output(0, 0, 0,
                tile_block, 0,
                nb_tile_block_ur, tile_block_ur, 0));
        transpose_4fma_ker(outp, (float *)Iw_temp);
    }
}

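/* Backward-weights diff_dst transform: tiles diff_dst, accumulates the bias
 * gradient on the fly (only over the tile_size part of each 6x6 tile), runs
 * trans_W_3x3_4x4_wu() and stores the result into the M buffer. */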
template <bool with_bias>
void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
        float *inp, float *tinp, float *dbias)
{

    const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding;
    float I[alpha][alpha][simd_w];
    float Iw[alpha][alpha][simd_w];

    array_offset_calculator<float, 5> input(inp,
            conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block);
    array_offset_calculator<float, 8> output(tinp,
            conv.nb_oc, alpha, alpha,
            conv.tile_block, conv.oc_block,
            conv.nb_tile_block_ur,
            conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);

    int tile_base_index = image * total_tiles;
    int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma);
    int nb_tile_block_ur =
            (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
            % conv.nb_tile_block_ur;
    int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
            / conv.nb_tile_block_ur;

    for (int tj = 0; tj < conv.jtiles; tj++) {
        for (int ti = 0; ti < conv.itiles; ti++) {
            for (int j = 0; j < alpha; j++) {
                int ydim = tj * tile_size + j;
                if (ydim < conv.oh) {
                    for (int i = 0; i < alpha; i++) {
                        int xdim = ti * tile_size + i;
                        if (xdim < conv.ow) {
                            float *input_base = &(input(0, 0, ydim, xdim, 0));

                            PRAGMA_OMP_SIMD()
                            for (int v = 0; v < simd_w; v++) {
                                I[j][i][v] = input_base[v];
                            }
                            if (with_bias && j < tile_size && i < tile_size) {
                                PRAGMA_OMP_SIMD()
                                for (int v = 0; v < simd_w; v++) {
                                    dbias[v] += input_base[v];
                                }
                            }
                        } else {
                            PRAGMA_OMP_SIMD()
                            for (int v = 0; v < simd_w; v++) {
                                I[j][i][v] = 0.0f;
                            }
                        }
                    }
                } else {
                    for (int i = 0; i < alpha; i++) {
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            I[j][i][v] = 0.0f;
                        }
                    }
                }
            }

            trans_W_3x3_4x4_wu(Iw, I);

            for (int j = 0; j < alpha; j++) {
                for (int i = 0; i < alpha; i++) {
                    store_output(&(output(0, j, i,
                                    tile_block, 0,
                                    nb_tile_block_ur,
                                    tile_block_ur, 0)),
                            Iw[j][i], true);
                }
            }
            tile_block_ur++;
            if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) {
                tile_block_ur = 0;
                nb_tile_block_ur++;
            }
            if (nb_tile_block_ur >= conv.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }
}

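/* Final stage of the weight-update pipeline: reads the accumulated U buffer,
 * applies the inverse transform trans_O_3x3_4x4_wu() and stores the 3x3
 * diff_weights block in the blocked (oc/16, ic/16, kh, kw, 16i, 16o) layout
 * of wp. */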
void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
        float *wp, float *twp)
{
    const int kh = 3;
    const int kw = 3;
    float Fw[alpha][alpha][simd_w][simd_w];
    float F[kh][kw][simd_w][simd_w];

    array_offset_calculator<float, 8> input(twp,
            conv.nb_ic, conv.nb_oc,
            alpha, alpha,
            conv.oc_block, conv.ic_block,
            conv.ic_simd_block, conv.oc_simd_block);
    array_offset_calculator<float, 6> output(wp,
            conv.oc/simd_w, conv.ic/simd_w,
            conv.kh, conv.kw,
            conv.ic_simd_block, conv.oc_simd_block);

    for (int j = 0; j < alpha; j++) {
        for (int i = 0; i < alpha; i++) {
            for (int v = 0; v < conv.ic_simd_block; v++) {
                PRAGMA_OMP_SIMD()
                for (int k = 0; k < conv.oc_simd_block; k++) {
                    Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k);
                }
            }
        }
    }

    trans_O_3x3_4x4_wu(Fw, F);

    for (int j = 0; j < kh; j++) {
        for (int i = 0; i < kw; i++) {
            for (int v = 0; v < conv.ic_simd_block; v++) {
                store_output(&(output(0, 0, j, i, v, 0)),
                        F[j][i][v], true);
            }
        }
    }
}

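/* Driver for forward and backward-data convolution; W_S_G_D presumably
 * stands for the stage order weight-transform, src-transform, gemm,
 * dst-transform. Each stage is a separate parallel_nd loop over the blocked
 * dimensions, communicating through the U/V/M scratchpad buffers described
 * in the notation comment below. */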
template <bool is_fwd>
void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
        const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const auto &p_ops = attr_->post_ops_;

    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int outh = is_fwd ? jcp.oh : jcp.ih;
    const int outw = is_fwd ? jcp.ow : jcp.iw;

    /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
     * and conv primitive with PostOps with relu before sum
     * (PostOps relu after sum is handled later) */
    auto output_transform = jcp.with_bias
            ? (jcp.with_eltwise
                    ? (jcp.with_sum
                            ? output_transform_data<is_fwd, true, true, true>
                            : output_transform_data<is_fwd, true, true, false>)
                    : (jcp.with_sum
                            ? output_transform_data<is_fwd, true, false, true>
                            : output_transform_data<is_fwd, true, false, false>))
            : (jcp.with_eltwise
                    ? (jcp.with_sum
                            ? output_transform_data<is_fwd, false, true, true>
                            : output_transform_data<is_fwd, false, true, false>)
                    : (jcp.with_sum
                            ? output_transform_data<is_fwd, false, false, true>
                            : output_transform_data<is_fwd, false, false, false>));

    /* Notation:
       FWD: dimM:oc, dimN:ntiles, dimK:ic,
       BWD: dimM:ic, dimN:ntiles, dimK:oc,
       FWD/BWD: V: src/diff_dst transform, U:weight transform,
                M:dst/diff_src transform */
    array_offset_calculator<float, 5> input(inp_ptr,
            jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
            jcp.dimK_reg_block);
    array_offset_calculator<float, 5> output(out_ptr,
            jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
            jcp.dimM_simd_block);
    array_offset_calculator<float, 6> weights(wei_ptr,
            jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
            jcp.ic_simd_block, jcp.oc_simd_block);
    array_offset_calculator<float, 2> bias(bias_ptr,
            jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);

    array_offset_calculator<float, 8> M(is_fwd
            ? scratchpad.template get<float>(key_wino_M)
            : scratchpad.template get<float>(key_wino_V),
            jcp.dimN_nb_block, jcp.dimM_nb_block,
            alpha, alpha,
            jcp.dimN_block, jcp.dimM_block,
            jcp.dimN_reg_block, jcp.dimM_simd_block);
    array_offset_calculator<float, 8> U(
            scratchpad.template get<float>(key_wino_U),
            jcp.dimM_nb_block,
            alpha, alpha,
            jcp.dimK_nb_block,
            jcp.dimM_block, jcp.dimK_block,
            jcp.dimK_reg_block, jcp.dimM_simd_block);
    array_offset_calculator<float, 8> V(is_fwd
            ? scratchpad.template get<float>(key_wino_V)
            : scratchpad.template get<float>(key_wino_M),
            jcp.dimN_nb_block, alpha, alpha,
            jcp.dimN_block, jcp.dimK_nb_block,
            jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);

    bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
            > 2 * LLC_cache_size ? true : false;

    const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;

    const bool wants_padded_bias = jcp.with_bias
            && jcp.oc_without_padding != jcp.oc;
    float last_slice_bias[simd_w] = {0};
    if (wants_padded_bias) {
        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
    }

    {
        parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
                [&](int img, int K_blk1, int K_blk2) {
            input_transform_data<is_fwd>(img, jcp,
                    &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
                    &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout);
        });

        parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
                [&](int ofm1, int ifm1, int ofm2, int ifm2) {
            float *U_base_ptr = is_fwd
                    ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
                    : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
            weight_transform_data<is_fwd>(jcp,
                    &(weights(ofm1 * jcp.oc_block + ofm2,
                        ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
        });

        parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
                [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {

            kernel_->gemm_loop_ker_first_iter(
                    (float *)&(M(N_blk1, M_blk1, oj, oi,
                            N_blk2, 0, 0, 0)),
                    (const float *)&(U(M_blk1, oj, oi,
                            0, 0, 0, 0, 0)),
                    (const float *)&(V(N_blk1, oj, oi,
                            N_blk2, 0, 0, 0, 0)));
            for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
                kernel_->gemm_loop_ker(
                        (float *)&(M(N_blk1, M_blk1, oj, oi,
                                N_blk2, 0, 0, 0)),
                        (const float *)&(U(M_blk1, oj, oi,
                                K_blk1, 0, 0, 0, 0)),
                        (const float *)&(V(N_blk1, oj, oi,
                                N_blk2, K_blk1,
                                0, 0, 0)));
            }

        });

        parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block,
                [&](int img, int M_blk1, int M_blk2) {

            const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;

            float *bias_ptr = wants_padded_bias
                    && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                    ? last_slice_bias : &bias(M_blk, 0);

            output_transform(img, jcp, p_ops,
                    &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
                    &(output(img, M_blk, 0, 0, 0)),
                    bias_ptr, output_is_aligned);

        });

    }
}

template struct _jit_avx512_common_convolution_winograd_t<true>;
template struct _jit_avx512_common_convolution_winograd_t<false>;

void jit_avx512_common_convolution_winograd_bwd_weights_t::
_maybe_execute_diff_bias_copy(float *diff_bias,
        const memory_tracking::grantor_t &scratchpad) const {
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
        for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
            diff_bias[oc] = padded_bias[oc];
    }
}

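/* Backward-weights driver; S_D_G_W presumably stands for src-transform,
 * diff_dst-transform, gemm, weight-transform. Everything runs inside one
 * OpenMP parallel region: barriers separate the transform, gemm and
 * weight-update stages, and per-thread bias partials in diff_bias_prv are
 * reduced into diff_bias at the end. */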
void jit_avx512_common_convolution_winograd_bwd_weights_t::
_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
        const memory_tracking::grantor_t &scratchpad) const {
    auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
    auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
    auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS);
    auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);

    const auto &jcp = kernel_->jcp;
    const int nthreads = jcp.nthr;

    auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
            diff_src_transform_bwd_weights<true> :
            diff_src_transform_bwd_weights<false>;
    auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
            ? diff_dst_transform_bwd_weights<true>
            : diff_dst_transform_bwd_weights<false>;

    array_offset_calculator<float, 5> src((float *)ptr_src,
            jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w);
    array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
            jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
    array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
            jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
    array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
            ? scratchpad.get<float>(key_conv_padded_bias) : ptr_diff_bias,
            jcp.oc/simd_w, simd_w);

    array_offset_calculator<float, 8> U(
            scratchpad.get<float>(key_wino_U),
            jcp.nb_ic, jcp.nb_oc,
            alpha, alpha,
            jcp.oc_block, jcp.ic_block,
            jcp.ic_simd_block, jcp.oc_simd_block);

    array_offset_calculator<float, 8> M(
            scratchpad.get<float>(key_wino_M),
            jcp.nb_oc, alpha, alpha,
            jcp.tile_block, jcp.oc_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
            jcp.oc_simd_block);
    array_offset_calculator<float, 8> V(
            scratchpad.get<float>(key_wino_V),
            jcp.nb_ic, alpha, alpha,
            jcp.tile_block, jcp.ic_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur,
            jcp.ic_simd_block * jcp.tile_4fma);

    const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
            * jcp.ic_simd_block;
    array_offset_calculator<float, 2> trans_buffer(
            scratchpad.get<float>(key_conv_tr_src),
            nthreads,
            trans_buffer_size);

    array_offset_calculator<float, 2> diff_bias_prv(
            scratchpad.get<float>(key_conv_bia_reduction),
            nthreads,
            jcp.oc);

PRAGMA_OMP(parallel num_threads(nthreads))
    {
        if (jcp.with_bias) {
            parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
                diff_bias_prv(ithr, ofm) = 0.0f;
            });

PRAGMA_OMP(for nowait)
            for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
                PRAGMA_OMP_SIMD()
                for (int v = 0; v < simd_w; v++)
                    diff_bias(bofm, v) = 0.0f;
            }
        }

        const int ithread = mkldnn_get_thread_num();

        parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
                [&](int img, int ifm1, int ifm2) {
            float *transb = jcp.ver == ver_4fma
                    ? &(trans_buffer(ithread, 0))
                    : NULL;
            diff_src_transform_bwd_weights_ver(img, jcp,
                    &(src(img, ifm1 * jcp.ic_block + ifm2,
                        0, 0, 0)),
                    &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
                    transb,
                    kernel_->transpose_4fma_ker);
        });

        parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
                [&](int img, int ofm1, int ofm2) {
            float *dbias = jcp.with_bias
                    ? &(diff_bias_prv(ithread,
                            simd_w * (ofm1 * jcp.oc_block + ofm2)))
                    : NULL;
            diff_dst_transform_bwd_weights_ver(img, jcp,
                    &(diff_dst(img, ofm1 * jcp.oc_block + ofm2,
                        0, 0, 0)),
                    &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)),
                    dbias);
        });

PRAGMA_OMP(barrier)

        for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
            parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
                    [&](int oj, int oi, int ofm1) {
                kernel_->gemm_loop_ker_first_iter(
                        (float *)&(U(ifm1, ofm1, oj, oi,
                                0, 0, 0, 0)),
                        (const float *)&(M(ofm1, oj, oi,
                                0, 0, 0, 0, 0)),
                        (const float *)&(V(ifm1, oj, oi,
                                0, 0, 0, 0, 0)));
                for (int tile_block = 1; tile_block < jcp.tile_block;
                        tile_block++) {
                    kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1,
                                    oj, oi,
                                    0, 0, 0, 0)),
                            (const float *)&(M(ofm1, oj, oi, tile_block,
                                    0, 0, 0, 0)),
                            (const float *)&(V(ifm1, oj, oi, tile_block,
                                    0, 0, 0, 0)));
                }
            });
        }

PRAGMA_OMP(barrier)

        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
                [&](int ifm1, int ofm1, int ofm2, int ifm2) {
            diff_weights_transform_bwd_weights(jcp,
                    &(diff_weights(ofm1 * jcp.oc_block + ofm2,
                        ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
                    &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
        });

        if (jcp.with_bias) {
PRAGMA_OMP(for)
            for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
                for (int ithr = 0; ithr < nthreads; ithr++) {
                    float* base_bias_ptr = &(diff_bias(ofm1, 0));
                    float* base_bias_prv_ptr = &(diff_bias_prv(
                            ithr * jcp.oc + ofm1 * simd_w));
                    PRAGMA_OMP_SIMD()
                    for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
                        base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
                    }
                }
            }
        }
    }

    _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad);
}

}
}
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s