1/*******************************************************************************
2* Copyright 2016-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include <initializer_list>
20
21#include "c_types_map.hpp"
22#include "memory_desc_wrapper.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26namespace mkldnn {
27namespace impl {
28
29status_t fill_blocked(memory_desc_t &md,
30 std::initializer_list<int> perm,
31 std::initializer_list<int> inner_blks,
32 std::initializer_list<int> inner_idxs) {
33 const bool ok = true
34 && perm.size() == (size_t)md.ndims
35 && inner_blks.size() == inner_idxs.size();
36 if (!ok) return status::invalid_arguments;
37
38 md.offset0 = 0;
39
40 blocking_desc_t &blk = md.format_desc.blocking;
41
42 dim_t block_size = 1;
43 dims_t blocks = {0};
44 utils::array_set(blocks, 1, md.ndims);
45
46 blk.inner_nblks = (int)inner_blks.size();
47
48 int iblk = 0;
49 for (const auto &b: inner_idxs)
50 blk.inner_idxs[iblk++] = b;
51
52 iblk = 0;
53 for (const auto &b: inner_blks) {
54 int dim = blk.inner_idxs[iblk];
55 block_size *= b;
56 blocks[dim] *= b;
57 blk.inner_blks[iblk++] = b;
58 }
59
60 utils::array_set(md.padded_offsets, 0, md.ndims);
61 for (int d = 0; d < md.ndims; ++d)
62 md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
63
64 dim_t stride = block_size;
65 // if only we use C++14, the initializer_list would have rbegin()/rend()...
66 for (int d = 0; d < md.ndims; ++d)
67 stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
68
69 for (const auto &d: perm) {
70 if (md.padded_dims[d] == 0) {
71 blk.strides[d] = 1;
72 continue;
73 }
74 stride /= md.padded_dims[d] / blocks[d];
75 blk.strides[d] = stride;
76 }
77
78 assert(stride == block_size);
79
80 return status::success;
81}
82
83status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
84 format_tag_t tag)
85{
86 using namespace format_tag;
87
88 if (memory_desc.ndims == 0) return status::invalid_arguments;
89
90# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
91 case tag: return fill_blocked(memory_desc, __VA_ARGS__)
92
93 switch (tag) {
94 C(a, {0}, {}, {});
95 C(ab, {0, 1}, {}, {});
96 C(abc, {0, 1, 2}, {}, {});
97 C(abcd, {0, 1, 2, 3}, {}, {});
98 C(abcde, {0, 1, 2, 3, 4}, {}, {});
99 C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
100 C(abdec, {0, 1, 3, 4, 2}, {}, {});
101 C(acb, {0, 2, 1}, {}, {});
102 C(acbde, {0, 2, 1, 3, 4}, {}, {});
103 C(acdb, {0, 2, 3, 1}, {}, {});
104 C(acdeb, {0, 2, 3, 4, 1}, {}, {});
105 C(ba, {1, 0}, {}, {});
106 C(bac, {1, 0, 2}, {}, {});
107 C(bacd, {1, 0, 2, 3}, {}, {});
108 C(bcda, {1, 2, 3, 0}, {}, {});
109 C(cba, {2, 1, 0}, {}, {});
110 C(cdba, {2, 3, 1, 0}, {}, {});
111 C(cdeba, {2, 3, 4, 1, 0}, {}, {});
112 C(decab, {3, 4, 2, 0, 1}, {}, {});
113
114 C(Abc4a, {0, 1, 2}, {4}, {0});
115 C(aBc4b, {0, 1, 2}, {4}, {1});
116 C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
117 C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
118 C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
119 C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
120 C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
121 C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
122 C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
123 C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
124 C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
125 C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
126 C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
127 C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
128 C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
129 C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
130 C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
131 C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
132 C(Acb4a, {0, 2, 1}, {4}, {0});
133 C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
134 C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
135
136 C(Abc16a, {0, 1, 2}, {16}, {0});
137 C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
138 C(aBc16b, {0, 1, 2}, {16}, {1});
139 C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
140 C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
141 C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
142 C(aBc8b, {0, 1, 2}, {8}, {1});
143 C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
144 C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
145 C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
146 C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
147 C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
148 C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
149 C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
150 C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
151 C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
152 C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
153 C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
154 C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
155 C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
156 C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
157 C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
158 C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
159 C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
160 C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
161 C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
162 C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
163 C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
164 C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
165 C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
166 C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
167 C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
168 C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
169 C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
170 C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
171 C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
172 C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
173 C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
174 C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
175 C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
176 C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
177 C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
178 C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
179 C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
180 C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
181 C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
182 C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
183 C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
184 C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
185 C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
186 C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
187 C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
188 C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
189 C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
190 C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
191 C(Acb16a, {0, 2, 1}, {16}, {0});
192 C(Acb8a, {0, 2, 1}, {8}, {0});
193 C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
194 C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
195 C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
196 C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
197 C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
198 C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
199 C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
200 C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
201 default: break;
202 }
203
204#undef C
205
206 return status::invalid_arguments;
207}
208
209}
210}
211
212// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
213