1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // |
6 | // This Source Code Form is subject to the terms of the Mozilla |
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
9 | |
10 | #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
11 | #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
12 | |
13 | namespace Eigen { |
14 | |
15 | namespace internal { |
16 | |
17 | |
18 | // perform a pseudo in-place sparse * sparse product assuming all matrices are col major |
19 | template<typename Lhs, typename Rhs, typename ResultType> |
20 | static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) |
21 | { |
22 | // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); |
23 | |
24 | typedef typename remove_all<Rhs>::type::Scalar RhsScalar; |
25 | typedef typename remove_all<ResultType>::type::Scalar ResScalar; |
26 | typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex; |
27 | |
28 | // make sure to call innerSize/outerSize since we fake the storage order. |
29 | Index rows = lhs.innerSize(); |
30 | Index cols = rhs.outerSize(); |
31 | //Index size = lhs.outerSize(); |
32 | eigen_assert(lhs.outerSize() == rhs.innerSize()); |
33 | |
34 | // allocate a temporary buffer |
35 | AmbiVector<ResScalar,StorageIndex> tempVector(rows); |
36 | |
37 | // mimics a resizeByInnerOuter: |
38 | if(ResultType::IsRowMajor) |
39 | res.resize(cols, rows); |
40 | else |
41 | res.resize(rows, cols); |
42 | |
43 | evaluator<Lhs> lhsEval(lhs); |
44 | evaluator<Rhs> rhsEval(rhs); |
45 | |
46 | // estimate the number of non zero entries |
47 | // given a rhs column containing Y non zeros, we assume that the respective Y columns |
48 | // of the lhs differs in average of one non zeros, thus the number of non zeros for |
49 | // the product of a rhs column with the lhs is X+Y where X is the average number of non zero |
50 | // per column of the lhs. |
51 | // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) |
52 | Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate(); |
53 | |
54 | res.reserve(estimated_nnz_prod); |
55 | double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols())); |
56 | for (Index j=0; j<cols; ++j) |
57 | { |
58 | // FIXME: |
59 | //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); |
60 | // let's do a more accurate determination of the nnz ratio for the current column j of res |
61 | tempVector.init(ratioColRes); |
62 | tempVector.setZero(); |
63 | for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) |
64 | { |
65 | // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) |
66 | tempVector.restart(); |
67 | RhsScalar x = rhsIt.value(); |
68 | for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) |
69 | { |
70 | tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; |
71 | } |
72 | } |
73 | res.startVec(j); |
74 | for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it) |
75 | res.insertBackByOuterInner(j,it.index()) = it.value(); |
76 | } |
77 | res.finalize(); |
78 | } |
79 | |
80 | template<typename Lhs, typename Rhs, typename ResultType, |
81 | int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, |
82 | int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, |
83 | int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> |
84 | struct sparse_sparse_product_with_pruning_selector; |
85 | |
86 | template<typename Lhs, typename Rhs, typename ResultType> |
87 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> |
88 | { |
89 | typedef typename ResultType::RealScalar RealScalar; |
90 | |
91 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
92 | { |
93 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
94 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); |
95 | res.swap(_res); |
96 | } |
97 | }; |
98 | |
99 | template<typename Lhs, typename Rhs, typename ResultType> |
100 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> |
101 | { |
102 | typedef typename ResultType::RealScalar RealScalar; |
103 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
104 | { |
105 | // we need a col-major matrix to hold the result |
106 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; |
107 | SparseTemporaryType _res(res.rows(), res.cols()); |
108 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); |
109 | res = _res; |
110 | } |
111 | }; |
112 | |
113 | template<typename Lhs, typename Rhs, typename ResultType> |
114 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> |
115 | { |
116 | typedef typename ResultType::RealScalar RealScalar; |
117 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
118 | { |
119 | // let's transpose the product to get a column x column product |
120 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
121 | internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); |
122 | res.swap(_res); |
123 | } |
124 | }; |
125 | |
126 | template<typename Lhs, typename Rhs, typename ResultType> |
127 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> |
128 | { |
129 | typedef typename ResultType::RealScalar RealScalar; |
130 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
131 | { |
132 | typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; |
133 | typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; |
134 | ColMajorMatrixLhs colLhs(lhs); |
135 | ColMajorMatrixRhs colRhs(rhs); |
136 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); |
137 | |
138 | // let's transpose the product to get a column x column product |
139 | // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; |
140 | // SparseTemporaryType _res(res.cols(), res.rows()); |
141 | // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); |
142 | // res = _res.transpose(); |
143 | } |
144 | }; |
145 | |
146 | template<typename Lhs, typename Rhs, typename ResultType> |
147 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> |
148 | { |
149 | typedef typename ResultType::RealScalar RealScalar; |
150 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
151 | { |
152 | typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs; |
153 | RowMajorMatrixLhs rowLhs(lhs); |
154 | sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); |
155 | } |
156 | }; |
157 | |
158 | template<typename Lhs, typename Rhs, typename ResultType> |
159 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> |
160 | { |
161 | typedef typename ResultType::RealScalar RealScalar; |
162 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
163 | { |
164 | typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs; |
165 | RowMajorMatrixRhs rowRhs(rhs); |
166 | sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); |
167 | } |
168 | }; |
169 | |
170 | template<typename Lhs, typename Rhs, typename ResultType> |
171 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> |
172 | { |
173 | typedef typename ResultType::RealScalar RealScalar; |
174 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
175 | { |
176 | typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; |
177 | ColMajorMatrixRhs colRhs(rhs); |
178 | internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); |
179 | } |
180 | }; |
181 | |
182 | template<typename Lhs, typename Rhs, typename ResultType> |
183 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> |
184 | { |
185 | typedef typename ResultType::RealScalar RealScalar; |
186 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
187 | { |
188 | typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; |
189 | ColMajorMatrixLhs colLhs(lhs); |
190 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); |
191 | } |
192 | }; |
193 | |
194 | } // end namespace internal |
195 | |
196 | } // end namespace Eigen |
197 | |
198 | #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
199 | |