1 | // ======================================================================== // |
2 | // Copyright 2009-2019 Intel Corporation // |
3 | // // |
4 | // Licensed under the Apache License, Version 2.0 (the "License"); // |
5 | // you may not use this file except in compliance with the License. // |
6 | // You may obtain a copy of the License at // |
7 | // // |
8 | // http://www.apache.org/licenses/LICENSE-2.0 // |
9 | // // |
10 | // Unless required by applicable law or agreed to in writing, software // |
11 | // distributed under the License is distributed on an "AS IS" BASIS, // |
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // |
13 | // See the License for the specific language governing permissions and // |
14 | // limitations under the License. // |
15 | // ======================================================================== // |
16 | |
17 | #include "common/tensor.h" |
18 | #include "image.h" |
19 | #include "node.h" |
20 | #include "input_reorder.h" |
21 | #include "output_reorder.h" |
22 | #include "transfer_function.h" |
23 | |
24 | #pragma once |
25 | |
26 | namespace oidn { |
27 | |
28 | // Progress state |
29 | struct Progress |
30 | { |
31 | ProgressMonitorFunction func; |
32 | void* userPtr; |
33 | int taskCount; |
34 | }; |
35 | |
36 | class Executable |
37 | { |
38 | public: |
39 | virtual ~Executable() {} |
40 | virtual void execute(const Progress& progress, int taskIndex) = 0; |
41 | }; |
42 | |
43 | template<int K> |
44 | class Network : public Executable |
45 | { |
46 | public: |
47 | Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap); |
48 | |
49 | void execute(const Progress& progress, int taskIndex) override; |
50 | |
51 | std::shared_ptr<memory> allocTensor(const memory::dims& dims, |
52 | memory::format_tag format = memory::format_tag::any, |
53 | void* data = nullptr); |
54 | |
55 | std::shared_ptr<memory> castTensor(const memory::dims& dims, |
56 | const std::shared_ptr<memory>& src, |
57 | size_t srcOffset = 0, |
58 | memory::format_tag format = memory::format_tag::any); |
59 | |
60 | std::shared_ptr<memory> castTensor(const memory::dims& dims, |
61 | const std::shared_ptr<memory>& src, |
62 | const memory::dims& srcOffset); |
63 | |
64 | void zeroTensor(const std::shared_ptr<memory>& dst); |
65 | |
66 | memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); |
67 | |
68 | std::shared_ptr<Node> addInputReorder(const Image& color, |
69 | const Image& albedo, |
70 | const Image& normal, |
71 | const std::shared_ptr<TransferFunction>& transferFunc, |
72 | int alignment, |
73 | const std::shared_ptr<memory>& userDst = nullptr); |
74 | |
75 | std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src, |
76 | const std::shared_ptr<TransferFunction>& transferFunc, |
77 | const Image& output); |
78 | |
79 | memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); |
80 | std::shared_ptr<Node> addConv(const std::string& name, |
81 | const std::shared_ptr<memory>& src, |
82 | const std::shared_ptr<memory>& userDst = nullptr, |
83 | bool relu = true); |
84 | |
85 | memory::dims getPoolDims(const memory::dims& srcDims); |
86 | std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src, |
87 | const std::shared_ptr<memory>& userDst = nullptr); |
88 | |
89 | memory::dims getUpsampleDims(const memory::dims& srcDims); |
90 | std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src, |
91 | const std::shared_ptr<memory>& userDst = nullptr); |
92 | |
93 | memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); |
94 | |
95 | std::shared_ptr<Node> addAutoexposure(const Image& color, |
96 | const std::shared_ptr<HDRTransferFunction>& transferFunc); |
97 | |
98 | void finalize(); |
99 | |
100 | private: |
101 | Ref<Device> device; |
102 | engine eng; |
103 | stream sm; |
104 | std::vector<std::shared_ptr<Node>> nodes; |
105 | std::map<std::string, Tensor> weightMap; |
106 | |
107 | // Memory allocation statistics |
108 | size_t activationAllocBytes = 0; // number of allocated activation bytes |
109 | size_t totalAllocBytes = 0; // total number of allocated bytes |
110 | }; |
111 | |
112 | } // namespace oidn |
113 | |