1 | /******************************************************************************* |
2 | * Copyright 2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "c_types_map.hpp" |
20 | #include "memory_desc_wrapper.hpp" |
21 | #include "mkldnn_debug.h" |
22 | #include "nstl.hpp" |
23 | #include "type_helpers.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | #include "jit_uni_reorder.hpp" |
27 | |
28 | using namespace mkldnn::impl::types; |
29 | using namespace mkldnn::impl::status; |
30 | |
31 | namespace mkldnn { |
32 | namespace impl { |
33 | namespace cpu { |
34 | |
35 | namespace tr { |
36 | |
/** ad-hoc structure to describe blocked memory layout
 *
 * The layout is a flat list of `ndims` entries. Entry k describes one level
 * of the blocked layout: `id[k]` is the logical dimension the entry belongs
 * to, `dims[k]` is its size and `strides[k]` is its stride (presumably in
 * elements, as in blocking_desc_t -- TODO confirm). A single logical
 * dimension may contribute several entries (its outer part plus each inner
 * block over it). */
struct layout_desc_t {
    data_type_t dt;     // data type of the described memory
    int ndims;          // number of valid entries in id/dims/strides
    dims_t id;          // logical dimension index of each entry
    dims_t dims;        // size of each entry
    strides_t strides;  // stride of each entry
};
45 | |
46 | status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, |
47 | layout_desc_t &ld) { |
48 | const auto md = memory_desc_wrapper(md_); |
49 | |
50 | bool ok = true |
51 | && md.is_blocking_desc() |
52 | && md.extra().flags == 0; |
53 | if (!ok) return invalid_arguments; |
54 | |
55 | const auto &bd = md.blocking_desc(); |
56 | |
57 | ld.ndims = 0; |
58 | ld.dt = md.data_type(); |
59 | |
60 | auto P = [&ld](int id, int dim, ptrdiff_t stride) { |
61 | assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); |
62 | ld.id[ld.ndims] = id; |
63 | ld.dims[ld.ndims] = dim; |
64 | ld.strides[ld.ndims] = stride; |
65 | ++ld.ndims; |
66 | }; |
67 | |
68 | dims_t blocks; |
69 | md.compute_blocks(blocks); |
70 | |
71 | for (int d = 0; d < md.ndims(); ++d) { |
72 | const int ld_ndims_start = ld.ndims; |
73 | if (blocks[d] != 1) { |
74 | stride_t stride = 1; |
75 | for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { |
76 | if (bd.inner_idxs[iblk] == d) |
77 | P(d, bd.inner_blks[iblk], stride); |
78 | stride *= bd.inner_blks[iblk]; |
79 | } |
80 | } |
81 | P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); |
82 | |
83 | // TODO: NOW: revisit, do we need a reverse? |
84 | // TODO: NOW: consider using strides instead of block sizes in md |
85 | // reverse the order of dims |
86 | for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { |
87 | const int idx0 = ld_ndims_start + ld_d; |
88 | const int idx1 = ld.ndims - 1 - ld_d; |
89 | nstl::swap(ld.dims[idx0], ld.dims[idx1]); |
90 | nstl::swap(ld.strides[idx0], ld.strides[idx1]); |
91 | } |
92 | } |
93 | |
94 | return success; |
95 | } |
96 | |
97 | status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, |
98 | const primitive_attr_t *attr) { |
99 | auto im_d = memory_desc_wrapper(imd); |
100 | auto om_d = memory_desc_wrapper(omd); |
101 | |
102 | bool ok = true |
103 | && im_d.is_blocking_desc() |
104 | && om_d.is_blocking_desc() |
105 | && !im_d.has_zero_dim() |
106 | && !om_d.has_zero_dim(); |
107 | if (!ok) |
108 | return unimplemented; |
109 | |
110 | dims_t iblocks, oblocks; |
111 | im_d.compute_blocks(iblocks); |
112 | om_d.compute_blocks(oblocks); |
113 | |
114 | /* padding_dim consistency check */ |
115 | for (int d = 0; d < im_d.ndims(); ++d) { |
116 | const auto pdim = im_d.padded_dims()[d]; |
117 | bool ok = true |
118 | && pdim == om_d.padded_dims()[d] |
119 | && pdim % iblocks[d] == 0 |
120 | && pdim % oblocks[d] == 0; |
121 | if (!ok) return unimplemented; |
122 | } |
123 | |
124 | layout_desc_t ild, old; |
125 | status_t status = cvt_mem_desc_to_layout_desc(imd, ild); |
126 | if (status != success) return status; |
127 | status = cvt_mem_desc_to_layout_desc(omd, old); |
128 | if (status != success) return status; |
129 | |
130 | p.itype = ild.dt; |
131 | p.otype = old.dt; |
132 | |
133 | p.scale_type = attr->output_scales_.has_default_values() |
134 | ? scale_type_t::NONE |
135 | : (attr->output_scales_.mask_ == 0 |
136 | ? scale_type_t::COMMON |
137 | : scale_type_t::MANY); |
138 | |
139 | ptrdiff_t ss[max_ndims] = {0}; |
140 | if (p.scale_type == scale_type_t::MANY) { |
141 | ptrdiff_t last_ss = 1; |
142 | for (int d = old.ndims - 1; d >=0; --d) { |
143 | assert((d == 0 || old.id[d - 1] <= old.id[d]) |
144 | && "logical dimensions should be in ascending order" ); |
145 | if (attr->output_scales_.mask_ & (1 << old.id[d])) { |
146 | ss[d] = last_ss; |
147 | last_ss *= old.dims[d]; |
148 | } |
149 | } |
150 | } |
151 | |
152 | int ndims = 0; |
153 | |
154 | int i_pos = 0; /* state for input -- current dimension */ |
155 | int o_pos = 0; /* state for output -- current dimension */ |
156 | |
157 | while (i_pos < ild.ndims && o_pos < old.ndims) { |
158 | assert(ild.id[i_pos] == old.id[o_pos]); |
159 | if (ild.id[i_pos] != old.id[o_pos]) |
160 | return runtime_error; |
161 | |
162 | assert(ndims < max_ndims); |
163 | if (ndims == max_ndims) |
164 | return runtime_error; |
165 | |
166 | if (ild.dims[i_pos] == old.dims[o_pos]) { |
167 | p.nodes[ndims].n = ild.dims[i_pos]; |
168 | p.nodes[ndims].is = ild.strides[i_pos]; |
169 | p.nodes[ndims].os = old.strides[o_pos]; |
170 | p.nodes[ndims].ss = ss[o_pos]; |
171 | ++ndims; |
172 | ++i_pos; |
173 | ++o_pos; |
174 | } else if (ild.dims[i_pos] < old.dims[o_pos]) { |
175 | assert(old.dims[o_pos] % ild.dims[i_pos] == 0); |
176 | int factor = old.dims[o_pos] / ild.dims[i_pos]; |
177 | p.nodes[ndims].n = ild.dims[i_pos]; |
178 | p.nodes[ndims].is = ild.strides[i_pos]; |
179 | p.nodes[ndims].os = old.strides[o_pos] * factor; |
180 | p.nodes[ndims].ss = ss[o_pos] * factor; |
181 | ++ndims; |
182 | ++i_pos; |
183 | old.dims[o_pos] = factor; |
184 | } else if (ild.dims[i_pos] > old.dims[o_pos]) { |
185 | assert(ild.dims[i_pos] % old.dims[o_pos] == 0); |
186 | int factor = ild.dims[i_pos] / old.dims[o_pos]; |
187 | p.nodes[ndims].n = old.dims[o_pos]; |
188 | p.nodes[ndims].is = ild.strides[i_pos] * factor; |
189 | p.nodes[ndims].os = old.strides[o_pos]; |
190 | p.nodes[ndims].ss = ss[o_pos]; |
191 | ++ndims; |
192 | ++o_pos; |
193 | ild.dims[i_pos] = factor; |
194 | } |
195 | } |
196 | p.ndims = ndims; |
197 | |
198 | dims_t zero_pos = {0}; |
199 | p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); |
200 | p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); |
201 | |
202 | const int sum_idx = attr->post_ops_.find(primitive_kind::sum); |
203 | p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; |
204 | |
205 | return success; |
206 | } |
207 | |
208 | void prb_normalize(prb_t &p) { |
209 | for (int d = 0; d < p.ndims; ++d) { |
210 | int min_pos = d; |
211 | for (int j = d + 1; j < p.ndims; ++j) { |
212 | bool new_min = false |
213 | || p.nodes[j].os < p.nodes[min_pos].os |
214 | || (true |
215 | && p.nodes[j].os == p.nodes[min_pos].os |
216 | && p.nodes[j].n < p.nodes[min_pos].n); |
217 | if (new_min) min_pos = j; |
218 | } |
219 | if (min_pos != d) |
220 | nstl::swap(p.nodes[d], p.nodes[min_pos]); |
221 | } |
222 | } |
223 | |
224 | void prb_simplify(prb_t &p) { |
225 | #if defined(__GNUC__) && __GNUC__ >= 4 |
226 | /* GCC produces bogus array subscript is above array bounds warning for |
227 | * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */ |
228 | #pragma GCC diagnostic push |
229 | #pragma GCC diagnostic ignored "-Warray-bounds" |
230 | #endif |
231 | for (int d = 0; d < p.ndims - 1; ++d) { |
232 | auto &this_node = p.nodes[d + 0]; |
233 | auto &next_node = p.nodes[d + 1]; |
234 | const bool fold = false |
235 | || next_node.n == (size_t)1 // trivial case, just drop next node |
236 | || (true // or real folding if possible |
237 | && next_node.is == (ptrdiff_t)this_node.n * this_node.is |
238 | && next_node.os == (ptrdiff_t)this_node.n * this_node.os |
239 | && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); |
240 | if (fold) { |
241 | this_node.n *= next_node.n; |
242 | for (int j = d + 2; j < p.ndims; ++j) |
243 | p.nodes[j - 1] = p.nodes[j]; |
244 | --p.ndims; |
245 | --d; // make another try |
246 | } |
247 | } |
248 | #if defined(__GNUC__) && __GNUC__ >= 4 |
249 | #pragma GCC diagnostic pop |
250 | #endif |
251 | } |
252 | |
253 | void prb_node_split(prb_t &p, int dim, size_t n1) { |
254 | assert(dim < p.ndims); |
255 | assert(p.ndims < max_ndims); |
256 | assert(p.nodes[dim].n % n1 == 0); |
257 | |
258 | p.ndims += 1; |
259 | |
260 | for (int d = p.ndims; d > dim + 1; --d) |
261 | p.nodes[d] = p.nodes[d - 1]; |
262 | |
263 | p.nodes[dim + 1].n = p.nodes[dim].n / n1; |
264 | p.nodes[dim + 1].is = p.nodes[dim].is * n1; |
265 | p.nodes[dim + 1].os = p.nodes[dim].os * n1; |
266 | p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; |
267 | |
268 | p.nodes[dim].n = n1; |
269 | } |
270 | |
271 | void prb_node_swap(prb_t &p, int d0, int d1) { |
272 | assert(d0 < p.ndims); |
273 | assert(d1 < p.ndims); |
274 | assert(p.ndims < max_ndims); |
275 | |
276 | if (d0 == d1) return; |
277 | |
278 | nstl::swap(p.nodes[d0], p.nodes[d1]); |
279 | } |
280 | |
281 | void prb_node_move(prb_t &p, int d0, int d1) { |
282 | assert(d0 < p.ndims); |
283 | assert(d1 < p.ndims); |
284 | assert(p.ndims < max_ndims); |
285 | |
286 | if (d0 == d1) return; |
287 | |
288 | node_t node = p.nodes[d0]; |
289 | |
290 | if (d0 < d1) |
291 | for (int d = d0; d < d1; ++d) |
292 | p.nodes[d] = p.nodes[d + 1]; |
293 | else |
294 | for (int d = d0; d > d1; --d) |
295 | p.nodes[d] = p.nodes[d - 1]; |
296 | |
297 | p.nodes[d1] = node; |
298 | } |
299 | |
300 | void prb_dump(const prb_t &p) { |
301 | printf("@@@ type:%s:%s ndims:%d " , mkldnn_dt2str(p.itype), |
302 | mkldnn_dt2str(p.otype), p.ndims); |
303 | for (int d = 0; d < p.ndims; ++d) |
304 | printf("[%zu:%td:%td:%td]" , |
305 | p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); |
306 | printf(" off:%zu:%zu\n" , p.ioff, p.ooff); |
307 | } |
308 | |
309 | } |
310 | |
311 | } |
312 | } |
313 | } |
314 | |