1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef PRIMITIVE_ATTR_HPP
18#define PRIMITIVE_ATTR_HPP
19
20#include "mkldnn.h"
21
22#include "c_types_map.hpp"
23#include "nstl.hpp"
24#include "utils.hpp"
25
26namespace mkldnn {
27namespace impl {
28
29struct rnn_data_qparams_t : public c_compatible {
30 rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
31 bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
32
33 status_t set(float scale, float shift) {
34 scale_ = scale;
35 shift_ = shift;
36 return status::success;
37 }
38
39 float scale_;
40 float shift_;
41};
42
43struct scales_t: public c_compatible {
44 scales_t(): count_(1), mask_(0), scales_(scales_buf_)
45 { set(1.); }
46
47 scales_t(const scales_t &rhs): scales_t()
48 { set(rhs.count_, rhs.mask_, rhs.scales_); }
49
50 ~scales_t() { cleanup(); }
51
52 scales_t &operator=(const scales_t &rhs) {
53 if (&rhs == this)
54 return *this;
55 status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
56 assert(status == status::success);
57 (void)status;
58 return *this;
59 }
60
61 bool has_default_values() const {
62 for (dim_t c = 0; c < count_; ++c) {
63 if(scales_[c] != 1.) return false;
64 }
65 return true;
66 }
67
68 status_t set(dim_t count, int mask, const float *scales);
69 status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
70
71 dim_t count_;
72 int mask_;
73 float *scales_;
74
75private:
76 enum { scales_buf_size = 16 };
77 float scales_buf_[scales_buf_size];
78
79 void cleanup() {
80 if (scales_ != scales_buf_ && scales_ != nullptr)
81 impl::free(scales_);
82
83 count_ = 1;
84 mask_ = 0;
85 scales_ = scales_buf_;
86 }
87};
88
89}
90}
91
92struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
93 struct entry_t {
94 struct eltwise_t {
95 mkldnn::impl::alg_kind_t alg;
96 float scale, alpha, beta;
97 };
98
99 mkldnn::impl::primitive_kind_t kind;
100 union {
101 struct { float scale; } sum;
102 eltwise_t eltwise;
103 };
104
105 bool is_eltwise(bool require_scale_one = true) const {
106 using namespace mkldnn::impl;
107 return kind == primitive_kind::eltwise
108 && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
109 }
110
111 bool is_relu(bool require_scale_one = true,
112 bool require_nslope_zero = true) const {
113 using namespace mkldnn::impl;
114 return is_eltwise(require_scale_one)
115 && eltwise.alg == alg_kind::eltwise_relu
116 && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
117 }
118
119 bool is_sum(bool require_scale_one = true) const {
120 using namespace mkldnn::impl;
121 return kind == primitive_kind::sum
122 && IMPLICATION(require_scale_one, sum.scale == 1.f);
123 }
124 };
125
126 mkldnn_post_ops(): len_(0) {}
127
128 mkldnn::impl::status_t append_sum(float scale);
129 mkldnn::impl::status_t append_eltwise(float scale,
130 mkldnn::impl::alg_kind_t alg, float alpha, float beta);
131
132 int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
133 int stop = -1) const {
134 if (stop == -1) stop = len_;
135 stop = mkldnn::impl::nstl::min(stop, len_);
136 for (int idx = start; idx < stop; ++idx)
137 if (entry_[idx].kind == kind) return idx;
138 return -1;
139 }
140
141 bool has_default_values() const { return len_ == 0; }
142
143 bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
144 { return find(kind, index, index + 1) == index; }
145
146 enum { capacity = 4 };
147
148 int len_;
149 entry_t entry_[capacity];
150};
151
152struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
153 mkldnn_primitive_attr()
154 : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
155 {}
156
157 mkldnn_primitive_attr *clone() const
158 { return new mkldnn_primitive_attr(*this); }
159
160 /** Returns true if the attributes have default values.
161 *
162 * @note The scratchpad_mode_ is not take into account */
163 bool has_default_values() const {
164 return true
165 && output_scales_.has_default_values()
166 && post_ops_.has_default_values()
167 && rnn_data_qparams_.has_default_values()
168 && rnn_weights_qparams_.has_default_values();
169 }
170
171 mkldnn::impl::status_t set_scratchpad_mode(
172 mkldnn::impl::scratchpad_mode_t scratchpad_mode);
173 mkldnn::impl::status_t set_post_ops(
174 const mkldnn::impl::post_ops_t &post_ops);
175
176 mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
177 mkldnn::impl::scales_t output_scales_;
178 mkldnn::impl::post_ops_t post_ops_;
179 mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
180 mkldnn::impl::scales_t rnn_weights_qparams_;
181};
182
183#endif
184