1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #ifndef ARROW_TENSOR_H |
19 | #define ARROW_TENSOR_H |
20 | |
21 | #include <cstdint> |
22 | #include <memory> |
23 | #include <string> |
24 | #include <vector> |
25 | |
26 | #include "arrow/buffer.h" |
27 | #include "arrow/type.h" |
28 | #include "arrow/type_traits.h" |
29 | #include "arrow/util/macros.h" |
30 | #include "arrow/util/visibility.h" |
31 | |
32 | namespace arrow { |
33 | |
34 | static inline bool is_tensor_supported(Type::type type_id) { |
35 | switch (type_id) { |
36 | case Type::UINT8: |
37 | case Type::INT8: |
38 | case Type::UINT16: |
39 | case Type::INT16: |
40 | case Type::UINT32: |
41 | case Type::INT32: |
42 | case Type::UINT64: |
43 | case Type::INT64: |
44 | case Type::HALF_FLOAT: |
45 | case Type::FLOAT: |
46 | case Type::DOUBLE: |
47 | return true; |
48 | default: |
49 | break; |
50 | } |
51 | return false; |
52 | } |
53 | |
// Forward declaration only; this header does not define SparseTensorImpl.
// The name has to be visible here because Tensor declares every
// SparseTensorImpl instantiation a friend.
template <typename SparseIndexType>
class SparseTensorImpl;
56 | |
/// A multi-dimensional array view: one DataType, one Buffer holding the
/// values, a shape, per-dimension strides and optional dimension names.
///
/// The constructors and most queries are defined out of line (not in this
/// header), so only what the declarations and inline bodies show is documented
/// here. ARROW_DISALLOW_COPY_AND_ASSIGN (arrow/util/macros.h) is applied in the
/// private section; per its name it disables copy construction and copy
/// assignment.
class ARROW_EXPORT Tensor {
 public:
  virtual ~Tensor() = default;

  /// Constructor with no dimension names or strides, data assumed to be row-major
  ///
  /// \param type element type of the tensor
  /// \param data buffer holding the tensor values
  /// \param shape extent of each dimension
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape);

  /// Constructor with non-negative strides
  ///
  /// \param type element type of the tensor
  /// \param data buffer holding the tensor values
  /// \param shape extent of each dimension
  /// \param strides step per dimension; NumericTensor::Value multiplies these
  ///        by the index and adds the result to the byte pointer raw_data(),
  ///        i.e. it consumes them as byte counts
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape, const std::vector<int64_t>& strides);

  /// Constructor with non-negative strides and dimension names
  ///
  /// \param type element type of the tensor
  /// \param data buffer holding the tensor values
  /// \param shape extent of each dimension
  /// \param strides step per dimension (see the constructor above)
  /// \param dim_names one name per dimension; the names are optional
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
         const std::vector<std::string>& dim_names);

  /// Element type; returns a copy of the owning pointer
  std::shared_ptr<DataType> type() const { return type_; }
  /// Value buffer; returns a copy of the owning pointer
  std::shared_ptr<Buffer> data() const { return data_; }

  /// Read-only byte pointer to the start of the buffer.
  /// Dereferences data_ without a null check, so data_ must be set.
  const uint8_t* raw_data() const { return data_->data(); }
  /// Writable byte pointer to the start of the buffer.
  /// Dereferences data_ without a null check, so data_ must be set.
  /// NOTE(review): presumably only meaningful when is_mutable() is true --
  /// confirm against Buffer::mutable_data().
  uint8_t* raw_mutable_data() { return data_->mutable_data(); }

  /// Extent of each dimension
  const std::vector<int64_t>& shape() const { return shape_; }
  /// Step per dimension, in the same order as shape()
  const std::vector<int64_t>& strides() const { return strides_; }

  /// Number of dimensions, i.e. shape().size() narrowed to int
  int ndim() const { return static_cast<int>(shape_.size()); }

  /// Name of dimension i. Defined out of line; since the names are optional,
  /// the result when none were supplied is not visible from this header.
  const std::string& dim_name(int i) const;

  /// Total number of value cells in the tensor
  int64_t size() const;

  /// Return true if the underlying data buffer is mutable
  /// (forwards to Buffer::is_mutable(); data_ must be set)
  bool is_mutable() const { return data_->is_mutable(); }

  /// Either row major or column major
  bool is_contiguous() const;

  /// AKA "C order"
  bool is_row_major() const;

  /// AKA "Fortran order"
  bool is_column_major() const;

  /// Type id of the element type (defined out of line)
  Type::type type_id() const;

  /// Compare with another tensor. Defined out of line; the exact notion of
  /// equality (shape, strides, values, names) is not visible from this header.
  bool Equals(const Tensor& other) const;

 protected:
  /// Leaves every member default-constructed: type_ and data_ are null and the
  /// vectors are empty. Accessors that dereference data_ must not be called on
  /// an object in that state.
  Tensor() {}

  std::shared_ptr<DataType> type_;
  std::shared_ptr<Buffer> data_;
  std::vector<int64_t> shape_;
  std::vector<int64_t> strides_;

  /// These names are optional
  std::vector<std::string> dim_names_;

  // Every SparseTensorImpl instantiation may access the non-public members.
  template <typename SparseIndexType>
  friend class SparseTensorImpl;

 private:
  ARROW_DISALLOW_COPY_AND_ASSIGN(Tensor);
};
123 | |
124 | template <typename TYPE> |
125 | class NumericTensor : public Tensor { |
126 | public: |
127 | using TypeClass = TYPE; |
128 | using value_type = typename TypeClass::c_type; |
129 | |
130 | /// Constructor with non-negative strides and dimension names |
131 | NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape, |
132 | const std::vector<int64_t>& strides, |
133 | const std::vector<std::string>& dim_names) |
134 | : Tensor(TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names) {} |
135 | |
136 | /// Constructor with no dimension names or strides, data assumed to be row-major |
137 | NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape) |
138 | : NumericTensor(data, shape, {}, {}) {} |
139 | |
140 | /// Constructor with non-negative strides |
141 | NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape, |
142 | const std::vector<int64_t>& strides) |
143 | : NumericTensor(data, shape, strides, {}) {} |
144 | |
145 | const value_type& Value(const std::vector<int64_t>& index) const { |
146 | int64_t offset = CalculateValueOffset(index); |
147 | const value_type* ptr = reinterpret_cast<const value_type*>(raw_data() + offset); |
148 | return *ptr; |
149 | } |
150 | |
151 | protected: |
152 | int64_t CalculateValueOffset(const std::vector<int64_t>& index) const { |
153 | int64_t offset = 0; |
154 | for (size_t i = 0; i < index.size(); ++i) { |
155 | offset += index[i] * strides_[i]; |
156 | } |
157 | return offset; |
158 | } |
159 | }; |
160 | |
161 | } // namespace arrow |
162 | |
163 | #endif // ARROW_TENSOR_H |
164 | |