1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "mkldnn.h"
18
19#include "c_types_map.hpp"
20#include "primitive_attr.hpp"
21#include "type_helpers.hpp"
22#include "utils.hpp"
23
24using namespace mkldnn::impl;
25using namespace mkldnn::impl::status;
26using namespace mkldnn::impl::utils;
27
28namespace mkldnn {
29namespace impl {
30
31status_t scales_t::set(dim_t count, int mask, const float *scales) {
32 cleanup();
33
34 count_ = count;
35 mask_ = mask;
36
37 if (count_ == 1) {
38 scales_ = scales_buf_;
39 utils::array_set(scales_, scales[0], scales_buf_size);
40 } else {
41 scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
42 if (scales_ == nullptr)
43 return status::out_of_memory;
44
45 for (dim_t c = 0; c < count_; ++c)
46 scales_[c] = scales[c];
47 }
48
49 return status::success;
50}
51
52}
53}
54
55status_t post_ops_t::append_sum(float scale) {
56 if (len_ == capacity)
57 return out_of_memory;
58
59 entry_[len_].kind = primitive_kind::sum;
60 entry_[len_].sum.scale = scale;
61
62 len_++;
63
64 return success;
65}
66
67status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
68 float beta) {
69 using namespace mkldnn::impl::alg_kind;
70 bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
71 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
72 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
73 if (!known_alg)
74 return invalid_arguments;
75
76 if (len_ == capacity)
77 return out_of_memory;
78
79 entry_[len_].kind = primitive_kind::eltwise;
80 entry_[len_].eltwise.scale = scale;
81 entry_[len_].eltwise.alg = alg;
82 entry_[len_].eltwise.alpha = alpha;
83 entry_[len_].eltwise.beta = beta;
84
85 len_++;
86
87 return success;
88}
89
90status_t primitive_attr_t::set_scratchpad_mode(
91 scratchpad_mode_t scratchpad_mode) {
92 using namespace mkldnn::impl::scratchpad_mode;
93
94 const bool ok = one_of(scratchpad_mode, library, user);
95 if (!ok)
96 return invalid_arguments;
97
98 scratchpad_mode_ = scratchpad_mode;
99 return success;
100}
101
102status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
103 this->post_ops_ = post_ops;
104 return success;
105}
106
107/* Public C API */
108
109status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
110 if (attr == nullptr)
111 return invalid_arguments;
112
113 return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
114 new mkldnn_primitive_attr);
115}
116
117status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
118 const primitive_attr_t *existing_attr) {
119 if (any_null(attr, existing_attr))
120 return invalid_arguments;
121
122 return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
123 existing_attr->clone());
124}
125
126status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
127 if (attr)
128 delete attr;
129
130 return success;
131}
132
133status_t mkldnn_primitive_attr_get_scratchpad_mode(
134 const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
135 if (any_null(attr, scratchpad_mode))
136 return invalid_arguments;
137
138 *scratchpad_mode = attr->scratchpad_mode_;
139
140 return success;
141}
142
143status_t mkldnn_primitive_attr_set_scratchpad_mode(
144 primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
145 if (any_null(attr))
146 return invalid_arguments;
147
148 return attr->set_scratchpad_mode(scratchpad_mode);
149}
150
151status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
152 dim_t *count, int *mask, const float **scales) {
153 if (any_null(attr, count, mask, scales))
154 return invalid_arguments;
155
156 *count = attr->output_scales_.count_;
157 *mask = attr->output_scales_.mask_;
158 *scales = attr->output_scales_.scales_;
159
160 return success;
161}
162
163status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
164 dim_t count, int mask, const float *scales) {
165 bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
166 if (!ok)
167 return invalid_arguments;
168
169 return attr->output_scales_.set(count, mask, scales);
170}
171
172status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
173 const post_ops_t **post_ops) {
174 if (any_null(attr, post_ops))
175 return invalid_arguments;
176
177 *post_ops = &attr->post_ops_;
178 return success;
179}
180
181status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
182 const post_ops_t *post_ops) {
183 if (any_null(attr, post_ops))
184 return invalid_arguments;
185
186 return attr->set_post_ops(*post_ops);
187}
188
189status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
190 if (post_ops == nullptr)
191 return invalid_arguments;
192
193 return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
194}
195
196status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
197 if (post_ops)
198 delete post_ops;
199
200 return success;
201}
202
203int mkldnn_post_ops_len(const post_ops_t *post_ops) {
204 if (post_ops)
205 return post_ops->len_;
206
207 return 0;
208}
209
210primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
211 int index) {
212 bool ok = post_ops && 0 <= index && index < post_ops->len_;
213 if (!ok)
214 return primitive_kind::undefined;
215
216 return post_ops->entry_[index].kind;
217}
218
219status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
220 if (post_ops == nullptr)
221 return invalid_arguments;
222
223 return post_ops->append_sum(scale);
224}
225
226namespace {
227bool simple_get_params_check(const post_ops_t *post_ops, int index,
228 primitive_kind_t kind) {
229 bool ok = true
230 && post_ops != nullptr
231 && 0 <= index
232 && index < post_ops->len_
233 && post_ops->entry_[index].kind == kind;
234 return ok;
235}
236}
237
238status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
239 float *scale) {
240 bool ok = true
241 && simple_get_params_check(post_ops, index, primitive_kind::sum)
242 && !any_null(scale);
243 if (!ok)
244 return invalid_arguments;
245
246 *scale = post_ops->entry_[index].sum.scale;
247 return success;
248}
249
250status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
251 alg_kind_t kind, float alpha, float beta) {
252 if (post_ops == nullptr)
253 return invalid_arguments;
254
255 return post_ops->append_eltwise(scale, kind, alpha, beta);
256}
257
258status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
259 int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
260 bool ok = true
261 && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
262 && !any_null(scale, alpha, beta);
263 if (!ok)
264 return invalid_arguments;
265
266 const auto &e = post_ops->entry_[index].eltwise;
267 *scale = e.scale;
268 *alg = e.alg;
269 *alpha = e.alpha;
270 *beta = e.beta;
271
272 return success;
273}
274
275status_t mkldnn_primitive_attr_set_rnn_data_qparams(
276 primitive_attr_t *attr, const float scale, const float shift) {
277 if (attr == nullptr)
278 return invalid_arguments;
279
280 return attr->rnn_data_qparams_.set(scale, shift);
281}
282
283status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
284 primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
285 bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
286 if (!ok)
287 return invalid_arguments;
288
289 return attr->rnn_weights_qparams_.set(count, mask, scales);
290}
291