1 | // ======================================================================== // |
2 | // Copyright 2009-2019 Intel Corporation // |
3 | // // |
4 | // Licensed under the Apache License, Version 2.0 (the "License"); // |
5 | // you may not use this file except in compliance with the License. // |
6 | // You may obtain a copy of the License at // |
7 | // // |
8 | // http://www.apache.org/licenses/LICENSE-2.0 // |
9 | // // |
10 | // Unless required by applicable law or agreed to in writing, software // |
11 | // distributed under the License is distributed on an "AS IS" BASIS, // |
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // |
13 | // See the License for the specific language governing permissions and // |
14 | // limitations under the License. // |
15 | // ======================================================================== // |
16 | |
// Marks a function as part of the exported C ABI of the library:
// dllexport on Windows, default symbol visibility elsewhere.
#if defined(_WIN32)
  #define OIDN_API extern "C" __declspec(dllexport)
#else
  #define OIDN_API extern "C" __attribute__((visibility("default")))
#endif
22 | |
// Locks the mutex of the device that owns the specified object until the end of
// the enclosing scope. The guard variable is named `lock`, so the macro can be
// used at most once per scope. The caller supplies the terminating semicolon:
//   OIDN_LOCK(obj);
// Use *only* inside OIDN_TRY/CATCH!
#define OIDN_LOCK(obj) \
  std::lock_guard<std::mutex> lock((obj)->getDevice()->getMutex())
27 | |
// Try/catch for converting exceptions to errors.
// OIDN_TRY opens a try block. OIDN_CATCH(obj) closes it and reports any caught
// exception through Device::setError, so no exception escapes the C API.
// `obj` must be declared outside the OIDN_TRY scope, because the handlers read it;
// when `obj` is null, nullptr is passed as the device.
// Mapping:
//   Exception          -> its own code and message
//   std::bad_alloc     -> Error::OutOfMemory
//   mkldnn::error      -> Error::OutOfMemory for mkldnn_out_of_memory,
//                         otherwise Error::Unknown with the mkldnn message
//   std::exception     -> Error::Unknown with what()
//   anything else      -> Error::Unknown, "unknown exception caught"
// Handler order matters: the more specific types must come before std::exception
// (Exception and mkldnn::error presumably derive from it -- confirm) and before
// the catch-all.
#define OIDN_TRY \
  try {

#define OIDN_CATCH(obj) \
  } catch (Exception& e) { \
    Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \
  } catch (std::bad_alloc&) { \
    Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
  } catch (mkldnn::error& e) { \
    if (e.status == mkldnn_out_of_memory) \
      Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
    else \
      Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \
  } catch (std::exception& e) { \
    Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \
  } catch (...) { \
    Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
  }
47 | |
#include "device.h"
#include "filter.h"
#include <limits>
#include <mutex>
51 | |
52 | namespace oidn { |
53 | |
54 | namespace |
55 | { |
56 | __forceinline void checkHandle(void* handle) |
57 | { |
58 | if (handle == nullptr) |
59 | throw Exception(Error::InvalidArgument, "invalid handle" ); |
60 | } |
61 | |
62 | template<typename T> |
63 | __forceinline void retainObject(T* obj) |
64 | { |
65 | if (obj) |
66 | { |
67 | obj->incRef(); |
68 | } |
69 | else |
70 | { |
71 | OIDN_TRY |
72 | checkHandle(obj); |
73 | OIDN_CATCH(obj) |
74 | } |
75 | } |
76 | |
77 | template<typename T> |
78 | __forceinline void releaseObject(T* obj) |
79 | { |
80 | if (obj == nullptr || obj->decRefKeep() == 0) |
81 | { |
82 | OIDN_TRY |
83 | checkHandle(obj); |
84 | OIDN_LOCK(obj); |
85 | obj->destroy(); |
86 | OIDN_CATCH(obj) |
87 | } |
88 | } |
89 | |
90 | template<> |
91 | __forceinline void releaseObject(Device* obj) |
92 | { |
93 | if (obj == nullptr || obj->decRefKeep() == 0) |
94 | { |
95 | OIDN_TRY |
96 | checkHandle(obj); |
97 | // Do NOT lock the device because it owns the mutex |
98 | obj->destroy(); |
99 | OIDN_CATCH(obj) |
100 | } |
101 | } |
102 | } |
103 | |
104 | OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) |
105 | { |
106 | Ref<Device> device = nullptr; |
107 | OIDN_TRY |
108 | if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) |
109 | device = makeRef<Device>(); |
110 | else |
111 | throw Exception(Error::InvalidArgument, "invalid device type" ); |
112 | OIDN_CATCH(device) |
113 | return (OIDNDevice)device.detach(); |
114 | } |
115 | |
116 | OIDN_API void oidnRetainDevice(OIDNDevice hDevice) |
117 | { |
118 | Device* device = (Device*)hDevice; |
119 | retainObject(device); |
120 | } |
121 | |
122 | OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) |
123 | { |
124 | Device* device = (Device*)hDevice; |
125 | releaseObject(device); |
126 | } |
127 | |
128 | OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) |
129 | { |
130 | Device* device = (Device*)hDevice; |
131 | OIDN_TRY |
132 | checkHandle(hDevice); |
133 | OIDN_LOCK(device); |
134 | device->set1i(name, value); |
135 | OIDN_CATCH(device) |
136 | } |
137 | |
138 | OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) |
139 | { |
140 | Device* device = (Device*)hDevice; |
141 | OIDN_TRY |
142 | checkHandle(hDevice); |
143 | OIDN_LOCK(device); |
144 | device->set1i(name, value); |
145 | OIDN_CATCH(device) |
146 | } |
147 | |
148 | OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) |
149 | { |
150 | Device* device = (Device*)hDevice; |
151 | OIDN_TRY |
152 | checkHandle(hDevice); |
153 | OIDN_LOCK(device); |
154 | return device->get1i(name); |
155 | OIDN_CATCH(device) |
156 | return false; |
157 | } |
158 | |
159 | OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) |
160 | { |
161 | Device* device = (Device*)hDevice; |
162 | OIDN_TRY |
163 | checkHandle(hDevice); |
164 | OIDN_LOCK(device); |
165 | return device->get1i(name); |
166 | OIDN_CATCH(device) |
167 | return 0; |
168 | } |
169 | |
170 | OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) |
171 | { |
172 | Device* device = (Device*)hDevice; |
173 | OIDN_TRY |
174 | checkHandle(hDevice); |
175 | OIDN_LOCK(device); |
176 | device->setErrorFunction((ErrorFunction)func, userPtr); |
177 | OIDN_CATCH(device) |
178 | } |
179 | |
180 | OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) |
181 | { |
182 | Device* device = (Device*)hDevice; |
183 | OIDN_TRY |
184 | return (OIDNError)Device::getError(device, outMessage); |
185 | OIDN_CATCH(device) |
186 | if (outMessage) *outMessage = "" ; |
187 | return OIDN_ERROR_UNKNOWN; |
188 | } |
189 | |
190 | OIDN_API void oidnCommitDevice(OIDNDevice hDevice) |
191 | { |
192 | Device* device = (Device*)hDevice; |
193 | OIDN_TRY |
194 | checkHandle(hDevice); |
195 | OIDN_LOCK(device); |
196 | device->commit(); |
197 | OIDN_CATCH(device) |
198 | } |
199 | |
200 | OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) |
201 | { |
202 | Device* device = (Device*)hDevice; |
203 | OIDN_TRY |
204 | checkHandle(hDevice); |
205 | OIDN_LOCK(device); |
206 | Ref<Buffer> buffer = device->newBuffer(byteSize); |
207 | return (OIDNBuffer)buffer.detach(); |
208 | OIDN_CATCH(device) |
209 | return nullptr; |
210 | } |
211 | |
212 | OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) |
213 | { |
214 | Device* device = (Device*)hDevice; |
215 | OIDN_TRY |
216 | checkHandle(hDevice); |
217 | OIDN_LOCK(device); |
218 | Ref<Buffer> buffer = device->newBuffer(ptr, byteSize); |
219 | return (OIDNBuffer)buffer.detach(); |
220 | OIDN_CATCH(device) |
221 | return nullptr; |
222 | } |
223 | |
224 | OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) |
225 | { |
226 | Buffer* buffer = (Buffer*)hBuffer; |
227 | retainObject(buffer); |
228 | } |
229 | |
230 | OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) |
231 | { |
232 | Buffer* buffer = (Buffer*)hBuffer; |
233 | releaseObject(buffer); |
234 | } |
235 | |
236 | OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) |
237 | { |
238 | Buffer* buffer = (Buffer*)hBuffer; |
239 | OIDN_TRY |
240 | checkHandle(hBuffer); |
241 | OIDN_LOCK(buffer); |
242 | return buffer->map(byteOffset, byteSize); |
243 | OIDN_CATCH(buffer) |
244 | return nullptr; |
245 | } |
246 | |
247 | OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) |
248 | { |
249 | Buffer* buffer = (Buffer*)hBuffer; |
250 | OIDN_TRY |
251 | checkHandle(hBuffer); |
252 | OIDN_LOCK(buffer); |
253 | return buffer->unmap(mappedPtr); |
254 | OIDN_CATCH(buffer) |
255 | } |
256 | |
257 | OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) |
258 | { |
259 | Device* device = (Device*)hDevice; |
260 | OIDN_TRY |
261 | checkHandle(hDevice); |
262 | OIDN_LOCK(device); |
263 | Ref<Filter> filter = device->newFilter(type); |
264 | return (OIDNFilter)filter.detach(); |
265 | OIDN_CATCH(device) |
266 | return nullptr; |
267 | } |
268 | |
269 | OIDN_API void oidnRetainFilter(OIDNFilter hFilter) |
270 | { |
271 | Filter* filter = (Filter*)hFilter; |
272 | retainObject(filter); |
273 | } |
274 | |
275 | OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) |
276 | { |
277 | Filter* filter = (Filter*)hFilter; |
278 | releaseObject(filter); |
279 | } |
280 | |
281 | OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, |
282 | OIDNBuffer hBuffer, OIDNFormat format, |
283 | size_t width, size_t height, |
284 | size_t byteOffset, |
285 | size_t bytePixelStride, size_t byteRowStride) |
286 | { |
287 | Filter* filter = (Filter*)hFilter; |
288 | OIDN_TRY |
289 | checkHandle(hFilter); |
290 | checkHandle(hBuffer); |
291 | OIDN_LOCK(filter); |
292 | Ref<Buffer> buffer = (Buffer*)hBuffer; |
293 | if (buffer->getDevice() != filter->getDevice()) |
294 | throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices" ); |
295 | Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); |
296 | filter->setImage(name, data); |
297 | OIDN_CATCH(filter) |
298 | } |
299 | |
300 | OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, |
301 | void* ptr, OIDNFormat format, |
302 | size_t width, size_t height, |
303 | size_t byteOffset, |
304 | size_t bytePixelStride, size_t byteRowStride) |
305 | { |
306 | Filter* filter = (Filter*)hFilter; |
307 | OIDN_TRY |
308 | checkHandle(hFilter); |
309 | OIDN_LOCK(filter); |
310 | Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); |
311 | filter->setImage(name, data); |
312 | OIDN_CATCH(filter) |
313 | } |
314 | |
315 | OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) |
316 | { |
317 | Filter* filter = (Filter*)hFilter; |
318 | OIDN_TRY |
319 | checkHandle(hFilter); |
320 | OIDN_LOCK(filter); |
321 | filter->set1i(name, int(value)); |
322 | OIDN_CATCH(filter) |
323 | } |
324 | |
325 | OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) |
326 | { |
327 | Filter* filter = (Filter*)hFilter; |
328 | OIDN_TRY |
329 | checkHandle(hFilter); |
330 | OIDN_LOCK(filter); |
331 | return filter->get1i(name); |
332 | OIDN_CATCH(filter) |
333 | return false; |
334 | } |
335 | |
336 | OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) |
337 | { |
338 | Filter* filter = (Filter*)hFilter; |
339 | OIDN_TRY |
340 | checkHandle(hFilter); |
341 | OIDN_LOCK(filter); |
342 | filter->set1i(name, value); |
343 | OIDN_CATCH(filter) |
344 | } |
345 | |
346 | OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) |
347 | { |
348 | Filter* filter = (Filter*)hFilter; |
349 | OIDN_TRY |
350 | checkHandle(hFilter); |
351 | OIDN_LOCK(filter); |
352 | return filter->get1i(name); |
353 | OIDN_CATCH(filter) |
354 | return 0; |
355 | } |
356 | |
357 | OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) |
358 | { |
359 | Filter* filter = (Filter*)hFilter; |
360 | OIDN_TRY |
361 | checkHandle(hFilter); |
362 | OIDN_LOCK(filter); |
363 | filter->set1f(name, value); |
364 | OIDN_CATCH(filter) |
365 | } |
366 | |
367 | OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) |
368 | { |
369 | Filter* filter = (Filter*)hFilter; |
370 | OIDN_TRY |
371 | checkHandle(hFilter); |
372 | OIDN_LOCK(filter); |
373 | return filter->get1f(name); |
374 | OIDN_CATCH(filter) |
375 | return 0; |
376 | } |
377 | |
378 | OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) |
379 | { |
380 | Filter* filter = (Filter*)hFilter; |
381 | OIDN_TRY |
382 | checkHandle(hFilter); |
383 | OIDN_LOCK(filter); |
384 | filter->setProgressMonitorFunction(func, userPtr); |
385 | OIDN_CATCH(filter) |
386 | } |
387 | |
388 | OIDN_API void oidnCommitFilter(OIDNFilter hFilter) |
389 | { |
390 | Filter* filter = (Filter*)hFilter; |
391 | OIDN_TRY |
392 | checkHandle(hFilter); |
393 | OIDN_LOCK(filter); |
394 | filter->commit(); |
395 | OIDN_CATCH(filter) |
396 | } |
397 | |
398 | OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) |
399 | { |
400 | Filter* filter = (Filter*)hFilter; |
401 | OIDN_TRY |
402 | checkHandle(hFilter); |
403 | OIDN_LOCK(filter); |
404 | filter->execute(); |
405 | OIDN_CATCH(filter) |
406 | } |
407 | |
408 | } // namespace oidn |
409 | |