1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
17#pragma once
18
19#include "common.h"
20#include <vector>
21
22namespace oidn {
23
24 class Node
25 {
26 public:
27 virtual ~Node() = default;
28
29 virtual void execute(stream& sm) = 0;
30
31 virtual std::shared_ptr<memory> getDst() const { return nullptr; }
32
33 virtual size_t getScratchpadSize() const { return 0; }
34 virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
35
36 virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
37 {
38 assert(0); // not supported
39 }
40 };
41
42 // Node wrapping an MKL-DNN primitive
43 class MklNode : public Node
44 {
45 private:
46 primitive prim;
47 std::unordered_map<int, memory> args;
48 std::shared_ptr<memory> scratchpad;
49
50 public:
51 MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
52 : prim(prim),
53 args(args)
54 {}
55
56 size_t getScratchpadSize() const override
57 {
58 const auto primDesc = prim.get_primitive_desc();
59 const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
60 if (scratchpadDesc == nullptr)
61 return 0;
62 return mkldnn_memory_desc_get_size(scratchpadDesc);
63 }
64
65 void setScratchpad(const std::shared_ptr<memory>& mem) override
66 {
67 scratchpad = mem;
68 args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
69 }
70
71 void execute(stream& sm) override
72 {
73 prim.execute(sm, args);
74 }
75 };
76
77 // Convolution node
78 class ConvNode : public MklNode
79 {
80 private:
81 std::shared_ptr<memory> src;
82 std::shared_ptr<memory> weights;
83 std::shared_ptr<memory> bias;
84 std::shared_ptr<memory> dst;
85
86 public:
87 ConvNode(const convolution_forward::primitive_desc& desc,
88 const std::shared_ptr<memory>& src,
89 const std::shared_ptr<memory>& weights,
90 const std::shared_ptr<memory>& bias,
91 const std::shared_ptr<memory>& dst)
92 : MklNode(convolution_forward(desc),
93 { { MKLDNN_ARG_SRC, *src },
94 { MKLDNN_ARG_WEIGHTS, *weights },
95 { MKLDNN_ARG_BIAS, *bias },
96 { MKLDNN_ARG_DST, *dst } }),
97 src(src), weights(weights), bias(bias), dst(dst)
98 {}
99
100 std::shared_ptr<memory> getDst() const override { return dst; }
101 };
102
103 // Pooling node
104 class PoolNode : public MklNode
105 {
106 private:
107 std::shared_ptr<memory> src;
108 std::shared_ptr<memory> dst;
109
110 public:
111 PoolNode(const pooling_forward::primitive_desc& desc,
112 const std::shared_ptr<memory>& src,
113 const std::shared_ptr<memory>& dst)
114 : MklNode(pooling_forward(desc),
115 { { MKLDNN_ARG_SRC, *src },
116 { MKLDNN_ARG_DST, *dst } }),
117 src(src), dst(dst)
118 {}
119
120 std::shared_ptr<memory> getDst() const override { return dst; }
121 };
122
123 // Reorder node
124 class ReorderNode : public MklNode
125 {
126 private:
127 std::shared_ptr<memory> src;
128 std::shared_ptr<memory> dst;
129
130 public:
131 ReorderNode(const std::shared_ptr<memory>& src,
132 const std::shared_ptr<memory>& dst)
133 : MklNode(reorder(reorder::primitive_desc(*src, *dst)),
134 { { MKLDNN_ARG_SRC, *src },
135 { MKLDNN_ARG_DST, *dst } }),
136 src(src), dst(dst)
137 {}
138
139 std::shared_ptr<memory> getDst() const override { return dst; }
140 };
141
142} // namespace oidn
143