1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_BLASUTIL_H
11#define EIGEN_BLASUTIL_H
12
13// This file contains many lightweight helper classes used to
14// implement and control fast level 2 and level 3 BLAS-like routines.
15
16namespace Eigen {
17
18namespace internal {
19
// forward declarations
// None of these templates is defined in this header; they are only declared so
// the helpers below (and includers) can name them. Any default template
// arguments are given here, once, and therefore apply to every later
// declaration/definition of the same template.

// Default: neither operand is conjugated.
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
struct gebp_kernel;

// Default: no conjugation and PanelMode disabled.
template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
struct gemm_pack_rhs;

// Default: no conjugation and PanelMode disabled.
template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
struct gemm_pack_lhs;

// Matrix-matrix product driver; storage order and conjugation are given per operand.
template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
  int ResStorageOrder>
struct general_matrix_matrix_product;

// Matrix-vector product driver; Version defaults to Specialized.
template<typename Index,
         typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
         typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
struct general_matrix_vector_product;
41
42
43template<bool Conjugate> struct conj_if;
44
45template<> struct conj_if<true> {
46 template<typename T>
47 inline T operator()(const T& x) const { return numext::conj(x); }
48 template<typename T>
49 inline T pconj(const T& x) const { return internal::pconj(x); }
50};
51
52template<> struct conj_if<false> {
53 template<typename T>
54 inline const T& operator()(const T& x) const { return x; }
55 template<typename T>
56 inline const T& pconj(const T& x) const { return x; }
57};
58
59// Generic implementation for custom complex types.
60template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
61struct conj_helper
62{
63 typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
64
65 EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
66 { return padd(c, pmul(x,y)); }
67
68 EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
69 { return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
70};
71
72template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
73{
74 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
75 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
76};
77
78template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
79{
80 typedef std::complex<RealScalar> Scalar;
81 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
82 { return c + pmul(x,y); }
83
84 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
85 { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
86};
87
88template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
89{
90 typedef std::complex<RealScalar> Scalar;
91 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
92 { return c + pmul(x,y); }
93
94 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
95 { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
96};
97
98template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
99{
100 typedef std::complex<RealScalar> Scalar;
101 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
102 { return c + pmul(x,y); }
103
104 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
105 { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
106};
107
108template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
109{
110 typedef std::complex<RealScalar> Scalar;
111 EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
112 { return padd(c, pmul(x,y)); }
113 EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
114 { return conj_if<Conj>()(x)*y; }
115};
116
117template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
118{
119 typedef std::complex<RealScalar> Scalar;
120 EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
121 { return padd(c, pmul(x,y)); }
122 EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
123 { return x*conj_if<Conj>()(y); }
124};
125
126template<typename From,typename To> struct get_factor {
127 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
128};
129
130template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
131 EIGEN_DEVICE_FUNC
132 static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
133};
134
135
136template<typename Scalar, typename Index>
137class BlasVectorMapper {
138 public:
139 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
140
141 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
142 return m_data[i];
143 }
144 template <typename Packet, int AlignmentType>
145 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
146 return ploadt<Packet, AlignmentType>(m_data + i);
147 }
148
149 template <typename Packet>
150 EIGEN_DEVICE_FUNC bool aligned(Index i) const {
151 return (UIntPtr(m_data+i)%sizeof(Packet))==0;
152 }
153
154 protected:
155 Scalar* m_data;
156};
157
158template<typename Scalar, typename Index, int AlignmentType>
159class BlasLinearMapper {
160 public:
161 typedef typename packet_traits<Scalar>::type Packet;
162 typedef typename packet_traits<Scalar>::half HalfPacket;
163
164 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
165
166 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
167 internal::prefetch(&operator()(i));
168 }
169
170 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
171 return m_data[i];
172 }
173
174 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
175 return ploadt<Packet, AlignmentType>(m_data + i);
176 }
177
178 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
179 return ploadt<HalfPacket, AlignmentType>(m_data + i);
180 }
181
182 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
183 pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
184 }
185
186 protected:
187 Scalar *m_data;
188};
189
190// Lightweight helper class to access matrix coefficients.
191template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
192class blas_data_mapper {
193 public:
194 typedef typename packet_traits<Scalar>::type Packet;
195 typedef typename packet_traits<Scalar>::half HalfPacket;
196
197 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
198 typedef BlasVectorMapper<Scalar, Index> VectorMapper;
199
200 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
201
202 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
203 getSubMapper(Index i, Index j) const {
204 return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
205 }
206
207 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
208 return LinearMapper(&operator()(i, j));
209 }
210
211 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
212 return VectorMapper(&operator()(i, j));
213 }
214
215
216 EIGEN_DEVICE_FUNC
217 EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
218 return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
219 }
220
221 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
222 return ploadt<Packet, AlignmentType>(&operator()(i, j));
223 }
224
225 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
226 return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
227 }
228
229 template<typename SubPacket>
230 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
231 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
232 }
233
234 template<typename SubPacket>
235 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
236 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
237 }
238
239 EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
240 EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
241
242 EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
243 if (UIntPtr(m_data)%sizeof(Scalar)) {
244 return -1;
245 }
246 return internal::first_default_aligned(m_data, size);
247 }
248
249 protected:
250 Scalar* EIGEN_RESTRICT m_data;
251 const Index m_stride;
252};
253
254// lightweight helper class to access matrix coefficients (const version)
255template<typename Scalar, typename Index, int StorageOrder>
256class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
257 public:
258 EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
259
260 EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
261 return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
262 }
263};
264
265
266/* Helper class to analyze the factors of a Product expression.
267 * In particular it allows to pop out operator-, scalar multiples,
268 * and conjugate */
269template<typename XprType> struct blas_traits
270{
271 typedef typename traits<XprType>::Scalar Scalar;
272 typedef const XprType& ExtractType;
273 typedef XprType _ExtractType;
274 enum {
275 IsComplex = NumTraits<Scalar>::IsComplex,
276 IsTransposed = false,
277 NeedToConjugate = false,
278 HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
279 && ( bool(XprType::IsVectorAtCompileTime)
280 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
281 ) ? 1 : 0
282 };
283 typedef typename conditional<bool(HasUsableDirectAccess),
284 ExtractType,
285 typename _ExtractType::PlainObject
286 >::type DirectLinearAccessType;
287 static inline ExtractType extract(const XprType& x) { return x; }
288 static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
289};
290
291// pop conjugate
292template<typename Scalar, typename NestedXpr>
293struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
294 : blas_traits<NestedXpr>
295{
296 typedef blas_traits<NestedXpr> Base;
297 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
298 typedef typename Base::ExtractType ExtractType;
299
300 enum {
301 IsComplex = NumTraits<Scalar>::IsComplex,
302 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
303 };
304 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
305 static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
306};
307
308// pop scalar multiple
309template<typename Scalar, typename NestedXpr, typename Plain>
310struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
311 : blas_traits<NestedXpr>
312{
313 typedef blas_traits<NestedXpr> Base;
314 typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
315 typedef typename Base::ExtractType ExtractType;
316 static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
317 static inline Scalar extractScalarFactor(const XprType& x)
318 { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
319};
320template<typename Scalar, typename NestedXpr, typename Plain>
321struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
322 : blas_traits<NestedXpr>
323{
324 typedef blas_traits<NestedXpr> Base;
325 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
326 typedef typename Base::ExtractType ExtractType;
327 static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
328 static inline Scalar extractScalarFactor(const XprType& x)
329 { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
330};
// constant * constant: both scalar-multiple specializations above would match
// this pattern equally well (an ambiguity). This more specialized version
// resolves it by inheriting the traits of the left-hand constant expression.
template<typename Scalar, typename Plain1, typename Plain2>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
                                                            const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
 : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
{};
336
337// pop opposite
338template<typename Scalar, typename NestedXpr>
339struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
340 : blas_traits<NestedXpr>
341{
342 typedef blas_traits<NestedXpr> Base;
343 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
344 typedef typename Base::ExtractType ExtractType;
345 static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
346 static inline Scalar extractScalarFactor(const XprType& x)
347 { return - Base::extractScalarFactor(x.nestedExpression()); }
348};
349
// pop/push transpose
// Transpose<expr> is analysed through expr's traits. IsTransposed is toggled
// relative to them, so a double transposition cancels out. The extracted
// expression is re-wrapped in a Transpose, and the scalar factor passes
// through unchanged.
template<typename NestedXpr>
struct blas_traits<Transpose<NestedXpr> >
 : blas_traits<NestedXpr>
{
  typedef typename NestedXpr::Scalar Scalar;
  typedef blas_traits<NestedXpr> Base;
  typedef Transpose<NestedXpr> XprType;
  typedef Transpose<const typename Base::_ExtractType> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
  typedef Transpose<const typename Base::_ExtractType> _ExtractType;
  // ExtractType itself when the nested traits report usable direct access,
  // otherwise ExtractType::PlainObject.
  typedef typename conditional<bool(Base::HasUsableDirectAccess),
    ExtractType,
    typename ExtractType::PlainObject
    >::type DirectLinearAccessType;
  enum {
    IsTransposed = Base::IsTransposed ? 0 : 1
  };
  static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
  static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
};
370
// A const-qualified expression type is analysed exactly like the unqualified one.
template<typename T>
struct blas_traits<const T>
 : blas_traits<T>
{};
375
376template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
377struct extract_data_selector {
378 static const typename T::Scalar* run(const T& m)
379 {
380 return blas_traits<T>::extract(m).data();
381 }
382};
383
384template<typename T>
385struct extract_data_selector<T,false> {
386 static typename T::Scalar* run(const T&) { return 0; }
387};
388
389template<typename T> const typename T::Scalar* extract_data(const T& m)
390{
391 return extract_data_selector<T>::run(m);
392}
393
394} // end namespace internal
395
396} // end namespace Eigen
397
398#endif // EIGEN_BLASUTIL_H
399