1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "mkldnn_traits.hpp"
20#include "mkldnn_thread.hpp"
21#include "type_helpers.hpp"
22#include "utils.hpp"
23
24#include "cpu_memory.hpp"
25
26namespace mkldnn {
27namespace impl {
28namespace cpu {
29
30using namespace mkldnn::impl;
31using namespace mkldnn::impl::data_type;
32using namespace mkldnn::impl::status;
33using namespace mkldnn::impl::format_tag;
34
/* Blocking pattern handled by typed_zero_pad_blk(): which of the first three
 * logical dimensions (called a, b and c) carry inner blocks. For two-letter
 * kinds, the first letter is the dimension of the outer in-block index and
 * the second letter that of the inner one. */
enum blk_kind_t {
    a = 0, // dimension 0 blocked
    b = 1, // dimension 1 blocked
    c = 2, // dimension 2 blocked
    ab = 3, // dimensions 0 (outer) and 1 (inner) blocked
    ba = 4, // dimensions 1 (outer) and 0 (inner) blocked
    bc = 5, // dimensions 1 (outer) and 2 (inner) blocked
    cb = 6 // dimensions 2 (outer) and 1 (inner) blocked
};
36
37template <data_type_t dt, blk_kind_t blk_kind, int blksize>
38void typed_zero_pad_blk(
39 const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
40 using data_t = typename prec_traits<dt>::type;
41 const auto &dims = m_d.dims();
42 const auto &pdims = m_d.padded_dims();
43 const auto &blk = m_d.blocking_desc();
44 auto dim_is_blocked = [&](int dim) {
45 for (int i = 0; i < blk.inner_nblks; i++)
46 if (blk.inner_idxs[i] == dim)
47 return true;
48 return false;
49 };
50 bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1),
51 C_blocked = dim_is_blocked(2);
52
53 assert(blk.inner_nblks < 4);
54 assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked)
55 || (C_blocked && B_blocked));
56
57 const int a_tail_s = A_blocked ? dims[0] % blksize : 0;
58 const int b_tail_s = B_blocked ? dims[1] % blksize : 0;
59 const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
60 assert(a_tail_s || b_tail_s || c_tail_s);
61
62 const int A = A_blocked ? pdims[0] / blksize : dims[0];
63 const int B = B_blocked ? pdims[1] / blksize : dims[1];
64 const int C = C_blocked ? pdims[2] / blksize : dims[2];
65 const int D = m_d.ndims() > 3 ? dims[3] : 1;
66 const int E = m_d.ndims() > 4 ? dims[4] : 1;
67 const int F = m_d.ndims() > 5 ? dims[5] : 1;
68 const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;
69
70 auto zeroize_tail = [&](data_t *d, const int tail_s) {
71 for (int b = tail_s; b < blksize; ++b)
72 d[b] = 0;
73 };
74 auto zeroize_tail_inner = [&](data_t *d, const int tail_s) {
75 for (int b1 = 0; b1 < blksize; ++b1)
76 for (int b2 = tail_s; b2 < blksize; ++b2)
77 d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
78 + b1 % inner_blk]
79 = 0;
80 };
81 auto zeroize_tail_outer = [&](data_t *d, const int tail_s) {
82 for (int b1 = tail_s; b1 < blksize; ++b1)
83 for (int b2 = 0; b2 < blksize; ++b2)
84 d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
85 + b1 % inner_blk]
86 = 0;
87 };
88
89 if (c_tail_s) {
90 parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) {
91 auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)];
92 if (blk_kind == c)
93 zeroize_tail(x, c_tail_s);
94 else if (blk_kind == bc)
95 zeroize_tail_inner(x, c_tail_s);
96 else if (blk_kind == cb)
97 zeroize_tail_outer(x, c_tail_s);
98 });
99 }
100
101 if (b_tail_s) {
102 parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) {
103 auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)];
104 if (blk_kind == b)
105 zeroize_tail(x, b_tail_s);
106 else if (blk_kind == ab || blk_kind == cb)
107 zeroize_tail_inner(x, b_tail_s);
108 else if (blk_kind == ba || blk_kind == bc)
109 zeroize_tail_outer(x, b_tail_s);
110 });
111 }
112
113 if (a_tail_s) {
114 parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) {
115 auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)];
116 if (blk_kind == a)
117 zeroize_tail(x, a_tail_s);
118 else if (blk_kind == ba)
119 zeroize_tail_inner(x, a_tail_s);
120 else if (blk_kind == ab)
121 zeroize_tail_outer(x, a_tail_s);
122 });
123 }
124}
125
126/*
127 * all
128 */
129template <data_type_t dt>
130void typed_zero_pad_generic_blocked(
131 const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
132 const int ndims = m_d.ndims();
133 const auto &dims = m_d.dims();
134 const auto &pdims = m_d.padded_dims();
135
136 const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);
137
138 /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
139 * | \ /
140 * | ---------------------
141 * has contiguous
142 * padding
143 *
144 * step <-- D_k+1 * ... * D_ndims-1
145 * step_dim <-- k
146 */
147
148 ptrdiff_t step = 1;
149 int step_dim = ndims - 1;
150 for (; step_dim >= 0; --step_dim) {
151 if (dims[step_dim] != pdims[step_dim])
152 break;
153 step *= dims[step_dim];
154 }
155
156 assert(step_dim >= 0 && "no zero padding is required");
157 if (step_dim < 0)
158 return;
159
160 parallel_nd(nelems / step, [&](ptrdiff_t e1) {
161 bool need_zero = false;
162
163 ptrdiff_t idx = e1;
164 for (int d = step_dim; d >= 0; --d) {
165 if (idx % pdims[d] >= dims[d]) {
166 need_zero = true;
167 break;
168 }
169 idx /= pdims[d];
170 }
171
172 if (need_zero) {
173 for (ptrdiff_t e0 = 0; e0 < step; ++e0)
174 data[m_d.off_l(e1 * step + e0, true)] = 0;
175 }
176 });
177}
178
179template <data_type_t dt>
180status_t cpu_memory_t::typed_zero_pad() const {
181 const memory_desc_wrapper mdw(md());
182
183 if (mdw.format_kind() != format_kind::blocked)
184 return unimplemented;
185
186 if (mdw.nelems(false) == mdw.nelems(true))
187 return success;
188
189 auto *data = (typename prec_traits<dt>::type *)data_;
190 auto blk = mdw.blocking_desc();
191
192 auto get_blksize = [&](int ind) {
193 int blksize = 1;
194 for (int i = 0; i < blk.inner_nblks; i++) {
195 if (blk.inner_idxs[i] == ind)
196 blksize *= blk.inner_blks[i];
197 }
198 return blksize;
199 };
200 const int blksize = get_blksize(blk.inner_idxs[0]);
201
202# define CASE(blksize_, blk_kind) \
203 do { \
204 if (blksize == blksize_) { \
205 typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
206 return success; \
207 } \
208 } while(0)
209
210 switch (blk.inner_nblks) {
211 case 1:
212 if (blk.inner_idxs[0] == 0) {
213 CASE(4, a);
214 CASE(8, a);
215 CASE(16, a);
216 } else if (blk.inner_idxs[0] == 1) {
217 CASE(4, b);
218 CASE(8, b);
219 CASE(16, b);
220 }
221 break;
222 case 2:
223 case 3:
224 if (!IMPLICATION(blk.inner_nblks == 3,
225 blk.inner_idxs[0] == blk.inner_idxs[2]))
226 break;
227
228 if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
229 CASE(4, ab);
230 CASE(8, ab);
231 CASE(16, ab);
232 } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) {
233 CASE(4, ba);
234 CASE(8, ba);
235 CASE(16, ba);
236 }
237 if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
238 CASE(4, bc);
239 CASE(8, bc);
240 CASE(16, bc);
241 } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) {
242 CASE(4, cb);
243 CASE(8, cb);
244 CASE(16, cb);
245 }
246 break;
247 default: break;
248 }
249
250# undef CASE
251
252 // the last line of defence
253 typed_zero_pad_generic_blocked<dt>(mdw, data);
254 return success;
255}
256
257status_t cpu_memory_t::zero_pad() const {
258 memory_desc_wrapper mdw(md());
259 const bool skip_zeroing = false
260 || data_ == nullptr
261 || mdw.is_zero()
262 || !mdw.is_blocking_desc();
263 if (skip_zeroing) return success;
264
265 switch (mdw.data_type()) {
266 case f32: return typed_zero_pad<f32>();
267 case s32: return typed_zero_pad<s32>();
268 case s8: return typed_zero_pad<s8>();
269 case u8: return typed_zero_pad<u8>();
270 default: assert(!"memory is undefined"); return unimplemented;
271 }
272 return unimplemented;
273}
274
275}
276}
277}
278