1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#ifndef ARROW_SPARSE_TENSOR_H
19#define ARROW_SPARSE_TENSOR_H
20
21#include <memory>
22#include <string>
23#include <vector>
24
25#include "arrow/tensor.h"
26
27namespace arrow {
28
29// ----------------------------------------------------------------------
30// SparseIndex class
31
/// \brief EXPERIMENTAL: Enumeration of the supported sparse tensor formats
struct SparseTensorFormat {
  /// Identifier of a sparse index layout
  enum type {
    /// Coordinate list format (see SparseCOOIndex)
    COO,
    /// Compressed sparse row format (see SparseCSRIndex)
    CSR
  };
};
36
37/// \brief EXPERIMENTAL: The base class for representing index of non-zero
38/// values in sparse tensor
39class ARROW_EXPORT SparseIndex {
40 public:
41 explicit SparseIndex(SparseTensorFormat::type format_id, int64_t non_zero_length)
42 : format_id_(format_id), non_zero_length_(non_zero_length) {}
43
44 virtual ~SparseIndex() = default;
45
46 SparseTensorFormat::type format_id() const { return format_id_; }
47 int64_t non_zero_length() const { return non_zero_length_; }
48
49 virtual std::string ToString() const = 0;
50
51 protected:
52 SparseTensorFormat::type format_id_;
53 int64_t non_zero_length_;
54};
55
56template <typename SparseIndexType>
57class SparseIndexBase : public SparseIndex {
58 public:
59 explicit SparseIndexBase(int64_t non_zero_length)
60 : SparseIndex(SparseIndexType::format_id, non_zero_length) {}
61};
62
63// ----------------------------------------------------------------------
64// SparseCOOIndex class
65
66/// \brief EXPERIMENTAL: The index data for COO sparse tensor
67class ARROW_EXPORT SparseCOOIndex : public SparseIndexBase<SparseCOOIndex> {
68 public:
69 using CoordsTensor = NumericTensor<Int64Type>;
70
71 static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::COO;
72
73 // Constructor with a column-major NumericTensor
74 explicit SparseCOOIndex(const std::shared_ptr<CoordsTensor>& coords);
75
76 const std::shared_ptr<CoordsTensor>& indices() const { return coords_; }
77
78 std::string ToString() const override;
79
80 bool Equals(const SparseCOOIndex& other) const {
81 return indices()->Equals(*other.indices());
82 }
83
84 protected:
85 std::shared_ptr<CoordsTensor> coords_;
86};
87
88// ----------------------------------------------------------------------
89// SparseCSRIndex class
90
91/// \brief EXPERIMENTAL: The index data for CSR sparse matrix
92class ARROW_EXPORT SparseCSRIndex : public SparseIndexBase<SparseCSRIndex> {
93 public:
94 using IndexTensor = NumericTensor<Int64Type>;
95
96 static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSR;
97
98 // Constructor with two index vectors
99 explicit SparseCSRIndex(const std::shared_ptr<IndexTensor>& indptr,
100 const std::shared_ptr<IndexTensor>& indices);
101
102 const std::shared_ptr<IndexTensor>& indptr() const { return indptr_; }
103 const std::shared_ptr<IndexTensor>& indices() const { return indices_; }
104
105 std::string ToString() const override;
106
107 bool Equals(const SparseCSRIndex& other) const {
108 return indptr()->Equals(*other.indptr()) && indices()->Equals(*other.indices());
109 }
110
111 protected:
112 std::shared_ptr<IndexTensor> indptr_;
113 std::shared_ptr<IndexTensor> indices_;
114};
115
116// ----------------------------------------------------------------------
117// SparseTensor class
118
119/// \brief EXPERIMENTAL: The base class of sparse tensor container
120class ARROW_EXPORT SparseTensor {
121 public:
122 virtual ~SparseTensor() = default;
123
124 SparseTensorFormat::type format_id() const { return sparse_index_->format_id(); }
125
126 std::shared_ptr<DataType> type() const { return type_; }
127 std::shared_ptr<Buffer> data() const { return data_; }
128
129 const uint8_t* raw_data() const { return data_->data(); }
130 uint8_t* raw_mutable_data() const { return data_->mutable_data(); }
131
132 const std::vector<int64_t>& shape() const { return shape_; }
133
134 const std::shared_ptr<SparseIndex>& sparse_index() const { return sparse_index_; }
135
136 int ndim() const { return static_cast<int>(shape_.size()); }
137
138 const std::string& dim_name(int i) const;
139
140 /// Total number of value cells in the sparse tensor
141 int64_t size() const;
142
143 /// Return true if the underlying data buffer is mutable
144 bool is_mutable() const { return data_->is_mutable(); }
145
146 /// Total number of non-zero cells in the sparse tensor
147 int64_t non_zero_length() const {
148 return sparse_index_ ? sparse_index_->non_zero_length() : 0;
149 }
150
151 bool Equals(const SparseTensor& other) const;
152
153 protected:
154 // Constructor with all attributes
155 SparseTensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
156 const std::vector<int64_t>& shape,
157 const std::shared_ptr<SparseIndex>& sparse_index,
158 const std::vector<std::string>& dim_names);
159
160 std::shared_ptr<DataType> type_;
161 std::shared_ptr<Buffer> data_;
162 std::vector<int64_t> shape_;
163 std::shared_ptr<SparseIndex> sparse_index_;
164
165 /// These names are optional
166 std::vector<std::string> dim_names_;
167};
168
169// ----------------------------------------------------------------------
170// SparseTensorImpl class
171
172/// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index
173/// type
174template <typename SparseIndexType>
175class ARROW_EXPORT SparseTensorImpl : public SparseTensor {
176 public:
177 virtual ~SparseTensorImpl() = default;
178
179 // Constructor with all attributes
180 SparseTensorImpl(const std::shared_ptr<SparseIndexType>& sparse_index,
181 const std::shared_ptr<DataType>& type,
182 const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
183 const std::vector<std::string>& dim_names)
184 : SparseTensor(type, data, shape, sparse_index, dim_names) {}
185
186 // Constructor for empty sparse tensor
187 SparseTensorImpl(const std::shared_ptr<DataType>& type,
188 const std::vector<int64_t>& shape,
189 const std::vector<std::string>& dim_names = {});
190
191 // Constructor with a dense numeric tensor
192 template <typename TYPE>
193 explicit SparseTensorImpl(const NumericTensor<TYPE>& tensor);
194
195 // Constructor with a dense tensor
196 explicit SparseTensorImpl(const Tensor& tensor);
197
198 private:
199 ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl);
200};
201
202/// \brief EXPERIMENTAL: Type alias for COO sparse tensor
203using SparseTensorCOO = SparseTensorImpl<SparseCOOIndex>;
204
205/// \brief EXPERIMENTAL: Type alias for CSR sparse matrix
206using SparseTensorCSR = SparseTensorImpl<SparseCSRIndex>;
207using SparseMatrixCSR = SparseTensorImpl<SparseCSRIndex>;
208
209} // namespace arrow
210
211#endif // ARROW_SPARSE_TENSOR_H
212