/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef _JIT_UNI_REORDER_HPP
18#define _JIT_UNI_REORDER_HPP
19
20#include <assert.h>
21
22#include "c_types_map.hpp"
23#include "type_helpers.hpp"
24
25#include "cpu_primitive.hpp"
26#include "cpu_reorder_pd.hpp"
27
28namespace mkldnn {
29namespace impl {
30namespace cpu {
31
32namespace tr {
33
34constexpr int max_ndims = MKLDNN_MAX_NDIMS;
35
/** one node of a transposition problem: a size plus the strides that are
 * applied to the input, the output and the scales for that node */
struct node_t {
    size_t n; // size of the node (cf. prb_node_split(), which splits n)
    ptrdiff_t is; // input stride
    ptrdiff_t os; // output stride
    ptrdiff_t ss; // scale stride
};
42
/** kind of scaling attached to a problem (see prb_t::scale_type).
 * NOTE(review): presumably NONE -- no scaling, COMMON -- a single factor,
 * MANY -- several factors addressed through node_t::ss; confirm against
 * prb_init() */
enum class scale_type_t { NONE, COMMON, MANY };
44
45struct prb_t {
46 data_type_t itype;
47 data_type_t otype;
48 int ndims;
49 node_t nodes[max_ndims];
50 ptrdiff_t ioff;
51 ptrdiff_t ooff;
52 scale_type_t scale_type;
53 float beta;
54};
55
56status_t prb_init(prb_t &prb, const memory_desc_t &imd,
57 const memory_desc_t &omd, const primitive_attr_t *attr);
58
59/** sorts the problem nodes so that output strides come in ascending order */
60void prb_normalize(prb_t &p);
61
62/** folds nodes together if possible */
63void prb_simplify(prb_t &p);
64
65/** splits the node dim into two of sizes n1 and n / n1
66 * @warning n must be multiple of n1 */
67void prb_node_split(prb_t &p, int dim, size_t n1);
68
69/** swaps d0 and d1 nodes */
70void prb_node_swap(prb_t &p, int d0, int d1);
71
72/** moves node d0 to the d1 position.
73 * nodes (d0, d1] are shifted to the left if d0 < d1 or
74 * to the right if d0 > d1 */
75void prb_node_move(prb_t &p, int d0, int d1);
76
77/** dumps the problem to stdout */
78void prb_dump(const prb_t &p);
79
/** arguments passed to a kernel on each call (see kernel_t::operator()) */
struct call_param_t {
    const void *in; // input data, not written through this pointer
    void *out; // output data
    const float *scale; // scales, not written through this pointer
};
85
86struct kernel_t {
87 struct desc_t {
88 int id;
89 prb_t prb;
90 };
91
92 kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {}
93 void operator()(const call_param_t *c) const { assert(ker_); ker_(c); }
94 virtual ~kernel_t() {}
95
96 /** inits kernel descriptor:
97 * desc -- kernel descriptor (output)
98 * prb -- transposition problem (input)
99 * ndims_ker_max -- limit the maximum number of dimensions kernel
100 * will process (optional, 0 -- no limitation) */
101 static status_t desc_init(desc_t &desc, const prb_t &prb,
102 int ndims_ker_max = 0);
103
104 /** creates kernel for the problem described in desc */
105 static kernel_t *create(const desc_t &desc);
106
107protected:
108 const desc_t desc_;
109 const prb_t &prb_ = desc_.prb;
110 void (*ker_)(const call_param_t *);
111};
112
/* TODO: add trans_t class */
114
115}
116
117/* for cpu reorder list */
118status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
119 engine_t *engine, const primitive_attr_t *attr,
120 engine_t *src_engine, const memory_desc_t *src_md,
121 engine_t *dst_engine, const memory_desc_t *dst_md);
122
123}
124}
125}
126
127#endif
128