| 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_SPARSEDENSEPRODUCT_H |
| 11 | #define EIGEN_SPARSEDENSEPRODUCT_H |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
| 15 | namespace internal { |
| 16 | |
| 17 | template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; |
| 18 | template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; |
| 19 | |
| 20 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, |
| 21 | typename AlphaType, |
| 22 | int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, |
| 23 | bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> |
| 24 | struct sparse_time_dense_product_impl; |
| 25 | |
| 26 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
| 27 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true> |
| 28 | { |
| 29 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
| 30 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
| 31 | typedef typename internal::remove_all<DenseResType>::type Res; |
| 32 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
| 33 | typedef evaluator<Lhs> LhsEval; |
| 34 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
| 35 | { |
| 36 | LhsEval lhsEval(lhs); |
| 37 | |
| 38 | Index n = lhs.outerSize(); |
| 39 | #ifdef EIGEN_HAS_OPENMP |
| 40 | Eigen::initParallel(); |
| 41 | Index threads = Eigen::nbThreads(); |
| 42 | #endif |
| 43 | |
| 44 | for(Index c=0; c<rhs.cols(); ++c) |
| 45 | { |
| 46 | #ifdef EIGEN_HAS_OPENMP |
| 47 | // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. |
| 48 | // It basically represents the minimal amount of work to be done to be worth it. |
| 49 | if(threads>1 && lhsEval.nonZerosEstimate() > 20000) |
| 50 | { |
| 51 | #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) |
| 52 | for(Index i=0; i<n; ++i) |
| 53 | processRow(lhsEval,rhs,res,alpha,i,c); |
| 54 | } |
| 55 | else |
| 56 | #endif |
| 57 | { |
| 58 | for(Index i=0; i<n; ++i) |
| 59 | processRow(lhsEval,rhs,res,alpha,i,c); |
| 60 | } |
| 61 | } |
| 62 | } |
| 63 | |
| 64 | static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col) |
| 65 | { |
| 66 | typename Res::Scalar tmp(0); |
| 67 | for(LhsInnerIterator it(lhsEval,i); it ;++it) |
| 68 | tmp += it.value() * rhs.coeff(it.index(),col); |
| 69 | res.coeffRef(i,col) += alpha * tmp; |
| 70 | } |
| 71 | |
| 72 | }; |
| 73 | |
| 74 | // FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format? |
| 75 | // -> let's disable it for now as it is conflicting with generic scalar*matrix and matrix*scalar operators |
| 76 | // template<typename T1, typename T2/*, int _Options, typename _StrideType*/> |
| 77 | // struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> > |
| 78 | // { |
| 79 | // enum { |
| 80 | // Defined = 1 |
| 81 | // }; |
| 82 | // typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType; |
| 83 | // }; |
| 84 | |
| 85 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType> |
| 86 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true> |
| 87 | { |
| 88 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
| 89 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
| 90 | typedef typename internal::remove_all<DenseResType>::type Res; |
| 91 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
| 92 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) |
| 93 | { |
| 94 | evaluator<Lhs> lhsEval(lhs); |
| 95 | for(Index c=0; c<rhs.cols(); ++c) |
| 96 | { |
| 97 | for(Index j=0; j<lhs.outerSize(); ++j) |
| 98 | { |
| 99 | // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); |
| 100 | typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); |
| 101 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
| 102 | res.coeffRef(it.index(),c) += it.value() * rhs_j; |
| 103 | } |
| 104 | } |
| 105 | } |
| 106 | }; |
| 107 | |
| 108 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
| 109 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false> |
| 110 | { |
| 111 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
| 112 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
| 113 | typedef typename internal::remove_all<DenseResType>::type Res; |
| 114 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
| 115 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
| 116 | { |
| 117 | evaluator<Lhs> lhsEval(lhs); |
| 118 | for(Index j=0; j<lhs.outerSize(); ++j) |
| 119 | { |
| 120 | typename Res::RowXpr res_j(res.row(j)); |
| 121 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
| 122 | res_j += (alpha*it.value()) * rhs.row(it.index()); |
| 123 | } |
| 124 | } |
| 125 | }; |
| 126 | |
| 127 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> |
| 128 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false> |
| 129 | { |
| 130 | typedef typename internal::remove_all<SparseLhsType>::type Lhs; |
| 131 | typedef typename internal::remove_all<DenseRhsType>::type Rhs; |
| 132 | typedef typename internal::remove_all<DenseResType>::type Res; |
| 133 | typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; |
| 134 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) |
| 135 | { |
| 136 | evaluator<Lhs> lhsEval(lhs); |
| 137 | for(Index j=0; j<lhs.outerSize(); ++j) |
| 138 | { |
| 139 | typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); |
| 140 | for(LhsInnerIterator it(lhsEval,j); it ;++it) |
| 141 | res.row(it.index()) += (alpha*it.value()) * rhs_j; |
| 142 | } |
| 143 | } |
| 144 | }; |
| 145 | |
| 146 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> |
| 147 | inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) |
| 148 | { |
| 149 | sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha); |
| 150 | } |
| 151 | |
| 152 | } // end namespace internal |
| 153 | |
| 154 | namespace internal { |
| 155 | |
| 156 | template<typename Lhs, typename Rhs, int ProductType> |
| 157 | struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> |
| 158 | : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> > |
| 159 | { |
| 160 | typedef typename Product<Lhs,Rhs>::Scalar Scalar; |
| 161 | |
| 162 | template<typename Dest> |
| 163 | static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) |
| 164 | { |
| 165 | typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested; |
| 166 | typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested; |
| 167 | LhsNested lhsNested(lhs); |
| 168 | RhsNested rhsNested(rhs); |
| 169 | internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha); |
| 170 | } |
| 171 | }; |
| 172 | |
| 173 | template<typename Lhs, typename Rhs, int ProductType> |
| 174 | struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> |
| 175 | : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> |
| 176 | {}; |
| 177 | |
| 178 | template<typename Lhs, typename Rhs, int ProductType> |
| 179 | struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> |
| 180 | : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> > |
| 181 | { |
| 182 | typedef typename Product<Lhs,Rhs>::Scalar Scalar; |
| 183 | |
| 184 | template<typename Dst> |
| 185 | static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) |
| 186 | { |
| 187 | typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested; |
| 188 | typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested; |
| 189 | LhsNested lhsNested(lhs); |
| 190 | RhsNested rhsNested(rhs); |
| 191 | |
| 192 | // transpose everything |
| 193 | Transpose<Dst> dstT(dst); |
| 194 | internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha); |
| 195 | } |
| 196 | }; |
| 197 | |
| 198 | template<typename Lhs, typename Rhs, int ProductType> |
| 199 | struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> |
| 200 | : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> |
| 201 | {}; |
| 202 | |
| 203 | template<typename LhsT, typename RhsT, bool NeedToTranspose> |
| 204 | struct sparse_dense_outer_product_evaluator |
| 205 | { |
| 206 | protected: |
| 207 | typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; |
| 208 | typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; |
| 209 | typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; |
| 210 | |
| 211 | // if the actual left-hand side is a dense vector, |
| 212 | // then build a sparse-view so that we can seamlessly iterate over it. |
| 213 | typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, |
| 214 | Lhs1, SparseView<Lhs1> >::type ActualLhs; |
| 215 | typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, |
| 216 | Lhs1 const&, SparseView<Lhs1> >::type LhsArg; |
| 217 | |
| 218 | typedef evaluator<ActualLhs> LhsEval; |
| 219 | typedef evaluator<ActualRhs> RhsEval; |
| 220 | typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; |
| 221 | typedef typename ProdXprType::Scalar Scalar; |
| 222 | |
| 223 | public: |
| 224 | enum { |
| 225 | Flags = NeedToTranspose ? RowMajorBit : 0, |
| 226 | CoeffReadCost = HugeCost |
| 227 | }; |
| 228 | |
| 229 | class InnerIterator : public LhsIterator |
| 230 | { |
| 231 | public: |
| 232 | InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) |
| 233 | : LhsIterator(xprEval.m_lhsXprImpl, 0), |
| 234 | m_outer(outer), |
| 235 | m_empty(false), |
| 236 | m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) |
| 237 | {} |
| 238 | |
| 239 | EIGEN_STRONG_INLINE Index outer() const { return m_outer; } |
| 240 | EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } |
| 241 | EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } |
| 242 | |
| 243 | EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } |
| 244 | EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } |
| 245 | |
| 246 | protected: |
| 247 | Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const |
| 248 | { |
| 249 | return rhs.coeff(outer); |
| 250 | } |
| 251 | |
| 252 | Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) |
| 253 | { |
| 254 | typename RhsEval::InnerIterator it(rhs, outer); |
| 255 | if (it && it.index()==0 && it.value()!=Scalar(0)) |
| 256 | return it.value(); |
| 257 | m_empty = true; |
| 258 | return Scalar(0); |
| 259 | } |
| 260 | |
| 261 | Index m_outer; |
| 262 | bool m_empty; |
| 263 | Scalar m_factor; |
| 264 | }; |
| 265 | |
| 266 | sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) |
| 267 | : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) |
| 268 | { |
| 269 | EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); |
| 270 | } |
| 271 | |
| 272 | // transpose case |
| 273 | sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) |
| 274 | : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) |
| 275 | { |
| 276 | EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); |
| 277 | } |
| 278 | |
| 279 | protected: |
| 280 | const LhsArg m_lhs; |
| 281 | evaluator<ActualLhs> m_lhsXprImpl; |
| 282 | evaluator<ActualRhs> m_rhsXprImpl; |
| 283 | }; |
| 284 | |
| 285 | // sparse * dense outer product |
| 286 | template<typename Lhs, typename Rhs> |
| 287 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape> |
| 288 | : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> |
| 289 | { |
| 290 | typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; |
| 291 | |
| 292 | typedef Product<Lhs, Rhs> XprType; |
| 293 | typedef typename XprType::PlainObject PlainObject; |
| 294 | |
| 295 | explicit product_evaluator(const XprType& xpr) |
| 296 | : Base(xpr.lhs(), xpr.rhs()) |
| 297 | {} |
| 298 | |
| 299 | }; |
| 300 | |
| 301 | template<typename Lhs, typename Rhs> |
| 302 | struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape> |
| 303 | : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> |
| 304 | { |
| 305 | typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; |
| 306 | |
| 307 | typedef Product<Lhs, Rhs> XprType; |
| 308 | typedef typename XprType::PlainObject PlainObject; |
| 309 | |
| 310 | explicit product_evaluator(const XprType& xpr) |
| 311 | : Base(xpr.lhs(), xpr.rhs()) |
| 312 | {} |
| 313 | |
| 314 | }; |
| 315 | |
| 316 | } // end namespace internal |
| 317 | |
| 318 | } // end namespace Eigen |
| 319 | |
| 320 | #endif // EIGEN_SPARSEDENSEPRODUCT_H |
| 321 | |