1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#ifndef ARROW_TENSOR_H
19#define ARROW_TENSOR_H
20
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "arrow/buffer.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
31
32namespace arrow {
33
34static inline bool is_tensor_supported(Type::type type_id) {
35 switch (type_id) {
36 case Type::UINT8:
37 case Type::INT8:
38 case Type::UINT16:
39 case Type::INT16:
40 case Type::UINT32:
41 case Type::INT32:
42 case Type::UINT64:
43 case Type::INT64:
44 case Type::HALF_FLOAT:
45 case Type::FLOAT:
46 case Type::DOUBLE:
47 return true;
48 default:
49 break;
50 }
51 return false;
52}
53
// Forward declaration so that Tensor can befriend every SparseTensorImpl
// instantiation (see the `friend class SparseTensorImpl` inside Tensor).
template <typename SparseIndexType>
class SparseTensorImpl;
56
/// \brief An n-dimensional array of values of a single DataType held in a Buffer.
///
/// A Tensor is described by its element type, the data Buffer, a shape (the
/// extent of each dimension), per-dimension strides and optional dimension
/// names. Strides are measured in bytes: NumericTensor::Value adds
/// sum(index[i] * strides[i]) to the uint8_t pointer returned by raw_data().
///
/// Copy construction and copy assignment are disabled via
/// ARROW_DISALLOW_COPY_AND_ASSIGN.
class ARROW_EXPORT Tensor {
 public:
  virtual ~Tensor() = default;

  /// Constructor with no dimension names or strides, data assumed to be row-major
  ///
  /// NOTE(review): strides are presumably derived from `shape` (and the element
  /// byte width) in the out-of-line definition -- confirm in the .cc file.
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape);

  /// Constructor with non-negative strides
  ///
  /// \param strides byte stride of each dimension; presumably one entry per
  ///   entry of `shape` (TODO confirm against the out-of-line definition)
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape, const std::vector<int64_t>& strides);

  /// Constructor with non-negative strides and dimension names
  ///
  /// \param dim_names optional names; presumably one per dimension when given
  Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
         const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
         const std::vector<std::string>& dim_names);

  /// Element data type
  std::shared_ptr<DataType> type() const { return type_; }
  /// Buffer holding the tensor's values
  std::shared_ptr<Buffer> data() const { return data_; }

  /// Read-only pointer to the first byte of data(); data() must be non-null
  const uint8_t* raw_data() const { return data_->data(); }
  /// Writable pointer to the first byte of data(), obtained through
  /// Buffer::mutable_data(); data() must be non-null.
  /// NOTE(review): callers should presumably check is_mutable() first.
  uint8_t* raw_mutable_data() { return data_->mutable_data(); }

  /// Extent of each dimension
  const std::vector<int64_t>& shape() const { return shape_; }
  /// Byte stride of each dimension
  const std::vector<int64_t>& strides() const { return strides_; }

  /// Number of dimensions (the length of shape())
  int ndim() const { return static_cast<int>(shape_.size()); }

  /// Name of dimension `i`.
  /// NOTE(review): the result when no names were supplied is decided by the
  /// out-of-line definition -- confirm before relying on it.
  const std::string& dim_name(int i) const;

  /// Total number of value cells in the tensor
  int64_t size() const;

  /// Return true if the underlying data buffer is mutable
  bool is_mutable() const { return data_->is_mutable(); }

  /// Either row major or column major
  bool is_contiguous() const;

  /// AKA "C order"
  bool is_row_major() const;

  /// AKA "Fortran order"
  bool is_column_major() const;

  /// Type id of the element type (presumably type()->id(); defined out of line)
  Type::type type_id() const;

  /// Compare this tensor with `other`.
  /// NOTE(review): which attributes take part (values, shape, strides, names)
  /// is decided by the out-of-line definition -- confirm.
  bool Equals(const Tensor& other) const;

 protected:
  /// Leaves every member empty; only derived classes and friends can call it.
  Tensor() {}

  /// Element data type
  std::shared_ptr<DataType> type_;
  /// Buffer holding the values
  std::shared_ptr<Buffer> data_;
  /// Extent of each dimension
  std::vector<int64_t> shape_;
  /// Byte stride of each dimension
  std::vector<int64_t> strides_;

  /// These names are optional
  std::vector<std::string> dim_names_;

  // Every SparseTensorImpl instantiation may access the protected members above.
  template <typename SparseIndexType>
  friend class SparseTensorImpl;

 private:
  ARROW_DISALLOW_COPY_AND_ASSIGN(Tensor);
};
123
124template <typename TYPE>
125class NumericTensor : public Tensor {
126 public:
127 using TypeClass = TYPE;
128 using value_type = typename TypeClass::c_type;
129
130 /// Constructor with non-negative strides and dimension names
131 NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
132 const std::vector<int64_t>& strides,
133 const std::vector<std::string>& dim_names)
134 : Tensor(TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names) {}
135
136 /// Constructor with no dimension names or strides, data assumed to be row-major
137 NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape)
138 : NumericTensor(data, shape, {}, {}) {}
139
140 /// Constructor with non-negative strides
141 NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
142 const std::vector<int64_t>& strides)
143 : NumericTensor(data, shape, strides, {}) {}
144
145 const value_type& Value(const std::vector<int64_t>& index) const {
146 int64_t offset = CalculateValueOffset(index);
147 const value_type* ptr = reinterpret_cast<const value_type*>(raw_data() + offset);
148 return *ptr;
149 }
150
151 protected:
152 int64_t CalculateValueOffset(const std::vector<int64_t>& index) const {
153 int64_t offset = 0;
154 for (size_t i = 0; i < index.size(); ++i) {
155 offset += index[i] * strides_[i];
156 }
157 return offset;
158 }
159};
160
161} // namespace arrow
162
163#endif // ARROW_TENSOR_H
164