| 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
| 11 | #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
// The product of a diagonal matrix with a sparse matrix can be easily
// implemented using expression templates.
// We have to consider two very different cases:
//  1 - diag * row-major sparse
//      => each inner vector <=> scalar * sparse vector product
//      => so we can reuse the sparse operand's InnerIterator and only rescale its values
//  2 - diag * col-major sparse
//      => each inner vector <=> dense vector * sparse vector cwise product
//      => a dedicated InnerIterator multiplies every stored value by the diagonal
//         coefficient matching its inner index
// The two other cases (sparse * diag) are symmetric.
| 26 | |
| 27 | namespace internal { |
| 28 | |
// Tags selecting how a sparse/diagonal product is evaluated:
//  - SDP_AsScalarProduct: all values of one inner vector of the sparse operand are
//    multiplied by the same diagonal coefficient (the one at the outer index);
//  - SDP_AsCwiseProduct: each stored value is multiplied by the diagonal coefficient
//    at its own inner index.
enum {
  SDP_AsScalarProduct = 0,
  SDP_AsCwiseProduct  = 1
};
| 33 | |
// Primary template: only declared here. It is implemented by partial specializations
// on the last parameter, one for SDP_AsScalarProduct and one for SDP_AsCwiseProduct.
//  SparseType : the sparse operand expression
//  DiagType   : the expression holding the diagonal coefficients
//  Tag        : one of the SDP_* evaluation strategies
template<typename SparseType, typename DiagType, int Tag>
struct sparse_diagonal_product_evaluator;
| 36 | |
| 37 | template<typename Lhs, typename Rhs, int ProductTag> |
| 38 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape> |
| 39 | : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> |
| 40 | { |
| 41 | typedef Product<Lhs, Rhs, DefaultProduct> XprType; |
| 42 | enum { CoeffReadCost = HugeCost, Flags = Rhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags |
| 43 | |
| 44 | typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base; |
| 45 | explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} |
| 46 | }; |
| 47 | |
| 48 | template<typename Lhs, typename Rhs, int ProductTag> |
| 49 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape> |
| 50 | : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> |
| 51 | { |
| 52 | typedef Product<Lhs, Rhs, DefaultProduct> XprType; |
| 53 | enum { CoeffReadCost = HugeCost, Flags = Lhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags |
| 54 | |
| 55 | typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; |
| 56 | explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {} |
| 57 | }; |
| 58 | |
| 59 | template<typename SparseXprType, typename DiagonalCoeffType> |
| 60 | struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct> |
| 61 | { |
| 62 | protected: |
| 63 | typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator; |
| 64 | typedef typename SparseXprType::Scalar Scalar; |
| 65 | |
| 66 | public: |
| 67 | class InnerIterator : public SparseXprInnerIterator |
| 68 | { |
| 69 | public: |
| 70 | InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) |
| 71 | : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), |
| 72 | m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) |
| 73 | {} |
| 74 | |
| 75 | EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } |
| 76 | protected: |
| 77 | typename DiagonalCoeffType::Scalar m_coeff; |
| 78 | }; |
| 79 | |
| 80 | sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) |
| 81 | : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) |
| 82 | {} |
| 83 | |
| 84 | Index nonZerosEstimate() const { return m_sparseXprImpl.nonZerosEstimate(); } |
| 85 | |
| 86 | protected: |
| 87 | evaluator<SparseXprType> m_sparseXprImpl; |
| 88 | evaluator<DiagonalCoeffType> m_diagCoeffImpl; |
| 89 | }; |
| 90 | |
| 91 | |
| 92 | template<typename SparseXprType, typename DiagCoeffType> |
| 93 | struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct> |
| 94 | { |
| 95 | typedef typename SparseXprType::Scalar Scalar; |
| 96 | typedef typename SparseXprType::StorageIndex StorageIndex; |
| 97 | |
| 98 | typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime |
| 99 | : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested; |
| 100 | |
| 101 | class InnerIterator |
| 102 | { |
| 103 | typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter; |
| 104 | public: |
| 105 | InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) |
| 106 | : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested) |
| 107 | {} |
| 108 | |
| 109 | inline Scalar value() const { return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); } |
| 110 | inline StorageIndex index() const { return m_sparseIter.index(); } |
| 111 | inline Index outer() const { return m_sparseIter.outer(); } |
| 112 | inline Index col() const { return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); } |
| 113 | inline Index row() const { return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); } |
| 114 | |
| 115 | EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_sparseIter; return *this; } |
| 116 | inline operator bool() const { return m_sparseIter; } |
| 117 | |
| 118 | protected: |
| 119 | SparseXprIter m_sparseIter; |
| 120 | DiagCoeffNested m_diagCoeffNested; |
| 121 | }; |
| 122 | |
| 123 | sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) |
| 124 | : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff) |
| 125 | {} |
| 126 | |
| 127 | Index nonZerosEstimate() const { return m_sparseXprEval.nonZerosEstimate(); } |
| 128 | |
| 129 | protected: |
| 130 | evaluator<SparseXprType> m_sparseXprEval; |
| 131 | DiagCoeffNested m_diagCoeffNested; |
| 132 | }; |
| 133 | |
| 134 | } // end namespace internal |
| 135 | |
| 136 | } // end namespace Eigen |
| 137 | |
| 138 | #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
| 139 | |