// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
#define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H

namespace Eigen {

namespace internal {

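// Computes the "conservative" sparse-sparse product res = lhs * rhs: entries that are
// numerically zero (e.g. from cancellation) are still stored explicitly (the product
// expression's .pruned() variant removes them). The result is built one outer vector at a
// time using dense scratch buffers (mask/values/indices) of size lhs.innerSize(). When
// sortedInsertion is true the inner indices of each outer vector are emitted in increasing
// order; otherwise they are emitted unsorted and are expected to be sorted afterwards,
// typically by a storage-order conversion performed by the caller.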
template<typename Lhs, typename Rhs, typename ResultType>
static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, bool sortedInsertion = false)
{
  typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
  typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
  typedef typename remove_all<ResultType>::type::Scalar ResScalar;

  // make sure to call innerSize/outerSize since we fake the storage order.
  Index rows = lhs.innerSize();
  Index cols = rhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  ei_declare_aligned_stack_constructed_variable(bool, mask, rows, 0);
  ei_declare_aligned_stack_constructed_variable(ResScalar, values, rows, 0);
  ei_declare_aligned_stack_constructed_variable(Index, indices, rows, 0);

  std::memset(mask,0,sizeof(bool)*rows);

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  // Estimate the number of non-zero entries of the result:
  // given a rhs column containing Y non-zeros, we assume that the corresponding Y columns
  // of the lhs differ on average by one non-zero, so the product of that rhs column with
  // the lhs holds about X+Y non-zeros, where X is the average number of non-zeros per
  // column of the lhs.
  // Therefore, we estimate nnz(lhs*rhs) ~= nnz(lhs) + nnz(rhs).
  Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
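  // The estimate is only used for the reserve() call below: an under-estimate merely
  // triggers reallocations during insertion and an over-estimate wastes some memory,
  // but neither affects the correctness of the result.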

  res.setZero();
  res.reserve(Index(estimated_nnz_prod));
  // we compute each column of the result, one after the other
  for (Index j=0; j<cols; ++j)
  {

    res.startVec(j);
    Index nnz = 0;
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
    {
      RhsScalar y = rhsIt.value();
      Index k = rhsIt.index();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
      {
        Index i = lhsIt.index();
        LhsScalar x = lhsIt.value();
        if(!mask[i])
        {
          mask[i] = true;
          values[i] = x * y;
          indices[nnz] = i;
          ++nnz;
        }
        else
          values[i] += x * y;
      }
    }
    if(!sortedInsertion)
    {
      // unordered insertion
      for(Index k=0; k<nnz; ++k)
      {
        Index i = indices[k];
        res.insertBackByOuterInnerUnordered(j,i) = values[i];
        mask[i] = false;
      }
    }
    else
    {
      // alternative ordered insertion code:
      const Index t200 = rows/11; // 11 == (log2(200)*1.39)
      const Index t = (rows*100)/139;

      // FIXME reserve nnz non zeros
      // FIXME implement faster sorting algorithms for very small nnz
      // If the result is sparse enough => sort the indices with a quick sort,
      // otherwise => loop through the entire vector.
      // To avoid performing an expensive log2 when the result is clearly very sparse,
      // we use a linear bound up to 200.
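      // Rough cost model behind the test below: sorting the nnz indices costs about
      // nnz*log2(nnz) while the dense fallback sweeps all 'rows' entries, so sorting is
      // chosen when nnz*log2(nnz) < rows/1.39 == t (the 1.39 factor is presumably a fudge
      // factor for the relative per-element cost of the two loops). For nnz < 200 the
      // cheaper linear test nnz < rows/11 is used instead, since 11 ~ log2(200)*1.39
      // bounds 1.39*log2(nnz) from above in that range.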
      if((nnz<200 && nnz<t200) || nnz * numext::log2(int(nnz)) < t)
      {
        if(nnz>1) std::sort(indices,indices+nnz);
        for(Index k=0; k<nnz; ++k)
        {
          Index i = indices[k];
          res.insertBackByOuterInner(j,i) = values[i];
          mask[i] = false;
        }
      }
      else
      {
        // dense path
        for(Index i=0; i<rows; ++i)
        {
          if(mask[i])
          {
            mask[i] = false;
            res.insertBackByOuterInner(j,i) = values[i];
          }
        }
      }
    }
  }
  res.finalize();
}


} // end namespace internal

namespace internal {

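// Selector reducing all eight combinations of lhs/rhs/result storage orders to the
// column-major kernel above. Mixed-order cases copy one operand into the other's storage
// order; cases where the result order does not match are handled by swapping the operands
// and/or converting the result, exploiting the fact that a row-major sparse matrix is the
// column-major storage of its transpose.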
template<typename Lhs, typename Rhs, typename ResultType,
  int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int ResStorageOrder = (traits<ResultType>::Flags&RowMajorBit) ? RowMajor : ColMajor>
struct conservative_sparse_sparse_product_selector;
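// For example (a sketch of the typical call path, not code from this file): with the
// default column-major SparseMatrix,
//   SparseMatrix<double> A(n,k), B(k,m);
//   SparseMatrix<double> C = A * B;
// the sparse product evaluator ends up in
//   conservative_sparse_sparse_product_selector<...,ColMajor,ColMajor,ColMajor>::run(A,B,C);
// i.e. the first specialization below.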

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
{
  typedef typename remove_all<Lhs>::type LhsCleaned;
  typedef typename LhsCleaned::Scalar Scalar;

  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrixAux;
    typedef typename sparse_eval<ColMajorMatrixAux,ResultType::RowsAtCompileTime,ResultType::ColsAtCompileTime,ColMajorMatrixAux::Flags>::type ColMajorMatrix;

    // If the result is tall and thin (in the extreme case a column vector),
    // then it is faster to sort the coefficients in place instead of transposing twice.
    // FIXME: the following heuristic is probably not very good.
    if(lhs.rows()>rhs.cols())
    {
      ColMajorMatrix resCol(lhs.rows(),rhs.cols());
      // perform sorted insertion
      internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol, true);
      res = resCol.markAsRValue();
    }
    else
    {
      ColMajorMatrixAux resCol(lhs.rows(),rhs.cols());
      // resort to transposition to sort the entries
      internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrixAux>(lhs, rhs, resCol, false);
      RowMajorMatrix resRow(resCol);
      res = resRow.markAsRValue();
    }
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRhs;
    typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
    RowMajorRhs rhsRow = rhs;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<RowMajorRhs,Lhs,RowMajorRes>(rhsRow, lhs, resRow);
    res = resRow;
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorLhs;
    typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorRes;
    RowMajorLhs lhsRow = lhs;
    RowMajorRes resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorLhs,RowMajorRes>(rhs, lhsRow, resRow);
    res = resRow;
  }
};

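// Both operands row-major: viewing a row-major matrix as the column-major storage of its
// transpose, impl(rhs, lhs, resRow) computes rhs^T * lhs^T = (lhs*rhs)^T into the
// column-major view of resRow, i.e. resRow itself holds lhs*rhs; the final assignment to
// the column-major res then sorts the entries of each column.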
template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
    RowMajorMatrix resRow(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
    res = resRow;
  }
};


template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
{
  typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;

  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
    ColMajorMatrix resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
    res = resCol;
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
    ColMajorLhs lhsCol = lhs;
    ColMajorRes resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<ColMajorLhs,Rhs,ColMajorRes>(lhsCol, rhs, resCol);
    res = resCol;
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRes;
    ColMajorRhs rhsCol = rhs;
    ColMajorRes resCol(lhs.rows(), rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorRhs,ColMajorRes>(lhs, rhsCol, resCol);
    res = resCol;
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename ResultType::StorageIndex> RowMajorMatrix;
    typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
    RowMajorMatrix resRow(lhs.rows(),rhs.cols());
    internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
    // sort the non zeros:
    ColMajorMatrix resCol(resRow);
    res = resCol;
  }
};

} // end namespace internal


namespace internal {

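// Accumulates the sparse-sparse product into a dense result: res(i,j) += lhs(i,k)*rhs(k,j)
// for every pair of matching non-zeros. No mask or scratch buffers are needed since the
// dense result can be addressed randomly through coeffRef. Note that res is not zeroed
// here; the caller is expected to initialize it (the product evaluator typically performs
// a setZero() before dispatching to this kernel).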
template<typename Lhs, typename Rhs, typename ResultType>
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
{
  typedef typename remove_all<Lhs>::type::Scalar LhsScalar;
  typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
  Index cols = rhs.outerSize();
  eigen_assert(lhs.outerSize() == rhs.innerSize());

  evaluator<Lhs> lhsEval(lhs);
  evaluator<Rhs> rhsEval(rhs);

  for (Index j=0; j<cols; ++j)
  {
    for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
    {
      RhsScalar y = rhsIt.value();
      Index k = rhsIt.index();
      for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
      {
        Index i = lhsIt.index();
        LhsScalar x = lhsIt.value();
        res.coeffRef(i,j) += x * y;
      }
    }
  }
}


} // end namespace internal

namespace internal {

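// Dispatch on the operands' storage orders for the dense-result path: mixed-order cases
// copy the row-major operand into a column-major temporary, while the all-row-major case
// avoids any copy by swapping the operands and writing through a transposed view of res.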
template<typename Lhs, typename Rhs, typename ResultType,
  int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
  int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
struct sparse_sparse_to_dense_product_selector;

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorLhs;
    ColMajorLhs lhsCol(lhs);
    internal::sparse_sparse_to_dense_product_impl<ColMajorLhs,Rhs,ResultType>(lhsCol, rhs, res);
  }
};

template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorRhs;
    ColMajorRhs rhsCol(rhs);
    internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorRhs,ResultType>(lhs, rhsCol, res);
  }
};

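// Both operands row-major: rather than copying either one, swap them and accumulate into a
// transposed view of res, so that res(j,i) += lhs(j,k)*rhs(k,i) while both inputs are
// traversed along their natural (row-major) outer dimension.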
template<typename Lhs, typename Rhs, typename ResultType>
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
{
  static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
  {
    Transpose<ResultType> trRes(res);
    internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
  }
};


} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H