| 1 | // ======================================================================== // |
| 2 | // Copyright 2009-2019 Intel Corporation // |
| 3 | // // |
| 4 | // Licensed under the Apache License, Version 2.0 (the "License"); // |
| 5 | // you may not use this file except in compliance with the License. // |
| 6 | // You may obtain a copy of the License at // |
| 7 | // // |
| 8 | // http://www.apache.org/licenses/LICENSE-2.0 // |
| 9 | // // |
| 10 | // Unless required by applicable law or agreed to in writing, software // |
| 11 | // distributed under the License is distributed on an "AS IS" BASIS, // |
| 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // |
| 13 | // See the License for the specific language governing permissions and // |
| 14 | // limitations under the License. // |
| 15 | // ======================================================================== // |
| 16 | |
| 17 | #include "common/tensor.h" |
| 18 | #include "image.h" |
| 19 | #include "node.h" |
| 20 | #include "input_reorder.h" |
| 21 | #include "output_reorder.h" |
| 22 | #include "transfer_function.h" |
| 23 | |
| 24 | #pragma once |
| 25 | |
| 26 | namespace oidn { |
| 27 | |
| 28 | // Progress state |
| 29 | struct Progress |
| 30 | { |
| 31 | ProgressMonitorFunction func; |
| 32 | void* userPtr; |
| 33 | int taskCount; |
| 34 | }; |
| 35 | |
| 36 | class Executable |
| 37 | { |
| 38 | public: |
| 39 | virtual ~Executable() {} |
| 40 | virtual void execute(const Progress& progress, int taskIndex) = 0; |
| 41 | }; |
| 42 | |
| 43 | template<int K> |
| 44 | class Network : public Executable |
| 45 | { |
| 46 | public: |
| 47 | Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap); |
| 48 | |
| 49 | void execute(const Progress& progress, int taskIndex) override; |
| 50 | |
| 51 | std::shared_ptr<memory> allocTensor(const memory::dims& dims, |
| 52 | memory::format_tag format = memory::format_tag::any, |
| 53 | void* data = nullptr); |
| 54 | |
| 55 | std::shared_ptr<memory> castTensor(const memory::dims& dims, |
| 56 | const std::shared_ptr<memory>& src, |
| 57 | size_t srcOffset = 0, |
| 58 | memory::format_tag format = memory::format_tag::any); |
| 59 | |
| 60 | std::shared_ptr<memory> castTensor(const memory::dims& dims, |
| 61 | const std::shared_ptr<memory>& src, |
| 62 | const memory::dims& srcOffset); |
| 63 | |
| 64 | void zeroTensor(const std::shared_ptr<memory>& dst); |
| 65 | |
| 66 | memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); |
| 67 | |
| 68 | std::shared_ptr<Node> addInputReorder(const Image& color, |
| 69 | const Image& albedo, |
| 70 | const Image& normal, |
| 71 | const std::shared_ptr<TransferFunction>& transferFunc, |
| 72 | int alignment, |
| 73 | const std::shared_ptr<memory>& userDst = nullptr); |
| 74 | |
| 75 | std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src, |
| 76 | const std::shared_ptr<TransferFunction>& transferFunc, |
| 77 | const Image& output); |
| 78 | |
| 79 | memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); |
| 80 | std::shared_ptr<Node> addConv(const std::string& name, |
| 81 | const std::shared_ptr<memory>& src, |
| 82 | const std::shared_ptr<memory>& userDst = nullptr, |
| 83 | bool relu = true); |
| 84 | |
| 85 | memory::dims getPoolDims(const memory::dims& srcDims); |
| 86 | std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src, |
| 87 | const std::shared_ptr<memory>& userDst = nullptr); |
| 88 | |
| 89 | memory::dims getUpsampleDims(const memory::dims& srcDims); |
| 90 | std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src, |
| 91 | const std::shared_ptr<memory>& userDst = nullptr); |
| 92 | |
| 93 | memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); |
| 94 | |
| 95 | std::shared_ptr<Node> addAutoexposure(const Image& color, |
| 96 | const std::shared_ptr<HDRTransferFunction>& transferFunc); |
| 97 | |
| 98 | void finalize(); |
| 99 | |
| 100 | private: |
| 101 | Ref<Device> device; |
| 102 | engine eng; |
| 103 | stream sm; |
| 104 | std::vector<std::shared_ptr<Node>> nodes; |
| 105 | std::map<std::string, Tensor> weightMap; |
| 106 | |
| 107 | // Memory allocation statistics |
| 108 | size_t activationAllocBytes = 0; // number of allocated activation bytes |
| 109 | size_t totalAllocBytes = 0; // total number of allocated bytes |
| 110 | }; |
| 111 | |
| 112 | } // namespace oidn |
| 113 | |