1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // |
6 | // This Source Code Form is subject to the terms of the Mozilla |
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
9 | |
10 | #ifndef EIGEN_SPARSETRIANGULARSOLVER_H |
11 | #define EIGEN_SPARSETRIANGULARSOLVER_H |
12 | |
13 | namespace Eigen { |
14 | |
15 | namespace internal { |
16 | |
17 | template<typename Lhs, typename Rhs, int Mode, |
18 | int UpLo = (Mode & Lower) |
19 | ? Lower |
20 | : (Mode & Upper) |
21 | ? Upper |
22 | : -1, |
23 | int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit> |
24 | struct sparse_solve_triangular_selector; |
25 | |
26 | // forward substitution, row-major |
27 | template<typename Lhs, typename Rhs, int Mode> |
28 | struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor> |
29 | { |
30 | typedef typename Rhs::Scalar Scalar; |
31 | typedef evaluator<Lhs> LhsEval; |
32 | typedef typename evaluator<Lhs>::InnerIterator LhsIterator; |
33 | static void run(const Lhs& lhs, Rhs& other) |
34 | { |
35 | LhsEval lhsEval(lhs); |
36 | for(Index col=0 ; col<other.cols() ; ++col) |
37 | { |
38 | for(Index i=0; i<lhs.rows(); ++i) |
39 | { |
40 | Scalar tmp = other.coeff(i,col); |
41 | Scalar lastVal(0); |
42 | Index lastIndex = 0; |
43 | for(LhsIterator it(lhsEval, i); it; ++it) |
44 | { |
45 | lastVal = it.value(); |
46 | lastIndex = it.index(); |
47 | if(lastIndex==i) |
48 | break; |
49 | tmp -= lastVal * other.coeff(lastIndex,col); |
50 | } |
51 | if (Mode & UnitDiag) |
52 | other.coeffRef(i,col) = tmp; |
53 | else |
54 | { |
55 | eigen_assert(lastIndex==i); |
56 | other.coeffRef(i,col) = tmp/lastVal; |
57 | } |
58 | } |
59 | } |
60 | } |
61 | }; |
62 | |
63 | // backward substitution, row-major |
64 | template<typename Lhs, typename Rhs, int Mode> |
65 | struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor> |
66 | { |
67 | typedef typename Rhs::Scalar Scalar; |
68 | typedef evaluator<Lhs> LhsEval; |
69 | typedef typename evaluator<Lhs>::InnerIterator LhsIterator; |
70 | static void run(const Lhs& lhs, Rhs& other) |
71 | { |
72 | LhsEval lhsEval(lhs); |
73 | for(Index col=0 ; col<other.cols() ; ++col) |
74 | { |
75 | for(Index i=lhs.rows()-1 ; i>=0 ; --i) |
76 | { |
77 | Scalar tmp = other.coeff(i,col); |
78 | Scalar l_ii(0); |
79 | LhsIterator it(lhsEval, i); |
80 | while(it && it.index()<i) |
81 | ++it; |
82 | if(!(Mode & UnitDiag)) |
83 | { |
84 | eigen_assert(it && it.index()==i); |
85 | l_ii = it.value(); |
86 | ++it; |
87 | } |
88 | else if (it && it.index() == i) |
89 | ++it; |
90 | for(; it; ++it) |
91 | { |
92 | tmp -= it.value() * other.coeff(it.index(),col); |
93 | } |
94 | |
95 | if (Mode & UnitDiag) other.coeffRef(i,col) = tmp; |
96 | else other.coeffRef(i,col) = tmp/l_ii; |
97 | } |
98 | } |
99 | } |
100 | }; |
101 | |
102 | // forward substitution, col-major |
103 | template<typename Lhs, typename Rhs, int Mode> |
104 | struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor> |
105 | { |
106 | typedef typename Rhs::Scalar Scalar; |
107 | typedef evaluator<Lhs> LhsEval; |
108 | typedef typename evaluator<Lhs>::InnerIterator LhsIterator; |
109 | static void run(const Lhs& lhs, Rhs& other) |
110 | { |
111 | LhsEval lhsEval(lhs); |
112 | for(Index col=0 ; col<other.cols() ; ++col) |
113 | { |
114 | for(Index i=0; i<lhs.cols(); ++i) |
115 | { |
116 | Scalar& tmp = other.coeffRef(i,col); |
117 | if (tmp!=Scalar(0)) // optimization when other is actually sparse |
118 | { |
119 | LhsIterator it(lhsEval, i); |
120 | while(it && it.index()<i) |
121 | ++it; |
122 | if(!(Mode & UnitDiag)) |
123 | { |
124 | eigen_assert(it && it.index()==i); |
125 | tmp /= it.value(); |
126 | } |
127 | if (it && it.index()==i) |
128 | ++it; |
129 | for(; it; ++it) |
130 | other.coeffRef(it.index(), col) -= tmp * it.value(); |
131 | } |
132 | } |
133 | } |
134 | } |
135 | }; |
136 | |
137 | // backward substitution, col-major |
138 | template<typename Lhs, typename Rhs, int Mode> |
139 | struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor> |
140 | { |
141 | typedef typename Rhs::Scalar Scalar; |
142 | typedef evaluator<Lhs> LhsEval; |
143 | typedef typename evaluator<Lhs>::InnerIterator LhsIterator; |
144 | static void run(const Lhs& lhs, Rhs& other) |
145 | { |
146 | LhsEval lhsEval(lhs); |
147 | for(Index col=0 ; col<other.cols() ; ++col) |
148 | { |
149 | for(Index i=lhs.cols()-1; i>=0; --i) |
150 | { |
151 | Scalar& tmp = other.coeffRef(i,col); |
152 | if (tmp!=Scalar(0)) // optimization when other is actually sparse |
153 | { |
154 | if(!(Mode & UnitDiag)) |
155 | { |
156 | // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements |
157 | LhsIterator it(lhsEval, i); |
158 | while(it && it.index()!=i) |
159 | ++it; |
160 | eigen_assert(it && it.index()==i); |
161 | other.coeffRef(i,col) /= it.value(); |
162 | } |
163 | LhsIterator it(lhsEval, i); |
164 | for(; it && it.index()<i; ++it) |
165 | other.coeffRef(it.index(), col) -= tmp * it.value(); |
166 | } |
167 | } |
168 | } |
169 | } |
170 | }; |
171 | |
172 | } // end namespace internal |
173 | |
174 | #ifndef EIGEN_PARSED_BY_DOXYGEN |
175 | |
176 | template<typename ExpressionType,unsigned int Mode> |
177 | template<typename OtherDerived> |
178 | void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const |
179 | { |
180 | eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows()); |
181 | eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); |
182 | |
183 | enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit }; |
184 | |
185 | typedef typename internal::conditional<copy, |
186 | typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; |
187 | OtherCopy otherCopy(other.derived()); |
188 | |
189 | internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy); |
190 | |
191 | if (copy) |
192 | other = otherCopy; |
193 | } |
194 | #endif |
195 | |
196 | // pure sparse path |
197 | |
198 | namespace internal { |
199 | |
200 | template<typename Lhs, typename Rhs, int Mode, |
201 | int UpLo = (Mode & Lower) |
202 | ? Lower |
203 | : (Mode & Upper) |
204 | ? Upper |
205 | : -1, |
206 | int StorageOrder = int(Lhs::Flags) & (RowMajorBit)> |
207 | struct sparse_solve_triangular_sparse_selector; |
208 | |
209 | // forward substitution, col-major |
210 | template<typename Lhs, typename Rhs, int Mode, int UpLo> |
211 | struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor> |
212 | { |
213 | typedef typename Rhs::Scalar Scalar; |
214 | typedef typename promote_index_type<typename traits<Lhs>::StorageIndex, |
215 | typename traits<Rhs>::StorageIndex>::type StorageIndex; |
216 | static void run(const Lhs& lhs, Rhs& other) |
217 | { |
218 | const bool IsLower = (UpLo==Lower); |
219 | AmbiVector<Scalar,StorageIndex> tempVector(other.rows()*2); |
220 | tempVector.setBounds(0,other.rows()); |
221 | |
222 | Rhs res(other.rows(), other.cols()); |
223 | res.reserve(other.nonZeros()); |
224 | |
225 | for(Index col=0 ; col<other.cols() ; ++col) |
226 | { |
227 | // FIXME estimate number of non zeros |
228 | tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/); |
229 | tempVector.setZero(); |
230 | tempVector.restart(); |
231 | for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt) |
232 | { |
233 | tempVector.coeffRef(rhsIt.index()) = rhsIt.value(); |
234 | } |
235 | |
236 | for(Index i=IsLower?0:lhs.cols()-1; |
237 | IsLower?i<lhs.cols():i>=0; |
238 | i+=IsLower?1:-1) |
239 | { |
240 | tempVector.restart(); |
241 | Scalar& ci = tempVector.coeffRef(i); |
242 | if (ci!=Scalar(0)) |
243 | { |
244 | // find |
245 | typename Lhs::InnerIterator it(lhs, i); |
246 | if(!(Mode & UnitDiag)) |
247 | { |
248 | if (IsLower) |
249 | { |
250 | eigen_assert(it.index()==i); |
251 | ci /= it.value(); |
252 | } |
253 | else |
254 | ci /= lhs.coeff(i,i); |
255 | } |
256 | tempVector.restart(); |
257 | if (IsLower) |
258 | { |
259 | if (it.index()==i) |
260 | ++it; |
261 | for(; it; ++it) |
262 | tempVector.coeffRef(it.index()) -= ci * it.value(); |
263 | } |
264 | else |
265 | { |
266 | for(; it && it.index()<i; ++it) |
267 | tempVector.coeffRef(it.index()) -= ci * it.value(); |
268 | } |
269 | } |
270 | } |
271 | |
272 | |
273 | Index count = 0; |
274 | // FIXME compute a reference value to filter zeros |
275 | for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector/*,1e-12*/); it; ++it) |
276 | { |
277 | ++ count; |
278 | // std::cerr << "fill " << it.index() << ", " << col << "\n"; |
279 | // std::cout << it.value() << " "; |
280 | // FIXME use insertBack |
281 | res.insert(it.index(), col) = it.value(); |
282 | } |
283 | // std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n"; |
284 | } |
285 | res.finalize(); |
286 | other = res.markAsRValue(); |
287 | } |
288 | }; |
289 | |
290 | } // end namespace internal |
291 | |
292 | #ifndef EIGEN_PARSED_BY_DOXYGEN |
293 | template<typename ExpressionType,unsigned int Mode> |
294 | template<typename OtherDerived> |
295 | void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const |
296 | { |
297 | eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows()); |
298 | eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); |
299 | |
300 | // enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit }; |
301 | |
302 | // typedef typename internal::conditional<copy, |
303 | // typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; |
304 | // OtherCopy otherCopy(other.derived()); |
305 | |
306 | internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived()); |
307 | |
308 | // if (copy) |
309 | // other = otherCopy; |
310 | } |
311 | #endif |
312 | |
313 | } // end namespace Eigen |
314 | |
315 | #endif // EIGEN_SPARSETRIANGULARSOLVER_H |
316 | |