1/*******************************************************************************
2* Copyright 2017-2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_JIT_TRANSPOSE_SRC_HPP
18#define CPU_JIT_TRANSPOSE_SRC_HPP
19
20#include "cpu_barrier.hpp"
21#include "jit_primitive_conf.hpp"
22
23namespace mkldnn {
24namespace impl {
25namespace cpu {
26
27struct jit_trans_src_t {
28 struct ctx_t {
29 const void *src;
30 const void *tr_src;
31 const void *src_prf;
32 const void *tr_src_prf;
33
34 /* 1st conv 4fma: backward by weights */
35 int nthr_oc_b; /* number of threads process given src image */
36 int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
37 simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
38 };
39
40 jit_trans_src_t(const jit_conv_conf_t *conf)
41 : conf_(conf), ker_(nullptr) {}
42 virtual ~jit_trans_src_t() {}
43
44 void operator()(const ctx_t *ctx)
45 { assert(ker_); ker_(ctx); }
46
47 const jit_conv_conf_t *conf_;
48 void (*ker_)(const ctx_t *);
49};
50
/* Argument block passed by pointer to jit_transpose4x16_src's generated
 * kernel (see its jit_ker signature). The generated code presumably reads
 * these fields by offset, so the field order must not change. */
struct jit_src_transpose_s {
    size_t size;             /* NOTE(review): unit/meaning set by the caller — confirm */
    void const *src;         /* input buffer */
    void const *tr_src;      /* transposed output buffer */
    void const *src_prf;     /* presumably a prefetch address for src */
    void const *tr_src_prf;  /* presumably a prefetch address for tr_src */
};
58
59struct jit_trans_dst_t {
60 struct ctx_t {
61 const void *src;
62 const void *tr_src;
63 const void *src_prf;
64 const void *tr_src_prf;
65
66 /* 1st conv 4fma: backward by weights */
67 int nthr_oc_b; /* number of threads process given src image */
68 int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
69 simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
70 };
71
72 jit_trans_dst_t(const jit_conv_conf_t *conf)
73 : conf_(conf), ker_(nullptr) {}
74 virtual ~jit_trans_dst_t() {}
75
76 void operator()(const ctx_t *ctx)
77 { assert(ker_); ker_(ctx); }
78
79 const jit_conv_conf_t *conf_;
80 void (*ker_)(const ctx_t *);
81};
82
/* Tuning knobs consumed by jit_transpose4x16_src at code-generation time.
 * NOTE(review): the names suggest prefetch settings (pf0 = distance for the
 * first-level prefetch, pf1 = enable a second-level prefetch); confirm in
 * jit_transpose4x16_src::generate(). */
struct jit_transpose4x16_src_t {
    int  src_pf0_distance;    /* prefetch distance applied to src */
    int  tr_src_pf0_distance; /* prefetch distance applied to tr_src */
    bool src_pf1;             /* prefetch flag for src */
    bool tr_src_pf1;          /* prefetch flag for tr_src */
};
89
90struct jit_transpose4x16_src : public jit_generator {
91 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src)
92
93 jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams,
94 jit_transpose4x16_src_t *tparams_)
95 : params(aparams), tparams(tparams_)
96 {
97 this->generate();
98 jit_ker = (decltype(jit_ker))this->getCode();
99 }
100
101 const jit_1x1_conv_conf_t *params;
102 const jit_transpose4x16_src_t *tparams;
103 void (*jit_ker)(jit_src_transpose_s *);
104
105 void operator()(jit_src_transpose_s *arg) { jit_ker(arg); }
106
107 static const int transpose_size = 4;
108private:
109 static const int typesize = sizeof(float);
110
111 int src_stride, tr_src_stride;
112
113 Xbyak::Reg64 imm_addr64 = rbx;
114
115 Xbyak::Opmask kF0 = k1;
116 Xbyak::Opmask kCC = k2;
117 Xbyak::Opmask k33 = k3;
118 Xbyak::Opmask kFFFF = k4;
119
120 Xbyak::Zmm vidx01 = zmm31;
121 Xbyak::Zmm vidx10 = zmm30;
122 Xbyak::Zmm vidx1 = zmm29;
123 Xbyak::Zmm vidxP = zmm28;
124
125 Xbyak::Reg64 reg_src = r8;
126 Xbyak::Reg64 reg_tr_src = r9;
127 Xbyak::Reg64 reg_src_prf = r10;
128 Xbyak::Reg64 reg_tr_src_prf = r11;
129 Xbyak::Reg64 reg_loop = r12;
130 Xbyak::Reg64 reg_tr_src_tmp = r13;
131 Xbyak::Reg32 regw_tmp = r14d;
132
133 void transpose_block(int ur, int nrows);
134 void transpose(int nrows);
135 void generate();
136};
137
138jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
139jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
140
141}
142}
143}
144
145#endif
146