1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
17#include "autoencoder.h"
18
19namespace oidn {
20
21 // --------------------------------------------------------------------------
22 // AutoencoderFilter
23 // --------------------------------------------------------------------------
24
25 AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
26 : Filter(device)
27 {
28 }
29
30 void AutoencoderFilter::setImage(const std::string& name, const Image& data)
31 {
32 if (name == "color")
33 color = data;
34 else if (name == "albedo")
35 albedo = data;
36 else if (name == "normal")
37 normal = data;
38 else if (name == "output")
39 output = data;
40
41 dirty = true;
42 }
43
44 void AutoencoderFilter::set1i(const std::string& name, int value)
45 {
46 if (name == "hdr")
47 hdr = value;
48 else if (name == "srgb")
49 srgb = value;
50 else if (name == "maxMemoryMB")
51 maxMemoryMB = value;
52
53 dirty = true;
54 }
55
56 int AutoencoderFilter::get1i(const std::string& name)
57 {
58 if (name == "hdr")
59 return hdr;
60 else if (name == "srgb")
61 return srgb;
62 else if (name == "maxMemoryMB")
63 return maxMemoryMB;
64 else if (name == "alignment")
65 return alignment;
66 else if (name == "overlap")
67 return overlap;
68 else
69 throw Exception(Error::InvalidArgument, "invalid parameter");
70 }
71
72 void AutoencoderFilter::set1f(const std::string& name, float value)
73 {
74 if (name == "hdrScale")
75 hdrScale = value;
76
77 dirty = true;
78 }
79
80 float AutoencoderFilter::get1f(const std::string& name)
81 {
82 if (name == "hdrScale")
83 return hdrScale;
84 else
85 throw Exception(Error::InvalidArgument, "invalid parameter");
86 }
87
  // Builds (or rebuilds) the denoising network from the current parameters.
  // No-op if nothing changed since the last commit. The Godot patch removed
  // the device task executor, so the build runs synchronously on this thread.
  void AutoencoderFilter::commit()
  {
    if (!dirty)
      return;

    // -- GODOT start --
    //device->executeTask([&]()
    //{
    // GODOT end --

    // Choose the SIMD block size K for the network: 16-wide when AVX-512 is
    // available, otherwise 8-wide (AVX2 path).
    if (mayiuse(avx512_common))
      net = buildNet<16>();
    else
      net = buildNet<8>();

    // GODOT start --
    //});
    // GODOT end --

    dirty = false;
  }
109
  // Runs the committed network over the image, processing it tile by tile to
  // bound memory use. Tiles overlap by `overlap` pixels on interior edges so
  // seams are hidden; only the non-overlapping core of each tile is written
  // to the output. Throws if called with uncommitted changes.
  void AutoencoderFilter::execute()
  {
    if (dirty)
      throw Exception(Error::InvalidOperation, "changes to the filter are not committed");

    // buildNet() returns null for zero-sized images; nothing to do then
    if (!net)
      return;
    // -- GODOT start --
    //device->executeTask([&]()
    //{
    // -- GODOT end --
    // Progress reporting: one task per tile
    Progress progress;
    progress.func = progressFunc;
    progress.userPtr = progressUserPtr;
    progress.taskCount = tileCountH * tileCountW;

    // Iterate over the tiles
    int tileIndex = 0;

    for (int i = 0; i < tileCountH; ++i)
    {
      const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
      const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
      const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
      const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
      const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
      const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer

      for (int j = 0; j < tileCountW; ++j)
      {
        const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
        const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
        const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
        const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
        const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
        const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer

        // Set the input tile
        inputReorder->setTile(h, w,
                              alignOffsetH, alignOffsetW,
                              tileH1, tileW1);

        // Set the output tile (skips the overlap borders on interior edges)
        outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
                               h + overlapBeginH, w + overlapBeginW,
                               tileH2, tileW2);

        //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);

        // Denoise the tile
        net->execute(progress, tileIndex);

        // Next tile
        tileIndex++;
      }
    }
    // -- GODOT start --
    //});
    // -- GODOT end --
  }
170
  // Picks a tile size/count so the estimated per-tile memory stays below
  // maxMemoryMB. Starts with a single full-image tile and alternately splits
  // the longer dimension until the pixel budget is met (or tiles reach the
  // 3*overlap minimum). Writes tileH/tileW/tileCountH/tileCountW.
  // NOTE(review): estimatedBytesBase and estimatedBytesPerPixel* are class
  // constants declared outside this file — presumably tuned per SIMD width.
  void AutoencoderFilter::computeTileSize()
  {
    const int minTileSize = 3*overlap;
    const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
    const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;

    tileCountH = 1;
    tileCountW = 1;
    tileH = roundUp(H, alignment);
    tileW = roundUp(W, alignment);

    // Divide the image into tiles until the tile size gets below the threshold
    while (int64_t(tileH) * tileW > maxTilePixels)
    {
      // Split the taller dimension first; each tile keeps 2*overlap extra
      // pixels so neighboring tiles can blend seamlessly
      if (tileH > minTileSize && tileH > tileW)
      {
        tileCountH++;
        tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
      }
      else if (tileW > minTileSize)
      {
        tileCountW++;
        tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
      }
      else
        break; // both dimensions at minimum — accept exceeding the budget
    }

    // Compute the final number of tiles (rounding up may have reduced the
    // count needed to cover the image)
    tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
    tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;

    if (device->isVerbose(2))
    {
      std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
      std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
    }
  }
209
  // Constructs the U-Net style denoising network for SIMD block size K.
  // Validates the input/output images, selects the weight blob matching the
  // provided feature set (color / +albedo / +normal, LDR vs HDR), computes
  // the tile size, then wires up the encoder (conv+pool), bottleneck, and
  // decoder (upsample+concat+conv) stages. Returns null for zero-sized
  // images. The memory-saving trick: concat outputs are allocated once and
  // producers write directly into the proper half via castTensor, so no
  // explicit concat ops are needed.
  template<int K>
  std::shared_ptr<Executable> AutoencoderFilter::buildNet()
  {
    H = color.height;
    W = color.width;

    // Configure the network
    int inputC;
    void* weightPtr;

    if (srgb && hdr)
      throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");

    // Pick the weight blob for the supplied feature images. The checks test
    // weightData.hdr* (not ldr*) for availability — presumably weight sets
    // come in hdr/ldr pairs, so hdr presence implies ldr presence; verify
    // against the filter constructors.
    if (color && !albedo && !normal && weightData.hdr)
    {
      inputC = 3; // color only
      weightPtr = hdr ? weightData.hdr : weightData.ldr;
    }
    else if (color && albedo && !normal && weightData.hdr_alb)
    {
      inputC = 6; // color + albedo
      weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
    }
    else if (color && albedo && normal && weightData.hdr_alb_nrm)
    {
      inputC = 9; // color + albedo + normal
      weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
    }
    else
    {
      throw Exception(Error::InvalidOperation, "unsupported combination of input features");
    }

    if (!output)
      throw Exception(Error::InvalidOperation, "output image not specified");

    // All images must be 3-channel float and match the color image's size
    if ((color.format != Format::Float3)
        || (albedo && albedo.format != Format::Float3)
        || (normal && normal.format != Format::Float3)
        || (output.format != Format::Float3))
      throw Exception(Error::InvalidOperation, "unsupported image format");

    if ((albedo && (albedo.width != W || albedo.height != H))
        || (normal && (normal.width != W || normal.height != H))
        || (output.width != W || output.height != H))
      throw Exception(Error::InvalidOperation, "image size mismatch");

    // Compute the tile size
    computeTileSize();

    // If the image size is zero, there is nothing else to do
    if (H <= 0 || W <= 0)
      return nullptr;

    // Parse the weights
    const auto weightMap = parseTensors(weightPtr);

    // Create the network
    std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);

    // Compute the tensor sizes. The trailing "//-> bufferName" comments
    // record which buffer each op's output will live in (see allocation
    // below): temp0/temp1 are ping-pong scratch, concatN are the skip
    // connections reused by the decoder.
    const auto inputDims = memory::dims({1, inputC, tileH, tileW});
    const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0

    const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
    const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
    const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
    const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
    const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
    const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
    const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
    const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
    const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
    const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
    const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
    const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
    const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
    const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
    const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
    const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
    const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
    const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
    const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
    const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
    const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
    const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
    const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
    const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
    const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
    const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
    const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
    const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
    const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
    const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
    const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
    const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0

    const auto outputDims = memory::dims({1, 3, tileH, tileW});

    // Allocate two temporary ping-pong buffers to decrease memory usage
    const auto temp0Dims = getMaxTensorDims({
      conv1Dims,
      conv2Dims,
      conv3Dims,
      conv4Dims,
      conv5Dims,
      conv6Dims,
      conv7Dims,
      conv8Dims,
      conv9Dims,
      conv10Dims,
      conv11Dims
    });

    const auto temp1Dims = getMaxTensorDims({
      conv1bDims,
      pool5Dims,
      conv6bDims,
      conv7bDims,
      conv8bDims,
      conv9bDims,
      conv10bDims,
    });

    auto temp0 = net->allocTensor(temp0Dims);
    auto temp1 = net->allocTensor(temp1Dims);

    // Allocate enough memory to hold the concat outputs. Then use the first
    // half to hold the previous conv output and the second half to hold the
    // pool/orig image output. This works because everything is C dimension
    // outermost, padded to K floats, and all the concats are on the C dimension.
    auto concat0Dst = net->allocTensor(concat0Dims);
    auto concat1Dst = net->allocTensor(concat1Dims);
    auto concat2Dst = net->allocTensor(concat2Dims);
    auto concat3Dst = net->allocTensor(concat3Dims);
    auto concat4Dst = net->allocTensor(concat4Dims);

    // Transfer function (virtual: RTLightmapFilter overrides it)
    std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();

    // Autoexposure: if no explicit hdrScale was set (NaN sentinel), add an
    // autoexposure node to derive the exposure from the color image
    if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
    {
      if (isnan(hdrScale))
        net->addAutoexposure(color, tf);
      else
        tf->setExposure(hdrScale);
    }

    // Input reorder: writes into the second half of concat0 (the skip half)
    auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
    inputReorder = net->addInputReorder(color, albedo, normal,
                                        transferFunc,
                                        alignment, inputReorderDst);

    // conv1
    auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);

    // conv1b
    auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);

    // pool1
    // Adjust pointer for pool1 to eliminate concat1
    auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
    auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);

    // conv2
    auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);

    // pool2
    // Adjust pointer for pool2 to eliminate concat2
    auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
    auto pool2 = net->addPool(conv2->getDst(), pool2Dst);

    // conv3
    auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);

    // pool3
    // Adjust pointer for pool3 to eliminate concat3
    auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
    auto pool3 = net->addPool(conv3->getDst(), pool3Dst);

    // conv4
    auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);

    // pool4
    // Adjust pointer for pool4 to eliminate concat4
    auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
    auto pool4 = net->addPool(conv4->getDst(), pool4Dst);

    // conv5 (bottleneck)
    auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);

    // pool5
    auto pool5 = net->addPool(conv5->getDst(), temp1);

    // upsample4: writes into the first half of concat4, next to pool4's output
    auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
    auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);

    // conv6: reads the full concat4 buffer (upsample4 + pool4 halves)
    auto conv6 = net->addConv("conv6", concat4Dst, temp0);

    // conv6b
    auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);

    // upsample3
    auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
    auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);

    // conv7
    auto conv7 = net->addConv("conv7", concat3Dst, temp0);

    // conv7b
    auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);

    // upsample2
    auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
    auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);

    // conv8
    auto conv8 = net->addConv("conv8", concat2Dst, temp0);

    // conv8b
    auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);

    // upsample1
    auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
    auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);

    // conv9
    auto conv9 = net->addConv("conv9", concat1Dst, temp0);

    // conv9b
    auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);

    // upsample0
    auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
    auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);

    // conv10
    auto conv10 = net->addConv("conv10", concat0Dst, temp0);

    // conv10b
    auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);

    // conv11: final layer, linear output (no ReLU)
    auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);

    // Output reorder: inverse transfer function + copy into the output image
    outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);

    net->finalize();
    return net;
  }
465
466 std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
467 {
468 if (hdr)
469 return std::make_shared<PQXTransferFunction>();
470 else if (srgb)
471 return std::make_shared<LinearTransferFunction>();
472 else
473 return std::make_shared<GammaTransferFunction>();
474 }
475
476// -- GODOT start --
477// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
478#if 0
479// -- GODOT end --
480
481 // --------------------------------------------------------------------------
482 // RTFilter
483 // --------------------------------------------------------------------------
484
  // Trained weight blobs for the RT filter, linked in from generated weight
  // files — one blob per supported input-feature combination. (This whole
  // section is compiled out by the Godot `#if 0` above.)
  namespace weights
  {
    // LDR
    extern unsigned char rt_ldr[]; // color
    extern unsigned char rt_ldr_alb[]; // color, albedo
    extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal

    // HDR
    extern unsigned char rt_hdr[]; // color
    extern unsigned char rt_hdr_alb[]; // color, albedo
    extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
  }
497
498 RTFilter::RTFilter(const Ref<Device>& device)
499 : AutoencoderFilter(device)
500 {
501 weightData.ldr = weights::rt_ldr;
502 weightData.ldr_alb = weights::rt_ldr_alb;
503 weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
504 weightData.hdr = weights::rt_hdr;
505 weightData.hdr_alb = weights::rt_hdr_alb;
506 weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
507 }
508// -- GODOT start --
509#endif
510// -- GODOT end --
511
512 // --------------------------------------------------------------------------
513 // RTLightmapFilter
514 // --------------------------------------------------------------------------
515
  // Trained weight blob for the lightmap filter, linked in from a generated
  // weight file. Only an HDR color-only set exists for this filter.
  namespace weights
  {
    // HDR
    extern unsigned char rtlightmap_hdr[]; // color
  }
521
522 RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
523 : AutoencoderFilter(device)
524 {
525 weightData.hdr = weights::rtlightmap_hdr;
526
527 hdr = true;
528 }
529
530 std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
531 {
532 return std::make_shared<LogTransferFunction>();
533 }
534
535} // namespace oidn
536