1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_JIT_UNI_ELTWISE_HPP |
18 | #define CPU_JIT_UNI_ELTWISE_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "type_helpers.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | #include "cpu_eltwise_pd.hpp" |
27 | #include "cpu_primitive.hpp" |
28 | |
29 | #include "jit_generator.hpp" |
30 | |
31 | namespace mkldnn { |
32 | namespace impl { |
33 | namespace cpu { |
34 | |
35 | template <cpu_isa_t isa> |
36 | struct jit_uni_eltwise_injector_f32 { |
37 | using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, |
38 | isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type; |
39 | |
40 | jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, |
41 | float alpha, float beta, bool save_state = true, |
42 | Xbyak::Reg64 p_table = Xbyak::util::rax, |
43 | Xbyak::Opmask k_mask = Xbyak::Opmask(1)) |
44 | : alg_(alg), alpha_(alpha), beta_(beta), h(host) |
45 | , save_state_(save_state), p_table(p_table), k_mask(k_mask) |
46 | { |
47 | using namespace alg_kind; |
48 | assert(utils::one_of(isa, sse42, avx2, avx512_common)); |
49 | assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, |
50 | eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, |
51 | eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); |
52 | } |
53 | |
54 | // note that eltwise.scale is ignored |
55 | jit_uni_eltwise_injector_f32(jit_generator *host, |
56 | const post_ops_t::entry_t::eltwise_t &eltwise, |
57 | bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, |
58 | Xbyak::Opmask k_mask = Xbyak::Opmask(1)) |
59 | : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, |
60 | eltwise.beta, save_state, p_table, k_mask) {} |
61 | |
62 | void compute_vector_range(size_t start_idx, size_t end_idx); |
63 | void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } |
64 | void prepare_table(bool gen_table=true); |
65 | void load_table_addr() { h->mov(p_table, l_table); } |
66 | |
67 | const alg_kind_t alg_; |
68 | const float alpha_; |
69 | const float beta_; |
70 | |
71 | jit_generator * const h; |
72 | |
73 | const bool save_state_; |
74 | const Xbyak::Reg64 p_table; |
75 | const Xbyak::Opmask k_mask; |
76 | Xbyak::Label l_table; |
77 | |
78 | private: |
79 | // if only the injector was inherited from jit_generator... |
80 | enum { |
81 | _cmp_le_os = jit_generator::_cmp_le_os, |
82 | _cmp_nle_us = jit_generator::_cmp_nle_us, |
83 | _op_floor = jit_generator::_op_floor, |
84 | }; |
85 | |
86 | size_t vlen = cpu_isa_traits<isa>::vlen; |
87 | |
88 | const static size_t preserved_vecs_max = 5; |
89 | |
90 | size_t vecs_to_preserve = 0; |
91 | size_t vecs_count = isa == avx512_common ? 32 : 16; |
92 | size_t preserved_vecs_count = 0; |
93 | size_t preserved_vec_idxs[preserved_vecs_max] = {0}; |
94 | size_t start_idx_tail = 0; |
95 | |
96 | Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; |
97 | |
98 | Xbyak::Address table_val(int index) |
99 | { return h->ptr[p_table + index * vlen]; } |
100 | |
101 | int aux_vecs_count(alg_kind_t alg); |
102 | |
103 | void compute_body(size_t start_idx, size_t end_idx); |
104 | void injector_preamble(size_t start_idx, size_t end_idx); |
105 | void injector_preamble_tail(size_t start_idx); |
106 | void injector_postamble(); |
107 | void assign_regs(); |
108 | |
109 | void exp_compute_vector(const Vmm &vmm_src); |
110 | void relu_compute_vector(const Vmm &vmm_src); |
111 | void relu_zero_ns_compute_vector(const Vmm &vmm_src); |
112 | void elu_compute_vector(const Vmm &vmm_src); |
113 | void tanh_compute_vector(const Vmm &vmm_src); |
114 | void square_compute_vector(const Vmm &vmm_src); |
115 | void abs_compute_vector(const Vmm &vmm_src); |
116 | void sqrt_compute_vector(const Vmm &vmm_src); |
117 | void linear_compute_vector(const Vmm &vmm_src); |
118 | void bounded_relu_compute_vector(const Vmm &vmm_src); |
119 | void soft_relu_compute_vector(const Vmm &vmm_src); |
120 | void logistic_compute_vector(const Vmm &vmm_src); |
121 | |
122 | void relu_prepare_table(); |
123 | void elu_prepare_table(); |
124 | void soft_relu_prepare_table(); |
125 | void abs_prepare_table(); |
126 | void sqrt_prepare_table(); |
127 | void linear_prepare_table(); |
128 | void bounded_relu_prepare_table(); |
129 | }; |
130 | |
// Forward declaration only: the primitives below hold a pointer to this
// kernel type, so its full definition is not required in this header.
struct jit_uni_eltwise_kernel_f32;
132 | |
133 | template <cpu_isa_t isa> |
134 | struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { |
135 | struct pd_t : public cpu_eltwise_fwd_pd_t { |
136 | using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; |
137 | |
138 | DECLARE_COMMON_PD_T( |
139 | JIT_IMPL_NAME_HELPER("jit:" , isa, "" ), |
140 | jit_uni_eltwise_fwd_t<isa>); |
141 | |
142 | status_t init(); |
143 | }; |
144 | |
145 | jit_uni_eltwise_fwd_t(const pd_t *apd); |
146 | ~jit_uni_eltwise_fwd_t(); |
147 | |
148 | typedef typename prec_traits<data_type::f32>::type data_t; |
149 | |
150 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
151 | execute_forward(ctx); |
152 | return status::success; |
153 | } |
154 | |
155 | private: |
156 | void execute_forward(const exec_ctx_t &ctx) const; |
157 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
158 | jit_uni_eltwise_kernel_f32 *kernel_; |
159 | }; |
160 | |
161 | template <cpu_isa_t isa> |
162 | struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { |
163 | struct pd_t : public cpu_eltwise_bwd_pd_t { |
164 | using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; |
165 | |
166 | DECLARE_COMMON_PD_T( |
167 | JIT_IMPL_NAME_HELPER("jit:" , isa, "" ), |
168 | jit_uni_eltwise_bwd_t<isa>); |
169 | |
170 | status_t init(); |
171 | }; |
172 | |
173 | jit_uni_eltwise_bwd_t(const pd_t *apd); |
174 | ~jit_uni_eltwise_bwd_t(); |
175 | |
176 | typedef typename prec_traits<data_type::f32>::type data_t; |
177 | |
178 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
179 | execute_backward(ctx); |
180 | return status::success; |
181 | } |
182 | |
183 | private: |
184 | void execute_backward(const exec_ctx_t &ctx) const; |
185 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
186 | jit_uni_eltwise_kernel_f32 *kernel_; |
187 | }; |
188 | |
189 | } |
190 | } |
191 | } |
192 | |
193 | #endif |
194 | |