1// ======================================================================== //
2// Copyright 2009-2019 Intel Corporation //
3// //
4// Licensed under the Apache License, Version 2.0 (the "License"); //
5// you may not use this file except in compliance with the License. //
6// You may obtain a copy of the License at //
7// //
8// http://www.apache.org/licenses/LICENSE-2.0 //
9// //
10// Unless required by applicable law or agreed to in writing, software //
11// distributed under the License is distributed on an "AS IS" BASIS, //
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
13// See the License for the specific language governing permissions and //
14// limitations under the License. //
15// ======================================================================== //
16
#include <cstring>

#include "exception.h"
#include "tensor.h"
19
20namespace oidn {
21
22 std::map<std::string, Tensor> parseTensors(void* buffer)
23 {
24 char* input = (char*)buffer;
25
26 // Parse the magic value
27 const int magic = *(unsigned short*)input;
28 if (magic != 0x41D7)
29 throw Exception(Error::InvalidOperation, "invalid tensor archive");
30 input += sizeof(unsigned short);
31
32 // Parse the version
33 const int majorVersion = *(unsigned char*)input++;
34 const int minorVersion = *(unsigned char*)input++;
35 UNUSED(minorVersion);
36 if (majorVersion > 1)
37 throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
38
39 // Parse the number of tensors
40 const int numTensors = *(int*)input;
41 input += sizeof(int);
42
43 // Parse the tensors
44 std::map<std::string, Tensor> tensorMap;
45 for (int i = 0; i < numTensors; ++i)
46 {
47 Tensor tensor;
48
49 // Parse the name
50 const int nameLen = *(unsigned char*)input++;
51 std::string name(input, nameLen);
52 input += nameLen;
53
54 // Parse the number of dimensions
55 const int ndims = *(unsigned char*)input++;
56
57 // Parse the shape of the tensor
58 tensor.dims.resize(ndims);
59 for (int i = 0; i < ndims; ++i)
60 tensor.dims[i] = ((int*)input)[i];
61 input += ndims * sizeof(int);
62
63 // Parse the format of the tensor
64 tensor.format = std::string(input, input + ndims);
65 input += ndims;
66
67 // Parse the data type of the tensor
68 const char type = *(unsigned char*)input++;
69 if (type != 'f') // only float32 is supported
70 throw Exception(Error::InvalidOperation, "unsupported tensor data type");
71
72 // Skip the data
73 tensor.data = (float*)input;
74 input += tensor.size() * sizeof(float);
75
76 // Add the tensor to the map
77 tensorMap.emplace(name, std::move(tensor));
78 }
79
80 return tensorMap;
81 }
82
83} // namespace oidn
84