| 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_BLASUTIL_H |
| 11 | #define EIGEN_BLASUTIL_H |
| 12 | |
| 13 | // This file contains many lightweight helper classes used to |
| 14 | // implement and control fast level 2 and level 3 BLAS-like routines. |
| 15 | |
| 16 | namespace Eigen { |
| 17 | |
| 18 | namespace internal { |
| 19 | |
| 20 | // forward declarations |
| 21 | template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> |
| 22 | struct gebp_kernel; |
| 23 | |
| 24 | template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> |
| 25 | struct gemm_pack_rhs; |
| 26 | |
| 27 | template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false> |
| 28 | struct gemm_pack_lhs; |
| 29 | |
| 30 | template< |
| 31 | typename Index, |
| 32 | typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, |
| 33 | typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, |
| 34 | int ResStorageOrder> |
| 35 | struct general_matrix_matrix_product; |
| 36 | |
| 37 | template<typename Index, |
| 38 | typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs, |
| 39 | typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized> |
| 40 | struct general_matrix_vector_product; |
| 41 | |
| 42 | |
| 43 | template<bool Conjugate> struct conj_if; |
| 44 | |
| 45 | template<> struct conj_if<true> { |
| 46 | template<typename T> |
| 47 | inline T operator()(const T& x) const { return numext::conj(x); } |
| 48 | template<typename T> |
| 49 | inline T pconj(const T& x) const { return internal::pconj(x); } |
| 50 | }; |
| 51 | |
| 52 | template<> struct conj_if<false> { |
| 53 | template<typename T> |
| 54 | inline const T& operator()(const T& x) const { return x; } |
| 55 | template<typename T> |
| 56 | inline const T& pconj(const T& x) const { return x; } |
| 57 | }; |
| 58 | |
| 59 | // Generic implementation for custom complex types. |
| 60 | template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs> |
| 61 | struct conj_helper |
| 62 | { |
| 63 | typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar; |
| 64 | |
| 65 | EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const |
| 66 | { return padd(c, pmul(x,y)); } |
| 67 | |
| 68 | EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const |
| 69 | { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); } |
| 70 | }; |
| 71 | |
| 72 | template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false> |
| 73 | { |
| 74 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); } |
| 75 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); } |
| 76 | }; |
| 77 | |
| 78 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true> |
| 79 | { |
| 80 | typedef std::complex<RealScalar> Scalar; |
| 81 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
| 82 | { return c + pmul(x,y); } |
| 83 | |
| 84 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
| 85 | { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); } |
| 86 | }; |
| 87 | |
| 88 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false> |
| 89 | { |
| 90 | typedef std::complex<RealScalar> Scalar; |
| 91 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
| 92 | { return c + pmul(x,y); } |
| 93 | |
| 94 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
| 95 | { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); } |
| 96 | }; |
| 97 | |
| 98 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true> |
| 99 | { |
| 100 | typedef std::complex<RealScalar> Scalar; |
| 101 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
| 102 | { return c + pmul(x,y); } |
| 103 | |
| 104 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
| 105 | { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); } |
| 106 | }; |
| 107 | |
| 108 | template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false> |
| 109 | { |
| 110 | typedef std::complex<RealScalar> Scalar; |
| 111 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const |
| 112 | { return padd(c, pmul(x,y)); } |
| 113 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const |
| 114 | { return conj_if<Conj>()(x)*y; } |
| 115 | }; |
| 116 | |
| 117 | template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj> |
| 118 | { |
| 119 | typedef std::complex<RealScalar> Scalar; |
| 120 | EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const |
| 121 | { return padd(c, pmul(x,y)); } |
| 122 | EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const |
| 123 | { return x*conj_if<Conj>()(y); } |
| 124 | }; |
| 125 | |
| 126 | template<typename From,typename To> struct get_factor { |
| 127 | EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); } |
| 128 | }; |
| 129 | |
| 130 | template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> { |
| 131 | EIGEN_DEVICE_FUNC |
| 132 | static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); } |
| 133 | }; |
| 134 | |
| 135 | |
| 136 | template<typename Scalar, typename Index> |
| 137 | class BlasVectorMapper { |
| 138 | public: |
| 139 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {} |
| 140 | |
| 141 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { |
| 142 | return m_data[i]; |
| 143 | } |
| 144 | template <typename Packet, int AlignmentType> |
| 145 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const { |
| 146 | return ploadt<Packet, AlignmentType>(m_data + i); |
| 147 | } |
| 148 | |
| 149 | template <typename Packet> |
| 150 | EIGEN_DEVICE_FUNC bool aligned(Index i) const { |
| 151 | return (UIntPtr(m_data+i)%sizeof(Packet))==0; |
| 152 | } |
| 153 | |
| 154 | protected: |
| 155 | Scalar* m_data; |
| 156 | }; |
| 157 | |
| 158 | template<typename Scalar, typename Index, int AlignmentType> |
| 159 | class BlasLinearMapper { |
| 160 | public: |
| 161 | typedef typename packet_traits<Scalar>::type Packet; |
| 162 | typedef typename packet_traits<Scalar>::half HalfPacket; |
| 163 | |
| 164 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {} |
| 165 | |
| 166 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { |
| 167 | internal::prefetch(&operator()(i)); |
| 168 | } |
| 169 | |
| 170 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { |
| 171 | return m_data[i]; |
| 172 | } |
| 173 | |
| 174 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { |
| 175 | return ploadt<Packet, AlignmentType>(m_data + i); |
| 176 | } |
| 177 | |
| 178 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { |
| 179 | return ploadt<HalfPacket, AlignmentType>(m_data + i); |
| 180 | } |
| 181 | |
| 182 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const { |
| 183 | pstoret<Scalar, Packet, AlignmentType>(m_data + i, p); |
| 184 | } |
| 185 | |
| 186 | protected: |
| 187 | Scalar *m_data; |
| 188 | }; |
| 189 | |
| 190 | // Lightweight helper class to access matrix coefficients. |
| 191 | template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned> |
| 192 | class blas_data_mapper { |
| 193 | public: |
| 194 | typedef typename packet_traits<Scalar>::type Packet; |
| 195 | typedef typename packet_traits<Scalar>::half HalfPacket; |
| 196 | |
| 197 | typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper; |
| 198 | typedef BlasVectorMapper<Scalar, Index> VectorMapper; |
| 199 | |
| 200 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {} |
| 201 | |
| 202 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> |
| 203 | getSubMapper(Index i, Index j) const { |
| 204 | return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride); |
| 205 | } |
| 206 | |
| 207 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { |
| 208 | return LinearMapper(&operator()(i, j)); |
| 209 | } |
| 210 | |
| 211 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { |
| 212 | return VectorMapper(&operator()(i, j)); |
| 213 | } |
| 214 | |
| 215 | |
| 216 | EIGEN_DEVICE_FUNC |
| 217 | EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const { |
| 218 | return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; |
| 219 | } |
| 220 | |
| 221 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { |
| 222 | return ploadt<Packet, AlignmentType>(&operator()(i, j)); |
| 223 | } |
| 224 | |
| 225 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { |
| 226 | return ploadt<HalfPacket, AlignmentType>(&operator()(i, j)); |
| 227 | } |
| 228 | |
| 229 | template<typename SubPacket> |
| 230 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { |
| 231 | pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride); |
| 232 | } |
| 233 | |
| 234 | template<typename SubPacket> |
| 235 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const { |
| 236 | return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride); |
| 237 | } |
| 238 | |
| 239 | EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; } |
| 240 | EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; } |
| 241 | |
| 242 | EIGEN_DEVICE_FUNC Index firstAligned(Index size) const { |
| 243 | if (UIntPtr(m_data)%sizeof(Scalar)) { |
| 244 | return -1; |
| 245 | } |
| 246 | return internal::first_default_aligned(m_data, size); |
| 247 | } |
| 248 | |
| 249 | protected: |
| 250 | Scalar* EIGEN_RESTRICT m_data; |
| 251 | const Index m_stride; |
| 252 | }; |
| 253 | |
| 254 | // lightweight helper class to access matrix coefficients (const version) |
| 255 | template<typename Scalar, typename Index, int StorageOrder> |
| 256 | class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> { |
| 257 | public: |
| 258 | EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {} |
| 259 | |
| 260 | EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const { |
| 261 | return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride); |
| 262 | } |
| 263 | }; |
| 264 | |
| 265 | |
| 266 | /* Helper class to analyze the factors of a Product expression. |
| 267 | * In particular it allows to pop out operator-, scalar multiples, |
| 268 | * and conjugate */ |
| 269 | template<typename XprType> struct blas_traits |
| 270 | { |
| 271 | typedef typename traits<XprType>::Scalar Scalar; |
| 272 | typedef const XprType& ; |
| 273 | typedef XprType ; |
| 274 | enum { |
| 275 | IsComplex = NumTraits<Scalar>::IsComplex, |
| 276 | IsTransposed = false, |
| 277 | NeedToConjugate = false, |
| 278 | HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit) |
| 279 | && ( bool(XprType::IsVectorAtCompileTime) |
| 280 | || int(inner_stride_at_compile_time<XprType>::ret) == 1) |
| 281 | ) ? 1 : 0 |
| 282 | }; |
| 283 | typedef typename conditional<bool(HasUsableDirectAccess), |
| 284 | ExtractType, |
| 285 | typename _ExtractType::PlainObject |
| 286 | >::type DirectLinearAccessType; |
| 287 | static inline ExtractType (const XprType& x) { return x; } |
| 288 | static inline const Scalar (const XprType&) { return Scalar(1); } |
| 289 | }; |
| 290 | |
| 291 | // pop conjugate |
| 292 | template<typename Scalar, typename NestedXpr> |
| 293 | struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > |
| 294 | : blas_traits<NestedXpr> |
| 295 | { |
| 296 | typedef blas_traits<NestedXpr> Base; |
| 297 | typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType; |
| 298 | typedef typename Base::ExtractType ; |
| 299 | |
| 300 | enum { |
| 301 | IsComplex = NumTraits<Scalar>::IsComplex, |
| 302 | NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex |
| 303 | }; |
| 304 | static inline ExtractType (const XprType& x) { return Base::extract(x.nestedExpression()); } |
| 305 | static inline Scalar (const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); } |
| 306 | }; |
| 307 | |
| 308 | // pop scalar multiple |
| 309 | template<typename Scalar, typename NestedXpr, typename Plain> |
| 310 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> > |
| 311 | : blas_traits<NestedXpr> |
| 312 | { |
| 313 | typedef blas_traits<NestedXpr> Base; |
| 314 | typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType; |
| 315 | typedef typename Base::ExtractType ExtractType; |
| 316 | static inline ExtractType (const XprType& x) { return Base::extract(x.rhs()); } |
| 317 | static inline Scalar extractScalarFactor(const XprType& x) |
| 318 | { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); } |
| 319 | }; |
| 320 | template<typename Scalar, typename NestedXpr, typename Plain> |
| 321 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > > |
| 322 | : blas_traits<NestedXpr> |
| 323 | { |
| 324 | typedef blas_traits<NestedXpr> Base; |
| 325 | typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType; |
| 326 | typedef typename Base::ExtractType ExtractType; |
| 327 | static inline ExtractType (const XprType& x) { return Base::extract(x.lhs()); } |
| 328 | static inline Scalar extractScalarFactor(const XprType& x) |
| 329 | { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; } |
| 330 | }; |
| 331 | template<typename Scalar, typename Plain1, typename Plain2> |
| 332 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>, |
| 333 | const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > > |
| 334 | : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> > |
| 335 | {}; |
| 336 | |
| 337 | // pop opposite |
| 338 | template<typename Scalar, typename NestedXpr> |
| 339 | struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > |
| 340 | : blas_traits<NestedXpr> |
| 341 | { |
| 342 | typedef blas_traits<NestedXpr> Base; |
| 343 | typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType; |
| 344 | typedef typename Base::ExtractType ; |
| 345 | static inline ExtractType (const XprType& x) { return Base::extract(x.nestedExpression()); } |
| 346 | static inline Scalar (const XprType& x) |
| 347 | { return - Base::extractScalarFactor(x.nestedExpression()); } |
| 348 | }; |
| 349 | |
| 350 | // pop/push transpose |
| 351 | template<typename NestedXpr> |
| 352 | struct blas_traits<Transpose<NestedXpr> > |
| 353 | : blas_traits<NestedXpr> |
| 354 | { |
| 355 | typedef typename NestedXpr::Scalar Scalar; |
| 356 | typedef blas_traits<NestedXpr> Base; |
| 357 | typedef Transpose<NestedXpr> XprType; |
| 358 | typedef Transpose<const typename Base::_ExtractType> ; // const to get rid of a compile error; anyway blas traits are only used on the RHS |
| 359 | typedef Transpose<const typename Base::_ExtractType> ; |
| 360 | typedef typename conditional<bool(Base::HasUsableDirectAccess), |
| 361 | ExtractType, |
| 362 | typename ExtractType::PlainObject |
| 363 | >::type DirectLinearAccessType; |
| 364 | enum { |
| 365 | IsTransposed = Base::IsTransposed ? 0 : 1 |
| 366 | }; |
| 367 | static inline ExtractType (const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); } |
| 368 | static inline Scalar (const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } |
| 369 | }; |
| 370 | |
| 371 | template<typename T> |
| 372 | struct blas_traits<const T> |
| 373 | : blas_traits<T> |
| 374 | {}; |
| 375 | |
| 376 | template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess> |
| 377 | struct { |
| 378 | static const typename T::Scalar* (const T& m) |
| 379 | { |
| 380 | return blas_traits<T>::extract(m).data(); |
| 381 | } |
| 382 | }; |
| 383 | |
| 384 | template<typename T> |
| 385 | struct <T,false> { |
| 386 | static typename T::Scalar* (const T&) { return 0; } |
| 387 | }; |
| 388 | |
| 389 | template<typename T> const typename T::Scalar* (const T& m) |
| 390 | { |
| 391 | return extract_data_selector<T>::run(m); |
| 392 | } |
| 393 | |
| 394 | } // end namespace internal |
| 395 | |
| 396 | } // end namespace Eigen |
| 397 | |
| 398 | #endif // EIGEN_BLASUTIL_H |
| 399 | |