1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // |
6 | // This Source Code Form is subject to the terms of the Mozilla |
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
9 | |
10 | #ifndef EIGEN_SPARSEDENSEPRODUCT_H |
11 | #define EIGEN_SPARSEDENSEPRODUCT_H |
12 | |
13 | namespace Eigen { |
14 | |
15 | namespace internal { |
16 | |
17 | template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; |
18 | template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; |
19 | |
20 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, |
21 | typename AlphaType, |
22 | int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, |
23 | bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> |
24 | struct sparse_time_dense_product_impl; |
25 | |
26 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
27 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true> |
28 | { |
29 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
30 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
31 | typedef typename internal::remove_all<DenseResType>::type Res; |
32 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
33 | typedef evaluator<Lhs> LhsEval; |
34 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
35 | { |
36 | LhsEval lhsEval(lhs); |
37 | |
38 | Index n = lhs.outerSize(); |
39 | #ifdef EIGEN_HAS_OPENMP |
40 | Eigen::initParallel(); |
41 | Index threads = Eigen::nbThreads(); |
42 | #endif |
43 | |
44 | for(Index c=0; c<rhs.cols(); ++c) |
45 | { |
46 | #ifdef EIGEN_HAS_OPENMP |
47 | // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. |
48 | // It basically represents the minimal amount of work to be done to be worth it. |
49 | if(threads>1 && lhsEval.nonZerosEstimate() > 20000) |
50 | { |
51 | #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) |
52 | for(Index i=0; i<n; ++i) |
53 | processRow(lhsEval,rhs,res,alpha,i,c); |
54 | } |
55 | else |
56 | #endif |
57 | { |
58 | for(Index i=0; i<n; ++i) |
59 | processRow(lhsEval,rhs,res,alpha,i,c); |
60 | } |
61 | } |
62 | } |
63 | |
64 | static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col) |
65 | { |
66 | typename Res::Scalar tmp(0); |
67 | for(LhsInnerIterator it(lhsEval,i); it ;++it) |
68 | tmp += it.value() * rhs.coeff(it.index(),col); |
69 | res.coeffRef(i,col) += alpha * tmp; |
70 | } |
71 | |
72 | }; |
73 | |
74 | // FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format? |
75 | // -> let's disable it for now as it is conflicting with generic scalar*matrix and matrix*scalar operators |
76 | // template<typename T1, typename T2/*, int _Options, typename _StrideType*/> |
77 | // struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> > |
78 | // { |
79 | // enum { |
80 | // Defined = 1 |
81 | // }; |
82 | // typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType; |
83 | // }; |
84 | |
85 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType> |
86 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true> |
87 | { |
88 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
89 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
90 | typedef typename internal::remove_all<DenseResType>::type Res; |
91 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
92 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) |
93 | { |
94 | evaluator<Lhs> lhsEval(lhs); |
95 | for(Index c=0; c<rhs.cols(); ++c) |
96 | { |
97 | for(Index j=0; j<lhs.outerSize(); ++j) |
98 | { |
99 | // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); |
100 | typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); |
101 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
102 | res.coeffRef(it.index(),c) += it.value() * rhs_j; |
103 | } |
104 | } |
105 | } |
106 | }; |
107 | |
108 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
109 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false> |
110 | { |
111 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
112 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
113 | typedef typename internal::remove_all<DenseResType>::type Res; |
114 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
115 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
116 | { |
117 | evaluator<Lhs> lhsEval(lhs); |
118 | for(Index j=0; j<lhs.outerSize(); ++j) |
119 | { |
120 | typename Res::RowXpr res_j(res.row(j)); |
121 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
122 | res_j += (alpha*it.value()) * rhs.row(it.index()); |
123 | } |
124 | } |
125 | }; |
126 | |
127 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
128 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false> |
129 | { |
130 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
131 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
132 | typedef typename internal::remove_all<DenseResType>::type Res; |
133 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
134 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
135 | { |
136 | evaluator<Lhs> lhsEval(lhs); |
137 | for(Index j=0; j<lhs.outerSize(); ++j) |
138 | { |
139 | typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); |
140 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
141 | res.row(it.index()) += (alpha*it.value()) * rhs_j; |
142 | } |
143 | } |
144 | }; |
145 | |
146 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> |
147 | inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) |
148 | { |
149 | sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha); |
150 | } |
151 | |
152 | } // end namespace internal |
153 | |
154 | namespace internal { |
155 | |
156 | template<typename Lhs, typename Rhs, int ProductType> |
157 | struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> |
158 | : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> > |
159 | { |
160 | typedef typename Product<Lhs,Rhs>::Scalar Scalar; |
161 | |
162 | template<typename Dest> |
163 | static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) |
164 | { |
165 | typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested; |
166 | typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested; |
167 | LhsNested lhsNested(lhs); |
168 | RhsNested rhsNested(rhs); |
169 | internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha); |
170 | } |
171 | }; |
172 | |
173 | template<typename Lhs, typename Rhs, int ProductType> |
174 | struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> |
175 | : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> |
176 | {}; |
177 | |
178 | template<typename Lhs, typename Rhs, int ProductType> |
179 | struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> |
180 | : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> > |
181 | { |
182 | typedef typename Product<Lhs,Rhs>::Scalar Scalar; |
183 | |
184 | template<typename Dst> |
185 | static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) |
186 | { |
187 | typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested; |
188 | typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested; |
189 | LhsNested lhsNested(lhs); |
190 | RhsNested rhsNested(rhs); |
191 | |
192 | // transpose everything |
193 | Transpose<Dst> dstT(dst); |
194 | internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha); |
195 | } |
196 | }; |
197 | |
198 | template<typename Lhs, typename Rhs, int ProductType> |
199 | struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> |
200 | : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> |
201 | {}; |
202 | |
203 | template<typename LhsT, typename RhsT, bool NeedToTranspose> |
204 | struct sparse_dense_outer_product_evaluator |
205 | { |
206 | protected: |
207 | typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; |
208 | typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; |
209 | typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; |
210 | |
211 | // if the actual left-hand side is a dense vector, |
212 | // then build a sparse-view so that we can seamlessly iterate over it. |
213 | typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, |
214 | Lhs1, SparseView<Lhs1> >::type ActualLhs; |
215 | typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, |
216 | Lhs1 const&, SparseView<Lhs1> >::type LhsArg; |
217 | |
218 | typedef evaluator<ActualLhs> LhsEval; |
219 | typedef evaluator<ActualRhs> RhsEval; |
220 | typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; |
221 | typedef typename ProdXprType::Scalar Scalar; |
222 | |
223 | public: |
224 | enum { |
225 | Flags = NeedToTranspose ? RowMajorBit : 0, |
226 | CoeffReadCost = HugeCost |
227 | }; |
228 | |
229 | class InnerIterator : public LhsIterator |
230 | { |
231 | public: |
232 | InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) |
233 | : LhsIterator(xprEval.m_lhsXprImpl, 0), |
234 | m_outer(outer), |
235 | m_empty(false), |
236 | m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) |
237 | {} |
238 | |
239 | EIGEN_STRONG_INLINE Index outer() const { return m_outer; } |
240 | EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } |
241 | EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } |
242 | |
243 | EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } |
244 | EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } |
245 | |
246 | protected: |
247 | Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const |
248 | { |
249 | return rhs.coeff(outer); |
250 | } |
251 | |
252 | Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) |
253 | { |
254 | typename RhsEval::InnerIterator it(rhs, outer); |
255 | if (it && it.index()==0 && it.value()!=Scalar(0)) |
256 | return it.value(); |
257 | m_empty = true; |
258 | return Scalar(0); |
259 | } |
260 | |
261 | Index m_outer; |
262 | bool m_empty; |
263 | Scalar m_factor; |
264 | }; |
265 | |
266 | sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) |
267 | : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) |
268 | { |
269 | EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); |
270 | } |
271 | |
272 | // transpose case |
273 | sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) |
274 | : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) |
275 | { |
276 | EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); |
277 | } |
278 | |
279 | protected: |
280 | const LhsArg m_lhs; |
281 | evaluator<ActualLhs> m_lhsXprImpl; |
282 | evaluator<ActualRhs> m_rhsXprImpl; |
283 | }; |
284 | |
285 | // sparse * dense outer product |
286 | template<typename Lhs, typename Rhs> |
287 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape> |
288 | : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> |
289 | { |
290 | typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; |
291 | |
292 | typedef Product<Lhs, Rhs> XprType; |
293 | typedef typename XprType::PlainObject PlainObject; |
294 | |
295 | explicit product_evaluator(const XprType& xpr) |
296 | : Base(xpr.lhs(), xpr.rhs()) |
297 | {} |
298 | |
299 | }; |
300 | |
301 | template<typename Lhs, typename Rhs> |
302 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape> |
303 | : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> |
304 | { |
305 | typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; |
306 | |
307 | typedef Product<Lhs, Rhs> XprType; |
308 | typedef typename XprType::PlainObject PlainObject; |
309 | |
310 | explicit product_evaluator(const XprType& xpr) |
311 | : Base(xpr.lhs(), xpr.rhs()) |
312 | {} |
313 | |
314 | }; |
315 | |
316 | } // end namespace internal |
317 | |
318 | } // end namespace Eigen |
319 | |
320 | #endif // EIGEN_SPARSEDENSEPRODUCT_H |
321 | |