1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
17#pragma once
18
19#include "node.h"
20#include "image.h"
21
22namespace oidn {
23
24 // Input reorder node
25 template<int K, class TransferFunction>
26 class InputReorderNode : public Node
27 {
28 private:
29 // Source
30 Image color;
31 Image albedo;
32 Image normal;
33
34 // Destination
35 std::shared_ptr<memory> dst;
36 float* dstPtr;
37 int C2;
38 int H2;
39 int W2;
40
41 // Tile
42 int h1Begin;
43 int w1Begin;
44 int h2Begin;
45 int w2Begin;
46 int H;
47 int W;
48
49 std::shared_ptr<TransferFunction> transferFunc;
50
51 public:
52 InputReorderNode(const Image& color,
53 const Image& albedo,
54 const Image& normal,
55 const std::shared_ptr<memory>& dst,
56 const std::shared_ptr<TransferFunction>& transferFunc)
57 : color(color), albedo(albedo), normal(normal),
58 dst(dst),
59 h1Begin(0), w1Begin(0),
60 H(color.height), W(color.width),
61 transferFunc(transferFunc)
62 {
63 const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
64 assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
65 assert(dstDesc.ndims == 4);
66 assert(dstDesc.data_type == memory::data_type::f32);
67 assert(dstDesc.dims[0] == 1);
68 //assert(dstDesc.dims[1] >= getPadded<K>(C1));
69
70 dstPtr = (float*)dst->get_data_handle();
71 C2 = dstDesc.dims[1];
72 H2 = dstDesc.dims[2];
73 W2 = dstDesc.dims[3];
74 }
75
76 void setTile(int h1, int w1, int h2, int w2, int H, int W) override
77 {
78 h1Begin = h1;
79 w1Begin = w1;
80 h2Begin = h2;
81 w2Begin = w2;
82 this->H = H;
83 this->W = W;
84 }
85
86 void execute(stream& sm) override
87 {
88 assert(H + h1Begin <= color.height);
89 assert(W + w1Begin <= color.width);
90 assert(H + h2Begin <= H2);
91 assert(W + w2Begin <= W2);
92
93 parallel_nd(H2, [&](int h2)
94 {
95 const int h = h2 - h2Begin;
96
97 if (h >= 0 && h < H)
98 {
99 const int h1 = h + h1Begin;
100
101 // Zero pad
102 for (int w2 = 0; w2 < w2Begin; ++w2)
103 {
104 int c = 0;
105 while (c < C2)
106 store(h2, w2, c, 0.f);
107 }
108
109 // Reorder
110 for (int w = 0; w < W; ++w)
111 {
112 const int w1 = w + w1Begin;
113 const int w2 = w + w2Begin;
114
115 int c = 0;
116 storeColor(h2, w2, c, (float*)color.get(h1, w1));
117 if (albedo)
118 storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
119 if (normal)
120 storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
121 while (c < C2)
122 store(h2, w2, c, 0.f);
123 }
124
125 // Zero pad
126 for (int w2 = W + w2Begin; w2 < W2; ++w2)
127 {
128 int c = 0;
129 while (c < C2)
130 store(h2, w2, c, 0.f);
131 }
132 }
133 else
134 {
135 // Zero pad
136 for (int w2 = 0; w2 < W2; ++w2)
137 {
138 int c = 0;
139 while (c < C2)
140 store(h2, w2, c, 0.f);
141 }
142 }
143 });
144 }
145
146 std::shared_ptr<memory> getDst() const override { return dst; }
147
148 private:
149 // Stores a single value
150 __forceinline void store(int h, int w, int& c, float value)
151 {
152 // Destination is in nChwKc format
153 float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
154 *dst_c = value;
155 c++;
156 }
157
158 // Stores a color
159 __forceinline void storeColor(int h, int w, int& c, const float* values)
160 {
161 #pragma unroll
162 for (int i = 0; i < 3; ++i)
163 {
164 // Load the value
165 float x = values[i];
166
167 // Sanitize the value
168 x = maxSafe(x, 0.f);
169
170 // Apply the transfer function
171 x = transferFunc->forward(x);
172
173 // Store the value
174 store(h, w, c, x);
175 }
176 }
177
178 // Stores an albedo
179 __forceinline void storeAlbedo(int h, int w, int& c, const float* values)
180 {
181 #pragma unroll
182 for (int i = 0; i < 3; ++i)
183 {
184 // Load the value
185 float x = values[i];
186
187 // Sanitize the value
188 x = clampSafe(x, 0.f, 1.f);
189
190 // Store the value
191 store(h, w, c, x);
192 }
193 }
194
195 // Stores a normal
196 __forceinline void storeNormal(int h, int w, int& c, const float* values)
197 {
198 // Load the normal
199 float x = values[0];
200 float y = values[1];
201 float z = values[2];
202
203 // Compute the length of the normal
204 const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
205
206 // Normalize the normal and transform it to [0..1]
207 if (isfinite(lengthSqr))
208 {
209 const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
210
211 const float scale = invLength * 0.5f;
212 const float offset = 0.5f;
213
214 x = x * scale + offset;
215 y = y * scale + offset;
216 z = z * scale + offset;
217 }
218 else
219 {
220 x = 0.f;
221 y = 0.f;
222 z = 0.f;
223 }
224
225 // Store the normal
226 store(h, w, c, x);
227 store(h, w, c, y);
228 store(h, w, c, z);
229 }
230 };
231
232} // namespace oidn
233