1 | /******************************************************************************* |
2 | * Copyright 2016-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <stddef.h> |
19 | #include <stdint.h> |
20 | |
21 | #include "mkldnn.h" |
22 | |
23 | #include "c_types_map.hpp" |
24 | #include "engine.hpp" |
25 | #include "type_helpers.hpp" |
26 | #include "utils.hpp" |
27 | |
28 | using namespace mkldnn::impl; |
29 | using namespace mkldnn::impl::utils; |
30 | using namespace mkldnn::impl::status; |
31 | using namespace mkldnn::impl::data_type; |
32 | |
33 | namespace { |
34 | bool memory_desc_sanity_check(int ndims,const dims_t dims, |
35 | data_type_t data_type, format_kind_t format_kind) { |
36 | if (ndims == 0) return true; |
37 | |
38 | bool ok = true |
39 | && dims != nullptr |
40 | && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS |
41 | && one_of(data_type, f32, s32, s8, u8) |
42 | && format_kind != format_kind::undef; |
43 | if (!ok) return false; |
44 | for (int d = 0; d < ndims; ++d) |
45 | if (dims[d] < 0) return false; |
46 | |
47 | return true; |
48 | } |
49 | |
50 | bool memory_desc_sanity_check(const memory_desc_t *md) { |
51 | if (md == nullptr) return false; |
52 | return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, |
53 | format_kind::any); |
54 | } |
55 | } |
56 | |
57 | status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, |
58 | const dims_t dims, data_type_t data_type, format_tag_t tag) { |
59 | if (any_null(memory_desc)) return invalid_arguments; |
60 | if (ndims == 0 || tag == format_tag::undef) { |
61 | *memory_desc = types::zero_md(); |
62 | return success; |
63 | } |
64 | |
65 | format_kind_t format_kind = types::format_tag_to_kind(tag); |
66 | |
67 | /* memory_desc != 0 */ |
68 | bool args_ok = !any_null(memory_desc) |
69 | && memory_desc_sanity_check(ndims, dims, data_type, format_kind); |
70 | if (!args_ok) return invalid_arguments; |
71 | |
72 | auto md = memory_desc_t(); |
73 | md.ndims = ndims; |
74 | array_copy(md.dims, dims, ndims); |
75 | md.data_type = data_type; |
76 | array_copy(md.padded_dims, dims, ndims); |
77 | md.format_kind = format_kind; |
78 | |
79 | status_t status = success; |
80 | if (tag == format_tag::undef) { |
81 | status = invalid_arguments; |
82 | } else if (tag == format_tag::any) { |
83 | // nop |
84 | } else if (format_kind == format_kind::blocked) { |
85 | status = memory_desc_wrapper::compute_blocking(md, tag); |
86 | } else { |
87 | assert(!"unreachable" ); |
88 | status = invalid_arguments; |
89 | } |
90 | |
91 | if (status == success) |
92 | *memory_desc = md; |
93 | |
94 | return status; |
95 | } |
96 | |
97 | status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, |
98 | int ndims, const dims_t dims, data_type_t data_type, |
99 | const dims_t strides) { |
100 | if (any_null(memory_desc)) return invalid_arguments; |
101 | if (ndims == 0) { |
102 | *memory_desc = types::zero_md(); |
103 | return success; |
104 | } |
105 | |
106 | /* memory_desc != 0 */ |
107 | bool args_ok = !any_null(memory_desc) |
108 | && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); |
109 | if (!args_ok) return invalid_arguments; |
110 | |
111 | auto md = memory_desc_t(); |
112 | md.ndims = ndims; |
113 | array_copy(md.dims, dims, ndims); |
114 | md.data_type = data_type; |
115 | array_copy(md.padded_dims, dims, ndims); |
116 | md.format_kind = format_kind::blocked; |
117 | |
118 | dims_t default_strides = {0}; |
119 | if (strides == nullptr) { |
120 | default_strides[md.ndims - 1] = 1; |
121 | for (int d = md.ndims - 2; d >= 0; --d) |
122 | default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; |
123 | strides = default_strides; |
124 | } else { |
125 | /* TODO: add sanity check for the provided strides */ |
126 | } |
127 | |
128 | array_copy(md.format_desc.blocking.strides, strides, md.ndims); |
129 | |
130 | *memory_desc = md; |
131 | |
132 | return status::success; |
133 | } |
134 | |
135 | status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, |
136 | const memory_desc_t *parent_md, const dims_t dims, |
137 | const dims_t offsets) { |
138 | if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) |
139 | return invalid_arguments; |
140 | |
141 | const memory_desc_wrapper src_d(parent_md); |
142 | |
143 | for (int d = 0; d < src_d.ndims(); ++d) { |
144 | if (dims[d] < 0 || offsets[d] < 0 |
145 | || (offsets[d] + dims[d] > src_d.dims()[d])) |
146 | return invalid_arguments; |
147 | } |
148 | |
149 | if (src_d.format_kind() != format_kind::blocked) |
150 | return unimplemented; |
151 | |
152 | dims_t blocks; |
153 | src_d.compute_blocks(blocks); |
154 | |
155 | memory_desc_t dst_d = *parent_md; |
156 | auto &dst_d_blk = dst_d.format_desc.blocking; |
157 | |
158 | /* TODO: put this into memory_desc_wrapper */ |
159 | for (int d = 0; d < src_d.ndims(); ++d) { |
160 | /* very limited functionality for now */ |
161 | const bool ok = true |
162 | && offsets[d] % blocks[d] == 0 /* [r1] */ |
163 | && src_d.padded_offsets()[d] == 0 |
164 | && (false |
165 | || dims[d] % blocks[d] == 0 |
166 | || dims[d] < blocks[d]); |
167 | if (!ok) |
168 | return unimplemented; |
169 | |
170 | const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; |
171 | |
172 | dst_d.dims[d] = dims[d]; |
173 | dst_d.padded_dims[d] = is_right_border |
174 | ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; |
175 | dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; |
176 | dst_d.offset0 += /* [r1] */ |
177 | offsets[d] / blocks[d] * dst_d_blk.strides[d]; |
178 | } |
179 | |
180 | *md = dst_d; |
181 | |
182 | return success; |
183 | } |
184 | |
185 | int mkldnn_memory_desc_equal(const memory_desc_t *lhs, |
186 | const memory_desc_t *rhs) { |
187 | if (lhs == rhs) return 1; |
188 | if (any_null(lhs, rhs)) return 0; |
189 | return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); |
190 | } |
191 | |
192 | size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { |
193 | if (md == nullptr) return 0; |
194 | return memory_desc_wrapper(*md).size(); |
195 | } |
196 | |
197 | status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, |
198 | engine_t *engine, void *handle) { |
199 | if (any_null(memory, engine)) return invalid_arguments; |
200 | memory_desc_t z_md = types::zero_md(); |
201 | return engine->memory_create(memory, md ? md : &z_md, handle); |
202 | } |
203 | |
204 | status_t mkldnn_memory_get_memory_desc(const memory_t *memory, |
205 | const memory_desc_t **md) { |
206 | if (any_null(memory, md)) return invalid_arguments; |
207 | *md = memory->md(); |
208 | return success; |
209 | } |
210 | |
211 | status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { |
212 | if (any_null(memory, engine)) return invalid_arguments; |
213 | *engine = memory->engine(); |
214 | return success; |
215 | } |
216 | |
217 | status_t mkldnn_memory_get_data_handle(const memory_t *memory, |
218 | void **handle) { |
219 | if (any_null(handle)) |
220 | return invalid_arguments; |
221 | if (memory == nullptr) { |
222 | *handle = nullptr; |
223 | return success; |
224 | } |
225 | return memory->get_data_handle(handle); |
226 | } |
227 | |
228 | status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { |
229 | if (any_null(memory)) return invalid_arguments; |
230 | return memory->set_data_handle(handle); |
231 | } |
232 | |
233 | status_t mkldnn_memory_destroy(memory_t *memory) { |
234 | delete memory; |
235 | return success; |
236 | } |
237 | |
238 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
239 | |