/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#include "mkldnn_types.h"
18
19#include "c_types_map.hpp"
20#include "type_helpers.hpp"
21#include "nstl.hpp"
22
23#include "jit_uni_pooling.hpp"
24
25namespace mkldnn {
26namespace impl {
27namespace cpu {
28
29template <cpu_isa_t isa>
30void jit_uni_pooling_fwd_t<isa>::execute_forward(const data_t *src,
31 data_t *dst, char *indices) const {
32 const memory_desc_wrapper src_d(pd()->src_md());
33 const memory_desc_wrapper dst_d(pd()->dst_md());
34 const memory_desc_wrapper indices_d(pd()->workspace_md());
35 const size_t ind_dt_size = indices
36 ? types::data_type_size(indices_d.data_type()) : 0;
37
38 const auto &jpp = pd()->jpp_;
39
40 auto ker = [&](int n, int b_c, int oh) {
41 auto arg = jit_pool_call_s();
42
43 const int ij = oh * jpp.stride_h;
44 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
45 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
46 const int ih = nstl::max(ij - jpp.t_pad, 0);
47
48 arg.src = &src[src_d.blk_off(n, b_c, ih)];
49 arg.dst = &dst[dst_d.blk_off(n, b_c, oh)];
50 if (indices) {
51 const size_t ind_off = indices_d.blk_off(n, b_c, oh);
52 arg.indices = &indices[ind_off * ind_dt_size];
53 }
54 arg.oh = oh == 0;
55 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
56 arg.kh_padding_shift = i_t_overflow*jpp.kw;
57 arg.kw_padding = 0;
58 arg.ker_area_h = (float)(jpp.kh -
59 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
60 nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
61 (*kernel_)(&arg);
62 };
63
64 parallel_nd(jpp.mb, jpp.nb_c, jpp.oh,
65 [&](int n, int b_c, int oh) {
66 ker(n, b_c, oh);
67 });
68}
69
70template <cpu_isa_t isa>
71void jit_uni_pooling_fwd_t<isa>::execute_forward_3d(const data_t *src,
72 data_t *dst, char *indices) const {
73 const memory_desc_wrapper src_d(pd()->src_md());
74 const memory_desc_wrapper dst_d(pd()->dst_md());
75 const memory_desc_wrapper indices_d(pd()->workspace_md());
76 const size_t ind_dt_size = indices
77 ? types::data_type_size(indices_d.data_type()) : 0;
78
79 const auto &jpp = pd()->jpp_;
80
81 auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
82 int d_b_overflow) {
83 auto arg = jit_pool_call_s();
84
85 const int ij = oh * jpp.stride_h;
86 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
87 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
88 const int ih = nstl::max(ij - jpp.t_pad, 0);
89
90 arg.src = &src[src_d.blk_off(n, b_c, id, ih)];
91 arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)];
92 if (indices) {
93 const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
94 arg.indices = &indices[ind_off * ind_dt_size];
95 }
96 arg.oh = (oh + od == 0);
97 arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
98 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
99 arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh;
100 arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
101 arg.kw_padding = 0;
102 arg.ker_area_h = (float)(jpp.kh -
103 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
104 nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
105 nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
106 nstl::max(0, jpp.f_pad - od*jpp.stride_d));
107
108
109 (*kernel_)(&arg);
110 };
111
112 parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
113 [&](int n, int b_c, int od) {
114 const int ik = od * jpp.stride_d;
115 const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
116 const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad)
117 -jpp.id;
118 const int id = nstl::max(ik - jpp.f_pad, 0);
119 for (int oh = 0; oh < jpp.oh; ++oh) {
120 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow);
121 }
122 });
123}
124
125template <cpu_isa_t isa>
126void jit_uni_pooling_bwd_t<isa>::execute_backward(const data_t *diff_dst,
127 const char *indices, data_t *diff_src) const {
128 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
129 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
130 const memory_desc_wrapper indices_d(pd()->workspace_md());
131 const size_t ind_dt_size = indices
132 ? types::data_type_size(indices_d.data_type()) : 0;
133
134 const auto &jpp = pd()->jpp_;
135
136 auto ker = [&](int n, int b_c, int oh) {
137 auto arg = jit_pool_call_s();
138
139 const int ij = oh * jpp.stride_h;
140 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
141 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
142 const int ih = nstl::max(ij - jpp.t_pad, 0);
143
144 arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)];
145 arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)];
146 if (indices) {
147 const size_t ind_off = indices_d.blk_off(n, b_c, oh);
148 arg.indices = &indices[ind_off * ind_dt_size];
149 }
150 arg.oh = (oh == 0);
151 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
152 arg.kh_padding_shift = i_t_overflow*jpp.kw;
153 arg.kw_padding = 0;
154 arg.ker_area_h = (float)(jpp.kh -
155 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
156 nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
157
158 (*kernel_)(&arg);
159 };
160
161 parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
162 for (int oh = 0; oh < jpp.oh; ++oh) {
163 ker(n, b_c, oh);
164 }
165 });
166}
167
168template <cpu_isa_t isa>
169void jit_uni_pooling_bwd_t<isa>::execute_backward_3d(const data_t *diff_dst,
170 const char *indices, data_t *diff_src) const {
171 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
172 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
173 const memory_desc_wrapper indices_d(pd()->workspace_md());
174 const size_t ind_dt_size = indices
175 ? types::data_type_size(indices_d.data_type()) : 0;
176
177 const auto &jpp = pd()->jpp_;
178
179 auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
180 int d_b_overflow, int zero_size, int kd) {
181 auto arg = jit_pool_call_s();
182
183 const int ij = oh * jpp.stride_h;
184 const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
185 const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
186 const int ih = nstl::max(ij - jpp.t_pad, 0);
187
188 arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)];
189 arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)];
190 if (indices) {
191 const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
192 arg.indices = &indices[ind_off * ind_dt_size];
193 }
194 arg.oh = zero_size;
195 arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
196 arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
197 arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh
198 + kd * jpp.kw * jpp.kh;
199 arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
200 arg.kw_padding = 0;
201 arg.ker_area_h = (float)(jpp.kh -
202 nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
203 nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
204 nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
205 nstl::max(0, jpp.f_pad - od*jpp.stride_d));
206
207 (*kernel_)(&arg);
208 };
209
210 if (jpp.simple_alg) {
211
212 parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
213 [&](int n, int b_c, int od) {
214 const int ik = od * jpp.stride_d;
215 const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
216 const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
217 - jpp.f_pad) - jpp.id;
218 const int id = nstl::max(ik - jpp.f_pad, 0);
219 int zero_s = jpp.stride_d - d_t_overflow - (nstl::max(
220 jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id);
221 for (int oh = 0; oh < jpp.oh; ++oh) {
222 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
223 (oh == 0) ? zero_s : 0, 0);
224 }
225 });
226 } else {
227 ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c
228 * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw;
229
230 parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; });
231
232 for (int kd = 0; kd < jpp.kd; ++kd) {
233 parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
234 for (int od = 0; od < jpp.od; ++od) {
235 const int ik = od * jpp.stride_d;
236 const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
237 const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
238 - jpp.f_pad) - jpp.id;
239 if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
240 continue;
241 const int id = nstl::max(ik - jpp.f_pad, 0);
242 for (int oh = 0; oh < jpp.oh; ++oh) {
243 ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
244 0, kd);
245 }
246 }
247 });
248 }
249 }
250}
251
252
253template struct jit_uni_pooling_fwd_t<sse42>;
254template struct jit_uni_pooling_bwd_t<sse42>;
255template struct jit_uni_pooling_fwd_t<avx>;
256template struct jit_uni_pooling_bwd_t<avx>;
257template struct jit_uni_pooling_fwd_t<avx512_common>;
258template struct jit_uni_pooling_bwd_t<avx512_common>;
259
260}
261}
262}
263
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
265