1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
18#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
19
20#include "c_types_map.hpp"
21#include "cpu_memory.hpp"
22
23#include "jit_generator.hpp"
24#include "jit_primitive_conf.hpp"
25
26namespace mkldnn {
27namespace impl {
28namespace cpu {
29
// Winograd tiling parameters shared by the kernels declared in this header.
// Edge length of one output tile.
constexpr int tile_size = 4;
// alpha determines the output tile_size.
// NOTE(review): 6 == 4 + 3 - 1, which would match a Winograd transform for a
// 3x3 filter -- confirm against the kernel implementation.
constexpr int alpha = 6;
// SIMD length used for vectorization.
// NOTE(review): presumably the number of f32 lanes in a 512-bit register --
// confirm.
constexpr int simd_w = 16;
35
36struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator {
37 _jit_avx512_common_conv_winograd_data_kernel_f32(
38 jit_conv_winograd_conf_t ajcp)
39 : jcp(ajcp)
40 {
41 //******************* First iter kernel ********************//
42 this->gemm_loop_generate(true);
43 gemm_loop_ker_first_iter
44 = (decltype(gemm_loop_ker_first_iter)) this->getCode();
45
46 //************** Subsequent iterations kernel **************//
47 if (jcp.dimK_nb_block > 1) {
48 align();
49 const Xbyak::uint8 *addr = getCurr();
50 this->gemm_loop_generate(false);
51 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
52 }
53 }
54
55 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32)
56
57 static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
58 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
59 const memory_desc_wrapper &weights_d,
60 const memory_desc_wrapper &dst_d);
61
62 static status_t init_conf_kernel(
63 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
64
65 jit_conv_winograd_conf_t jcp;
66 void (*gemm_loop_ker)(float *, const float *, const float *);
67 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
68
69protected:
70 using reg64_t = const Xbyak::Reg64;
71 enum { typesize = sizeof(float) };
72
73 void gemm_loop_generate(bool is_beta_zero);
74
75 /* registers used for GEMM */
76 reg64_t reg_dstC = abi_param1;
77 reg64_t reg_srcA = abi_param2;
78 reg64_t reg_srcB = abi_param3;
79
80 reg64_t reg_dimM_block_loop_cnt = r10;
81 reg64_t reg_dimK_block_loop_cnt = r11;
82};
83
84struct jit_avx512_common_conv_winograd_fwd_kernel_f32
85 : _jit_avx512_common_conv_winograd_data_kernel_f32 {
86 using _jit_avx512_common_conv_winograd_data_kernel_f32::
87 _jit_avx512_common_conv_winograd_data_kernel_f32;
88
89 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
90
91 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
92 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
93 const memory_desc_wrapper &weights_d,
94 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
95};
96
97struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32
98 : public _jit_avx512_common_conv_winograd_data_kernel_f32 {
99 using _jit_avx512_common_conv_winograd_data_kernel_f32::
100 _jit_avx512_common_conv_winograd_data_kernel_f32;
101
102 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
103 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
104 const memory_desc_wrapper &weights_d,
105 const memory_desc_wrapper &diff_dst_d);
106};
107
108struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32
109 : public jit_generator {
110 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32)
111
112 jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
113 jit_conv_winograd_conf_t ajcp)
114 : jcp(ajcp)
115 {
116
117 //******************* First iter kernel ********************//
118 {
119 align();
120 const Xbyak::uint8 *addr = getCurr();
121 this->gemm_loop_generate(true);
122 gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
123 }
124
125 if (jcp.tile_block > 1) {
126 align();
127 const Xbyak::uint8 *addr = getCurr();
128 this->gemm_loop_generate(false);
129 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
130 }
131
132 if (jcp.ver == ver_4fma) {
133 align();
134 const Xbyak::uint8 *addr = getCurr();
135 this->transpose_ker_generate();
136 transpose_4fma_ker = (decltype(transpose_4fma_ker))addr;
137 }
138 }
139
140 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
141 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
142 const memory_desc_wrapper &diff_dst_d,
143 const memory_desc_wrapper &diff_weights_d);
144
145 jit_conv_winograd_conf_t jcp;
146 void (*gemm_loop_ker)(float *, const float *, const float *);
147 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
148 void (*transpose_4fma_ker)(float *, float *);
149
150private:
151 using reg64_t = const Xbyak::Reg64;
152 enum { typesize = sizeof(float) };
153
154 void gemm_loop_generate(bool is_first_tile);
155 void transpose_ker_generate();
156
157 reg64_t reg_origB = abi_param2;
158 reg64_t reg_transB = abi_param1;
159
160 reg64_t reg_dstC = abi_param1;
161 reg64_t reg_srcA_const = abi_param2;
162 reg64_t reg_srcB = abi_param3;
163
164 reg64_t reg_sp = rsp;
165 reg64_t reg_srcA = r9;
166 reg64_t reg_nb_ic = r10;
167 reg64_t reg_loop_cpt = r11;
168 reg64_t reg_transB_idx = r13;
169
170 /* Registers used by new kernel */
171 reg64_t reg_dimM_block_loop_cnt = r10;
172 reg64_t reg_dimK_block_loop_cnt = r12;
173 reg64_t reg_dimN_block_loop_cnt = r11;
174};
175}
176}
177}
178
179#endif
180