| 1 | /******************************************************************************* | 
| 2 | * Copyright 2017-2018 Intel Corporation | 
| 3 | * | 
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); | 
| 5 | * you may not use this file except in compliance with the License. | 
| 6 | * You may obtain a copy of the License at | 
| 7 | * | 
| 8 | *     http://www.apache.org/licenses/LICENSE-2.0 | 
| 9 | * | 
| 10 | * Unless required by applicable law or agreed to in writing, software | 
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, | 
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
| 13 | * See the License for the specific language governing permissions and | 
| 14 | * limitations under the License. | 
| 15 | *******************************************************************************/ | 
| 16 |  | 
| 17 | #ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP | 
| 18 | #define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP | 
| 19 |  | 
| 20 | #include "c_types_map.hpp" | 
| 21 | #include "memory_tracking.hpp" | 
| 22 | #include "mkldnn_thread.hpp" | 
| 23 |  | 
| 24 | #include "cpu_convolution_pd.hpp" | 
| 25 | #include "cpu_primitive.hpp" | 
| 26 |  | 
| 27 | #include "jit_avx512_common_conv_winograd_kernel_f32.hpp" | 
| 28 |  | 
| 29 | namespace mkldnn { | 
| 30 | namespace impl { | 
| 31 | namespace cpu { | 
| 32 |  | 
| 33 | namespace winograd_avx512_common { | 
| 34 | inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, | 
| 35 |         const jit_conv_winograd_conf_t &jcp) { | 
| 36 |     using namespace memory_tracking::names; | 
| 37 |  | 
| 38 |     size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; | 
| 39 |     size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic | 
| 40 |         * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); | 
| 41 |     size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc | 
| 42 |         * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); | 
| 43 |  | 
| 44 |     scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); | 
| 45 |     scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); | 
| 46 |     scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); | 
| 47 |  | 
| 48 |     if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { | 
| 49 |         const int nthr = mkldnn_get_max_threads(); | 
| 50 |  | 
| 51 |         size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr | 
| 52 |             * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; | 
| 53 |         scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); | 
| 54 |  | 
| 55 |         size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; | 
| 56 |         scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); | 
| 57 |  | 
| 58 |         size_t padded_bias_sz = | 
| 59 |             jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; | 
| 60 |         scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); | 
| 61 |     } | 
| 62 | } | 
| 63 | } | 
| 64 |  | 
| 65 | template <bool is_fwd> | 
| 66 | struct _jit_avx512_common_convolution_winograd_t { | 
| 67 |     _jit_avx512_common_convolution_winograd_t( | 
| 68 |             const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) | 
| 69 |         : kernel_(nullptr), attr_(attr) { | 
| 70 |         kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); | 
| 71 |     } | 
| 72 |  | 
| 73 |     ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } | 
| 74 |  | 
| 75 |     protected: | 
| 76 |         void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, | 
| 77 |                 float *wei_ptr, float *bias_ptr, | 
| 78 |                 const memory_tracking::grantor_t &scratchpad) const; | 
| 79 |         _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; | 
| 80 |         const primitive_attr_t *attr_; | 
| 81 | }; | 
| 82 |  | 
| 83 | struct jit_avx512_common_convolution_winograd_fwd_t | 
| 84 |      : _jit_avx512_common_convolution_winograd_t<true> | 
| 85 |      , public cpu_primitive_t | 
| 86 |     { | 
| 87 |     struct pd_t : public cpu_convolution_fwd_pd_t { | 
| 88 |         pd_t(engine_t *engine, const convolution_desc_t *adesc, | 
| 89 |                 const primitive_attr_t *attr, | 
| 90 |                 const typename pd_t::base_class *hint_fwd_pd) | 
| 91 |             : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) | 
| 92 |             , jcp_() {} | 
| 93 |  | 
| 94 |         DECLARE_COMMON_PD_T( | 
| 95 |                 JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), | 
| 96 |                 jit_avx512_common_convolution_winograd_fwd_t); | 
| 97 |  | 
| 98 |         status_t init() { | 
| 99 |             bool ok = true | 
| 100 |                 && is_fwd() | 
| 101 |                 && utils::one_of(desc()->alg_kind, | 
| 102 |                         alg_kind::convolution_auto, | 
| 103 |                         alg_kind::convolution_winograd) | 
| 104 |                 && expect_data_types(data_type::f32, data_type::f32, | 
| 105 |                         data_type::f32, data_type::f32, data_type::f32) | 
| 106 |                 && !has_zero_dim_memory() | 
| 107 |                 && set_default_formats(); | 
| 108 |             if (!ok) return status::unimplemented; | 
| 109 |  | 
| 110 |             status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: | 
| 111 |                 init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), | 
| 112 |                         *attr()); | 
| 113 |             if (status != status::success) return status; | 
| 114 |             set_default_alg_kind(alg_kind::convolution_winograd); | 
| 115 |  | 
| 116 |             auto scratchpad = scratchpad_registry().registrar(); | 
| 117 |             winograd_avx512_common::init_scratchpad(scratchpad, jcp_); | 
| 118 |  | 
| 119 |             return status; | 
| 120 |         } | 
| 121 |  | 
| 122 |         jit_conv_winograd_conf_t jcp_; | 
| 123 |  | 
| 124 |     protected: | 
| 125 |         bool set_default_formats() { | 
| 126 |             using namespace format_tag; | 
| 127 |             auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; | 
| 128 |             return set_default_formats_common(nChw16c, wei_tag, nChw16c); | 
| 129 |         } | 
| 130 |     }; | 
| 131 |  | 
| 132 |     jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) | 
| 133 |         : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr()) | 
| 134 |         , cpu_primitive_t(apd, true) {} | 
| 135 |  | 
| 136 |     ~jit_avx512_common_convolution_winograd_fwd_t(){}; | 
| 137 |  | 
| 138 |     typedef typename prec_traits<data_type::f32>::type data_t; | 
| 139 |  | 
| 140 |     virtual status_t execute(const exec_ctx_t &ctx) const override | 
| 141 |     { | 
| 142 |         auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); | 
| 143 |         auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); | 
| 144 |         auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); | 
| 145 |         auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); | 
| 146 |         this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, | 
| 147 |                 (float *)bias, this->scratchpad(ctx)); | 
| 148 |         return status::success; | 
| 149 |     } | 
| 150 |  | 
| 151 | private: | 
| 152 |     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } | 
| 153 | }; | 
| 154 |  | 
| 155 | struct jit_avx512_common_convolution_winograd_bwd_data_t | 
| 156 |         : _jit_avx512_common_convolution_winograd_t<false>, | 
| 157 |         public cpu_primitive_t { | 
| 158 |     struct pd_t : public cpu_convolution_bwd_data_pd_t { | 
| 159 |         pd_t(engine_t *engine, const convolution_desc_t *adesc, | 
| 160 |                 const primitive_attr_t *attr, | 
| 161 |                 const convolution_fwd_pd_t *hint_fwd_pd) | 
| 162 |             : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) | 
| 163 |             , jcp_() {} | 
| 164 |  | 
| 165 |         DECLARE_COMMON_PD_T( | 
| 166 |                 JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), | 
| 167 |                 jit_avx512_common_convolution_winograd_bwd_data_t); | 
| 168 |  | 
| 169 |         status_t init() { | 
| 170 |             bool ok = true | 
| 171 |                 && desc()->prop_kind == prop_kind::backward_data | 
| 172 |                 && expect_data_types(data_type::f32, data_type::f32, | 
| 173 |                         data_type::undef, data_type::f32, data_type::f32) | 
| 174 |                 && utils::one_of(desc()->alg_kind, | 
| 175 |                         alg_kind::convolution_auto, | 
| 176 |                         alg_kind::convolution_winograd) | 
| 177 |                 && !has_zero_dim_memory() | 
| 178 |                 && set_default_formats() | 
| 179 |                 && mkldnn_thr_syncable(); | 
| 180 |             if (!ok) return status::unimplemented; | 
| 181 |  | 
| 182 |             status_t status = | 
| 183 |                 jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( | 
| 184 |                         jcp_, *desc(), *diff_src_md(), *weights_md(), | 
| 185 |                         *diff_dst_md()); | 
| 186 |             if (status != status::success) return status; | 
| 187 |             set_default_alg_kind(alg_kind::convolution_winograd); | 
| 188 |  | 
| 189 |             auto scratchpad = scratchpad_registry().registrar(); | 
| 190 |             winograd_avx512_common::init_scratchpad(scratchpad, jcp_); | 
| 191 |  | 
| 192 |             return status; | 
| 193 |         } | 
| 194 |  | 
| 195 |         jit_conv_winograd_conf_t jcp_; | 
| 196 |  | 
| 197 |     protected: | 
| 198 |         bool set_default_formats() { | 
| 199 |             using namespace format_tag; | 
| 200 |             auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; | 
| 201 |             return set_default_formats_common(nChw16c, wei_tag, nChw16c); | 
| 202 |         } | 
| 203 |     }; | 
| 204 |  | 
| 205 |     jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) | 
| 206 |         : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr()) | 
| 207 |         , cpu_primitive_t(apd, true) {} | 
| 208 |  | 
| 209 |     ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; | 
| 210 |  | 
| 211 |     typedef typename prec_traits<data_type::f32>::type data_t; | 
| 212 |  | 
| 213 |     virtual status_t execute(const exec_ctx_t &ctx) const override { | 
| 214 |         auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); | 
| 215 |         auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); | 
| 216 |         auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); | 
| 217 |         this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, | 
| 218 |                 (float *)weights, nullptr, this->scratchpad(ctx)); | 
| 219 |         return status::success; | 
| 220 |     } | 
| 221 |  | 
| 222 | private: | 
| 223 |     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } | 
| 224 | }; | 
| 225 |  | 
| 226 | struct jit_avx512_common_convolution_winograd_bwd_weights_t | 
| 227 |         : public cpu_primitive_t { | 
| 228 |     struct pd_t : public cpu_convolution_bwd_weights_pd_t { | 
| 229 |         pd_t(engine_t *engine, const convolution_desc_t *adesc, | 
| 230 |                 const primitive_attr_t *attr, | 
| 231 |                 const convolution_fwd_pd_t *hint_fwd_pd) | 
| 232 |             : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, | 
| 233 |                     hint_fwd_pd) | 
| 234 |             , jcp_() {} | 
| 235 |  | 
| 236 |         DECLARE_COMMON_PD_T( | 
| 237 |                 JIT_IMPL_NAME_HELPER("jit_wino:" , avx512_common, "" ), | 
| 238 |                 jit_avx512_common_convolution_winograd_bwd_weights_t); | 
| 239 |  | 
| 240 |         status_t init() { | 
| 241 |             bool ok = true | 
| 242 |                 && desc()->prop_kind == prop_kind::backward_weights | 
| 243 |                 && utils::one_of(desc()->alg_kind, | 
| 244 |                         alg_kind::convolution_auto, | 
| 245 |                         alg_kind::convolution_winograd) | 
| 246 |                 && expect_data_types(data_type::f32, data_type::f32, | 
| 247 |                         data_type::f32, data_type::f32, data_type::f32) | 
| 248 |                 && !has_zero_dim_memory() | 
| 249 |                 && set_default_formats() | 
| 250 |                 && mkldnn_thr_syncable(); | 
| 251 |             if (!ok) return status::unimplemented; | 
| 252 |  | 
| 253 |             status_t status = | 
| 254 |                 jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: | 
| 255 |                 init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), | 
| 256 |                         *diff_weights_md()); | 
| 257 |             if (status != status::success) return status; | 
| 258 |             set_default_alg_kind(alg_kind::convolution_winograd); | 
| 259 |  | 
| 260 |             auto scratchpad = scratchpad_registry().registrar(); | 
| 261 |             winograd_avx512_common::init_scratchpad(scratchpad, jcp_); | 
| 262 |  | 
| 263 |             return status; | 
| 264 |         } | 
| 265 |  | 
| 266 |         jit_conv_winograd_conf_t jcp_; | 
| 267 |  | 
| 268 |     protected: | 
| 269 |         bool set_default_formats() { | 
| 270 |             using namespace format_tag; | 
| 271 |             auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; | 
| 272 |             return set_default_formats_common(nChw16c, wei_tag, nChw16c); | 
| 273 |         } | 
| 274 |     }; | 
| 275 |  | 
| 276 |     jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) | 
| 277 |         : cpu_primitive_t(apd, true), kernel_(nullptr) | 
| 278 |     { | 
| 279 |         kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( | 
| 280 |                 pd()->jcp_); | 
| 281 |     } | 
| 282 |  | 
| 283 |     ~jit_avx512_common_convolution_winograd_bwd_weights_t() | 
| 284 |     { delete kernel_; } | 
| 285 |  | 
| 286 |     typedef typename prec_traits<data_type::f32>::type data_t; | 
| 287 |  | 
| 288 |     virtual status_t execute(const exec_ctx_t &ctx) const override | 
| 289 |     { | 
| 290 |         _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); | 
| 291 |         return status::success; | 
| 292 |     } | 
| 293 |  | 
| 294 | private: | 
| 295 |     void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, | 
| 296 |             const memory_tracking::grantor_t &scratchpad) const; | 
| 297 |     void _maybe_execute_diff_bias_copy(float *diff_bias, | 
| 298 |             const memory_tracking::grantor_t &scratchpad) const; | 
| 299 |  | 
| 300 |     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } | 
| 301 |     jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; | 
| 302 | }; | 
| 303 |  | 
// Winograd transform helpers. Only the prototypes are visible here; the
// definitions live elsewhere (presumably the matching .cpp -- TODO confirm).
// NOTE(review): the following is read from the identifiers and array extents
// only; confirm against the definitions.
//   - trans_<X>_... appears to transform tensor X: W = weights, O = output,
//     I = input.
//   - The 6x6 leading extents look like a 6x6 transformed tile.
//   - The 16 extents look like 16-wide channel blocks, consistent with the
//     nChw16c / *16i16o format tags used by the primitives in this header.
//   - The *_wu suffix presumably marks the weights-update (bwd-weights)
//     variants.
void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]);
void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]);
void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]);
void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]);
void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]);
void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]);
| 311 |  | 
| 312 | } | 
| 313 | } | 
| 314 | } | 
| 315 |  | 
| 316 | #endif | 
| 317 |  | 
| 318 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s | 
| 319 |  |