1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #ifndef ARROW_SPARSE_TENSOR_H |
19 | #define ARROW_SPARSE_TENSOR_H |
20 | |
21 | #include <memory> |
22 | #include <string> |
23 | #include <vector> |
24 | |
25 | #include "arrow/tensor.h" |
26 | |
27 | namespace arrow { |
28 | |
29 | // ---------------------------------------------------------------------- |
30 | // SparseIndex class |
31 | |
/// \brief EXPERIMENTAL: Sparse tensor format enumeration
///
/// The enum is wrapped in a struct so that its values are always spelled with
/// the SparseTensorFormat:: qualifier.
struct SparseTensorFormat {
  enum type {
    /// Format tag used by SparseCOOIndex
    COO,
    /// Format tag used by SparseCSRIndex
    CSR
  };
};
36 | |
37 | /// \brief EXPERIMENTAL: The base class for representing index of non-zero |
38 | /// values in sparse tensor |
39 | class ARROW_EXPORT SparseIndex { |
40 | public: |
41 | explicit SparseIndex(SparseTensorFormat::type format_id, int64_t non_zero_length) |
42 | : format_id_(format_id), non_zero_length_(non_zero_length) {} |
43 | |
44 | virtual ~SparseIndex() = default; |
45 | |
46 | SparseTensorFormat::type format_id() const { return format_id_; } |
47 | int64_t non_zero_length() const { return non_zero_length_; } |
48 | |
49 | virtual std::string ToString() const = 0; |
50 | |
51 | protected: |
52 | SparseTensorFormat::type format_id_; |
53 | int64_t non_zero_length_; |
54 | }; |
55 | |
56 | template <typename SparseIndexType> |
57 | class SparseIndexBase : public SparseIndex { |
58 | public: |
59 | explicit SparseIndexBase(int64_t non_zero_length) |
60 | : SparseIndex(SparseIndexType::format_id, non_zero_length) {} |
61 | }; |
62 | |
63 | // ---------------------------------------------------------------------- |
64 | // SparseCOOIndex class |
65 | |
66 | /// \brief EXPERIMENTAL: The index data for COO sparse tensor |
67 | class ARROW_EXPORT SparseCOOIndex : public SparseIndexBase<SparseCOOIndex> { |
68 | public: |
69 | using CoordsTensor = NumericTensor<Int64Type>; |
70 | |
71 | static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::COO; |
72 | |
73 | // Constructor with a column-major NumericTensor |
74 | explicit SparseCOOIndex(const std::shared_ptr<CoordsTensor>& coords); |
75 | |
76 | const std::shared_ptr<CoordsTensor>& indices() const { return coords_; } |
77 | |
78 | std::string ToString() const override; |
79 | |
80 | bool Equals(const SparseCOOIndex& other) const { |
81 | return indices()->Equals(*other.indices()); |
82 | } |
83 | |
84 | protected: |
85 | std::shared_ptr<CoordsTensor> coords_; |
86 | }; |
87 | |
88 | // ---------------------------------------------------------------------- |
89 | // SparseCSRIndex class |
90 | |
91 | /// \brief EXPERIMENTAL: The index data for CSR sparse matrix |
92 | class ARROW_EXPORT SparseCSRIndex : public SparseIndexBase<SparseCSRIndex> { |
93 | public: |
94 | using IndexTensor = NumericTensor<Int64Type>; |
95 | |
96 | static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSR; |
97 | |
98 | // Constructor with two index vectors |
99 | explicit SparseCSRIndex(const std::shared_ptr<IndexTensor>& indptr, |
100 | const std::shared_ptr<IndexTensor>& indices); |
101 | |
102 | const std::shared_ptr<IndexTensor>& indptr() const { return indptr_; } |
103 | const std::shared_ptr<IndexTensor>& indices() const { return indices_; } |
104 | |
105 | std::string ToString() const override; |
106 | |
107 | bool Equals(const SparseCSRIndex& other) const { |
108 | return indptr()->Equals(*other.indptr()) && indices()->Equals(*other.indices()); |
109 | } |
110 | |
111 | protected: |
112 | std::shared_ptr<IndexTensor> indptr_; |
113 | std::shared_ptr<IndexTensor> indices_; |
114 | }; |
115 | |
116 | // ---------------------------------------------------------------------- |
117 | // SparseTensor class |
118 | |
119 | /// \brief EXPERIMENTAL: The base class of sparse tensor container |
120 | class ARROW_EXPORT SparseTensor { |
121 | public: |
122 | virtual ~SparseTensor() = default; |
123 | |
124 | SparseTensorFormat::type format_id() const { return sparse_index_->format_id(); } |
125 | |
126 | std::shared_ptr<DataType> type() const { return type_; } |
127 | std::shared_ptr<Buffer> data() const { return data_; } |
128 | |
129 | const uint8_t* raw_data() const { return data_->data(); } |
130 | uint8_t* raw_mutable_data() const { return data_->mutable_data(); } |
131 | |
132 | const std::vector<int64_t>& shape() const { return shape_; } |
133 | |
134 | const std::shared_ptr<SparseIndex>& sparse_index() const { return sparse_index_; } |
135 | |
136 | int ndim() const { return static_cast<int>(shape_.size()); } |
137 | |
138 | const std::string& dim_name(int i) const; |
139 | |
140 | /// Total number of value cells in the sparse tensor |
141 | int64_t size() const; |
142 | |
143 | /// Return true if the underlying data buffer is mutable |
144 | bool is_mutable() const { return data_->is_mutable(); } |
145 | |
146 | /// Total number of non-zero cells in the sparse tensor |
147 | int64_t non_zero_length() const { |
148 | return sparse_index_ ? sparse_index_->non_zero_length() : 0; |
149 | } |
150 | |
151 | bool Equals(const SparseTensor& other) const; |
152 | |
153 | protected: |
154 | // Constructor with all attributes |
155 | SparseTensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data, |
156 | const std::vector<int64_t>& shape, |
157 | const std::shared_ptr<SparseIndex>& sparse_index, |
158 | const std::vector<std::string>& dim_names); |
159 | |
160 | std::shared_ptr<DataType> type_; |
161 | std::shared_ptr<Buffer> data_; |
162 | std::vector<int64_t> shape_; |
163 | std::shared_ptr<SparseIndex> sparse_index_; |
164 | |
165 | /// These names are optional |
166 | std::vector<std::string> dim_names_; |
167 | }; |
168 | |
169 | // ---------------------------------------------------------------------- |
170 | // SparseTensorImpl class |
171 | |
172 | /// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index |
173 | /// type |
174 | template <typename SparseIndexType> |
175 | class ARROW_EXPORT SparseTensorImpl : public SparseTensor { |
176 | public: |
177 | virtual ~SparseTensorImpl() = default; |
178 | |
179 | // Constructor with all attributes |
180 | SparseTensorImpl(const std::shared_ptr<SparseIndexType>& sparse_index, |
181 | const std::shared_ptr<DataType>& type, |
182 | const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape, |
183 | const std::vector<std::string>& dim_names) |
184 | : SparseTensor(type, data, shape, sparse_index, dim_names) {} |
185 | |
186 | // Constructor for empty sparse tensor |
187 | SparseTensorImpl(const std::shared_ptr<DataType>& type, |
188 | const std::vector<int64_t>& shape, |
189 | const std::vector<std::string>& dim_names = {}); |
190 | |
191 | // Constructor with a dense numeric tensor |
192 | template <typename TYPE> |
193 | explicit SparseTensorImpl(const NumericTensor<TYPE>& tensor); |
194 | |
195 | // Constructor with a dense tensor |
196 | explicit SparseTensorImpl(const Tensor& tensor); |
197 | |
198 | private: |
199 | ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl); |
200 | }; |
201 | |
202 | /// \brief EXPERIMENTAL: Type alias for COO sparse tensor |
203 | using SparseTensorCOO = SparseTensorImpl<SparseCOOIndex>; |
204 | |
205 | /// \brief EXPERIMENTAL: Type alias for CSR sparse matrix |
206 | using SparseTensorCSR = SparseTensorImpl<SparseCSRIndex>; |
207 | using SparseMatrixCSR = SparseTensorImpl<SparseCSRIndex>; |
208 | |
209 | } // namespace arrow |
210 | |
211 | #endif // ARROW_SPARSE_TENSOR_H |
212 | |