1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // |
6 | // This Source Code Form is subject to the terms of the Mozilla |
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
9 | |
10 | #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
11 | #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
12 | |
13 | namespace Eigen { |
14 | |
// The product of a diagonal matrix with a sparse matrix can easily be
// implemented using expression templates.
// We have to consider two very different cases:
//  1 - diag * row-major sparse
//      => each inner vector <=> scalar * sparse vector product
//      => so we can reuse CwiseUnaryOp::InnerIterator
//  2 - diag * col-major sparse
//      => each inner vector <=> dense vector * sparse vector cwise product
//      => again, we can reuse specialization of CwiseBinaryOp::InnerIterator
//         for that particular case
// The two remaining cases (sparse * diag, with either storage order) are
// symmetric. Case 1 is handled by the SDP_AsScalarProduct specialization of
// sparse_diagonal_product_evaluator, case 2 by the SDP_AsCwiseProduct one.
26 | |
27 | namespace internal { |
28 | |
// Tags selecting how a diagonal/sparse product is evaluated, one inner vector
// of the sparse operand at a time:
//  - SDP_AsScalarProduct: all values of an inner vector are multiplied by the
//    same diagonal coefficient, the one selected by the outer index;
//  - SDP_AsCwiseProduct: each value of an inner vector is multiplied by the
//    diagonal coefficient selected by its own inner index.
enum {
  SDP_AsScalarProduct = 0,
  SDP_AsCwiseProduct  = 1
};

// Primary template, intentionally left undefined: only the specializations for
// the SDP_* tags above are provided in this file.
template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
struct sparse_diagonal_product_evaluator;
36 | |
37 | template<typename Lhs, typename Rhs, int ProductTag> |
38 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape> |
39 | : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> |
40 | { |
41 | typedef Product<Lhs, Rhs, DefaultProduct> XprType; |
42 | enum { CoeffReadCost = HugeCost, Flags = Rhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags |
43 | |
44 | typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base; |
45 | explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} |
46 | }; |
47 | |
48 | template<typename Lhs, typename Rhs, int ProductTag> |
49 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape> |
50 | : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> |
51 | { |
52 | typedef Product<Lhs, Rhs, DefaultProduct> XprType; |
53 | enum { CoeffReadCost = HugeCost, Flags = Lhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags |
54 | |
55 | typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; |
56 | explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {} |
57 | }; |
58 | |
59 | template<typename SparseXprType, typename DiagonalCoeffType> |
60 | struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct> |
61 | { |
62 | protected: |
63 | typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator; |
64 | typedef typename SparseXprType::Scalar Scalar; |
65 | |
66 | public: |
67 | class InnerIterator : public SparseXprInnerIterator |
68 | { |
69 | public: |
70 | InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) |
71 | : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), |
72 | m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) |
73 | {} |
74 | |
75 | EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } |
76 | protected: |
77 | typename DiagonalCoeffType::Scalar m_coeff; |
78 | }; |
79 | |
80 | sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) |
81 | : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) |
82 | {} |
83 | |
84 | Index nonZerosEstimate() const { return m_sparseXprImpl.nonZerosEstimate(); } |
85 | |
86 | protected: |
87 | evaluator<SparseXprType> m_sparseXprImpl; |
88 | evaluator<DiagonalCoeffType> m_diagCoeffImpl; |
89 | }; |
90 | |
91 | |
92 | template<typename SparseXprType, typename DiagCoeffType> |
93 | struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct> |
94 | { |
95 | typedef typename SparseXprType::Scalar Scalar; |
96 | typedef typename SparseXprType::StorageIndex StorageIndex; |
97 | |
98 | typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime |
99 | : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested; |
100 | |
101 | class InnerIterator |
102 | { |
103 | typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter; |
104 | public: |
105 | InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) |
106 | : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested) |
107 | {} |
108 | |
109 | inline Scalar value() const { return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); } |
110 | inline StorageIndex index() const { return m_sparseIter.index(); } |
111 | inline Index outer() const { return m_sparseIter.outer(); } |
112 | inline Index col() const { return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); } |
113 | inline Index row() const { return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); } |
114 | |
115 | EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_sparseIter; return *this; } |
116 | inline operator bool() const { return m_sparseIter; } |
117 | |
118 | protected: |
119 | SparseXprIter m_sparseIter; |
120 | DiagCoeffNested m_diagCoeffNested; |
121 | }; |
122 | |
123 | sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) |
124 | : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff) |
125 | {} |
126 | |
127 | Index nonZerosEstimate() const { return m_sparseXprEval.nonZerosEstimate(); } |
128 | |
129 | protected: |
130 | evaluator<SparseXprType> m_sparseXprEval; |
131 | DiagCoeffNested m_diagCoeffNested; |
132 | }; |
133 | |
134 | } // end namespace internal |
135 | |
136 | } // end namespace Eigen |
137 | |
138 | #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H |
139 | |