From 6579e36eb4f19e4a01d89277f927da8391a74c04 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 25 Mar 2025 08:26:23 -0700 Subject: [PATCH] Allow Tensor trace to be passed to a TensorRef. --- unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h | 6 ++++++ unsupported/test/cxx11_tensor_ref.cpp | 3 +++ 2 files changed, 9 insertions(+) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h index c1499852c..5357a482d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h @@ -27,6 +27,10 @@ struct traits > : public traits { typedef std::remove_reference_t Nested_; static constexpr int NumDimensions = XprTraits::NumDimensions - array_size::value; static constexpr int Layout = XprTraits::Layout; + enum { + // Trace is read-only. + Flags = traits::Flags & ~LvalueBit + }; }; template @@ -203,6 +207,8 @@ struct TensorEvaluator, Device> { return true; } + EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; } + EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp index cf097499d..5e796b47b 100644 --- a/unsupported/test/cxx11_tensor_ref.cpp +++ b/unsupported/test/cxx11_tensor_ref.cpp @@ -41,6 +41,9 @@ static void test_simple_lvalue_ref() { for (int i = 0; i < 6; ++i) { VERIFY_IS_EQUAL(input(i), -i * 2); } + + TensorRef> ref5(input.trace()); + VERIFY_IS_EQUAL(ref5[0], input[0]); } static void test_simple_rvalue_ref() { -- GitLab