| 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_SPARSEPRODUCT_H |
| 11 | #define EIGEN_SPARSEPRODUCT_H |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
| 15 | /** \returns an expression of the product of two sparse matrices. |
| 16 | * By default a conservative product preserving the symbolic non zeros is performed. |
| 17 | * The automatic pruning of the small values can be achieved by calling the pruned() function |
| 18 | * in which case a totally different product algorithm is employed: |
| 19 | * \code |
| 20 | * C = (A*B).pruned(); // supress numerical zeros (exact) |
| 21 | * C = (A*B).pruned(ref); |
| 22 | * C = (A*B).pruned(ref,epsilon); |
| 23 | * \endcode |
| 24 | * where \c ref is a meaningful non zero reference value. |
| 25 | * */ |
| 26 | template<typename Derived> |
| 27 | template<typename OtherDerived> |
| 28 | inline const Product<Derived,OtherDerived,AliasFreeProduct> |
| 29 | SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const |
| 30 | { |
| 31 | return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived()); |
| 32 | } |
| 33 | |
| 34 | namespace internal { |
| 35 | |
| 36 | // sparse * sparse |
| 37 | template<typename Lhs, typename Rhs, int ProductType> |
| 38 | struct generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType> |
| 39 | { |
| 40 | template<typename Dest> |
| 41 | static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) |
| 42 | { |
| 43 | evalTo(dst, lhs, rhs, typename evaluator_traits<Dest>::Shape()); |
| 44 | } |
| 45 | |
| 46 | // dense += sparse * sparse |
| 47 | template<typename Dest,typename ActualLhs> |
| 48 | static void addTo(Dest& dst, const ActualLhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0) |
| 49 | { |
| 50 | typedef typename nested_eval<ActualLhs,Dynamic>::type LhsNested; |
| 51 | typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; |
| 52 | LhsNested lhsNested(lhs); |
| 53 | RhsNested rhsNested(rhs); |
| 54 | internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type, |
| 55 | typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst); |
| 56 | } |
| 57 | |
| 58 | // dense -= sparse * sparse |
| 59 | template<typename Dest> |
| 60 | static void subTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0) |
| 61 | { |
| 62 | addTo(dst, -lhs, rhs); |
| 63 | } |
| 64 | |
| 65 | protected: |
| 66 | |
| 67 | // sparse = sparse * sparse |
| 68 | template<typename Dest> |
| 69 | static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, SparseShape) |
| 70 | { |
| 71 | typedef typename nested_eval<Lhs,Dynamic>::type LhsNested; |
| 72 | typedef typename nested_eval<Rhs,Dynamic>::type RhsNested; |
| 73 | LhsNested lhsNested(lhs); |
| 74 | RhsNested rhsNested(rhs); |
| 75 | internal::conservative_sparse_sparse_product_selector<typename remove_all<LhsNested>::type, |
| 76 | typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst); |
| 77 | } |
| 78 | |
| 79 | // dense = sparse * sparse |
| 80 | template<typename Dest> |
| 81 | static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, DenseShape) |
| 82 | { |
| 83 | dst.setZero(); |
| 84 | addTo(dst, lhs, rhs); |
| 85 | } |
| 86 | }; |
| 87 | |
| 88 | // sparse * sparse-triangular |
| 89 | template<typename Lhs, typename Rhs, int ProductType> |
| 90 | struct generic_product_impl<Lhs, Rhs, SparseShape, SparseTriangularShape, ProductType> |
| 91 | : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType> |
| 92 | {}; |
| 93 | |
| 94 | // sparse-triangular * sparse |
| 95 | template<typename Lhs, typename Rhs, int ProductType> |
| 96 | struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, ProductType> |
| 97 | : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType> |
| 98 | {}; |
| 99 | |
| 100 | // dense = sparse-product (can be sparse*sparse, sparse*perm, etc.) |
| 101 | template< typename DstXprType, typename Lhs, typename Rhs> |
| 102 | struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense> |
| 103 | { |
| 104 | typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType; |
| 105 | static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &) |
| 106 | { |
| 107 | Index dstRows = src.rows(); |
| 108 | Index dstCols = src.cols(); |
| 109 | if((dst.rows()!=dstRows) || (dst.cols()!=dstCols)) |
| 110 | dst.resize(dstRows, dstCols); |
| 111 | |
| 112 | generic_product_impl<Lhs, Rhs>::evalTo(dst,src.lhs(),src.rhs()); |
| 113 | } |
| 114 | }; |
| 115 | |
| 116 | // dense += sparse-product (can be sparse*sparse, sparse*perm, etc.) |
| 117 | template< typename DstXprType, typename Lhs, typename Rhs> |
| 118 | struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::add_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense> |
| 119 | { |
| 120 | typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType; |
| 121 | static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &) |
| 122 | { |
| 123 | generic_product_impl<Lhs, Rhs>::addTo(dst,src.lhs(),src.rhs()); |
| 124 | } |
| 125 | }; |
| 126 | |
| 127 | // dense -= sparse-product (can be sparse*sparse, sparse*perm, etc.) |
| 128 | template< typename DstXprType, typename Lhs, typename Rhs> |
| 129 | struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::sub_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense> |
| 130 | { |
| 131 | typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType; |
| 132 | static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &) |
| 133 | { |
| 134 | generic_product_impl<Lhs, Rhs>::subTo(dst,src.lhs(),src.rhs()); |
| 135 | } |
| 136 | }; |
| 137 | |
// Evaluator for a SparseView wrapping a product expression.
// The conservative product is not used here. The constructor computes the
// result eagerly with sparse_sparse_product_with_pruning_selector and stores it
// in m_result; the Base part of this object is then an evaluator of m_result.
// NOTE(review): the base class is spelled via
// Product<Lhs,Rhs,DefaultProduct>::PlainObject while Base below uses
// XprType::PlainObject; static_cast<Base*>(this) requires both to name the
// same evaluator type -- confirm.
template<typename Lhs, typename Rhs, int Options>
struct unary_evaluator<SparseView<Product<Lhs, Rhs, Options> >, IteratorBased>
 : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
{
  typedef SparseView<Product<Lhs, Rhs, Options> > XprType;
  typedef typename XprType::PlainObject PlainObject;
  typedef evaluator<PlainObject> Base;

  // Computes the pruned product of xpr's operands into m_result.
  explicit unary_evaluator(const XprType& xpr)
    : m_result(xpr.rows(), xpr.cols())
  {
    using std::abs;
    // Base classes are initialized before data members, so Base was built
    // before m_result existed; rebuild it in place from m_result.
    ::new (static_cast<Base*>(this)) Base(m_result);
    typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
    typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
    LhsNested lhsNested(xpr.nestedExpression().lhs());
    RhsNested rhsNested(xpr.nestedExpression().rhs());

    // The pruning tolerance passed to the kernel is |reference| * epsilon,
    // both read from the SparseView expression.
    internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type,
                                                          typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result,
                                                                                                                  abs(xpr.reference())*xpr.epsilon());
  }

protected:
  // Owned storage for the evaluated (pruned) product; Base evaluates it.
  PlainObject m_result;
};
| 164 | |
| 165 | } // end namespace internal |
| 166 | |
| 167 | } // end namespace Eigen |
| 168 | |
| 169 | #endif // EIGEN_SPARSEPRODUCT_H |
| 170 | |