1/*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#ifndef CPU_RNN_REORDERS_HPP
18#define CPU_RNN_REORDERS_HPP
19
20#include <assert.h>
21
22#include "type_helpers.hpp"
23#include "mkldnn_thread.hpp"
24#include "utils.hpp"
25#include "simple_q10n.hpp"
26#include "cpu_reorder_pd.hpp"
27#include "../gemm/os_blas.hpp"
28
29namespace mkldnn {
30namespace impl {
31namespace cpu {
32
33template <data_type_t type_i, data_type_t type_o>
34struct rnn_data_reorder_t : public cpu_primitive_t {
35 struct pd_t : public cpu_reorder_pd_t {
36 using cpu_reorder_pd_t::cpu_reorder_pd_t;
37
38 DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
39
40 static status_t create(reorder_pd_t **reorder_pd,
41 engine_t *engine, const primitive_attr_t *attr,
42 engine_t *src_engine, const memory_desc_t *src_md,
43 engine_t *dst_engine, const memory_desc_t *dst_md) {
44 const memory_desc_wrapper id(src_md), od(dst_md);
45 bool args_ok = true
46 && id.data_type() == type_i
47 && od.data_type() == type_o
48 && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc)
49 && od == id;
50 if (!args_ok) return status::invalid_arguments;
51
52 auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
53 dst_md);
54 if (_pd == nullptr) return out_of_memory;
55 if (_pd->init() != success) { delete _pd; return unimplemented; }
56 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
57 }
58 };
59
60private:
61 typedef typename prec_traits<type_i>::type in_data_t;
62 typedef typename prec_traits<type_o>::type out_data_t;
63
64 rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
65
66 virtual status_t execute(const exec_ctx_t &ctx) const override {
67 auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
68 auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
69 const memory_desc_wrapper &input_d = pd()->src_md();
70 const memory_desc_wrapper &output_d = pd()->dst_md();
71 const size_t nelems = input_d.nelems();
72 const float scale = pd()->attr()->rnn_data_qparams_.scale_;
73 const float shift = pd()->attr()->rnn_data_qparams_.shift_;
74
75 parallel_nd(nelems, [&](size_t i) {
76 float in = (float)input[input_d.off_l(i)] * scale + shift;
77 output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
78 });
79
80 return status::success;
81 }
82
83 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
84};
85
/* Reorders plain RNN weights (ldigo or ldgoi) into the MKL packed-GEMM
 * layout (ldigo_p), quantizing every element on the fly (the scratchpad
 * copy is int8_t) and storing per-column compensation sums inside the
 * destination buffer. Only functional when built with
 * USE_MKL_PACKED_GEMM (cblas_gemm_s8u8s32_pack). */
template <data_type_t type_i, data_type_t type_o>
struct rnn_weights_reorder_t : public cpu_primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);

        /* Validates descriptors/attributes and creates the pd.
         * NOTE: when USE_MKL_PACKED_GEMM is 0 the #if below makes
         * everything after it dead code. */
        static status_t create(reorder_pd_t **reorder_pd,
                engine_t *engine, const primitive_attr_t *attr,
                engine_t *src_engine, const memory_desc_t *src_md,
                engine_t *dst_engine, const memory_desc_t *dst_md) {
#if !USE_MKL_PACKED_GEMM
            return status::unimplemented;
#endif
            const memory_desc_wrapper id(src_md), od(dst_md);
            /* Destination must be packed ldigo with a single part; the
             * packing loop in execute() relies on n_parts == 1. */
            bool args_ok = true
                && id.data_type() == type_i
                && od.data_type() == type_o
                && od.format_kind() == format_kind::rnn_packed
                && od.rnn_packed_desc().format == mkldnn_ldigo_p
                && od.rnn_packed_desc().n_parts == 1
                && attr != nullptr;
            if (!args_ok) return status::invalid_arguments;

            format_tag_t itag = id.matches_one_of_tag(
                    format_tag::ldigo, format_tag::ldgoi);
            if (itag == format_tag::undef) return status::invalid_arguments;

            /* Only a single common scale (mask 0) or one scale per
             * (gate, output channel) pair (mask 3) is supported — see
             * the scales[] indexing in execute(). */
            const int mask = attr->rnn_weights_qparams_.mask_;
            if (!utils::one_of(mask, 0, 3)) return status::unimplemented;

            auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
                    dst_md);
            if (_pd == nullptr) return out_of_memory;
            /* itag_ must be set before init(): init_scratchpad() sizes
             * the reduction buffer based on it. */
            _pd->itag_ = itag;
            if (_pd->init() != success) { delete _pd; return unimplemented; }
            return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
        }

        /* Runs the base init, then books the scratchpad this reorder
         * needs (create() has already filled in itag_). */
        status_t init() {
            status_t status = cpu_reorder_pd_t::init();
            if (status != status::success) return status;

            init_scratchpad();

            return status::success;
        }

        /* Source layout tag: format_tag::ldigo or format_tag::ldgoi. */
        format_tag_t itag_ = mkldnn_format_tag_undef;

    private:
        /* Books two scratchpad buffers: an int8 quantized copy of the
         * whole weights tensor and — for the ldigo source layout only —
         * one partial compensation array per thread. Source dims are
         * {L, D, I, G, O}, so dims[0]*dims[1]*dims[3]*dims[4] is the
         * size of one L*D*G*O compensation array. */
        void init_scratchpad() {
            const memory_desc_wrapper id(src_md());
            const size_t nelems = id.nelems();
            const auto &dims = id.dims();

            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            size_t quantization_size = sizeof(int8_t) * nelems;
            size_t reduction_size = itag_ == ldigo
                    ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
                            * dims[1] * dims[3] * dims[4]
                    : 0;
            scratchpad.book(
                    key_reorder_rnn_weights_quantization, quantization_size);
            scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
        }
    };

private:
    typedef typename prec_traits<type_i>::type in_data_t;
    typedef typename prec_traits<type_o>::type out_data_t;

    rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}

    /* Two phases: (1) quantize the weights into scratchpad while
     * accumulating per-(l,d,g,o) column sums (the compensation terms),
     * then (2) feed the quantized matrix to MKL's packed-GEMM packer. */
    virtual status_t execute(const exec_ctx_t &ctx) const override {
#if USE_MKL_PACKED_GEMM
        auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
        auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
        const memory_desc_wrapper &input_d = pd()->src_md();
        const memory_desc_wrapper &output_d = pd()->dst_md();
        const auto &dims = input_d.dims();

        const int L = dims[0]; // layers
        const int D = dims[1]; // directions
        const int I = dims[2]; // input channels
        const int G = dims[3]; // gates
        const int O = dims[4]; // output channels

        const bool is_igo = pd()->itag_ == format_tag::ldigo;

        /* Quantize input & compute compensation */
        auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>(
                memory_tracking::names::key_reorder_rnn_weights_quantization);
        auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>(
                memory_tracking::names::key_reorder_rnn_weights_reduction);
        /* Compensation lives inside the destination buffer, after the
         * packed weights, at the descriptor-provided offset. */
        float *comp = reinterpret_cast<float *>(
                output + output_d.rnn_packed_desc().offset_compensation);
        const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
        const int mask = pd()->attr()->rnn_weights_qparams_.mask_;

        if (is_igo) {
            /* In ldigo the reduction axis I is not innermost per column,
             * so the threads are split over both L*D and I; each I-slice
             * writes its own partial sums which are reduced below. */
            int nthr = mkldnn_get_max_threads();
            int LD_nthr = nstl::min(L * D, nthr);
            int I_nthr = nstl::min(I, nthr / LD_nthr);
            parallel(nthr, [&](const int ithr, const int nthr) {
                int LD_ithr = -1, LD_s = -1, LD_e = -1;
                int I_ithr = -1, I_s = -1, I_e = -1;
                if (ithr < LD_nthr * I_nthr) {
                    LD_ithr = ithr % LD_nthr;
                    I_ithr = ithr / LD_nthr;
                    balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
                    balance211(I, I_nthr, I_ithr, I_s, I_e);
                }
                /* Idle threads keep LD_s == LD_e == -1 and skip the loop
                 * (their comp_ithr, computed with I_ithr == -1, is never
                 * dereferenced). */
                int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
                for (int ld = LD_s; ld < LD_e; ld++) {
                    for (int go = 0; go < G * O; go++)
                        comp_ithr[ld * G * O + go] = 0;
                    for (int i = I_s; i < I_e; i++) {
                        PRAGMA_OMP_SIMD()
                        for (int go = 0; go < G * O; go++) {
                            /* mask == 0: one common scale; otherwise one
                             * scale per (gate, output channel). */
                            const float s = scales[(mask == 0) ? 0 : go];
                            int8_t q = qz_b0<in_data_t, out_data_t>()(
                                    input[ld * I * G * O + i * G * O + go], s);
                            /* The (int32_t) cast is narrowed straight
                             * back to int8_t by the store. */
                            quantized[ld * I * G * O + i * G * O + go]
                                    = (int32_t)q;
                            comp_ithr[ld * G * O + go] += (int32_t)q;
                        }
                    }
                }
            });
            /* Reduce per-thread partial sums into the final float
             * compensation array. */
            parallel_nd(L * D * G * O,
                    [&](int s) { comp[s] = saturate<float>(reduction[s]); });
            for (int i = 1; i < I_nthr; i++) {
                parallel_nd(L * D * G * O, [&](int s) {
                    comp[s] += saturate<float>(
                            reduction[i * L * D * G * O + s]);
                });
            }
        } else {
            /* ldgoi: the I values of one (ld, go) column are contiguous,
             * so each task quantizes and reduces its column directly. */
            parallel_nd(L * D, G * O, [&](int ld, int go) {
                int32_t compensation = 0;
                const float s = scales[(mask == 0) ? 0 : go];
                PRAGMA_OMP_SIMD()
                for (int i = 0; i < I; i++) {
                    int8_t q = qz_b0<in_data_t, out_data_t>()(
                            input[ld * G * O * I + go * I + i], s);
                    compensation += (int32_t)q;
                    quantized[ld * G * O * I + go * I + i] = q;
                }
                comp[ld * G * O + go] = saturate<float>(compensation);
            });
        }

        /* Pack */
        auto off_igo = [&](int l, int d, int i, int g, int o) {
            return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
        };
        auto off_goi = [&](int l, int d, int i, int g, int o) {
            return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
        };
        int n_parts = output_d.rnn_packed_desc().n_parts;
        const size_t *size_packed_cell
                = output_d.rnn_packed_desc().part_pack_size;
        const int *parts = output_d.rnn_packed_desc().parts;
        const int n = output_d.rnn_packed_desc().n;
        char *to_pack = output;
        /* create() guarantees n_parts == 1, so p == 0 and g == 0 on
         * every iteration (the off_goi(l, d, g, 0, 0) argument order
         * is only correct under that guarantee). */
        for (int l = 0; l < L; l++) {
            for (int d = 0; d < D; d++) {
                for (int p = 0; p < n_parts; p++) {
                    int g = (p > 0) ? parts[p - 1] : 0;
                    int m_p = parts[p] * O;
                    int k_p = I;
                    cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
                            is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
                            &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
                                    off_goi(l, d, g, 0, 0)],
                            is_igo ? G * O : I, to_pack);
                    to_pack += size_packed_cell[p];
                }
            }
        }
#endif
        return status::success;
    }

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
};
274
275template <>
276struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
277 : public cpu_primitive_t {
278 struct pd_t : public cpu_reorder_pd_t {
279 using cpu_reorder_pd_t::cpu_reorder_pd_t;
280
281 DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
282
283 static status_t create(reorder_pd_t **reorder_pd,
284 engine_t *engine, const primitive_attr_t *attr,
285 engine_t *src_engine, const memory_desc_t *src_md,
286 engine_t *dst_engine, const memory_desc_t *dst_md) {
287#if !USE_MKL_PACKED_GEMM
288 return status::unimplemented;
289#endif
290 const memory_desc_wrapper id(src_md), od(dst_md);
291 bool args_ok = true
292 && id.data_type() == data_type::f32
293 && od.data_type() == data_type::f32
294 && od.format_kind() == format_kind::rnn_packed
295 && utils::one_of(od.rnn_packed_desc().format,
296 mkldnn_ldigo_p, mkldnn_ldgoi_p)
297 && attr->has_default_values();
298 if (!args_ok) return status::invalid_arguments;
299
300 format_tag_t itag = id.matches_one_of_tag(
301 format_tag::ldigo, format_tag::ldgoi);
302 if (itag == format_tag::undef) return status::invalid_arguments;
303
304 const int mask = attr->rnn_weights_qparams_.mask_;
305 if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
306
307 auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
308 dst_md);
309 if (_pd == nullptr) return out_of_memory;
310 if (_pd->init() != success) { delete _pd; return unimplemented; }
311 _pd->itag_ = itag;
312 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
313 }
314
315 format_tag_t itag_;
316 };
317
318private:
319 rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
320
321 virtual status_t execute(const exec_ctx_t &ctx) const override {
322#if USE_MKL_PACKED_GEMM
323 auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM);
324 auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO);
325 const memory_desc_wrapper &input_d = pd()->src_md();
326 const memory_desc_wrapper &output_d = pd()->dst_md();
327 const auto &dims = input_d.dims();
328 const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
329 const int L = dims[0];
330 const int D = dims[1];
331 const int I = dims[2];
332 const int G = dims[3];
333 const int O = dims[4];
334
335 /* Pack */
336 bool cross_case = false
337 || (pd()->itag_ == format_tag::ldigo
338 && rnn_pdata.format == mkldnn_ldgoi_p)
339 || (pd()->itag_ == format_tag::ldgoi
340 && rnn_pdata.format == mkldnn_ldigo_p);
341 auto trans = cross_case ? CblasTrans : CblasNoTrans;
342 int n_parts = rnn_pdata.n_parts;
343 const size_t *size_packed_cell = rnn_pdata.part_pack_size;
344 const int *parts = rnn_pdata.parts;
345 const int n = rnn_pdata.n;
346
347 const bool is_igo = pd()->itag_ == format_tag::ldigo;
348 auto off_igo = [&](int l, int d, int i, int g, int o) {
349 return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
350 };
351 auto off_goi = [&](int l, int d, int i, int g, int o) {
352 return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
353 };
354 for (int l = 0; l < L; l++) {
355 for (int d = 0; d < D; d++) {
356 for (int p = 0; p < n_parts; p++) {
357 int g = (p > 0) ? parts[p - 1] : 0;
358 int m_p = is_igo ? parts[p] * O : I;
359 int k_p = is_igo ? I : parts[p] * O;
360 int ld = is_igo ? G * O : I;
361 cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
362 k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
363 off_goi(l, d, 0, g, 0)],
364 ld, output);
365 output += size_packed_cell[p] / sizeof(float);
366 }
367 }
368 }
369#endif
370 return status::success;
371 }
372
373 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
374};
375
376} // namespace cpu
377} // namespace impl
378} // namespace mkldnn
379
380#endif
381