| 1 | // ======================================================================== // | 
| 2 | // Copyright 2009-2019 Intel Corporation                                    // | 
| 3 | //                                                                          // | 
| 4 | // Licensed under the Apache License, Version 2.0 (the "License");          // | 
| 5 | // you may not use this file except in compliance with the License.         // | 
| 6 | // You may obtain a copy of the License at                                  // | 
| 7 | //                                                                          // | 
| 8 | //     http://www.apache.org/licenses/LICENSE-2.0                           // | 
| 9 | //                                                                          // | 
| 10 | // Unless required by applicable law or agreed to in writing, software      // | 
| 11 | // distributed under the License is distributed on an "AS IS" BASIS,        // | 
| 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // | 
| 13 | // See the License for the specific language governing permissions and      // | 
| 14 | // limitations under the License.                                           // | 
| 15 | // ======================================================================== // | 
| 16 |  | 
#include "exception.h"
#include "tensor.h"
#include <cstring>
| 19 |  | 
| 20 | namespace oidn { | 
| 21 |  | 
| 22 |   std::map<std::string, Tensor> parseTensors(void* buffer) | 
| 23 |   { | 
| 24 |     char* input = (char*)buffer; | 
| 25 |  | 
| 26 |     // Parse the magic value | 
| 27 |     const int magic = *(unsigned short*)input; | 
| 28 |     if (magic != 0x41D7) | 
| 29 |       throw Exception(Error::InvalidOperation, "invalid tensor archive" ); | 
| 30 |     input += sizeof(unsigned short); | 
| 31 |  | 
| 32 |     // Parse the version | 
| 33 |     const int majorVersion = *(unsigned char*)input++; | 
| 34 |     const int minorVersion = *(unsigned char*)input++; | 
| 35 |     UNUSED(minorVersion); | 
| 36 |     if (majorVersion > 1) | 
| 37 |       throw Exception(Error::InvalidOperation, "unsupported tensor archive version" ); | 
| 38 |  | 
| 39 |     // Parse the number of tensors | 
| 40 |     const int numTensors = *(int*)input; | 
| 41 |     input += sizeof(int); | 
| 42 |  | 
| 43 |     // Parse the tensors | 
| 44 |     std::map<std::string, Tensor> tensorMap; | 
| 45 |     for (int i = 0; i < numTensors; ++i) | 
| 46 |     { | 
| 47 |       Tensor tensor; | 
| 48 |  | 
| 49 |       // Parse the name | 
| 50 |       const int nameLen = *(unsigned char*)input++; | 
| 51 |       std::string name(input, nameLen); | 
| 52 |       input += nameLen; | 
| 53 |  | 
| 54 |       // Parse the number of dimensions | 
| 55 |       const int ndims = *(unsigned char*)input++; | 
| 56 |  | 
| 57 |       // Parse the shape of the tensor | 
| 58 |       tensor.dims.resize(ndims); | 
| 59 |       for (int i = 0; i < ndims; ++i) | 
| 60 |         tensor.dims[i] = ((int*)input)[i]; | 
| 61 |       input += ndims * sizeof(int); | 
| 62 |  | 
| 63 |       // Parse the format of the tensor | 
| 64 |       tensor.format = std::string(input, input + ndims); | 
| 65 |       input += ndims; | 
| 66 |  | 
| 67 |       // Parse the data type of the tensor | 
| 68 |       const char type = *(unsigned char*)input++; | 
| 69 |       if (type != 'f') // only float32 is supported | 
| 70 |         throw Exception(Error::InvalidOperation, "unsupported tensor data type" ); | 
| 71 |  | 
| 72 |       // Skip the data | 
| 73 |       tensor.data = (float*)input; | 
| 74 |       input += tensor.size() * sizeof(float); | 
| 75 |  | 
| 76 |       // Add the tensor to the map | 
| 77 |       tensorMap.emplace(name, std::move(tensor)); | 
| 78 |     } | 
| 79 |  | 
| 80 |     return tensorMap; | 
| 81 |   } | 
| 82 |  | 
| 83 | } // namespace oidn | 
| 84 |  |