1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
17#pragma once
18
19#include "node.h"
20#include "image.h"
21
22namespace oidn {
23
24 // Output reorder node
25 template<int K, class TransferFunction>
26 class OutputReorderNode : public Node
27 {
28 private:
29 // Source
30 std::shared_ptr<memory> src;
31 const float* srcPtr;
32 int H1;
33 int W1;
34
35 // Destination
36 Image output;
37
38 // Tile
39 int h1Begin;
40 int w1Begin;
41 int h2Begin;
42 int w2Begin;
43 int H;
44 int W;
45
46 std::shared_ptr<TransferFunction> transferFunc;
47
48 public:
49 OutputReorderNode(const std::shared_ptr<memory>& src,
50 const Image& output,
51 const std::shared_ptr<TransferFunction>& transferFunc)
52 : src(src),
53 output(output),
54 h1Begin(0), w1Begin(0),
55 h2Begin(0), w2Begin(0),
56 H(output.height), W(output.width),
57 transferFunc(transferFunc)
58 {
59 const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
60 MAYBE_UNUSED(srcDesc);
61 assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
62 assert(srcDesc.ndims == 4);
63 assert(srcDesc.data_type == memory::data_type::f32);
64 assert(srcDesc.dims[0] == 1);
65 // We assume output data is <= K OC
66 assert(srcDesc.dims[1] == K);
67
68 srcPtr = (float*)src->get_data_handle();
69 H1 = srcDesc.dims[2];
70 W1 = srcDesc.dims[3];
71 }
72
73 void setTile(int h1, int w1, int h2, int w2, int H, int W) override
74 {
75 h1Begin = h1;
76 w1Begin = w1;
77 h2Begin = h2;
78 w2Begin = w2;
79 this->H = H;
80 this->W = W;
81 }
82
83 void execute(stream& sm) override
84 {
85 assert(h1Begin + H <= H1);
86 assert(w1Begin + W <= W1);
87 assert(h2Begin + H <= output.height);
88 assert(w2Begin + W <= output.width);
89
90 const int C1 = K;
91
92 parallel_nd(H, [&](int h)
93 {
94 const int h1 = h + h1Begin;
95 const int h2 = h + h2Begin;
96
97 for (int w = 0; w < W; ++w)
98 {
99 const int w1 = w + w1Begin;
100 const int w2 = w + w2Begin;
101 float* dstPtr_C = (float*)output.get(h2, w2);
102
103 // Source is in nChwKc format. In this case C is 1 so this is really nhwc
104 const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
105
106 #pragma unroll
107 for (int i = 0; i < 3; ++i)
108 {
109 // Load the value
110 float x = srcPtr_C[i];
111
112 // The CNN output may contain negative values or even NaNs, so it must be sanitized
113 x = maxSafe(x, 0.f);
114
115 // Apply the inverse transfer function
116 x = transferFunc->inverse(x);
117
118 // Sanitize and store the final value
119 dstPtr_C[i] = max(x, 0.f);
120 }
121 }
122 });
123 }
124 };
125
126} // namespace oidn
127