1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP
18#define JIT_SSE42_CONV_KERNEL_F32_HPP
19
#include <memory>

#include "c_types_map.hpp"
#include "cpu_memory.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_eltwise.hpp"
25
26namespace mkldnn {
27namespace impl {
28namespace cpu {
29
30struct jit_sse42_conv_fwd_kernel_f32: public jit_generator {
31 jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
32 const primitive_attr_t &attr)
33 : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
34 {
35 if (jcp.with_eltwise)
36 eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this,
37 jcp.eltwise);
38
39 this->generate();
40 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
41 }
42
43 ~jit_sse42_conv_fwd_kernel_f32() {
44 delete eltwise_injector_;
45 }
46
47 static bool post_ops_ok(jit_conv_conf_t &jcp,
48 const primitive_attr_t &attr);
49
50 static status_t init_conf(jit_conv_conf_t &jcp,
51 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
52 const memory_desc_wrapper &weights_d,
53 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
54
55 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32)
56 jit_conv_conf_t jcp;
57 const primitive_attr_t &attr_;
58 void (*jit_ker)(jit_conv_call_s *);
59
60private:
61 using reg64_t = const Xbyak::Reg64;
62 reg64_t reg_input = rax;
63 reg64_t aux_reg_input = r8;
64 reg64_t reg_kernel = rdx;
65 reg64_t aux_reg_kernel = r9;
66 reg64_t reg_output = rsi;
67 reg64_t reg_bias = rbx;
68
69 reg64_t kj = r10;
70 reg64_t oi_iter = r11;
71 reg64_t ki_iter = r12;
72 reg64_t reg_kh = abi_not_param1;
73 reg64_t simd_iter = r15;
74 reg64_t reg_oc_blocks = r14;
75 reg64_t imm_addr64 = reg_oc_blocks;
76 Xbyak::Reg32 reg_ci_flag = r13d;
77
78 jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_;
79
80 inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
81 int oc_blocks);
82 inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
83 inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
84 inline void solve_common(int oc_blocks);
85
86 void generate();
87};
88
89}
90}
91}
92
93#endif
94