1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_WINO_REORDER_HPP |
18 | #define CPU_WINO_REORDER_HPP |
19 | |
20 | #include "mkldnn_thread.hpp" |
21 | |
22 | #include "simple_q10n.hpp" |
23 | |
24 | namespace mkldnn { |
25 | namespace impl { |
26 | namespace cpu { |
27 | |
28 | template <data_type_t type_i, data_type_t type_o> |
29 | struct wino_reorder_t : public cpu_primitive_t { |
30 | struct pd_t : public cpu_reorder_pd_t { |
31 | using cpu_reorder_pd_t::cpu_reorder_pd_t; |
32 | |
33 | DECLARE_COMMON_PD_T("wino_reorder" , wino_reorder_t); |
34 | |
35 | static status_t create(reorder_pd_t **reorder_pd, |
36 | engine_t *engine, const primitive_attr_t *attr, |
37 | engine_t *src_engine, const memory_desc_t *src_md, |
38 | engine_t *dst_engine, const memory_desc_t *dst_md) { |
39 | const memory_desc_wrapper id(src_md), od(dst_md); |
40 | bool args_ok = true |
41 | && id.data_type() == type_i |
42 | && od.data_type() == type_o |
43 | && id.matches_tag(utils::pick(id.ndims() - 4, |
44 | format_tag::oihw, format_tag::goihw)) |
45 | && od.format_kind() == format_kind::wino |
46 | && utils::one_of(od.wino_desc().wino_format, |
47 | mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, |
48 | mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); |
49 | if (!args_ok) return status::invalid_arguments; |
50 | |
51 | auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, |
52 | dst_md); |
53 | if (_pd == nullptr) return status::out_of_memory; |
54 | if (_pd->init() != status::success) { |
55 | delete _pd; |
56 | return status::unimplemented; |
57 | } |
58 | return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); |
59 | } |
60 | |
61 | status_t init() { |
62 | status_t status = cpu_reorder_pd_t::init(); |
63 | if (status != status::success) return status; |
64 | |
65 | init_scratchpad(); |
66 | |
67 | return status::success; |
68 | } |
69 | |
70 | private: |
71 | void init_scratchpad() { |
72 | auto &o = memory_desc_wrapper(dst_md()).wino_desc(); |
73 | size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; |
74 | size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; |
75 | |
76 | using namespace memory_tracking::names; |
77 | auto scratchpad = scratchpad_registry().registrar(); |
78 | scratchpad.book(key_reorder_wino_transform_space, |
79 | sizeof(in_data_t) * transform_space_size); |
80 | scratchpad.book(key_reorder_wino_plain, |
81 | sizeof(out_data_t) * plain_size); |
82 | } |
83 | }; |
84 | |
85 | private: |
86 | typedef typename prec_traits<type_i>::type in_data_t; |
87 | typedef typename prec_traits<type_o>::type out_data_t; |
88 | const int unsign_val_in_wino_domain_ = 5; |
89 | |
90 | wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { |
91 | const memory_desc_wrapper src_d(pd()->src_md()); |
92 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
93 | |
94 | r_ = dst_d.wino_desc().r; |
95 | w_alpha_ = dst_d.wino_desc().alpha; |
96 | wino_format_ = dst_d.wino_desc().wino_format; |
97 | |
98 | const auto &in_dims = src_d.dims(); |
99 | int groups; |
100 | int groups_offset; |
101 | if (src_d.ndims() == 5) { |
102 | groups = in_dims[0]; |
103 | groups_offset = 1; |
104 | } else { |
105 | groups = 1; |
106 | groups_offset = 0; |
107 | } |
108 | assert(groups == 1); // groups are not supported now |
109 | MAYBE_UNUSED(groups); |
110 | |
111 | or_oc_ = in_dims[0 + groups_offset]; |
112 | or_ic_ = in_dims[1 + groups_offset]; |
113 | kh_ = in_dims[2 + groups_offset]; |
114 | kw_ = in_dims[3 + groups_offset]; |
115 | |
116 | oc_ = dst_d.wino_desc().oc; |
117 | ic_ = dst_d.wino_desc().ic; |
118 | oc_block_ = dst_d.wino_desc().oc_block; |
119 | ic_block_ = dst_d.wino_desc().ic_block; |
120 | assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); |
121 | nb_oc_ = oc_ / oc_block_; |
122 | nb_ic_ = ic_ / ic_block_; |
123 | ic2_block_ = 1; |
124 | if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) |
125 | ic2_block_ = dst_d.wino_desc().ic2_block; |
126 | oc2_block_ = dst_d.wino_desc().oc2_block; |
127 | assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); |
128 | |
129 | adj_scale_ = dst_d.wino_desc().adj_scale; |
130 | |
131 | size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; |
132 | size_wspace_ = r_ * w_alpha_ * oc_block_; |
133 | } |
134 | |
135 | void transform(out_data_t *__restrict tmp_wei, |
136 | const in_data_t *__restrict input, |
137 | in_data_t *__restrict wspace) const { |
138 | const memory_desc_wrapper src_d(pd()->src_md()); |
139 | |
140 | const int smask = pd()->attr()->output_scales_.mask_; |
141 | const int ndims_mask = math::ilog2q(smask + 1); |
142 | const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); |
143 | const float *__restrict scales = pd()->attr()->output_scales_.scales_; |
144 | assert(D_mask == 1 || D_mask == (size_t)oc_); |
145 | |
146 | /* transform weights to winograd domain */ |
147 | const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, |
148 | { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; |
149 | |
150 | const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, |
151 | { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, |
152 | { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, |
153 | { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, |
154 | { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, |
155 | { 0.f, 0.f, 1.f } }; |
156 | |
157 | float *__restrict g; |
158 | if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, |
159 | mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) |
160 | g = (float *)G_2x2_3x3; |
161 | else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) |
162 | g = (float *)G_4x4_3x3; |
163 | else { |
164 | assert("Unknown winograd weights target layout" ); |
165 | return; |
166 | } |
167 | |
168 | int Z = oc_ * ic_; |
169 | assert(r_ == kh_ && r_ == kw_); |
170 | |
171 | for (int iic = 0; iic < ic_; iic++) { |
172 | for (int ob = 0; ob < nb_oc_; ob++) { |
173 | const in_data_t *__restrict _inp |
174 | = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; |
175 | out_data_t *__restrict _out |
176 | = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; |
177 | |
178 | for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); |
179 | |
180 | for_nd(0, 1, r_, w_alpha_, oc_block_, |
181 | [&](int ih, int j, int ioc) { |
182 | for (int iw = 0; iw < r_; ++iw) { |
183 | int inp_oc = ob * oc_block_ + ioc; |
184 | int inp_ic = iic; |
185 | in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) |
186 | ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] |
187 | : 0.f; |
188 | wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] |
189 | += inp_v * g[j * r_ + iw]; |
190 | } |
191 | }); |
192 | |
193 | for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, |
194 | [&](int i, int j, int ioc) { |
195 | float t = 0; |
196 | for (int k = 0; k < r_; ++k) |
197 | t += g[i * r_ + k] |
198 | * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; |
199 | if (type_o == data_type::s8) { |
200 | const float scale = (D_mask == 1) |
201 | ? scales[0] |
202 | : scales[ob * oc_block_ + ioc]; |
203 | _out[(i * w_alpha_ + j) * Z + ioc] |
204 | = qz_b0<in_data_t, out_data_t>()( |
205 | (in_data_t)t, scale * adj_scale_); |
206 | } else { |
207 | _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; |
208 | } |
209 | }); |
210 | }} |
211 | } |
212 | |
213 | void reorder_to_aaOIoi(out_data_t *__restrict output, |
214 | const out_data_t *__restrict tmp_wei) const { |
215 | int32_t *__restrict dst_bias = nullptr; |
216 | if (type_o == data_type::s8) { |
217 | const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; |
218 | const size_t bias_size = w_alpha_ * w_alpha_ * oc_; |
219 | |
220 | dst_bias = (int32_t *)(output + bias_shift); |
221 | utils::array_set((int32_t *)dst_bias, 0, bias_size); |
222 | } |
223 | int index = 0; |
224 | for (int u_h = 0; u_h < w_alpha_; u_h++) { |
225 | for (int u_w = 0; u_w < w_alpha_; u_w++) { |
226 | for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { |
227 | int u_h_shift = u_h * w_alpha_ * ic_ * oc_; |
228 | int u_w_shift = u_w * ic_ * oc_; |
229 | int u_h_shift_b = u_h * w_alpha_ * oc_; |
230 | int u_w_shift_b = u_w * oc_; |
231 | int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; |
232 | for (int ib = 0; ib < nb_ic_; ib++) { |
233 | for (int i = 0; i < ic_block_; i++) { |
234 | int _i = ib * ic_block_; |
235 | int _o = ob * oc_block_; |
236 | int ic_shift = (_i + i) * oc_; |
237 | int oc_shift = (_o + o); |
238 | int ic_block_shift = ib * oc_block_ * ic_block_ + i; |
239 | int src_offset = |
240 | u_h_shift + u_w_shift + ic_shift + oc_shift; |
241 | int dst_offset = u_h_shift + u_w_shift + oc_block_shift |
242 | + ic_block_shift; |
243 | |
244 | output[dst_offset] = tmp_wei[src_offset]; |
245 | if (type_o == data_type::s8) { |
246 | int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; |
247 | if (index != unsign_val_in_wino_domain_) |
248 | dst_bias[bias_offset] |
249 | -= (128 * (int32_t)output[dst_offset]); |
250 | else |
251 | dst_bias[bias_offset] = 0; |
252 | } |
253 | }} |
254 | }); |
255 | index++; |
256 | }} |
257 | } |
258 | |
259 | void reorder_to_aaOio(out_data_t *__restrict output, |
260 | const out_data_t *__restrict tmp_wei) const { |
261 | for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, |
262 | [&](int u_h, int u_w, int ob) { |
263 | for (int ib = 0; ib < nb_ic_; ib++) { |
264 | for (int i = 0; i < ic_block_; i++) { |
265 | for (int o = 0; o < oc_block_; o++) { |
266 | int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ |
267 | + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); |
268 | |
269 | int dst_offset |
270 | = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ |
271 | + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ |
272 | + ob * nb_ic_ * ic_block_ * oc_block_ |
273 | + ib * ic_block_ * oc_block_ + i * oc_block_ + o; |
274 | output[dst_offset] = tmp_wei[src_offset]; |
275 | }}} |
276 | }); |
277 | } |
278 | |
279 | void reorder_to_aaOBiOo(out_data_t *__restrict output, |
280 | const out_data_t *__restrict tmp_wei) const { |
281 | int oc_chunks = nb_oc_ / oc2_block_; |
282 | |
283 | for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, |
284 | [&](int u_h, int u_w, int occ) { |
285 | for (int ib = 0; ib < nb_ic_; ib++) { |
286 | out_data_t *__restrict wei_ptr = output |
287 | + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) |
288 | * oc2_block_ * ic_block_ * oc_block_; |
289 | int wei_offset = 0; |
290 | for (int i = 0; i < ic_block_; i++) { |
291 | for (int ob2 = 0; ob2 < oc2_block_; ob2++) { |
292 | for (int o = 0; o < oc_block_; o++) { |
293 | int icp = ib * ic_block_ + i; |
294 | int ocp = |
295 | occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; |
296 | |
297 | int src_offset = u_h * w_alpha_ * ic_ * oc_ |
298 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
299 | wei_ptr[wei_offset + o] = tmp_wei[src_offset]; |
300 | } |
301 | wei_offset += oc_block_; |
302 | }} |
303 | } |
304 | }); |
305 | } |
306 | |
307 | void reorder_to_OBaaIBOIio(out_data_t *__restrict output, |
308 | const out_data_t *__restrict tmp_wei) const { |
309 | int ic_chunks = nb_ic_ / ic2_block_; |
310 | int oc_chunks = nb_oc_ / oc2_block_; |
311 | |
312 | for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, |
313 | [&](int occ, int u_h, int u_w) { |
314 | for (int icc = 0; icc < ic_chunks; icc++) { |
315 | for (int ob = 0; ob < oc2_block_; ob++) { |
316 | int ocp = (occ * oc2_block_ + ob) * oc_block_; |
317 | for (int ib = 0; ib < ic2_block_; ib++) { |
318 | for (int i = 0; i < ic_block_; i++) { |
319 | int icp = (icc * ic2_block_ + ib) * ic_block_ + i; |
320 | |
321 | int src_offset = u_h * w_alpha_ * ic_ * oc_ |
322 | + u_w * ic_ * oc_ + icp * oc_ + ocp; |
323 | int wei_offset |
324 | = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) |
325 | * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ |
326 | + ib) * ic_block_ + i) * oc_block_; |
327 | for (int o = 0; o < oc_block_; o++) |
328 | output[wei_offset + o] = tmp_wei[src_offset + o]; |
329 | }} |
330 | }} |
331 | }); |
332 | } |
333 | |
334 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
335 | auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); |
336 | auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); |
337 | |
338 | auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>( |
339 | memory_tracking::names::key_reorder_wino_transform_space); |
340 | auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>( |
341 | memory_tracking::names::key_reorder_wino_plain); |
342 | |
343 | transform(tmp_wei, input, wspace); |
344 | |
345 | /* reorder to winograd domain */ |
346 | switch (wino_format_) { |
347 | case mkldnn_wino_wei_aaOIoi: |
348 | reorder_to_aaOIoi(output, tmp_wei); break; |
349 | case mkldnn_wino_wei_aaOio: |
350 | reorder_to_aaOio(output, tmp_wei); break; |
351 | case mkldnn_wino_wei_aaOBiOo: |
352 | reorder_to_aaOBiOo(output, tmp_wei); break; |
353 | case mkldnn_wino_wei_OBaaIBOIio: |
354 | reorder_to_OBaaIBOIio(output, tmp_wei); break; |
355 | default: assert("Unknown wino format" ); break; |
356 | } |
357 | |
358 | return status::success; |
359 | } |
360 | |
361 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } |
362 | int r_, w_alpha_; |
363 | int ic_, oc_, or_ic_, or_oc_, kh_, kw_; |
364 | int oc_block_, ic_block_, oc2_block_, ic2_block_; |
365 | float adj_scale_; |
366 | int nb_oc_, nb_ic_; |
367 | mkldnn_wino_memory_format_t wino_format_; |
368 | int size_wino_wei_; |
369 | int size_wspace_; |
370 | }; |
371 | |
372 | } // namespace cpu |
373 | } // namespace impl |
374 | } // namespace mkldnn |
375 | |
376 | #endif |
377 | |