1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_JIT_UNI_POOLING_HPP
18#define CPU_JIT_UNI_POOLING_HPP
19
20#include <assert.h>
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26#include "cpu_pooling_pd.hpp"
27#include "cpu_primitive.hpp"
28
29#include "jit_uni_pool_kernel_f32.hpp"
30
31namespace mkldnn {
32namespace impl {
33namespace cpu {
34
35template <cpu_isa_t isa>
36struct jit_uni_pooling_fwd_t: public cpu_primitive_t {
37 struct pd_t: public cpu_pooling_fwd_pd_t {
38 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
39
40 DECLARE_COMMON_PD_T(
41 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
42 jit_uni_pooling_fwd_t<isa>);
43
44 status_t init() {
45 using namespace utils;
46
47 bool ok = true
48 && set_default_params() == status::success
49 && is_fwd()
50 && !has_zero_dim_memory()
51 && everyone_is(data_type::f32,
52 src_md()->data_type,
53 dst_md()->data_type)
54 && attr()->has_default_values()
55 && memory_desc_matches_tag(*src_md(), desired_fmt_tag())
56 && memory_desc_matches_tag(*dst_md(), desired_fmt_tag());
57 if (!ok) return status::unimplemented;
58
59 bool is_training = desc_.prop_kind == prop_kind::forward_training;
60 if (desc()->alg_kind == alg_kind::pooling_max && is_training)
61 init_default_ws();
62
63 return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
64 }
65
66 format_tag_t desired_fmt_tag() {
67 using namespace format_tag;
68 return ndims() == 4
69 ? isa == avx512_common ? nChw16c : nChw8c
70 : isa == avx512_common ? nCdhw16c : nCdhw8c;
71 }
72
73 jit_pool_conf_t jpp_;
74 };
75
76 jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
77 { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
78
79 ~jit_uni_pooling_fwd_t() { delete kernel_; }
80
81 typedef typename prec_traits<data_type::f32>::type data_t;
82
83 virtual status_t execute(const exec_ctx_t &ctx) const override {
84 auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
85 auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
86 auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE);
87
88 if (pd()->ndims() == 5)
89 execute_forward_3d(src, dst, ws);
90 else
91 execute_forward(src, dst, ws);
92
93 return status::success;
94 }
95
96private:
97 void execute_forward(const data_t *src, data_t *dst, char *indices) const;
98 void execute_forward_3d(const data_t *src, data_t *dst,
99 char *indices) const;
100 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
101 jit_uni_pool_kernel_f32<isa> *kernel_;
102};
103
104template <cpu_isa_t isa>
105struct jit_uni_pooling_bwd_t: public cpu_primitive_t {
106 struct pd_t: public cpu_pooling_bwd_pd_t {
107 using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
108
109 DECLARE_COMMON_PD_T(
110 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
111 jit_uni_pooling_bwd_t<isa>);
112
113 status_t init() {
114 using namespace utils;
115
116 bool ok = true
117 && set_default_params() == status::success
118 && !is_fwd()
119 && !has_zero_dim_memory()
120 && everyone_is(data_type::f32,
121 diff_src_md()->data_type,
122 diff_dst_md()->data_type)
123 && attr()->has_default_values()
124 && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag())
125 && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag());
126 if (!ok) return status::unimplemented;
127
128 if (desc()->alg_kind == alg_kind::pooling_max) {
129 init_default_ws();
130 if (!compare_ws(hint_fwd_pd_))
131 return status::unimplemented;
132 }
133
134 return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
135 }
136
137 format_tag_t desired_fmt_tag() {
138 using namespace format_tag;
139 return ndims()
140 ? isa == avx512_common ? nChw16c : nChw8c
141 : isa == avx512_common ? nCdhw16c : nCdhw8c;
142 }
143
144 jit_pool_conf_t jpp_;
145 };
146
147 jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd)
148 { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
149
150 ~jit_uni_pooling_bwd_t() { delete kernel_; }
151
152 typedef typename prec_traits<data_type::f32>::type data_t;
153
154 virtual status_t execute(const exec_ctx_t &ctx) const override {
155 auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
156 auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE);
157 auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
158
159 if (pd()->ndims() == 5)
160 execute_backward_3d(diff_dst, ws, diff_src);
161 else
162 execute_backward(diff_dst, ws, diff_src);
163
164 return status::success;
165 }
166
167private:
168 void execute_backward(const data_t *diff_dst, const char *indices,
169 data_t *diff_src) const;
170 void execute_backward_3d(const data_t *diff_dst, const char *indices,
171 data_t *diff_src) const;
172 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
173 jit_uni_pool_kernel_f32<isa> *kernel_;
174};
175
176}
177}
178}
179
180#endif
181
182// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
183