1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_JIT_UNI_ELTWISE_HPP
18#define CPU_JIT_UNI_ELTWISE_HPP
19
20#include <assert.h>
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26#include "cpu_eltwise_pd.hpp"
27#include "cpu_primitive.hpp"
28
29#include "jit_generator.hpp"
30
31namespace mkldnn {
32namespace impl {
33namespace cpu {
34
35template <cpu_isa_t isa>
36struct jit_uni_eltwise_injector_f32 {
37 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
38 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
39
40 jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
41 float alpha, float beta, bool save_state = true,
42 Xbyak::Reg64 p_table = Xbyak::util::rax,
43 Xbyak::Opmask k_mask = Xbyak::Opmask(1))
44 : alg_(alg), alpha_(alpha), beta_(beta), h(host)
45 , save_state_(save_state), p_table(p_table), k_mask(k_mask)
46 {
47 using namespace alg_kind;
48 assert(utils::one_of(isa, sse42, avx2, avx512_common));
49 assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
50 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
51 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
52 }
53
54 // note that eltwise.scale is ignored
55 jit_uni_eltwise_injector_f32(jit_generator *host,
56 const post_ops_t::entry_t::eltwise_t &eltwise,
57 bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
58 Xbyak::Opmask k_mask = Xbyak::Opmask(1))
59 : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
60 eltwise.beta, save_state, p_table, k_mask) {}
61
62 void compute_vector_range(size_t start_idx, size_t end_idx);
63 void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
64 void prepare_table(bool gen_table=true);
65 void load_table_addr() { h->mov(p_table, l_table); }
66
67 const alg_kind_t alg_;
68 const float alpha_;
69 const float beta_;
70
71 jit_generator * const h;
72
73 const bool save_state_;
74 const Xbyak::Reg64 p_table;
75 const Xbyak::Opmask k_mask;
76 Xbyak::Label l_table;
77
78private:
79 // if only the injector was inherited from jit_generator...
80 enum {
81 _cmp_le_os = jit_generator::_cmp_le_os,
82 _cmp_nle_us = jit_generator::_cmp_nle_us,
83 _op_floor = jit_generator::_op_floor,
84 };
85
86 size_t vlen = cpu_isa_traits<isa>::vlen;
87
88 const static size_t preserved_vecs_max = 5;
89
90 size_t vecs_to_preserve = 0;
91 size_t vecs_count = isa == avx512_common ? 32 : 16;
92 size_t preserved_vecs_count = 0;
93 size_t preserved_vec_idxs[preserved_vecs_max] = {0};
94 size_t start_idx_tail = 0;
95
96 Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
97
98 Xbyak::Address table_val(int index)
99 { return h->ptr[p_table + index * vlen]; }
100
101 int aux_vecs_count(alg_kind_t alg);
102
103 void compute_body(size_t start_idx, size_t end_idx);
104 void injector_preamble(size_t start_idx, size_t end_idx);
105 void injector_preamble_tail(size_t start_idx);
106 void injector_postamble();
107 void assign_regs();
108
109 void exp_compute_vector(const Vmm &vmm_src);
110 void relu_compute_vector(const Vmm &vmm_src);
111 void relu_zero_ns_compute_vector(const Vmm &vmm_src);
112 void elu_compute_vector(const Vmm &vmm_src);
113 void tanh_compute_vector(const Vmm &vmm_src);
114 void square_compute_vector(const Vmm &vmm_src);
115 void abs_compute_vector(const Vmm &vmm_src);
116 void sqrt_compute_vector(const Vmm &vmm_src);
117 void linear_compute_vector(const Vmm &vmm_src);
118 void bounded_relu_compute_vector(const Vmm &vmm_src);
119 void soft_relu_compute_vector(const Vmm &vmm_src);
120 void logistic_compute_vector(const Vmm &vmm_src);
121
122 void relu_prepare_table();
123 void elu_prepare_table();
124 void soft_relu_prepare_table();
125 void abs_prepare_table();
126 void sqrt_prepare_table();
127 void linear_prepare_table();
128 void bounded_relu_prepare_table();
129};
130
131struct jit_uni_eltwise_kernel_f32;
132
133template <cpu_isa_t isa>
134struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
135 struct pd_t : public cpu_eltwise_fwd_pd_t {
136 using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
137
138 DECLARE_COMMON_PD_T(
139 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
140 jit_uni_eltwise_fwd_t<isa>);
141
142 status_t init();
143 };
144
145 jit_uni_eltwise_fwd_t(const pd_t *apd);
146 ~jit_uni_eltwise_fwd_t();
147
148 typedef typename prec_traits<data_type::f32>::type data_t;
149
150 virtual status_t execute(const exec_ctx_t &ctx) const override {
151 execute_forward(ctx);
152 return status::success;
153 }
154
155private:
156 void execute_forward(const exec_ctx_t &ctx) const;
157 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
158 jit_uni_eltwise_kernel_f32 *kernel_;
159};
160
161template <cpu_isa_t isa>
162struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
163 struct pd_t : public cpu_eltwise_bwd_pd_t {
164 using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
165
166 DECLARE_COMMON_PD_T(
167 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
168 jit_uni_eltwise_bwd_t<isa>);
169
170 status_t init();
171 };
172
173 jit_uni_eltwise_bwd_t(const pd_t *apd);
174 ~jit_uni_eltwise_bwd_t();
175
176 typedef typename prec_traits<data_type::f32>::type data_t;
177
178 virtual status_t execute(const exec_ctx_t &ctx) const override {
179 execute_backward(ctx);
180 return status::success;
181 }
182
183private:
184 void execute_backward(const exec_ctx_t &ctx) const;
185 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
186 jit_uni_eltwise_kernel_f32 *kernel_;
187};
188
189}
190}
191}
192
193#endif
194