/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef MEMORY_DESC_WRAPPER_HPP
#define MEMORY_DESC_WRAPPER_HPP

#include <assert.h>

#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"

#include "type_helpers.hpp"

namespace mkldnn {
namespace impl {

/** thin wrapper class over \struct memory_desc_t that allows easy
 * manipulation of the underlying C structure, which is taken by reference */
struct memory_desc_wrapper: public c_compatible {
    const memory_desc_t *md_;

    /** constructors that take a pointer or a reference to a constant
     * underlying C memory descriptor \param md */
    memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
    memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
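
    /* Usage sketch (illustrative only; `md` stands for any fully initialized
     * memory_desc_t, e.g. a plain 4D f32 descriptor):
     *     memory_desc_wrapper mdw(md);
     *     dim_t n = mdw.nelems();     // number of data elements (no padding)
     *     size_t bytes = mdw.size();  // bytes required to store the memory
     */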

    /* implementing attributes */
    int ndims() const { return md_->ndims; }
    const dims_t &dims() const { return md_->dims; }
    data_type_t data_type() const { return md_->data_type; }

    const dims_t &padded_dims() const { return md_->padded_dims; }
    const dims_t &padded_offsets() const { return md_->padded_offsets; }
    dim_t offset0() const { return md_->offset0; }

    format_kind_t format_kind() const { return md_->format_kind; }

    bool is_blocking_desc() const
    { return format_kind() == format_kind::blocked; }
    bool is_wino_desc() const
    { return format_kind() == format_kind::wino; }
    bool is_rnn_packed_desc() const
    { return format_kind() == format_kind::rnn_packed; }

    const blocking_desc_t &blocking_desc() const {
        assert(is_blocking_desc());
        return md_->format_desc.blocking;
    }
    const wino_desc_t &wino_desc() const {
        assert(is_wino_desc());
        return md_->format_desc.wino_desc;
    }
    const rnn_packed_desc_t &rnn_packed_desc() const {
        assert(is_rnn_packed_desc());
        return md_->format_desc.rnn_packed_desc;
    }

    const memory_extra_desc_t &extra() const { return md_->extra; }

    /* some useful functions */

    /** returns the number of elements including padding if \param with_padding
     * is true, and the number of data elements otherwise */
    dim_t nelems(bool with_padding = false) const {
        if (is_zero()) return 0;
        return utils::array_product(
                with_padding ? padded_dims() : dims(), ndims());
    }

    /** returns true if memory descriptor is zero */
    bool is_zero() const { return ndims() == 0; }

    /** returns true if memory descriptor contains zero as one of its dims */
    bool has_zero_dim() const { return nelems() == 0; }

    /** returns the size of the data type (a shortcut) */
    size_t data_type_size() const
    { return types::data_type_size(data_type()); }

    /** returns the size of the data type of the additional buffer */
    size_t additional_buffer_data_size() const {
        if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
            return sizeof(int32_t);
        return 0;
    }

    /** returns true if the memory format has an additional buffer */
    bool is_additional_buffer() const {
        return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
    }


    /** returns the size of additional buffer */
    size_t additional_buffer_size() const {
        if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
            int cmask = extra().compensation_mask;
            assert(cmask == 1 || cmask == 3);
            dim_t prod = 1;
            for (int d = 0; d < ndims(); ++d)
                if (cmask & (1<<d)) prod *= padded_dims()[d];
            return prod * additional_buffer_data_size();
        }

        return 0;
    }

    /** returns the size required to store the described memory
     * note: if offset0 != 0, returns 0 (the behavior needs to be specified) */
    size_t size() const {
        if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
            return 0;

        if (format_kind() == format_kind::wino) {
            return wino_desc().size;
        } else if (format_kind() == format_kind::rnn_packed) {
            return rnn_packed_desc().size;
        } else {
            if (offset0() != 0) return 0;

            dims_t blocks = {0};
            compute_blocks(blocks);

            const auto &bd = blocking_desc();

            size_t max_size = 0;
            for (int d = 0; d < ndims(); ++d)
                max_size = nstl::max<size_t>(max_size,
                        padded_dims()[d] / blocks[d] * bd.strides[d]);

            if (max_size == 1 && bd.inner_nblks != 0) {
                max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
            }

            return max_size * data_type_size() + additional_buffer_size();
        }
    }
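
    /* Allocation sketch (illustrative; the allocator is a stand-in, not part
     * of this API):
     *     memory_desc_wrapper mdw(md);
     *     void *buf = malloc(mdw.size());  // accounts for padding, blocking,
     *                                      // and any additional (extra) buffer
     */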

    /** returns true if data is dense in memory */
    bool is_dense(bool with_padding = false) const {
        if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
            return false;
        return nelems(with_padding) * data_type_size() == size();
    }

    /** returns true if memory desc is fully defined */
    bool is_defined() const { return format_kind() != format_kind::any; }

    /** returns true if the only (potentially) padded dim is \param dim */
    bool only_padded_dim(int dim) const {
        for (int d = 0; d < ndims(); ++d)
            if (d != dim && dims()[d] != padded_dims()[d])
                return false;
        return true;
    }

    /** returns true if memory desc has blocked layout and block dims are 1s */
    bool is_plain() const {
        if (!is_blocking_desc()) return false;
        return blocking_desc().inner_nblks == 0;
    }
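
    /* For instance (illustrative): a plain "nchw" descriptor has
     * inner_nblks == 0, so is_plain() returns true, whereas a blocked layout
     * such as "nChw16c" (one inner block of 16 over channels) returns false. */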

    /** returns overall block sizes */
    void compute_blocks(dims_t blocks) const {
        if (!is_blocking_desc()) {
            utils::array_set(blocks, 0, ndims());
            return;
        }

        utils::array_set(blocks, 1, ndims());

        const auto &bd = blocking_desc();
        for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
            blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
    }
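
    /* E.g. (illustrative): for a 4D "nChw16c" descriptor compute_blocks()
     * fills blocks = {1, 16, 1, 1} -- only the channel dim carries an inner
     * block -- while any plain layout yields all 1s. */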

    /* comparison section */

    bool operator==(const memory_desc_wrapper &rhs) const
    { return *this->md_ == *rhs.md_; }
    bool operator!=(const memory_desc_wrapper &rhs) const
    { return !operator==(rhs); }
    bool operator==(const memory_desc_t &rhs) const
    { return operator==(memory_desc_wrapper(rhs)); }
    bool operator!=(const memory_desc_t &rhs) const
    { return !operator==(rhs); }

    /** returns true if the data (w/o padding if with_padding == false and w/
     * padding otherwise) has the same physical structure, i.e. dimensions,
     * strides, and blocked structure. Depending on the with_data_type flag the
     * data_type is or is not taken into account. dim_start allows checking
     * similarity only for the logical part of the data [dim_start .. ndims()].
     * CAUTION: format kinds any and undef are not similar to anything, hence
     * the following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
    /* TODO: revise */
    bool similar_to(const memory_desc_wrapper &rhs,
            bool with_padding = true, bool with_data_type = true,
            int dim_start = 0) const;

    /** returns true if one memory can be reordered to another */
    bool consistent_with(const memory_desc_wrapper &rhs) const;

    /** returns true if the memory desc corresponds to the given format tag and
     * strides.
     * @sa memory_desc_matches_tag */
    bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
        return memory_desc_matches_tag(*md_, tag, strides);
    }

    /** returns matching tag (or undef if match is not found)
     * XXX: This is a workaround that eventually should go away! */
    template <typename... Tags>
    format_tag_t matches_one_of_tag(Tags ...tags) const {
        for (const auto tag: {tags...}) {
            if (memory_desc_matches_tag(*md_, tag))
                return tag;
        }
        return format_tag::undef;
    }
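
    /* Usage sketch (illustrative; nchw/nhwc are the standard plain 4D tags):
     *     const auto tag = mdw.matches_one_of_tag(
     *             format_tag::nchw, format_tag::nhwc);
     *     if (tag == format_tag::undef) return status::unimplemented;
     */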

    /* offset section */

    /** returns the physical offset for a given logical position. the logical
     * position is represented by an array \param pos. if \param is_pos_padded
     * is true, \param pos represents the position in the already padded area */
    dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
        assert(is_blocking_desc());
        const blocking_desc_t &blk = blocking_desc();

        dims_t pos_copy = {0};
        for (int d = 0; d < ndims(); ++d)
            pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);

        dim_t phys_offset = offset0();

        if (blk.inner_nblks > 0) {
            dim_t blk_stride = 1;
            for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
                const int d = blk.inner_idxs[iblk];
                const dim_t p = pos_copy[d] % blk.inner_blks[iblk];

                phys_offset += p * blk_stride;

                pos_copy[d] /= blk.inner_blks[iblk];

                blk_stride *= blk.inner_blks[iblk];
            }
        }

        for (int d = 0; d < ndims(); ++d) {
            const dim_t p = pos_copy[d];
            phys_offset += p * blk.strides[d];
        }

        return phys_offset;
    }
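
    /* Worked example (illustrative): for a dense plain "nchw" tensor with
     * dims {2, 32, 4, 4} (strides {512, 16, 4, 1}, no padding, no inner
     * blocks, offset0 == 0), pos = {1, 3, 2, 1} gives
     *     off_v(pos) = 1*512 + 3*16 + 2*4 + 1*1 = 569. */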

    /** returns the physical offset for a given logical offset. the logical
     * offset is represented by a scalar \param l_offset. if \param
     * is_pos_padded is true, \param l_offset represents the logical offset in
     * the already padded area */
    dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
        assert(is_blocking_desc());
        dims_t pos;
        for (int rd = 0; rd < ndims(); ++rd) {
            const int d = ndims() - 1 - rd;
            const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
            pos[d] = l_offset % cur_dim;
            l_offset /= cur_dim;
        }
        return off_v(pos, is_pos_padded);
    }
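
    /* Note: off_l() unflattens the scalar offset into a logical position and
     * defers to off_v(), so for the dense "nchw" example above it is an
     * identity, e.g. off_l(569) == 569. */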

    /** returns the physical offset for a logical position given as a tuple of
     * indices (\param xn, ..., \param x1, \param x0) */
    template<typename... Args>
    dim_t off(Args... args) const {
        assert(sizeof...(args) == ndims());
        dims_t pos = { args... };
        return off_v(pos, false);
    }
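
    /* Usage sketch (illustrative; assumes a 4D descriptor and dim_t indices
     * n, c, h, w within dims()):
     *     const dim_t phys = mdw.off(n, c, h, w);
     *     // src[phys] then addresses the element at logical (n, c, h, w)
     */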

    /** returns the physical offset for a logical position given as a tuple of
     * indices (\param xn, ..., \param x1, \param x0) in the already padded
     * area */
    template<typename... Args>
    dim_t off_padding(Args... args) const {
        assert(sizeof...(args) == ndims());
        dims_t pos = { args... };
        return off_v(pos, true);
    }

    /** returns the physical offset for a logical position given as a tuple of
     * block indices (\param bn, ..., \param b1, \param b0). It is the user's
     * responsibility to adjust the result to get the offset within blocks */
    template<typename ...Args>
    dim_t blk_off(Args... args) const {
        return _blk_off<sizeof...(args), Args...>(args...);
    }

    template<bool skip_first, typename T, typename ...Args>
    dim_t blk_off(T xn, Args... args) const {
        return skip_first
            ? blk_off<Args...>(args...)
            : blk_off<T, Args...>(xn, args...);
    }
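
    /* Illustrative sketch: for a blocked layout such as "nChw16c", where the
     * channel index splits into a block index c / 16 and an intra-block index
     * c % 16, something like
     *     const dim_t base = mdw.blk_off(n, c / 16, h, w);
     * returns the offset of the start of that 16-element channel block; the
     * caller adds c % 16 to address a particular channel. */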

    /* static functions section */
    /* TODO: replace with non-static, once md_ becomes non-const ref */

    static status_t compute_blocking(memory_desc_t &memory_desc,
            format_tag_t tag);

private:
    /* TODO: put logical_offset in utils */
    template<typename T>
    dim_t logical_offset(T x0) const { return x0; }

    template<typename T, typename... Args>
    dim_t logical_offset(T xn, Args... args) const {
        const size_t n_args = sizeof...(args);
        return xn * utils::array_product<n_args>(
                &dims()[ndims() - n_args]) + logical_offset(args...);
    }

    template<int ORIG_LEN, typename ...Void>
    dim_t _blk_off() const { return offset0(); }

    template<int ORIG_LEN, typename T, typename ...Args>
    dim_t _blk_off(T xc, Args ...args) const {
        assert(is_blocking_desc());
        constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
        return xc * blocking_desc().strides[dc]
            + _blk_off<ORIG_LEN, Args...>(args...);
    }
};

inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
        bool with_padding, bool with_data_type, int dim_start) const {
    using namespace utils;

    if (one_of(format_kind(), format_kind::undef, format_kind::any))
        return false;
    if (is_wino_desc() || is_rnn_packed_desc())
        return false;

    const int ds = dim_start;
    const auto &blk = blocking_desc();
    const auto &r_blk = rhs.blocking_desc();

    return ndims() == rhs.ndims()
        && dim_start <= ndims() /* guard */
        && format_kind() == rhs.format_kind()
        && IMPLICATION(with_data_type, data_type() == rhs.data_type())
        && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
        && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
        && blk.inner_nblks == r_blk.inner_nblks
        && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
        && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
        && IMPLICATION(with_padding, true
                && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
                    ndims() - ds)
                && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
                    ndims() - ds));
}

inline bool memory_desc_wrapper::consistent_with(
        const memory_desc_wrapper &rhs) const {
    if (ndims() == rhs.ndims()) {
        for (int d = 0; d < ndims(); ++d) {
            if (dims()[d] != rhs.dims()[d]) return false;
        }
        return true;
    } else {
        /* TODO: revise.
         * is the following possible?
         * [1, a, b] <--reorder--> [a, b]
         * [a, 1, b] <--reorder--> [a, b]
         * no, at least not for now */
        return false;
    }
}

}
}

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s