From da2901a548949a6c1679b768529b24b93c0731d0 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 24 Nov 2020 09:49:47 -0800 Subject: [PATCH] Add `scalar_cast_product_op`, eliminate boolean product warnings. The new binary op multiplies two values and returns a value of known type. In cases where `scalar_product_op` exists, it will use that. Otherwise, it will use `operator*` and cast the result. This is used to replace instances of `alpha = alpha * a_lhs * a_rhs` and avoids multiplying bools (`-Wbool-in-int-context`). Modified: - `GeneralMatrixMatrix.h`, `generic_product_impl::scaleAndAddTo(...)` - `GeneralProduct.h`, `gemv_dense_selector::run()` - `ProductEvaluators.h`, `generic_product_impl::eval_dynamic(...)` --- Eigen/src/Core/GeneralProduct.h | 20 +++-- Eigen/src/Core/ProductEvaluators.h | 5 +- Eigen/src/Core/functors/BinaryFunctors.h | 84 ++++++++++++++++++- Eigen/src/Core/products/GeneralMatrixMatrix.h | 7 +- 4 files changed, 102 insertions(+), 14 deletions(-) diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h index bf7ef54b5..486c2e301 100644 --- a/Eigen/src/Core/GeneralProduct.h +++ b/Eigen/src/Core/GeneralProduct.h @@ -217,19 +217,22 @@ template<> struct gemv_dense_selector typedef typename Rhs::Scalar RhsScalar; typedef typename Dest::Scalar ResScalar; typedef typename Dest::RealScalar RealScalar; - + typedef internal::blas_traits LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; typedef internal::blas_traits RhsBlasTraits; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - + typedef Map, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits::size)> MappedDest; ActualLhsType actualLhs = LhsBlasTraits::extract(lhs); ActualRhsType actualRhs = RhsBlasTraits::extract(rhs); - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) - * RhsBlasTraits::extractScalarFactor(rhs); + internal::scalar_cast_product_op lr_mul; + internal::scalar_product_op a_mul; + ResScalar actualAlpha = a_mul(alpha, + lr_mul(LhsBlasTraits::extractScalarFactor(lhs), + RhsBlasTraits::extractScalarFactor(rhs))); // make sure Dest is a compile-time vector type (bug 1166) typedef typename conditional::type ActualDest; @@ -310,7 +313,7 @@ template<> struct gemv_dense_selector typedef typename Lhs::Scalar LhsScalar; typedef typename Rhs::Scalar RhsScalar; typedef typename Dest::Scalar ResScalar; - + typedef internal::blas_traits LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; typedef internal::blas_traits RhsBlasTraits; @@ -320,8 +323,11 @@ template<> struct gemv_dense_selector typename add_const::type actualLhs = LhsBlasTraits::extract(lhs); typename add_const::type actualRhs = RhsBlasTraits::extract(rhs); - ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs) - * RhsBlasTraits::extractScalarFactor(rhs); + internal::scalar_cast_product_op lr_mul; + internal::scalar_product_op a_mul; + ResScalar actualAlpha = a_mul(alpha, + lr_mul(LhsBlasTraits::extractScalarFactor(lhs), + RhsBlasTraits::extractScalarFactor(rhs))); enum { // FIXME find a way to allow an inner stride on the result if packet_traits::size==1 diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 02b58438c..09cd2bce2 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -441,8 +441,9 @@ struct generic_product_impl }; // FIXME: in c++11 this should be auto, and extractScalarFactor should also return auto // this is important for real*complex_mat - Scalar actualAlpha = blas_traits::extractScalarFactor(lhs) - * blas_traits::extractScalarFactor(rhs); + internal::scalar_cast_product_op mul; + Scalar actualAlpha = mul(blas_traits::extractScalarFactor(lhs), + blas_traits::extractScalarFactor(rhs)); eval_dynamic_impl(dst, blas_traits::extract(lhs).template conjugateIf(), blas_traits::extract(rhs).template conjugateIf(), diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index f3509c4b9..c560d44bf 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -97,6 +97,84 @@ struct functor_traits > { template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool scalar_product_op::operator() (const bool& a, const bool& b) const { return a && b; } +/** \internal + * \brief Template functor to compute the product of two scalars and cast the result. + * + * This op will use scalar_product_op if it exists, otherwise + * will fall back to operator* and cast the result. + * + */ +template +struct scalar_cast_product_op; + +template +struct scalar_cast_product_op{ + typedef ResultScalar result_type; + typedef scalar_product_op ScalarProductOp; +#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_product_op) +#else + scalar_cast_product_op() { + EIGEN_SCALAR_BINARY_OP_PLUGIN + } +#endif + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ResultScalar operator()(const LhsScalar& a, const RhsScalar& b) const { + return cast(ScalarProductOp()(a, b)); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const + { return ScalarProductOp().packetOp(a, b); } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultScalar predux(const Packet& a) const + { return internal::cast(ScalarProductOp().predux(a)); } +}; +template +struct functor_traits > { + typedef functor_traits > ScalarProductTraits; + enum { + Cost = ScalarProductTraits::Cost, + PacketAccess = EIGEN_SCALAR_BINARY_SUPPORTED(product, LhsScalar, RhsScalar) + && ScalarProductTraits::PacketAccess + }; +}; + +template +struct scalar_cast_product_op{ + typedef ResultScalar result_type; +#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_product_op) +#else + scalar_cast_product_op() { + EIGEN_SCALAR_BINARY_OP_PLUGIN + } +#endif + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ResultScalar operator()(const LhsScalar& a, const RhsScalar& b) const { + // TODO: ideally this would use internal::cast, but without auto/decltype we + // cannot determine the return type of a * b. + return static_cast(a * b); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const + { return internal::pmul(a, b); } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultScalar predux(const Packet& a) const + { return internal::cast(internal::predux(a)); } +}; +template +struct functor_traits > { + enum { + Cost = (NumTraits::MulCost + NumTraits::MulCost)/2, // rough estimate! + PacketAccess = is_same::value && packet_traits::HasMul && packet_traits::HasMul + }; +}; /** \internal * \brief Template functor to compute the conjugate product of two scalars @@ -110,13 +188,13 @@ struct scalar_conj_product_op : binary_op_base enum { Conj = NumTraits::IsComplex }; - + typedef typename ScalarBinaryOpTraits::ReturnType result_type; - + EIGEN_EMPTY_STRUCT_CTOR(scalar_conj_product_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return conj_helper().pmul(a,b); } - + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const { return conj_helper().pmul(a,b); } diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 0d55bdf9e..09c7a139d 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -489,8 +489,11 @@ struct generic_product_impl typename internal::add_const_on_value_type::type lhs = LhsBlasTraits::extract(a_lhs); typename internal::add_const_on_value_type::type rhs = RhsBlasTraits::extract(a_rhs); - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) - * RhsBlasTraits::extractScalarFactor(a_rhs); + internal::scalar_cast_product_op lr_mul; + internal::scalar_product_op a_mul; + Scalar actualAlpha = a_mul(alpha, + lr_mul(LhsBlasTraits::extractScalarFactor(a_lhs), + RhsBlasTraits::extractScalarFactor(a_rhs))); typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; -- GitLab