/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_REORDERS_HPP
#define CPU_RNN_REORDERS_HPP

#include <assert.h>

#include "type_helpers.hpp"
#include "mkldnn_thread.hpp"
#include "utils.hpp"
#include "simple_q10n.hpp"
#include "cpu_reorder_pd.hpp"
#include "../gemm/os_blas.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

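/* Reorder that quantizes RNN data (activation) tensors. Source and
 * destination must share the same tnc or ldsnc layout; each element is
 * scaled and shifted with the rnn_data_qparams attribute and then rounded
 * and saturated to the destination data type. */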
template <data_type_t type_i, data_type_t type_o>
struct rnn_data_reorder_t : public cpu_primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);

        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true
                && id.data_type() == type_i
                && od.data_type() == type_o
                && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc)
                && od == id;
            if (!args_ok) return status::invalid_arguments;

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init() != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }
    };

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}

    virtual status_t execute(const exec_ctx_t &ctx) const override {
        auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
        auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const size_t nelems = input_d.nelems();
        const float scale = pd()->attr()->rnn_data_qparams_.scale_;
        const float shift = pd()->attr()->rnn_data_qparams_.shift_;

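        /* Elementwise affine quantization: out[i] = qz(in[i] * scale + shift),
         * where qz_a1b0 rounds and saturates to out_data_t. */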
        parallel_nd(nelems, [&](size_t i) {
            float in = (float)input[input_d.off_l(i)] * scale + shift;
            output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
        });

        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};

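/* Reorder that quantizes RNN weights to int8 and packs them with the MKL
 * packed-GEMM API (ldigo or ldgoi source, ldigo_p packed destination).
 * Besides the packed weights it writes per-(gate, output channel)
 * compensation terms (column sums of the quantized weights) at
 * offset_compensation in the destination, so that the int8 RNN kernels can
 * compensate for the shift applied to the quantized activations. */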
template <data_type_t type_i, data_type_t type_o>
struct rnn_weights_reorder_t : public cpu_primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);

        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
#if !USE_MKL_PACKED_GEMM
            return status::unimplemented;
#endif
            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true
                && id.data_type() == type_i
                && od.data_type() == type_o
                && od.format_kind() == format_kind::rnn_packed
                && od.rnn_packed_desc().format == mkldnn_ldigo_p
                && od.rnn_packed_desc().n_parts == 1
                && attr != nullptr;
            if (!args_ok) return status::invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(
                    format_tag::ldigo, format_tag::ldgoi);
            if (itag == format_tag::undef) return status::invalid_arguments;

            const int mask = attr->rnn_weights_qparams_.mask_;
            if (!utils::one_of(mask, 0, 3)) return status::unimplemented;

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            _pd->itag_ = itag;
            if (_pd->init() != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        status_t init() {
            status_t status = cpu_reorder_pd_t::init();
            if (status != status::success) return status;

            init_scratchpad();

            return status::success;
        }

        format_tag_t itag_ = mkldnn_format_tag_undef;

    private:
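        /* Scratchpad: an int8 buffer holding the quantized copy of the
         * weights, plus, for ldigo sources, a per-thread int32 buffer used to
         * accumulate the partial compensation sums before they are reduced
         * into the destination. */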
        void init_scratchpad() {
            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            size_t quantization_size = sizeof(int8_t) * nelems;
            size_t reduction_size = itag_ == format_tag::ldigo
                    ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
                            * dims[1] * dims[3] * dims[4]
                    : 0;
            scratchpad.book(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
        }
    };

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}

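    /* Two phases: (1) quantize the weights and accumulate the compensation
     * sums in the scratchpad, (2) pack the quantized weights into the
     * destination with cblas_gemm_s8u8s32_pack. */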
    virtual status_t execute(const exec_ctx_t &ctx) const override {
#if USE_MKL_PACKED_GEMM
        auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
        auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const auto &dims = input_d.dims();

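        /* Logical weight dims: L layers, D directions, I input channels,
         * G gates, O output channels. */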
        const int L = dims[0];
        const int D = dims[1];
        const int I = dims[2];
        const int G = dims[3];
        const int O = dims[4];

        const bool is_igo = pd()->itag_ == format_tag::ldigo;

        /* Quantize input & compute compensation */
        auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>(
                memory_tracking::names::key_reorder_rnn_weights_quantization);
        auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>(
                memory_tracking::names::key_reorder_rnn_weights_reduction);
        float *comp = reinterpret_cast<float *>(
                output + output_d.rnn_packed_desc().offset_compensation);
        const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
        const int mask = pd()->attr()->rnn_weights_qparams_.mask_;

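        /* Two quantization paths depending on the source layout. For ldigo
         * sources the reduction over input channels is split across threads,
         * each accumulating into its own slice of the scratchpad, and the
         * partial sums are combined below. For ldgoi sources the input
         * channels are contiguous, so every (layer * dir, gate * output)
         * pair reduces its own row directly into comp. */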
        if (is_igo) {
            int nthr = mkldnn_get_max_threads();
            int LD_nthr = nstl::min(L * D, nthr);
            int I_nthr = nstl::min(I, nthr / LD_nthr);
            parallel(nthr, [&](const int ithr, const int nthr) {
                int LD_ithr = -1, LD_s = -1, LD_e = -1;
                int I_ithr = -1, I_s = -1, I_e = -1;
                if (ithr < LD_nthr * I_nthr) {
                    LD_ithr = ithr % LD_nthr;
                    I_ithr = ithr / LD_nthr;
                    balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
                    balance211(I, I_nthr, I_ithr, I_s, I_e);
                }
                /* threads beyond the LD x I decomposition have no work */
                if (ithr >= LD_nthr * I_nthr) return;
                int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
                for (int ld = LD_s; ld < LD_e; ld++) {
                    for (int go = 0; go < G * O; go++)
                        comp_ithr[ld * G * O + go] = 0;
                    for (int i = I_s; i < I_e; i++) {
                        PRAGMA_OMP_SIMD()
                        for (int go = 0; go < G * O; go++) {
                            const float s = scales[(mask == 0) ? 0 : go];
                            int8_t q = qz_b0<in_data_t, out_data_t>()(
                                    input[ld * I * G * O + i * G * O + go], s);
                            quantized[ld * I * G * O + i * G * O + go] = q;
                            comp_ithr[ld * G * O + go] += (int32_t)q;
                        }
                    }
                }
            });
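            /* Combine the per-thread partial sums into the f32 compensation
             * at offset_compensation in the destination. */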
            parallel_nd(L * D * G * O,
                    [&](int s) { comp[s] = saturate<float>(reduction[s]); });
            for (int i = 1; i < I_nthr; i++) {
                parallel_nd(L * D * G * O, [&](int s) {
                    comp[s] += saturate<float>(
                            reduction[i * L * D * G * O + s]);
                });
            }
        } else {
            parallel_nd(L * D, G * O, [&](int ld, int go) {
                int32_t compensation = 0;
                const float s = scales[(mask == 0) ? 0 : go];
                PRAGMA_OMP_SIMD()
                for (int i = 0; i < I; i++) {
                    int8_t q = qz_b0<in_data_t, out_data_t>()(
                            input[ld * G * O * I + go * I + i], s);
                    compensation += (int32_t)q;
                    quantized[ld * G * O * I + go * I + i] = q;
                }
                comp[ld * G * O + go] = saturate<float>(compensation);
            });
        }

        /* Pack */
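        /* Each (layer, direction, part) block of the quantized weights is
         * packed with the MKL packed-GEMM layout; to_pack advances by the
         * packed size of each cell. With n_parts == 1 (enforced in create())
         * there is a single part covering all gates. */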
        auto off_igo = [&](int l, int d, int i, int g, int o) {
            return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
        };
        auto off_goi = [&](int l, int d, int i, int g, int o) {
            return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
        };
        int n_parts = output_d.rnn_packed_desc().n_parts;
        const size_t *size_packed_cell
                = output_d.rnn_packed_desc().part_pack_size;
        const int *parts = output_d.rnn_packed_desc().parts;
        const int n = output_d.rnn_packed_desc().n;
        char *to_pack = output;
        for (int l = 0; l < L; l++) {
            for (int d = 0; d < D; d++) {
                for (int p = 0; p < n_parts; p++) {
                    int g = (p > 0) ? parts[p - 1] : 0;
                    int m_p = parts[p] * O;
                    int k_p = I;
                    cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
                            is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
                            &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
                                            off_goi(l, d, 0, g, 0)],
                            is_igo ? G * O : I, to_pack);
                    to_pack += size_packed_cell[p];
                }
            }
        }
#endif
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};

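/* f32 -> f32 specialization: no quantization, the weights are only packed
 * with cblas_sgemm_pack. A transpose is applied when the source tag and the
 * packed destination format disagree (ldigo source with an ldgoi_p
 * destination, or ldgoi source with an ldigo_p destination). */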
template <>
struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
        : public cpu_primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);

        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
#if !USE_MKL_PACKED_GEMM
            return status::unimplemented;
#endif
            const memory_desc_wrapper id(src_md), od(dst_md);
            bool args_ok = true
                && id.data_type() == data_type::f32
                && od.data_type() == data_type::f32
                && od.format_kind() == format_kind::rnn_packed
                && utils::one_of(od.rnn_packed_desc().format,
                        mkldnn_ldigo_p, mkldnn_ldgoi_p)
                && attr->has_default_values();
            if (!args_ok) return status::invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(
                    format_tag::ldigo, format_tag::ldgoi);
            if (itag == format_tag::undef) return status::invalid_arguments;

            const int mask = attr->rnn_weights_qparams_.mask_;
            if (!utils::one_of(mask, 0, 3)) return status::unimplemented;

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init() != status::success) {
                delete _pd;
                return status::unimplemented;
            }
            _pd->itag_ = itag;
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        format_tag_t itag_ = mkldnn_format_tag_undef;
    };

private:
    rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}

    virtual status_t execute(const exec_ctx_t &ctx) const override {
#if USE_MKL_PACKED_GEMM
        auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM);
        auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const auto &dims = input_d.dims();
        const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
        const int L = dims[0];
        const int D = dims[1];
        const int I = dims[2];
        const int G = dims[3];
        const int O = dims[4];

        /* Pack */
        bool cross_case = false
                || (pd()->itag_ == format_tag::ldigo
                        && rnn_pdata.format == mkldnn_ldgoi_p)
                || (pd()->itag_ == format_tag::ldgoi
                        && rnn_pdata.format == mkldnn_ldigo_p);
        auto trans = cross_case ? CblasTrans : CblasNoTrans;
        int n_parts = rnn_pdata.n_parts;
        const size_t *size_packed_cell = rnn_pdata.part_pack_size;
        const int *parts = rnn_pdata.parts;
        const int n = rnn_pdata.n;

        const bool is_igo = pd()->itag_ == format_tag::ldigo;
        auto off_igo = [&](int l, int d, int i, int g, int o) {
            return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
        };
        auto off_goi = [&](int l, int d, int i, int g, int o) {
            return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
        };
        for (int l = 0; l < L; l++) {
            for (int d = 0; d < D; d++) {
                for (int p = 0; p < n_parts; p++) {
                    int g = (p > 0) ? parts[p - 1] : 0;
                    int m_p = is_igo ? parts[p] * O : I;
                    int k_p = is_igo ? I : parts[p] * O;
                    int ld = is_igo ? G * O : I;
                    cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
                            k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
                                                    off_goi(l, d, 0, g, 0)],
                            ld, output);
                    output += size_packed_cell[p] / sizeof(float);
                }
            }
        }
#endif
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};

} // namespace cpu
} // namespace impl
} // namespace mkldnn

#endif