1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils.hpp"
18
19#include "inner_product_pd.hpp"
20
21namespace mkldnn {
22namespace impl {
23
24using namespace prop_kind;
25
26memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
27 return desc->prop_kind == backward_data
28 ? &desc->diff_src_desc : &desc->src_desc;
29}
30
31memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
32 return desc->prop_kind == backward_weights
33 ? &desc->diff_weights_desc : &desc->weights_desc;
34}
35
36memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
37 return desc->prop_kind == backward_weights
38 ? &desc->diff_bias_desc : &desc->bias_desc;
39}
40
41memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
42 return utils::one_of(desc->prop_kind, forward_inference, forward_training)
43 ? &desc->dst_desc : &desc->diff_dst_desc;
44}
45
46const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
47{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
48const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
49{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
50const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
51{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
52const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
53{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
54
55}
56}
57