1 | // ======================================================================== // |
2 | // Copyright 2009-2019 Intel Corporation // |
3 | // // |
4 | // Licensed under the Apache License, Version 2.0 (the "License"); // |
5 | // you may not use this file except in compliance with the License. // |
6 | // You may obtain a copy of the License at // |
7 | // // |
8 | // http://www.apache.org/licenses/LICENSE-2.0 // |
9 | // // |
10 | // Unless required by applicable law or agreed to in writing, software // |
11 | // distributed under the License is distributed on an "AS IS" BASIS, // |
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // |
13 | // See the License for the specific language governing permissions and // |
14 | // limitations under the License. // |
15 | // ======================================================================== // |
16 | |
17 | #include "autoencoder.h" |
18 | |
19 | namespace oidn { |
20 | |
21 | // -------------------------------------------------------------------------- |
22 | // AutoencoderFilter |
23 | // -------------------------------------------------------------------------- |
24 | |
// Constructs the filter bound to the given device. Weight data pointers are
// left null here; the derived filter's constructor (e.g. RTLightmapFilter)
// is responsible for filling in weightData before commit() is called.
AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
  : Filter(device)
{
}
29 | |
30 | void AutoencoderFilter::setImage(const std::string& name, const Image& data) |
31 | { |
32 | if (name == "color" ) |
33 | color = data; |
34 | else if (name == "albedo" ) |
35 | albedo = data; |
36 | else if (name == "normal" ) |
37 | normal = data; |
38 | else if (name == "output" ) |
39 | output = data; |
40 | |
41 | dirty = true; |
42 | } |
43 | |
44 | void AutoencoderFilter::set1i(const std::string& name, int value) |
45 | { |
46 | if (name == "hdr" ) |
47 | hdr = value; |
48 | else if (name == "srgb" ) |
49 | srgb = value; |
50 | else if (name == "maxMemoryMB" ) |
51 | maxMemoryMB = value; |
52 | |
53 | dirty = true; |
54 | } |
55 | |
56 | int AutoencoderFilter::get1i(const std::string& name) |
57 | { |
58 | if (name == "hdr" ) |
59 | return hdr; |
60 | else if (name == "srgb" ) |
61 | return srgb; |
62 | else if (name == "maxMemoryMB" ) |
63 | return maxMemoryMB; |
64 | else if (name == "alignment" ) |
65 | return alignment; |
66 | else if (name == "overlap" ) |
67 | return overlap; |
68 | else |
69 | throw Exception(Error::InvalidArgument, "invalid parameter" ); |
70 | } |
71 | |
72 | void AutoencoderFilter::set1f(const std::string& name, float value) |
73 | { |
74 | if (name == "hdrScale" ) |
75 | hdrScale = value; |
76 | |
77 | dirty = true; |
78 | } |
79 | |
80 | float AutoencoderFilter::get1f(const std::string& name) |
81 | { |
82 | if (name == "hdrScale" ) |
83 | return hdrScale; |
84 | else |
85 | throw Exception(Error::InvalidArgument, "invalid parameter" ); |
86 | } |
87 | |
88 | void AutoencoderFilter::commit() |
89 | { |
90 | if (!dirty) |
91 | return; |
92 | |
93 | // -- GODOT start -- |
94 | //device->executeTask([&]() |
95 | //{ |
96 | // GODOT end -- |
97 | |
98 | if (mayiuse(avx512_common)) |
99 | net = buildNet<16>(); |
100 | else |
101 | net = buildNet<8>(); |
102 | |
103 | // GODOT start -- |
104 | //}); |
105 | // GODOT end -- |
106 | |
107 | dirty = false; |
108 | } |
109 | |
// Runs the committed network over the image, one tile at a time. Tiles
// overlap by 'overlap' pixels on interior edges; the overlapped border of
// each output tile is discarded so tile seams are not visible. Throws if
// the filter has uncommitted changes.
void AutoencoderFilter::execute()
{
  if (dirty)
    throw Exception(Error::InvalidOperation, "changes to the filter are not committed");

  // A null network means the committed image size was zero — nothing to do
  if (!net)
    return;
  // -- GODOT start --
  //device->executeTask([&]()
  //{
  // -- GODOT end --
  // Progress reporting: one task per tile
  Progress progress;
  progress.func = progressFunc;
  progress.userPtr = progressUserPtr;
  progress.taskCount = tileCountH * tileCountW;

  // Iterate over the tiles
  int tileIndex = 0;

  for (int i = 0; i < tileCountH; ++i)
  {
    const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
    const int overlapBeginH = i > 0            ? overlap : 0; // overlap on the top
    const int overlapEndH   = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
    const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
    const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
    const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer

    for (int j = 0; j < tileCountW; ++j)
    {
      const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
      const int overlapBeginW = j > 0            ? overlap : 0; // overlap on the left
      const int overlapEndW   = j < tileCountW-1 ? overlap : 0; // overlap on the right
      const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
      const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
      const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer

      // Set the input tile
      inputReorder->setTile(h, w,
                            alignOffsetH, alignOffsetW,
                            tileH1, tileW1);

      // Set the output tile (skips the overlapped border of the tile buffer)
      outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
                             h + overlapBeginH, w + overlapBeginW,
                             tileH2, tileW2);

      //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);

      // Denoise the tile
      net->execute(progress, tileIndex);

      // Next tile
      tileIndex++;
    }
  }
  // -- GODOT start --
  //});
  // -- GODOT end --
}
170 | |
// Computes the tile size (tileH x tileW) and tile counts so that processing
// a single tile stays within the maxMemoryMB budget. Starts from one tile
// covering the whole (alignment-rounded) image and repeatedly splits along
// the currently larger dimension until the per-tile pixel count fits, then
// derives the final tile counts from the effective (overlap-reduced) stride.
void AutoencoderFilter::computeTileSize()
{
  const int minTileSize = 3*overlap; // a tile must exceed its two overlap borders
  const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
  // 64-bit math: maxMemoryMB*1024*1024 can overflow 32 bits
  const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;

  tileCountH = 1;
  tileCountW = 1;
  tileH = roundUp(H, alignment);
  tileW = roundUp(W, alignment);

  // Divide the image into tiles until the tile size gets below the threshold
  while (int64_t(tileH) * tileW > maxTilePixels)
  {
    // Split the larger dimension first; each split re-derives the tile
    // extent from the overlap-free interior, rounded up to the alignment
    if (tileH > minTileSize && tileH > tileW)
    {
      tileCountH++;
      tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
    }
    else if (tileW > minTileSize)
    {
      tileCountW++;
      tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
    }
    else
      break; // both dimensions at the minimum — cannot split further
  }

  // Compute the final number of tiles from the effective stride
  // (tile size minus the overlap consumed on both sides)
  tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
  tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;

  if (device->isVerbose(2))
  {
    std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
    std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
  }
}
209 | |
// Builds the tiled U-Net denoising network for block size K (8 for AVX2,
// 16 for AVX-512). Validates the bound images, selects the weight blob
// matching the enabled auxiliary features (albedo/normal) and HDR mode,
// computes the tile size, and wires the encoder/decoder layers into two
// ping-pong temporaries plus aliased concat buffers to minimize memory.
// Returns nullptr when the image size is zero.
template<int K>
std::shared_ptr<Executable> AutoencoderFilter::buildNet()
{
  H = color.height;
  W = color.width;

  // Configure the network
  int inputC;
  void* weightPtr;

  if (srgb && hdr)
    throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");

  // Select input channel count and weights from the feature combination.
  // NOTE(review): each branch gates on the *hdr* blob being present; when
  // hdr is disabled the matching ldr blob is assumed to exist too — for
  // filters that ship only HDR weights (e.g. RTLightmapFilter), disabling
  // hdr would select a null pointer. Confirm callers cannot do this.
  if (color && !albedo && !normal && weightData.hdr)
  {
    inputC = 3;
    weightPtr = hdr ? weightData.hdr : weightData.ldr;
  }
  else if (color && albedo && !normal && weightData.hdr_alb)
  {
    inputC = 6;
    weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
  }
  else if (color && albedo && normal && weightData.hdr_alb_nrm)
  {
    inputC = 9;
    weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
  }
  else
  {
    throw Exception(Error::InvalidOperation, "unsupported combination of input features");
  }

  if (!output)
    throw Exception(Error::InvalidOperation, "output image not specified");

  // All images must be 3-channel float
  if ((color.format != Format::Float3)
      || (albedo && albedo.format != Format::Float3)
      || (normal && normal.format != Format::Float3)
      || (output.format != Format::Float3))
    throw Exception(Error::InvalidOperation, "unsupported image format");

  // All images must share the color image's resolution
  if ((albedo && (albedo.width != W || albedo.height != H))
      || (normal && (normal.width != W || normal.height != H))
      || (output.width != W || output.height != H))
    throw Exception(Error::InvalidOperation, "image size mismatch");

  // Compute the tile size
  computeTileSize();

  // If the image size is zero, there is nothing else to do
  if (H <= 0 || W <= 0)
    return nullptr;

  // Parse the weights
  const auto weightMap = parseTensors(weightPtr);

  // Create the network
  std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);

  // Compute the tensor sizes for every layer up front; the "//-> name"
  // annotations record which shared buffer each result will live in
  const auto inputDims        = memory::dims({1, inputC, tileH, tileW});
  const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0

  const auto conv1Dims     = net->getConvDims("conv1", inputReorderDims); //-> temp0
  const auto conv1bDims    = net->getConvDims("conv1b", conv1Dims);       //-> temp1
  const auto pool1Dims     = net->getPoolDims(conv1bDims);                //-> concat1
  const auto conv2Dims     = net->getConvDims("conv2", pool1Dims);        //-> temp0
  const auto pool2Dims     = net->getPoolDims(conv2Dims);                 //-> concat2
  const auto conv3Dims     = net->getConvDims("conv3", pool2Dims);        //-> temp0
  const auto pool3Dims     = net->getPoolDims(conv3Dims);                 //-> concat3
  const auto conv4Dims     = net->getConvDims("conv4", pool3Dims);        //-> temp0
  const auto pool4Dims     = net->getPoolDims(conv4Dims);                 //-> concat4
  const auto conv5Dims     = net->getConvDims("conv5", pool4Dims);        //-> temp0
  const auto pool5Dims     = net->getPoolDims(conv5Dims);                 //-> temp1
  const auto upsample4Dims = net->getUpsampleDims(pool5Dims);             //-> concat4
  const auto concat4Dims   = net->getConcatDims(upsample4Dims, pool4Dims);
  const auto conv6Dims     = net->getConvDims("conv6", concat4Dims);      //-> temp0
  const auto conv6bDims    = net->getConvDims("conv6b", conv6Dims);       //-> temp1
  const auto upsample3Dims = net->getUpsampleDims(conv6bDims);            //-> concat3
  const auto concat3Dims   = net->getConcatDims(upsample3Dims, pool3Dims);
  const auto conv7Dims     = net->getConvDims("conv7", concat3Dims);      //-> temp0
  const auto conv7bDims    = net->getConvDims("conv7b", conv7Dims);       //-> temp1
  const auto upsample2Dims = net->getUpsampleDims(conv7bDims);            //-> concat2
  const auto concat2Dims   = net->getConcatDims(upsample2Dims, pool2Dims);
  const auto conv8Dims     = net->getConvDims("conv8", concat2Dims);      //-> temp0
  const auto conv8bDims    = net->getConvDims("conv8b", conv8Dims);       //-> temp1
  const auto upsample1Dims = net->getUpsampleDims(conv8bDims);            //-> concat1
  const auto concat1Dims   = net->getConcatDims(upsample1Dims, pool1Dims);
  const auto conv9Dims     = net->getConvDims("conv9", concat1Dims);      //-> temp0
  const auto conv9bDims    = net->getConvDims("conv9b", conv9Dims);       //-> temp1
  const auto upsample0Dims = net->getUpsampleDims(conv9bDims);            //-> concat0
  const auto concat0Dims   = net->getConcatDims(upsample0Dims, inputReorderDims);
  const auto conv10Dims    = net->getConvDims("conv10", concat0Dims);     //-> temp0
  const auto conv10bDims   = net->getConvDims("conv10b", conv10Dims);     //-> temp1
  const auto conv11Dims    = net->getConvDims("conv11", conv10bDims);     //-> temp0

  const auto outputDims = memory::dims({1, 3, tileH, tileW});

  // Allocate two temporary ping-pong buffers to decrease memory usage;
  // each is sized to the largest tensor it will ever hold
  const auto temp0Dims = getMaxTensorDims({
    conv1Dims,
    conv2Dims,
    conv3Dims,
    conv4Dims,
    conv5Dims,
    conv6Dims,
    conv7Dims,
    conv8Dims,
    conv9Dims,
    conv10Dims,
    conv11Dims
  });

  const auto temp1Dims = getMaxTensorDims({
    conv1bDims,
    pool5Dims,
    conv6bDims,
    conv7bDims,
    conv8bDims,
    conv9bDims,
    conv10bDims,
  });

  auto temp0 = net->allocTensor(temp0Dims);
  auto temp1 = net->allocTensor(temp1Dims);

  // Allocate enough memory to hold the concat outputs. Then use the first
  // half to hold the previous conv output and the second half to hold the
  // pool/orig image output. This works because everything is C dimension
  // outermost, padded to K floats, and all the concats are on the C dimension.
  auto concat0Dst = net->allocTensor(concat0Dims);
  auto concat1Dst = net->allocTensor(concat1Dims);
  auto concat2Dst = net->allocTensor(concat2Dims);
  auto concat3Dst = net->allocTensor(concat3Dims);
  auto concat4Dst = net->allocTensor(concat4Dims);

  // Transfer function (depends on hdr/srgb; see makeTransferFunc)
  std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();

  // Autoexposure: when no explicit hdrScale is given (NaN), derive the
  // exposure from the color image
  if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
  {
    if (isnan(hdrScale))
      net->addAutoexposure(color, tf);
    else
      tf->setExposure(hdrScale);
  }

  // Input reorder: writes directly into the second half of concat0
  // (aliased via castTensor), which eliminates a copy at the final concat
  auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
  inputReorder = net->addInputReorder(color, albedo, normal,
                                      transferFunc,
                                      alignment, inputReorderDst);

  // conv1
  auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);

  // conv1b
  auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);

  // pool1
  // Adjust pointer for pool1 to eliminate concat1
  auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
  auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);

  // conv2
  auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);

  // pool2
  // Adjust pointer for pool2 to eliminate concat2
  auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
  auto pool2 = net->addPool(conv2->getDst(), pool2Dst);

  // conv3
  auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);

  // pool3
  // Adjust pointer for pool3 to eliminate concat3
  auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
  auto pool3 = net->addPool(conv3->getDst(), pool3Dst);

  // conv4
  auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);

  // pool4
  // Adjust pointer for pool4 to eliminate concat4
  auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
  auto pool4 = net->addPool(conv4->getDst(), pool4Dst);

  // conv5 (bottleneck of the U-Net)
  auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);

  // pool5
  auto pool5 = net->addPool(conv5->getDst(), temp1);

  // upsample4: writes into the first half of concat4, next to pool4's output
  auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
  auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);

  // conv6: reads the implicitly-concatenated [upsample4 | pool4] buffer
  auto conv6 = net->addConv("conv6", concat4Dst, temp0);

  // conv6b
  auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);

  // upsample3
  auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
  auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);

  // conv7
  auto conv7 = net->addConv("conv7", concat3Dst, temp0);

  // conv7b
  auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);

  // upsample2
  auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
  auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);

  // conv8
  auto conv8 = net->addConv("conv8", concat2Dst, temp0);

  // conv8b
  auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);

  // upsample1
  auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
  auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);

  // conv9
  auto conv9 = net->addConv("conv9", concat1Dst, temp0);

  // conv9b
  auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);

  // upsample0: joins up with the reordered input stored in concat0
  auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
  auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);

  // conv10
  auto conv10 = net->addConv("conv10", concat0Dst, temp0);

  // conv10b
  auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);

  // conv11: final layer, linear activation
  auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);

  // Output reorder: applies the inverse transfer function and writes the result
  outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);

  net->finalize();
  return net;
}
465 | |
466 | std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc() |
467 | { |
468 | if (hdr) |
469 | return std::make_shared<PQXTransferFunction>(); |
470 | else if (srgb) |
471 | return std::make_shared<LinearTransferFunction>(); |
472 | else |
473 | return std::make_shared<GammaTransferFunction>(); |
474 | } |
475 | |
476 | // -- GODOT start -- |
477 | // Godot doesn't need Raytracing filters. Removing them saves space in the weights files. |
478 | #if 0 |
479 | // -- GODOT end -- |
480 | |
481 | // -------------------------------------------------------------------------- |
482 | // RTFilter |
483 | // -------------------------------------------------------------------------- |
484 | |
// Weight blobs generated into the binary from the training data.
// (This whole section is compiled out for Godot — see the #if 0 above.)
namespace weights
{
  // LDR
  extern unsigned char rt_ldr[];         // color
  extern unsigned char rt_ldr_alb[];     // color, albedo
  extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal

  // HDR
  extern unsigned char rt_hdr[];         // color
  extern unsigned char rt_hdr_alb[];     // color, albedo
  extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
}
497 | |
// Generic raytracing denoiser: wires every LDR/HDR weight blob into the
// base filter so buildNet() can pick the one matching the enabled features.
// (Compiled out for Godot — see the #if 0 above.)
RTFilter::RTFilter(const Ref<Device>& device)
  : AutoencoderFilter(device)
{
  weightData.ldr         = weights::rt_ldr;
  weightData.ldr_alb     = weights::rt_ldr_alb;
  weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
  weightData.hdr         = weights::rt_hdr;
  weightData.hdr_alb     = weights::rt_hdr_alb;
  weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
}
508 | // -- GODOT start -- |
509 | #endif |
510 | // -- GODOT end -- |
511 | |
512 | // -------------------------------------------------------------------------- |
513 | // RTLightmapFilter |
514 | // -------------------------------------------------------------------------- |
515 | |
// Weight blob for the lightmap denoiser (the only one Godot ships).
namespace weights
{
  // HDR
  extern unsigned char rtlightmap_hdr[]; // color
}
521 | |
// Lightmap denoiser: only HDR weights exist, so hdr is enabled by default.
RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
  : AutoencoderFilter(device)
{
  weightData.hdr = weights::rtlightmap_hdr;

  // NOTE(review): no LDR weights are set; if a caller disabled hdr via
  // set1i("hdr", 0), buildNet() would select a null weight pointer —
  // confirm callers never do this.
  hdr = true;
}
529 | |
// Lightmaps are HDR-only; the log transfer function compresses their large
// dynamic range better than the PQX curve used by the base filter.
std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
{
  return std::make_shared<LogTransferFunction>();
}
534 | |
535 | } // namespace oidn |
536 | |