1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
17#include "common/tensor.h"
18#include "image.h"
19#include "node.h"
20#include "input_reorder.h"
21#include "output_reorder.h"
22#include "transfer_function.h"
23
24#pragma once
25
26namespace oidn {
27
28 // Progress state
29 struct Progress
30 {
31 ProgressMonitorFunction func;
32 void* userPtr;
33 int taskCount;
34 };
35
36 class Executable
37 {
38 public:
39 virtual ~Executable() {}
40 virtual void execute(const Progress& progress, int taskIndex) = 0;
41 };
42
  // A network that is assembled node by node through the add* methods below
  // and then run as an Executable.
  //
  // K: compile-time parameter of the network implementation (presumably the
  //    channel block size of the blocked memory layout — TODO confirm in
  //    network.cpp).
  //
  // NOTE(review): all member functions are defined out of line; keep these
  // signatures (and default arguments) in sync with their definitions.
  template<int K>
  class Network : public Executable
  {
  public:
    // device:    device the network runs on.
    // weightMap: named weight tensors; presumably looked up by layer name in
    //            getConvDims/addConv (confirm).
    Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);

    // Executes the task `taskIndex` of the network (see Executable::execute).
    void execute(const Progress& progress, int taskIndex) override;

    // Returns a tensor with the given dims and format (format defaults to
    // memory::format_tag::any). When `data` is non-null it presumably wraps that
    // caller-owned buffer instead of allocating — confirm in network.cpp.
    std::shared_ptr<memory> allocTensor(const memory::dims& dims,
                                        memory::format_tag format = memory::format_tag::any,
                                        void* data = nullptr);

    // Returns a tensor of the given dims that views `src` starting at the
    // scalar offset `srcOffset` (presumably in bytes or elements — confirm),
    // without copying (presumably; confirm).
    std::shared_ptr<memory> castTensor(const memory::dims& dims,
                                       const std::shared_ptr<memory>& src,
                                       size_t srcOffset = 0,
                                       memory::format_tag format = memory::format_tag::any);

    // Overload of castTensor taking the offset as per-dimension coordinates.
    std::shared_ptr<memory> castTensor(const memory::dims& dims,
                                       const std::shared_ptr<memory>& src,
                                       const memory::dims& srcOffset);

    // Fills `dst` with zeros.
    void zeroTensor(const std::shared_ptr<memory>& dst);

    // Returns the dims of the tensor produced by addInputReorder for an input
    // of dims `srcDims` and the given `alignment` (presumably spatial padding
    // to a multiple of `alignment` — confirm).
    memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);

    // Adds a node that converts the color/albedo/normal images into the
    // network's input tensor, using `transferFunc`. If `userDst` is given it is
    // used as the destination tensor (presumably otherwise one is allocated).
    std::shared_ptr<Node> addInputReorder(const Image& color,
                                          const Image& albedo,
                                          const Image& normal,
                                          const std::shared_ptr<TransferFunction>& transferFunc,
                                          int alignment,
                                          const std::shared_ptr<memory>& userDst = nullptr);

    // Adds a node that converts the tensor `src` into the `output` image,
    // using `transferFunc`.
    std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
                                           const std::shared_ptr<TransferFunction>& transferFunc,
                                           const Image& output);

    // Returns the output dims of the convolution layer `name` for an input of
    // dims `srcDims`.
    memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
    // Adds the convolution layer `name` reading from `src`. `userDst` optionally
    // supplies the destination tensor; `relu` (default true) presumably selects
    // a fused ReLU activation.
    std::shared_ptr<Node> addConv(const std::string& name,
                                  const std::shared_ptr<memory>& src,
                                  const std::shared_ptr<memory>& userDst = nullptr,
                                  bool relu = true);

    // Returns the output dims of a pooling node for an input of dims `srcDims`.
    memory::dims getPoolDims(const memory::dims& srcDims);
    // Adds a pooling node reading from `src`, optionally writing into `userDst`.
    std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
                                  const std::shared_ptr<memory>& userDst = nullptr);

    // Returns the output dims of an upsampling node for an input of dims `srcDims`.
    memory::dims getUpsampleDims(const memory::dims& srcDims);
    // Adds an upsampling node reading from `src`, optionally writing into `userDst`.
    std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
                                      const std::shared_ptr<memory>& userDst = nullptr);

    // Returns the dims of the concatenation of two tensors with dims `src1Dims`
    // and `src2Dims` (presumably along the channel axis — confirm).
    memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);

    // Adds a node that computes an exposure value from `color` for the HDR
    // transfer function `transferFunc` (presumably by setting its exposure — confirm).
    std::shared_ptr<Node> addAutoexposure(const Image& color,
                                          const std::shared_ptr<HDRTransferFunction>& transferFunc);

    // Completes construction of the network (presumably required after the last
    // add* call and before execute — confirm).
    void finalize();

  private:
    // NOTE(review): keep this declaration order; the out-of-line constructor's
    // initializer list may depend on it (e.g. if `sm` is created from `eng`).
    Ref<Device> device;
    engine eng;
    stream sm;
    std::vector<std::shared_ptr<Node>> nodes;  // nodes created by the add* methods (presumably in execution order)
    std::map<std::string, Tensor> weightMap;   // weights supplied to the constructor

    // Memory allocation statistics
    size_t activationAllocBytes = 0; // number of allocated activation bytes
    size_t totalAllocBytes = 0;      // total number of allocated bytes
  };
111
112} // namespace oidn
113