1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_COMPLEX_CUDA_H
11#define EIGEN_COMPLEX_CUDA_H
12
13// clang-format off
14
15namespace Eigen {
16
17namespace internal {
18
19#if defined(__CUDACC__) && defined(EIGEN_USE_GPU)
20
// Many std::complex operators, such as operator+, operator-, operator* and
// operator/, are not constexpr. As a result, clang does not treat them as
// device functions, and Eigen functors that use these operators fail to
// compile. Here we manually specialize those functors for complex types when
// building for CUDA, so that only device-callable code is invoked.
26
27// Sum
28template<typename T> struct scalar_sum_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
29 typedef typename std::complex<T> result_type;
30
31 EIGEN_EMPTY_STRUCT_CTOR(scalar_sum_op)
32 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
33 return std::complex<T>(numext::real(a) + numext::real(b),
34 numext::imag(a) + numext::imag(b));
35 }
36};
37
38template<typename T> struct scalar_sum_op<std::complex<T>, std::complex<T> > : scalar_sum_op<const std::complex<T>, const std::complex<T> > {};
39
40
41// Difference
42template<typename T> struct scalar_difference_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
43 typedef typename std::complex<T> result_type;
44
45 EIGEN_EMPTY_STRUCT_CTOR(scalar_difference_op)
46 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
47 return std::complex<T>(numext::real(a) - numext::real(b),
48 numext::imag(a) - numext::imag(b));
49 }
50};
51
52template<typename T> struct scalar_difference_op<std::complex<T>, std::complex<T> > : scalar_difference_op<const std::complex<T>, const std::complex<T> > {};
53
54
55// Product
56template<typename T> struct scalar_product_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
57 enum {
58 Vectorizable = packet_traits<std::complex<T>>::HasMul
59 };
60 typedef typename std::complex<T> result_type;
61
62 EIGEN_EMPTY_STRUCT_CTOR(scalar_product_op)
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
64 const T a_real = numext::real(a);
65 const T a_imag = numext::imag(a);
66 const T b_real = numext::real(b);
67 const T b_imag = numext::imag(b);
68 return std::complex<T>(a_real * b_real - a_imag * b_imag,
69 a_real * b_imag + a_imag * b_real);
70 }
71};
72
73template<typename T> struct scalar_product_op<std::complex<T>, std::complex<T> > : scalar_product_op<const std::complex<T>, const std::complex<T> > {};
74
75
76// Quotient
77template<typename T> struct scalar_quotient_op<const std::complex<T>, const std::complex<T> > : binary_op_base<const std::complex<T>, const std::complex<T> > {
78 enum {
79 Vectorizable = packet_traits<std::complex<T>>::HasDiv
80 };
81 typedef typename std::complex<T> result_type;
82
83 EIGEN_EMPTY_STRUCT_CTOR(scalar_quotient_op)
84 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<T> operator() (const std::complex<T>& a, const std::complex<T>& b) const {
85 const T a_real = numext::real(a);
86 const T a_imag = numext::imag(a);
87 const T b_real = numext::real(b);
88 const T b_imag = numext::imag(b);
89 const T norm = T(1) / (b_real * b_real + b_imag * b_imag);
90 return std::complex<T>((a_real * b_real + a_imag * b_imag) * norm,
91 (a_imag * b_real - a_real * b_imag) * norm);
92 }
93};
94
95template<typename T> struct scalar_quotient_op<std::complex<T>, std::complex<T> > : scalar_quotient_op<const std::complex<T>, const std::complex<T> > {};
96
97#endif
98
99} // end namespace internal
100
101} // end namespace Eigen
102
103#endif // EIGEN_COMPLEX_CUDA_H
104