1 | // This file is part of Eigen, a lightweight C++ template library |
2 | // for linear algebra. |
3 | // |
4 | // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr> |
5 | // |
6 | // This Source Code Form is subject to the terms of the Mozilla |
7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
9 | |
10 | #ifndef EIGEN_BLASUTIL_H |
11 | #define EIGEN_BLASUTIL_H |
12 | |
13 | // This file contains many lightweight helper classes used to |
14 | // implement and control fast level 2 and level 3 BLAS-like routines. |
15 | |
16 | namespace Eigen { |
17 | |
18 | namespace internal { |
19 | |
20 | // forward declarations |
21 | template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> |
22 | struct gebp_kernel; |
23 | |
24 | template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> |
25 | struct gemm_pack_rhs; |
26 | |
27 | template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false> |
28 | struct gemm_pack_lhs; |
29 | |
30 | template< |
31 | typename Index, |
32 | typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, |
33 | typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, |
34 | int ResStorageOrder> |
35 | struct general_matrix_matrix_product; |
36 | |
37 | template<typename Index, |
38 | typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs, |
39 | typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized> |
40 | struct general_matrix_vector_product; |
41 | |
42 | |
43 | template<bool Conjugate> struct conj_if; |
44 | |
45 | template<> struct conj_if<true> { |
46 | template<typename T> |
47 | inline T operator()(const T& x) const { return numext::conj(x); } |
48 | template<typename T> |
49 | inline T pconj(const T& x) const { return internal::pconj(x); } |
50 | }; |
51 | |
52 | template<> struct conj_if<false> { |
53 | template<typename T> |
54 | inline const T& operator()(const T& x) const { return x; } |
55 | template<typename T> |
56 | inline const T& pconj(const T& x) const { return x; } |
57 | }; |
58 | |
59 | // Generic implementation for custom complex types. |
60 | template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs> |
61 | struct conj_helper |
62 | { |
63 | typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar; |
64 | |
65 | EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const |
66 | { return padd(c, pmul(x,y)); } |
67 | |
68 | EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const |
69 | { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); } |
70 | }; |
71 | |
72 | template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false> |
73 | { |
74 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); } |
75 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); } |
76 | }; |
77 | |
78 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true> |
79 | { |
80 | typedef std::complex<RealScalar> Scalar; |
81 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
82 | { return c + pmul(x,y); } |
83 | |
84 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
85 | { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); } |
86 | }; |
87 | |
88 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false> |
89 | { |
90 | typedef std::complex<RealScalar> Scalar; |
91 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
92 | { return c + pmul(x,y); } |
93 | |
94 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
95 | { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); } |
96 | }; |
97 | |
98 | template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true> |
99 | { |
100 | typedef std::complex<RealScalar> Scalar; |
101 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const |
102 | { return c + pmul(x,y); } |
103 | |
104 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const |
105 | { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); } |
106 | }; |
107 | |
108 | template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false> |
109 | { |
110 | typedef std::complex<RealScalar> Scalar; |
111 | EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const |
112 | { return padd(c, pmul(x,y)); } |
113 | EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const |
114 | { return conj_if<Conj>()(x)*y; } |
115 | }; |
116 | |
117 | template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj> |
118 | { |
119 | typedef std::complex<RealScalar> Scalar; |
120 | EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const |
121 | { return padd(c, pmul(x,y)); } |
122 | EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const |
123 | { return x*conj_if<Conj>()(y); } |
124 | }; |
125 | |
126 | template<typename From,typename To> struct get_factor { |
127 | EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); } |
128 | }; |
129 | |
130 | template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> { |
131 | EIGEN_DEVICE_FUNC |
132 | static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); } |
133 | }; |
134 | |
135 | |
136 | template<typename Scalar, typename Index> |
137 | class BlasVectorMapper { |
138 | public: |
139 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {} |
140 | |
141 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { |
142 | return m_data[i]; |
143 | } |
144 | template <typename Packet, int AlignmentType> |
145 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const { |
146 | return ploadt<Packet, AlignmentType>(m_data + i); |
147 | } |
148 | |
149 | template <typename Packet> |
150 | EIGEN_DEVICE_FUNC bool aligned(Index i) const { |
151 | return (UIntPtr(m_data+i)%sizeof(Packet))==0; |
152 | } |
153 | |
154 | protected: |
155 | Scalar* m_data; |
156 | }; |
157 | |
158 | template<typename Scalar, typename Index, int AlignmentType> |
159 | class BlasLinearMapper { |
160 | public: |
161 | typedef typename packet_traits<Scalar>::type Packet; |
162 | typedef typename packet_traits<Scalar>::half HalfPacket; |
163 | |
164 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {} |
165 | |
166 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { |
167 | internal::prefetch(&operator()(i)); |
168 | } |
169 | |
170 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { |
171 | return m_data[i]; |
172 | } |
173 | |
174 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { |
175 | return ploadt<Packet, AlignmentType>(m_data + i); |
176 | } |
177 | |
178 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { |
179 | return ploadt<HalfPacket, AlignmentType>(m_data + i); |
180 | } |
181 | |
182 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const { |
183 | pstoret<Scalar, Packet, AlignmentType>(m_data + i, p); |
184 | } |
185 | |
186 | protected: |
187 | Scalar *m_data; |
188 | }; |
189 | |
190 | // Lightweight helper class to access matrix coefficients. |
191 | template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned> |
192 | class blas_data_mapper { |
193 | public: |
194 | typedef typename packet_traits<Scalar>::type Packet; |
195 | typedef typename packet_traits<Scalar>::half HalfPacket; |
196 | |
197 | typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper; |
198 | typedef BlasVectorMapper<Scalar, Index> VectorMapper; |
199 | |
200 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {} |
201 | |
202 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType> |
203 | getSubMapper(Index i, Index j) const { |
204 | return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride); |
205 | } |
206 | |
207 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { |
208 | return LinearMapper(&operator()(i, j)); |
209 | } |
210 | |
211 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { |
212 | return VectorMapper(&operator()(i, j)); |
213 | } |
214 | |
215 | |
216 | EIGEN_DEVICE_FUNC |
217 | EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const { |
218 | return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; |
219 | } |
220 | |
221 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { |
222 | return ploadt<Packet, AlignmentType>(&operator()(i, j)); |
223 | } |
224 | |
225 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { |
226 | return ploadt<HalfPacket, AlignmentType>(&operator()(i, j)); |
227 | } |
228 | |
229 | template<typename SubPacket> |
230 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { |
231 | pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride); |
232 | } |
233 | |
234 | template<typename SubPacket> |
235 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const { |
236 | return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride); |
237 | } |
238 | |
239 | EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; } |
240 | EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; } |
241 | |
242 | EIGEN_DEVICE_FUNC Index firstAligned(Index size) const { |
243 | if (UIntPtr(m_data)%sizeof(Scalar)) { |
244 | return -1; |
245 | } |
246 | return internal::first_default_aligned(m_data, size); |
247 | } |
248 | |
249 | protected: |
250 | Scalar* EIGEN_RESTRICT m_data; |
251 | const Index m_stride; |
252 | }; |
253 | |
254 | // lightweight helper class to access matrix coefficients (const version) |
255 | template<typename Scalar, typename Index, int StorageOrder> |
256 | class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> { |
257 | public: |
258 | EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {} |
259 | |
260 | EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const { |
261 | return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride); |
262 | } |
263 | }; |
264 | |
265 | |
266 | /* Helper class to analyze the factors of a Product expression. |
267 | * In particular it allows to pop out operator-, scalar multiples, |
268 | * and conjugate */ |
269 | template<typename XprType> struct blas_traits |
270 | { |
271 | typedef typename traits<XprType>::Scalar Scalar; |
272 | typedef const XprType& ; |
273 | typedef XprType ; |
274 | enum { |
275 | IsComplex = NumTraits<Scalar>::IsComplex, |
276 | IsTransposed = false, |
277 | NeedToConjugate = false, |
278 | HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit) |
279 | && ( bool(XprType::IsVectorAtCompileTime) |
280 | || int(inner_stride_at_compile_time<XprType>::ret) == 1) |
281 | ) ? 1 : 0 |
282 | }; |
283 | typedef typename conditional<bool(HasUsableDirectAccess), |
284 | ExtractType, |
285 | typename _ExtractType::PlainObject |
286 | >::type DirectLinearAccessType; |
287 | static inline ExtractType (const XprType& x) { return x; } |
288 | static inline const Scalar (const XprType&) { return Scalar(1); } |
289 | }; |
290 | |
291 | // pop conjugate |
292 | template<typename Scalar, typename NestedXpr> |
293 | struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > |
294 | : blas_traits<NestedXpr> |
295 | { |
296 | typedef blas_traits<NestedXpr> Base; |
297 | typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType; |
298 | typedef typename Base::ExtractType ; |
299 | |
300 | enum { |
301 | IsComplex = NumTraits<Scalar>::IsComplex, |
302 | NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex |
303 | }; |
304 | static inline ExtractType (const XprType& x) { return Base::extract(x.nestedExpression()); } |
305 | static inline Scalar (const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); } |
306 | }; |
307 | |
308 | // pop scalar multiple |
309 | template<typename Scalar, typename NestedXpr, typename Plain> |
310 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> > |
311 | : blas_traits<NestedXpr> |
312 | { |
313 | typedef blas_traits<NestedXpr> Base; |
314 | typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType; |
315 | typedef typename Base::ExtractType ExtractType; |
316 | static inline ExtractType (const XprType& x) { return Base::extract(x.rhs()); } |
317 | static inline Scalar extractScalarFactor(const XprType& x) |
318 | { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); } |
319 | }; |
320 | template<typename Scalar, typename NestedXpr, typename Plain> |
321 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > > |
322 | : blas_traits<NestedXpr> |
323 | { |
324 | typedef blas_traits<NestedXpr> Base; |
325 | typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType; |
326 | typedef typename Base::ExtractType ExtractType; |
327 | static inline ExtractType (const XprType& x) { return Base::extract(x.lhs()); } |
328 | static inline Scalar extractScalarFactor(const XprType& x) |
329 | { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; } |
330 | }; |
331 | template<typename Scalar, typename Plain1, typename Plain2> |
332 | struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>, |
333 | const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > > |
334 | : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> > |
335 | {}; |
336 | |
337 | // pop opposite |
338 | template<typename Scalar, typename NestedXpr> |
339 | struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> > |
340 | : blas_traits<NestedXpr> |
341 | { |
342 | typedef blas_traits<NestedXpr> Base; |
343 | typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType; |
344 | typedef typename Base::ExtractType ; |
345 | static inline ExtractType (const XprType& x) { return Base::extract(x.nestedExpression()); } |
346 | static inline Scalar (const XprType& x) |
347 | { return - Base::extractScalarFactor(x.nestedExpression()); } |
348 | }; |
349 | |
350 | // pop/push transpose |
351 | template<typename NestedXpr> |
352 | struct blas_traits<Transpose<NestedXpr> > |
353 | : blas_traits<NestedXpr> |
354 | { |
355 | typedef typename NestedXpr::Scalar Scalar; |
356 | typedef blas_traits<NestedXpr> Base; |
357 | typedef Transpose<NestedXpr> XprType; |
358 | typedef Transpose<const typename Base::_ExtractType> ; // const to get rid of a compile error; anyway blas traits are only used on the RHS |
359 | typedef Transpose<const typename Base::_ExtractType> ; |
360 | typedef typename conditional<bool(Base::HasUsableDirectAccess), |
361 | ExtractType, |
362 | typename ExtractType::PlainObject |
363 | >::type DirectLinearAccessType; |
364 | enum { |
365 | IsTransposed = Base::IsTransposed ? 0 : 1 |
366 | }; |
367 | static inline ExtractType (const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); } |
368 | static inline Scalar (const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } |
369 | }; |
370 | |
371 | template<typename T> |
372 | struct blas_traits<const T> |
373 | : blas_traits<T> |
374 | {}; |
375 | |
376 | template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess> |
377 | struct { |
378 | static const typename T::Scalar* (const T& m) |
379 | { |
380 | return blas_traits<T>::extract(m).data(); |
381 | } |
382 | }; |
383 | |
384 | template<typename T> |
385 | struct <T,false> { |
386 | static typename T::Scalar* (const T&) { return 0; } |
387 | }; |
388 | |
389 | template<typename T> const typename T::Scalar* (const T& m) |
390 | { |
391 | return extract_data_selector<T>::run(m); |
392 | } |
393 | |
394 | } // end namespace internal |
395 | |
396 | } // end namespace Eigen |
397 | |
398 | #endif // EIGEN_BLASUTIL_H |
399 | |