1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr> |
6 | // |
7 | // This Source Code Form is subject to the terms of the Mozilla |
8 | // Public License v. 2.0. If a copy of the MPL was not distributed |
9 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
10 | |
11 | #ifndef EIGEN_SPARSE_TRIANGULARVIEW_H |
12 | #define EIGEN_SPARSE_TRIANGULARVIEW_H |
13 | |
14 | namespace Eigen { |
15 | |
16 | /** \ingroup SparseCore_Module |
17 | * |
18 | * \brief Base class for a triangular part in a \b sparse matrix |
19 | * |
20 | * This class is an abstract base class of class TriangularView, and objects of type TriangularViewImpl cannot be instantiated. |
21 | * It extends class TriangularView with additional methods which are available for sparse expressions only. |
22 | * |
23 | * \sa class TriangularView, SparseMatrixBase::triangularView() |
24 | */ |
25 | template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse> |
26 | : public SparseMatrixBase<TriangularView<MatrixType,Mode> > |
27 | { |
28 | enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit)) |
29 | || ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)), |
30 | SkipLast = !SkipFirst, |
31 | SkipDiag = (Mode&ZeroDiag) ? 1 : 0, |
32 | HasUnitDiag = (Mode&UnitDiag) ? 1 : 0 |
33 | }; |
34 | |
35 | typedef TriangularView<MatrixType,Mode> TriangularViewType; |
36 | |
37 | protected: |
38 | // dummy solve function to make TriangularView happy. |
39 | void solve() const; |
40 | |
41 | typedef SparseMatrixBase<TriangularViewType> Base; |
42 | public: |
43 | |
44 | EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType) |
45 | |
46 | typedef typename MatrixType::Nested MatrixTypeNested; |
47 | typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef; |
48 | typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned; |
49 | |
50 | template<typename RhsType, typename DstType> |
51 | EIGEN_DEVICE_FUNC |
52 | EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const { |
53 | if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs))) |
54 | dst = rhs; |
55 | this->solveInPlace(dst); |
56 | } |
57 | |
58 | /** Applies the inverse of \c *this to the dense vector or matrix \a other, "in-place" */ |
59 | template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const; |
60 | |
61 | /** Applies the inverse of \c *this to the sparse vector or matrix \a other, "in-place" */ |
62 | template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const; |
63 | |
64 | }; |
65 | |
66 | namespace internal { |
67 | |
68 | template<typename ArgType, unsigned int Mode> |
69 | struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased> |
70 | : evaluator_base<TriangularView<ArgType,Mode> > |
71 | { |
72 | typedef TriangularView<ArgType,Mode> XprType; |
73 | |
74 | protected: |
75 | |
76 | typedef typename XprType::Scalar Scalar; |
77 | typedef typename XprType::StorageIndex StorageIndex; |
78 | typedef typename evaluator<ArgType>::InnerIterator EvalIterator; |
79 | |
80 | enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit)) |
81 | || ((Mode&Upper) && (ArgType::Flags&RowMajorBit)), |
82 | SkipLast = !SkipFirst, |
83 | SkipDiag = (Mode&ZeroDiag) ? 1 : 0, |
84 | HasUnitDiag = (Mode&UnitDiag) ? 1 : 0 |
85 | }; |
86 | |
87 | public: |
88 | |
89 | enum { |
90 | CoeffReadCost = evaluator<ArgType>::CoeffReadCost, |
91 | Flags = XprType::Flags |
92 | }; |
93 | |
94 | explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {} |
95 | |
96 | inline Index nonZerosEstimate() const { |
97 | return m_argImpl.nonZerosEstimate(); |
98 | } |
99 | |
100 | class InnerIterator : public EvalIterator |
101 | { |
102 | typedef EvalIterator Base; |
103 | public: |
104 | |
105 | EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer) |
106 | : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize()) |
107 | { |
108 | if(SkipFirst) |
109 | { |
110 | while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer)) |
111 | Base::operator++(); |
112 | if(HasUnitDiag) |
113 | m_returnOne = m_containsDiag; |
114 | } |
115 | else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer())) |
116 | { |
117 | if((!SkipFirst) && Base::operator bool()) |
118 | Base::operator++(); |
119 | m_returnOne = m_containsDiag; |
120 | } |
121 | } |
122 | |
123 | EIGEN_STRONG_INLINE InnerIterator& operator++() |
124 | { |
125 | if(HasUnitDiag && m_returnOne) |
126 | m_returnOne = false; |
127 | else |
128 | { |
129 | Base::operator++(); |
130 | if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer())) |
131 | { |
132 | if((!SkipFirst) && Base::operator bool()) |
133 | Base::operator++(); |
134 | m_returnOne = m_containsDiag; |
135 | } |
136 | } |
137 | return *this; |
138 | } |
139 | |
140 | EIGEN_STRONG_INLINE operator bool() const |
141 | { |
142 | if(HasUnitDiag && m_returnOne) |
143 | return true; |
144 | if(SkipFirst) return Base::operator bool(); |
145 | else |
146 | { |
147 | if (SkipDiag) return (Base::operator bool() && this->index() < this->outer()); |
148 | else return (Base::operator bool() && this->index() <= this->outer()); |
149 | } |
150 | } |
151 | |
152 | // inline Index row() const { return (ArgType::Flags&RowMajorBit ? Base::outer() : this->index()); } |
153 | // inline Index col() const { return (ArgType::Flags&RowMajorBit ? this->index() : Base::outer()); } |
154 | inline StorageIndex index() const |
155 | { |
156 | if(HasUnitDiag && m_returnOne) return internal::convert_index<StorageIndex>(Base::outer()); |
157 | else return Base::index(); |
158 | } |
159 | inline Scalar value() const |
160 | { |
161 | if(HasUnitDiag && m_returnOne) return Scalar(1); |
162 | else return Base::value(); |
163 | } |
164 | |
165 | protected: |
166 | bool m_returnOne; |
167 | bool m_containsDiag; |
168 | private: |
169 | Scalar& valueRef(); |
170 | }; |
171 | |
172 | protected: |
173 | evaluator<ArgType> m_argImpl; |
174 | const ArgType& m_arg; |
175 | }; |
176 | |
177 | } // end namespace internal |
178 | |
179 | template<typename Derived> |
180 | template<int Mode> |
181 | inline const TriangularView<const Derived, Mode> |
182 | SparseMatrixBase<Derived>::triangularView() const |
183 | { |
184 | return TriangularView<const Derived, Mode>(derived()); |
185 | } |
186 | |
187 | } // end namespace Eigen |
188 | |
189 | #endif // EIGEN_SPARSE_TRIANGULARVIEW_H |
190 | |