1 | /******************************************************************************* |
---|---|
2 | * Copyright 2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils.hpp" |
18 | |
19 | #include "convolution_pd.hpp" |
20 | |
21 | namespace mkldnn { |
22 | namespace impl { |
23 | |
24 | using namespace prop_kind; |
25 | |
26 | memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { |
27 | return desc->prop_kind == backward_data |
28 | ? &desc->diff_src_desc : &desc->src_desc; |
29 | } |
30 | |
31 | memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { |
32 | return desc->prop_kind == backward_weights |
33 | ? &desc->diff_weights_desc : &desc->weights_desc; |
34 | } |
35 | |
36 | memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { |
37 | return desc->prop_kind == backward_weights |
38 | ? &desc->diff_bias_desc : &desc->bias_desc; |
39 | } |
40 | |
41 | memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { |
42 | return utils::one_of(desc->prop_kind, forward_inference, forward_training) |
43 | ? &desc->dst_desc : &desc->diff_dst_desc; |
44 | } |
45 | |
46 | const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) |
47 | { return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); } |
48 | const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) |
49 | { return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); } |
50 | const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) |
51 | { return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); } |
52 | const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) |
53 | { return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); } |
54 | |
55 | } |
56 | } |
57 |