| 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
| 11 | #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
| 15 | namespace internal { |
| 16 | |
| 17 | |
| 18 | // perform a pseudo in-place sparse * sparse product assuming all matrices are col major |
| 19 | template<typename Lhs, typename Rhs, typename ResultType> |
| 20 | static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) |
| 21 | { |
| 22 | // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); |
| 23 | |
| 24 | typedef typename remove_all<Rhs>::type::Scalar RhsScalar; |
| 25 | typedef typename remove_all<ResultType>::type::Scalar ResScalar; |
| 26 | typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex; |
| 27 | |
| 28 | // make sure to call innerSize/outerSize since we fake the storage order. |
| 29 | Index rows = lhs.innerSize(); |
| 30 | Index cols = rhs.outerSize(); |
| 31 | //Index size = lhs.outerSize(); |
| 32 | eigen_assert(lhs.outerSize() == rhs.innerSize()); |
| 33 | |
| 34 | // allocate a temporary buffer |
| 35 | AmbiVector<ResScalar,StorageIndex> tempVector(rows); |
| 36 | |
| 37 | // mimics a resizeByInnerOuter: |
| 38 | if(ResultType::IsRowMajor) |
| 39 | res.resize(cols, rows); |
| 40 | else |
| 41 | res.resize(rows, cols); |
| 42 | |
| 43 | evaluator<Lhs> lhsEval(lhs); |
| 44 | evaluator<Rhs> rhsEval(rhs); |
| 45 | |
| 46 | // estimate the number of non zero entries |
| 47 | // given a rhs column containing Y non zeros, we assume that the respective Y columns |
| 48 | // of the lhs differs in average of one non zeros, thus the number of non zeros for |
| 49 | // the product of a rhs column with the lhs is X+Y where X is the average number of non zero |
| 50 | // per column of the lhs. |
| 51 | // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) |
| 52 | Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate(); |
| 53 | |
| 54 | res.reserve(estimated_nnz_prod); |
| 55 | double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols())); |
| 56 | for (Index j=0; j<cols; ++j) |
| 57 | { |
| 58 | // FIXME: |
| 59 | //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); |
| 60 | // let's do a more accurate determination of the nnz ratio for the current column j of res |
| 61 | tempVector.init(ratioColRes); |
| 62 | tempVector.setZero(); |
| 63 | for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) |
| 64 | { |
| 65 | // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) |
| 66 | tempVector.restart(); |
| 67 | RhsScalar x = rhsIt.value(); |
| 68 | for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) |
| 69 | { |
| 70 | tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; |
| 71 | } |
| 72 | } |
| 73 | res.startVec(j); |
| 74 | for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it) |
| 75 | res.insertBackByOuterInner(j,it.index()) = it.value(); |
| 76 | } |
| 77 | res.finalize(); |
| 78 | } |
| 79 | |
| 80 | template<typename Lhs, typename Rhs, typename ResultType, |
| 81 | int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, |
| 82 | int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, |
| 83 | int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> |
| 84 | struct sparse_sparse_product_with_pruning_selector; |
| 85 | |
| 86 | template<typename Lhs, typename Rhs, typename ResultType> |
| 87 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> |
| 88 | { |
| 89 | typedef typename ResultType::RealScalar RealScalar; |
| 90 | |
| 91 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 92 | { |
| 93 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
| 94 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); |
| 95 | res.swap(_res); |
| 96 | } |
| 97 | }; |
| 98 | |
| 99 | template<typename Lhs, typename Rhs, typename ResultType> |
| 100 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> |
| 101 | { |
| 102 | typedef typename ResultType::RealScalar RealScalar; |
| 103 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 104 | { |
| 105 | // we need a col-major matrix to hold the result |
| 106 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; |
| 107 | SparseTemporaryType _res(res.rows(), res.cols()); |
| 108 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); |
| 109 | res = _res; |
| 110 | } |
| 111 | }; |
| 112 | |
| 113 | template<typename Lhs, typename Rhs, typename ResultType> |
| 114 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> |
| 115 | { |
| 116 | typedef typename ResultType::RealScalar RealScalar; |
| 117 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 118 | { |
| 119 | // let's transpose the product to get a column x column product |
| 120 | typename remove_all<ResultType>::type _res(res.rows(), res.cols()); |
| 121 | internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); |
| 122 | res.swap(_res); |
| 123 | } |
| 124 | }; |
| 125 | |
| 126 | template<typename Lhs, typename Rhs, typename ResultType> |
| 127 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> |
| 128 | { |
| 129 | typedef typename ResultType::RealScalar RealScalar; |
| 130 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 131 | { |
| 132 | typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; |
| 133 | typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; |
| 134 | ColMajorMatrixLhs colLhs(lhs); |
| 135 | ColMajorMatrixRhs colRhs(rhs); |
| 136 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); |
| 137 | |
| 138 | // let's transpose the product to get a column x column product |
| 139 | // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; |
| 140 | // SparseTemporaryType _res(res.cols(), res.rows()); |
| 141 | // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); |
| 142 | // res = _res.transpose(); |
| 143 | } |
| 144 | }; |
| 145 | |
| 146 | template<typename Lhs, typename Rhs, typename ResultType> |
| 147 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> |
| 148 | { |
| 149 | typedef typename ResultType::RealScalar RealScalar; |
| 150 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 151 | { |
| 152 | typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs; |
| 153 | RowMajorMatrixLhs rowLhs(lhs); |
| 154 | sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); |
| 155 | } |
| 156 | }; |
| 157 | |
| 158 | template<typename Lhs, typename Rhs, typename ResultType> |
| 159 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> |
| 160 | { |
| 161 | typedef typename ResultType::RealScalar RealScalar; |
| 162 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 163 | { |
| 164 | typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs; |
| 165 | RowMajorMatrixRhs rowRhs(rhs); |
| 166 | sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); |
| 167 | } |
| 168 | }; |
| 169 | |
| 170 | template<typename Lhs, typename Rhs, typename ResultType> |
| 171 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> |
| 172 | { |
| 173 | typedef typename ResultType::RealScalar RealScalar; |
| 174 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 175 | { |
| 176 | typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; |
| 177 | ColMajorMatrixRhs colRhs(rhs); |
| 178 | internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); |
| 179 | } |
| 180 | }; |
| 181 | |
| 182 | template<typename Lhs, typename Rhs, typename ResultType> |
| 183 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> |
| 184 | { |
| 185 | typedef typename ResultType::RealScalar RealScalar; |
| 186 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) |
| 187 | { |
| 188 | typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; |
| 189 | ColMajorMatrixLhs colLhs(lhs); |
| 190 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); |
| 191 | } |
| 192 | }; |
| 193 | |
| 194 | } // end namespace internal |
| 195 | |
| 196 | } // end namespace Eigen |
| 197 | |
| 198 | #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |
| 199 | |