1 | /******************************************************************************* |
2 | * Copyright 2017-2018 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_JIT_TRANSPOSE_SRC_HPP |
18 | #define CPU_JIT_TRANSPOSE_SRC_HPP |
19 | |
20 | #include "cpu_barrier.hpp" |
21 | #include "jit_primitive_conf.hpp" |
22 | |
23 | namespace mkldnn { |
24 | namespace impl { |
25 | namespace cpu { |
26 | |
27 | struct jit_trans_src_t { |
28 | struct ctx_t { |
29 | const void *src; |
30 | const void *tr_src; |
31 | const void *src_prf; |
32 | const void *tr_src_prf; |
33 | |
34 | /* 1st conv 4fma: backward by weights */ |
35 | int nthr_oc_b; /* number of threads process given src image */ |
36 | int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ |
37 | simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ |
38 | }; |
39 | |
40 | jit_trans_src_t(const jit_conv_conf_t *conf) |
41 | : conf_(conf), ker_(nullptr) {} |
42 | virtual ~jit_trans_src_t() {} |
43 | |
44 | void operator()(const ctx_t *ctx) |
45 | { assert(ker_); ker_(ctx); } |
46 | |
47 | const jit_conv_conf_t *conf_; |
48 | void (*ker_)(const ctx_t *); |
49 | }; |
50 | |
/* Argument block for the jit_transpose4x16_src kernel (see the signature of
 * jit_transpose4x16_src::jit_ker). Plain aggregate; member order is kept
 * as-is because generated code may address members by offset. */
struct jit_src_transpose_s {
    /* NOTE(review): unit of size (rows vs. bytes) is not visible here --
     * confirm against the kernel generator. */
    size_t size;
    /* input buffer and transposed output buffer */
    const void *src;
    const void *tr_src;
    /* presumably prefetch addresses for src / tr_src -- confirm */
    const void *src_prf;
    const void *tr_src_prf;
};
58 | |
59 | struct jit_trans_dst_t { |
60 | struct ctx_t { |
61 | const void *src; |
62 | const void *tr_src; |
63 | const void *src_prf; |
64 | const void *tr_src_prf; |
65 | |
66 | /* 1st conv 4fma: backward by weights */ |
67 | int nthr_oc_b; /* number of threads process given src image */ |
68 | int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ |
69 | simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ |
70 | }; |
71 | |
72 | jit_trans_dst_t(const jit_conv_conf_t *conf) |
73 | : conf_(conf), ker_(nullptr) {} |
74 | virtual ~jit_trans_dst_t() {} |
75 | |
76 | void operator()(const ctx_t *ctx) |
77 | { assert(ker_); ker_(ctx); } |
78 | |
79 | const jit_conv_conf_t *conf_; |
80 | void (*ker_)(const ctx_t *); |
81 | }; |
82 | |
/* Tuning parameters consumed by jit_transpose4x16_src (stored there as
 * `tparams`). Plain aggregate with no defaults.
 * NOTE(review): "pf0"/"pf1" presumably name two prefetch levels and the
 * distances are presumably counted in rows -- confirm against generate(). */
struct jit_transpose4x16_src_t {
    /* pf0 distance for the source and for the transposed buffer */
    int src_pf0_distance;
    int tr_src_pf0_distance;
    /* pf1 on/off switches for the source and for the transposed buffer */
    bool src_pf1;
    bool tr_src_pf1;
};
89 | |
90 | struct jit_transpose4x16_src : public jit_generator { |
91 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) |
92 | |
93 | jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, |
94 | jit_transpose4x16_src_t *tparams_) |
95 | : params(aparams), tparams(tparams_) |
96 | { |
97 | this->generate(); |
98 | jit_ker = (decltype(jit_ker))this->getCode(); |
99 | } |
100 | |
101 | const jit_1x1_conv_conf_t *params; |
102 | const jit_transpose4x16_src_t *tparams; |
103 | void (*jit_ker)(jit_src_transpose_s *); |
104 | |
105 | void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } |
106 | |
107 | static const int transpose_size = 4; |
108 | private: |
109 | static const int typesize = sizeof(float); |
110 | |
111 | int src_stride, tr_src_stride; |
112 | |
113 | Xbyak::Reg64 imm_addr64 = rbx; |
114 | |
115 | Xbyak::Opmask kF0 = k1; |
116 | Xbyak::Opmask kCC = k2; |
117 | Xbyak::Opmask k33 = k3; |
118 | Xbyak::Opmask kFFFF = k4; |
119 | |
120 | Xbyak::Zmm vidx01 = zmm31; |
121 | Xbyak::Zmm vidx10 = zmm30; |
122 | Xbyak::Zmm vidx1 = zmm29; |
123 | Xbyak::Zmm vidxP = zmm28; |
124 | |
125 | Xbyak::Reg64 reg_src = r8; |
126 | Xbyak::Reg64 reg_tr_src = r9; |
127 | Xbyak::Reg64 reg_src_prf = r10; |
128 | Xbyak::Reg64 reg_tr_src_prf = r11; |
129 | Xbyak::Reg64 reg_loop = r12; |
130 | Xbyak::Reg64 reg_tr_src_tmp = r13; |
131 | Xbyak::Reg32 regw_tmp = r14d; |
132 | |
133 | void transpose_block(int ur, int nrows); |
134 | void transpose(int nrows); |
135 | void generate(); |
136 | }; |
137 | |
138 | jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); |
139 | jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); |
140 | |
141 | } |
142 | } |
143 | } |
144 | |
145 | #endif |
146 | |