1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef JIT_SSE42_CONV_KERNEL_F32_HPP |
18 | #define JIT_SSE42_CONV_KERNEL_F32_HPP |
19 | |
20 | #include "c_types_map.hpp" |
21 | #include "cpu_memory.hpp" |
22 | #include "jit_generator.hpp" |
23 | #include "jit_primitive_conf.hpp" |
24 | #include "jit_uni_eltwise.hpp" |
25 | |
26 | namespace mkldnn { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { |
31 | jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, |
32 | const primitive_attr_t &attr) |
33 | : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) |
34 | { |
35 | if (jcp.with_eltwise) |
36 | eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this, |
37 | jcp.eltwise); |
38 | |
39 | this->generate(); |
40 | jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); |
41 | } |
42 | |
43 | ~jit_sse42_conv_fwd_kernel_f32() { |
44 | delete eltwise_injector_; |
45 | } |
46 | |
47 | static bool post_ops_ok(jit_conv_conf_t &jcp, |
48 | const primitive_attr_t &attr); |
49 | |
50 | static status_t init_conf(jit_conv_conf_t &jcp, |
51 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
52 | const memory_desc_wrapper &weights_d, |
53 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); |
54 | |
55 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) |
56 | jit_conv_conf_t jcp; |
57 | const primitive_attr_t &attr_; |
58 | void (*jit_ker)(jit_conv_call_s *); |
59 | |
60 | private: |
61 | using reg64_t = const Xbyak::Reg64; |
62 | reg64_t reg_input = rax; |
63 | reg64_t aux_reg_input = r8; |
64 | reg64_t reg_kernel = rdx; |
65 | reg64_t aux_reg_kernel = r9; |
66 | reg64_t reg_output = rsi; |
67 | reg64_t reg_bias = rbx; |
68 | |
69 | reg64_t kj = r10; |
70 | reg64_t oi_iter = r11; |
71 | reg64_t ki_iter = r12; |
72 | reg64_t reg_kh = abi_not_param1; |
73 | reg64_t simd_iter = r15; |
74 | reg64_t reg_oc_blocks = r14; |
75 | reg64_t imm_addr64 = reg_oc_blocks; |
76 | Xbyak::Reg32 reg_ci_flag = r13d; |
77 | |
78 | jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_; |
79 | |
80 | inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, |
81 | int oc_blocks); |
82 | inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); |
83 | inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); |
84 | inline void solve_common(int oc_blocks); |
85 | |
86 | void generate(); |
87 | }; |
88 | |
89 | } |
90 | } |
91 | } |
92 | |
93 | #endif |
94 | |