1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include <initializer_list> |
20 | |
21 | #include "c_types_map.hpp" |
22 | #include "memory_desc_wrapper.hpp" |
23 | #include "type_helpers.hpp" |
24 | #include "utils.hpp" |
25 | |
26 | namespace mkldnn { |
27 | namespace impl { |
28 | |
29 | status_t fill_blocked(memory_desc_t &md, |
30 | std::initializer_list<int> perm, |
31 | std::initializer_list<int> inner_blks, |
32 | std::initializer_list<int> inner_idxs) { |
33 | const bool ok = true |
34 | && perm.size() == (size_t)md.ndims |
35 | && inner_blks.size() == inner_idxs.size(); |
36 | if (!ok) return status::invalid_arguments; |
37 | |
38 | md.offset0 = 0; |
39 | |
40 | blocking_desc_t &blk = md.format_desc.blocking; |
41 | |
42 | dim_t block_size = 1; |
43 | dims_t blocks = {0}; |
44 | utils::array_set(blocks, 1, md.ndims); |
45 | |
46 | blk.inner_nblks = (int)inner_blks.size(); |
47 | |
48 | int iblk = 0; |
49 | for (const auto &b: inner_idxs) |
50 | blk.inner_idxs[iblk++] = b; |
51 | |
52 | iblk = 0; |
53 | for (const auto &b: inner_blks) { |
54 | int dim = blk.inner_idxs[iblk]; |
55 | block_size *= b; |
56 | blocks[dim] *= b; |
57 | blk.inner_blks[iblk++] = b; |
58 | } |
59 | |
60 | utils::array_set(md.padded_offsets, 0, md.ndims); |
61 | for (int d = 0; d < md.ndims; ++d) |
62 | md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); |
63 | |
64 | dim_t stride = block_size; |
65 | // if only we use C++14, the initializer_list would have rbegin()/rend()... |
66 | for (int d = 0; d < md.ndims; ++d) |
67 | stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; |
68 | |
69 | for (const auto &d: perm) { |
70 | if (md.padded_dims[d] == 0) { |
71 | blk.strides[d] = 1; |
72 | continue; |
73 | } |
74 | stride /= md.padded_dims[d] / blocks[d]; |
75 | blk.strides[d] = stride; |
76 | } |
77 | |
78 | assert(stride == block_size); |
79 | |
80 | return status::success; |
81 | } |
82 | |
83 | status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, |
84 | format_tag_t tag) |
85 | { |
86 | using namespace format_tag; |
87 | |
88 | if (memory_desc.ndims == 0) return status::invalid_arguments; |
89 | |
90 | # define C(tag, ... /* perm, inner_blks, inner_idxs */) \ |
91 | case tag: return fill_blocked(memory_desc, __VA_ARGS__) |
92 | |
93 | switch (tag) { |
94 | C(a, {0}, {}, {}); |
95 | C(ab, {0, 1}, {}, {}); |
96 | C(abc, {0, 1, 2}, {}, {}); |
97 | C(abcd, {0, 1, 2, 3}, {}, {}); |
98 | C(abcde, {0, 1, 2, 3, 4}, {}, {}); |
99 | C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); |
100 | C(abdec, {0, 1, 3, 4, 2}, {}, {}); |
101 | C(acb, {0, 2, 1}, {}, {}); |
102 | C(acbde, {0, 2, 1, 3, 4}, {}, {}); |
103 | C(acdb, {0, 2, 3, 1}, {}, {}); |
104 | C(acdeb, {0, 2, 3, 4, 1}, {}, {}); |
105 | C(ba, {1, 0}, {}, {}); |
106 | C(bac, {1, 0, 2}, {}, {}); |
107 | C(bacd, {1, 0, 2, 3}, {}, {}); |
108 | C(bcda, {1, 2, 3, 0}, {}, {}); |
109 | C(cba, {2, 1, 0}, {}, {}); |
110 | C(cdba, {2, 3, 1, 0}, {}, {}); |
111 | C(cdeba, {2, 3, 4, 1, 0}, {}, {}); |
112 | C(decab, {3, 4, 2, 0, 1}, {}, {}); |
113 | |
114 | C(Abc4a, {0, 1, 2}, {4}, {0}); |
115 | C(aBc4b, {0, 1, 2}, {4}, {1}); |
116 | C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); |
117 | C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); |
118 | C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); |
119 | C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); |
120 | C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); |
121 | C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); |
122 | C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); |
123 | C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); |
124 | C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); |
125 | C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); |
126 | C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); |
127 | C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); |
128 | C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); |
129 | C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); |
130 | C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); |
131 | C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); |
132 | C(Acb4a, {0, 2, 1}, {4}, {0}); |
133 | C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); |
134 | C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); |
135 | |
136 | C(Abc16a, {0, 1, 2}, {16}, {0}); |
137 | C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); |
138 | C(aBc16b, {0, 1, 2}, {16}, {1}); |
139 | C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); |
140 | C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); |
141 | C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); |
142 | C(aBc8b, {0, 1, 2}, {8}, {1}); |
143 | C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); |
144 | C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); |
145 | C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); |
146 | C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); |
147 | C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); |
148 | C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); |
149 | C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); |
150 | C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); |
151 | C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); |
152 | C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); |
153 | C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); |
154 | C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); |
155 | C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); |
156 | C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); |
157 | C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); |
158 | C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); |
159 | C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); |
160 | C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); |
161 | C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); |
162 | C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); |
163 | C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); |
164 | C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); |
165 | C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); |
166 | C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); |
167 | C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); |
168 | C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); |
169 | C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); |
170 | C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); |
171 | C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); |
172 | C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); |
173 | C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); |
174 | C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); |
175 | C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); |
176 | C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); |
177 | C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); |
178 | C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); |
179 | C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); |
180 | C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); |
181 | C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); |
182 | C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); |
183 | C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); |
184 | C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); |
185 | C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); |
186 | C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); |
187 | C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); |
188 | C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); |
189 | C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); |
190 | C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); |
191 | C(Acb16a, {0, 2, 1}, {16}, {0}); |
192 | C(Acb8a, {0, 2, 1}, {8}, {0}); |
193 | C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); |
194 | C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); |
195 | C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); |
196 | C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); |
197 | C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); |
198 | C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); |
199 | C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); |
200 | C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); |
201 | default: break; |
202 | } |
203 | |
204 | #undef C |
205 | |
206 | return status::invalid_arguments; |
207 | } |
208 | |
209 | } |
210 | } |
211 | |
212 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
213 | |