1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "c_types_map.hpp"
20#include "memory_desc_wrapper.hpp"
21#include "mkldnn_debug.h"
22#include "nstl.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26#include "jit_uni_reorder.hpp"
27
28using namespace mkldnn::impl::types;
29using namespace mkldnn::impl::status;
30
31namespace mkldnn {
32namespace impl {
33namespace cpu {
34
35namespace tr {
36
/** ad-hoc structure to describe blocked memory layout as a flat list of
 * (logical-dim-id, size, stride) entries; filled by
 * cvt_mem_desc_to_layout_desc() below */
struct layout_desc_t {
    data_type_t dt;     // data type of the underlying memory
    int ndims;          // number of entries used in id/dims/strides
    dims_t id;          // logical dimension each entry belongs to
                        // (a blocked dim contributes several entries)
    dims_t dims;        // size of each entry (block or outer-dim size)
    strides_t strides;  // element stride of each entry
};
45
/** Converts a blocked memory descriptor into a flat layout_desc_t.
 *
 * Each logical dimension of @p md_ is expanded into one entry for its outer
 * part plus one entry per inner block it participates in. Entries for one
 * logical dimension are emitted in ascending-stride order (innermost first)
 * and then reversed, so the final order within a dimension is
 * outermost-to-innermost.
 *
 * @returns invalid_arguments if @p md_ is not a plain blocking descriptor
 *          (or carries extra flags), success otherwise.
 */
status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
        layout_desc_t &ld) {
    const auto md = memory_desc_wrapper(md_);

    // only plain blocked layouts without extra metadata are supported
    bool ok = true
        && md.is_blocking_desc()
        && md.extra().flags == 0;
    if (!ok) return invalid_arguments;

    const auto &bd = md.blocking_desc();

    ld.ndims = 0;
    ld.dt = md.data_type();

    // appends one (id, dim, stride) entry to ld
    auto P = [&ld](int id, int dim, ptrdiff_t stride) {
        // guard against overflowing the fixed-size arrays in layout_desc_t
        assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
        ld.id[ld.ndims] = id;
        ld.dims[ld.ndims] = dim;
        ld.strides[ld.ndims] = stride;
        ++ld.ndims;
    };

    dims_t blocks;
    md.compute_blocks(blocks);

    for (int d = 0; d < md.ndims(); ++d) {
        const int ld_ndims_start = ld.ndims;
        if (blocks[d] != 1) {
            stride_t stride = 1;
            // walk inner blocks from innermost outwards; `stride` accumulates
            // across ALL inner blocks (not just those of dim d), since they
            // all contribute to the inner-block element stride
            for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
                if (bd.inner_idxs[iblk] == d)
                    P(d, bd.inner_blks[iblk], stride);
                stride *= bd.inner_blks[iblk];
            }
        }
        // outer part of dim d: padded size divided by the total block size
        P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);

        // TODO: NOW: revisit, do we need a reverse?
        // TODO: NOW: consider using strides instead of block sizes in md
        // reverse the order of dims
        for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
            const int idx0 = ld_ndims_start + ld_d;
            const int idx1 = ld.ndims - 1 - ld_d;
            nstl::swap(ld.dims[idx0], ld.dims[idx1]);
            nstl::swap(ld.strides[idx0], ld.strides[idx1]);
        }
    }

    return success;
}
96
/** Initializes the reorder problem descriptor @p p from the input/output
 * memory descriptors and attributes.
 *
 * Both descriptors are flattened to layout_desc_t form and then merged into
 * a single list of nodes, one per common dimension chunk: whenever the input
 * and output chunk sizes differ, the larger one is split by the (assumed
 * integral) factor so the walk stays in lockstep.
 *
 * @returns unimplemented for unsupported layouts / inconsistent paddings,
 *          runtime_error on internal inconsistencies, success otherwise.
 */
status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
        const primitive_attr_t *attr) {
    auto im_d = memory_desc_wrapper(imd);
    auto om_d = memory_desc_wrapper(omd);

    bool ok = true
        && im_d.is_blocking_desc()
        && om_d.is_blocking_desc()
        && !im_d.has_zero_dim()
        && !om_d.has_zero_dim();
    if (!ok)
        return unimplemented;

    dims_t iblocks, oblocks;
    im_d.compute_blocks(iblocks);
    om_d.compute_blocks(oblocks);

    /* padding_dim consistency check: padded sizes must match and be
     * divisible by both the input and output block sizes, otherwise the
     * lockstep merge below cannot line the layouts up */
    for (int d = 0; d < im_d.ndims(); ++d) {
        const auto pdim = im_d.padded_dims()[d];
        bool ok = true
            && pdim == om_d.padded_dims()[d]
            && pdim % iblocks[d] == 0
            && pdim % oblocks[d] == 0;
        if (!ok) return unimplemented;
    }

    layout_desc_t ild, old;
    status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
    if (status != success) return status;
    status = cvt_mem_desc_to_layout_desc(omd, old);
    if (status != success) return status;

    p.itype = ild.dt;
    p.otype = old.dt;

    // NONE: no scaling; COMMON: single scale (mask 0); MANY: per-dim scales
    p.scale_type = attr->output_scales_.has_default_values()
        ? scale_type_t::NONE
        : (attr->output_scales_.mask_ == 0
                ? scale_type_t::COMMON
                : scale_type_t::MANY);

    /* compute dense strides of the scales array over the output layout:
     * only dimensions selected by the scales mask advance the stride */
    ptrdiff_t ss[max_ndims] = {0};
    if (p.scale_type == scale_type_t::MANY) {
        ptrdiff_t last_ss = 1;
        for (int d = old.ndims - 1; d >=0; --d) {
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
                ss[d] = last_ss;
                last_ss *= old.dims[d];
            }
        }
    }

    int ndims = 0;

    int i_pos = 0; /* state for input -- current dimension */
    int o_pos = 0; /* state for output -- current dimension */

    /* merge the two flattened layouts into p.nodes, splitting the larger
     * chunk whenever the current input/output chunk sizes differ */
    while (i_pos < ild.ndims && o_pos < old.ndims) {
        assert(ild.id[i_pos] == old.id[o_pos]);
        if (ild.id[i_pos] != old.id[o_pos])
            return runtime_error;

        assert(ndims < max_ndims);
        if (ndims == max_ndims)
            return runtime_error;

        if (ild.dims[i_pos] == old.dims[o_pos]) {
            // chunks match exactly: emit one node and advance both sides
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++i_pos;
            ++o_pos;
        } else if (ild.dims[i_pos] < old.dims[o_pos]) {
            // output chunk is larger: emit the input-sized part, leave the
            // remaining `factor` in the output chunk for the next iteration
            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
            int factor = old.dims[o_pos] / ild.dims[i_pos];
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos] * factor;
            p.nodes[ndims].ss = ss[o_pos] * factor;
            ++ndims;
            ++i_pos;
            old.dims[o_pos] = factor;
        } else if (ild.dims[i_pos] > old.dims[o_pos]) {
            // input chunk is larger: mirror of the case above
            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
            int factor = ild.dims[i_pos] / old.dims[o_pos];
            p.nodes[ndims].n = old.dims[o_pos];
            p.nodes[ndims].is = ild.strides[i_pos] * factor;
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++o_pos;
            ild.dims[i_pos] = factor;
        }
    }
    p.ndims = ndims;

    // element offsets of the logical origin in both memories
    dims_t zero_pos = {0};
    p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
    p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);

    // beta != 0 means accumulate into the destination (sum post-op)
    const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
    p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;

    return success;
}
207
208void prb_normalize(prb_t &p) {
209 for (int d = 0; d < p.ndims; ++d) {
210 int min_pos = d;
211 for (int j = d + 1; j < p.ndims; ++j) {
212 bool new_min = false
213 || p.nodes[j].os < p.nodes[min_pos].os
214 || (true
215 && p.nodes[j].os == p.nodes[min_pos].os
216 && p.nodes[j].n < p.nodes[min_pos].n);
217 if (new_min) min_pos = j;
218 }
219 if (min_pos != d)
220 nstl::swap(p.nodes[d], p.nodes[min_pos]);
221 }
222}
223
224void prb_simplify(prb_t &p) {
225#if defined(__GNUC__) && __GNUC__ >= 4
226/* GCC produces bogus array subscript is above array bounds warning for
227 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
228#pragma GCC diagnostic push
229#pragma GCC diagnostic ignored "-Warray-bounds"
230#endif
231 for (int d = 0; d < p.ndims - 1; ++d) {
232 auto &this_node = p.nodes[d + 0];
233 auto &next_node = p.nodes[d + 1];
234 const bool fold = false
235 || next_node.n == (size_t)1 // trivial case, just drop next node
236 || (true // or real folding if possible
237 && next_node.is == (ptrdiff_t)this_node.n * this_node.is
238 && next_node.os == (ptrdiff_t)this_node.n * this_node.os
239 && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
240 if (fold) {
241 this_node.n *= next_node.n;
242 for (int j = d + 2; j < p.ndims; ++j)
243 p.nodes[j - 1] = p.nodes[j];
244 --p.ndims;
245 --d; // make another try
246 }
247 }
248#if defined(__GNUC__) && __GNUC__ >= 4
249#pragma GCC diagnostic pop
250#endif
251}
252
253void prb_node_split(prb_t &p, int dim, size_t n1) {
254 assert(dim < p.ndims);
255 assert(p.ndims < max_ndims);
256 assert(p.nodes[dim].n % n1 == 0);
257
258 p.ndims += 1;
259
260 for (int d = p.ndims; d > dim + 1; --d)
261 p.nodes[d] = p.nodes[d - 1];
262
263 p.nodes[dim + 1].n = p.nodes[dim].n / n1;
264 p.nodes[dim + 1].is = p.nodes[dim].is * n1;
265 p.nodes[dim + 1].os = p.nodes[dim].os * n1;
266 p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
267
268 p.nodes[dim].n = n1;
269}
270
271void prb_node_swap(prb_t &p, int d0, int d1) {
272 assert(d0 < p.ndims);
273 assert(d1 < p.ndims);
274 assert(p.ndims < max_ndims);
275
276 if (d0 == d1) return;
277
278 nstl::swap(p.nodes[d0], p.nodes[d1]);
279}
280
281void prb_node_move(prb_t &p, int d0, int d1) {
282 assert(d0 < p.ndims);
283 assert(d1 < p.ndims);
284 assert(p.ndims < max_ndims);
285
286 if (d0 == d1) return;
287
288 node_t node = p.nodes[d0];
289
290 if (d0 < d1)
291 for (int d = d0; d < d1; ++d)
292 p.nodes[d] = p.nodes[d + 1];
293 else
294 for (int d = d0; d > d1; --d)
295 p.nodes[d] = p.nodes[d - 1];
296
297 p.nodes[d1] = node;
298}
299
300void prb_dump(const prb_t &p) {
301 printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
302 mkldnn_dt2str(p.otype), p.ndims);
303 for (int d = 0; d < p.ndims; ++d)
304 printf("[%zu:%td:%td:%td]",
305 p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
306 printf(" off:%zu:%zu\n", p.ioff, p.ooff);
307}
308
309}
310
311}
312}
313}
314