From 26e5beb8cb6c7b697a6d60a142d5c864948450ae Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 26 Aug 2021 13:05:23 -0700 Subject: [PATCH] Device-compatible Tuple implementation. An analogue of `std::tuple` that works on device. Context: I've tried `std::tuple` in various versions of NVCC and clang, and although code seems to compile, it often fails to run - generating "illegal memory access" errors, or "illegal instruction" errors. This replacement does work on device. --- Eigen/src/Core/arch/GPU/Tuple.h | 302 ++++++++++++++++++++++++++++++++ Eigen/src/Core/util/Meta.h | 54 ++++++ test/CMakeLists.txt | 1 + test/tuple_test.cpp | 123 +++++++++++++ 4 files changed, 480 insertions(+) create mode 100644 Eigen/src/Core/arch/GPU/Tuple.h create mode 100644 test/tuple_test.cpp diff --git a/Eigen/src/Core/arch/GPU/Tuple.h b/Eigen/src/Core/arch/GPU/Tuple.h new file mode 100644 index 000000000..d381cd886 --- /dev/null +++ b/Eigen/src/Core/arch/GPU/Tuple.h @@ -0,0 +1,302 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2021 The Eigen Team +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TUPLE_GPU +#define EIGEN_TUPLE_GPU + +#include +#include + +// This is a replacement of std::tuple that can be used in device code. + +namespace Eigen { +namespace internal { +namespace tuple_impl { + +// Internal tuple implementation. +template +class TupleImpl; + +// Generic recursive tuple. +template +class TupleImpl { + public: + // Tuple may contain Eigen types. + EIGEN_MAKE_ALIGNED_OPERATOR_NEW + + // Default constructor, enable if all types are default-constructible. + template::value + && reduce_all::value...>::value + >::type> + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC + TupleImpl() : head_{}, tail_{} {} + + // Element constructor. 
+ template 1 || std::is_convertible::value) + >::type> + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC + TupleImpl(U1&& arg1, Us&&... args) + : head_(std::forward(arg1)), tail_(std::forward(args)...) {} + + // The first stored value. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + T1& head() { + return head_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + const T1& head() const { + return head_; + } + + // The tail values. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + TupleImpl& tail() { + return tail_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + const TupleImpl& tail() const { + return tail_; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void swap(TupleImpl& other) { + using numext::swap; + swap(head_, other.head_); + swap(tail_, other.tail_); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TupleImpl& operator=(const TupleImpl& other) { + head_ = other.head_; + tail_ = other.tail_; + return *this; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TupleImpl& operator=(TupleImpl&& other) { + head_ = std::move(other.head_); + tail_ = std::move(other.tail_); + return *this; + } + + private: + // Allow related tuples to reference head_/tail_. + template + friend class TupleImpl; + + T1 head_; + TupleImpl tail_; +}; + +// Empty tuple specialization. +template<> +class TupleImpl {}; + +template +struct is_tuple : std::false_type {}; + +template +struct is_tuple< TupleImpl > : std::true_type {}; + +// Gets an element from a tuple. +template +struct tuple_get_impl { + using TupleType = TupleImpl; + using ReturnType = typename tuple_get_impl::ReturnType; + + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + ReturnType& run(TupleType& tuple) { + return tuple_get_impl::run(tuple.tail()); + } + + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + const ReturnType& run(const TupleType& tuple) { + return tuple_get_impl::run(tuple.tail()); + } +}; + +// Base case, getting the head element. 
+template +struct tuple_get_impl<0, T1, Ts...> { + using TupleType = TupleImpl; + using ReturnType = T1; + + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + T1& run(TupleType& tuple) { + return tuple.head(); + } + + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + const T1& run(const TupleType& tuple) { + return tuple.head(); + } +}; + +// Concatenates N Tuples. +template +struct tuple_cat_impl; + +template +struct tuple_cat_impl, TupleImpl, Tuples...> { + using TupleType1 = TupleImpl; + using TupleType2 = TupleImpl; + using MergedTupleType = TupleImpl; + + using ReturnType = typename tuple_cat_impl::ReturnType; + + // Uses the index sequences to extract and merge elements from tuple1 and tuple2, + // then recursively calls again. + template + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ReturnType run(Tuple1&& tuple1, index_sequence, + Tuple2&& tuple2, index_sequence, + MoreTuples&&... tuples) { + return tuple_cat_impl::run( + MergedTupleType(tuple_get_impl::run(std::forward(tuple1))..., + tuple_get_impl::run(std::forward(tuple2))...), + std::forward(tuples)...); + } + + // Concatenates the first two tuples. + template + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ReturnType run(Tuple1&& tuple1, Tuple2&& tuple2, MoreTuples&&... tuples) { + return run(std::forward(tuple1), make_index_sequence{}, + std::forward(tuple2), make_index_sequence{}, + std::forward(tuples)...); + } +}; + +// Base case with a single tuple. +template +struct tuple_cat_impl<1, TupleImpl > { + using ReturnType = TupleImpl; + + template + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ReturnType run(Tuple1&& tuple1) { + return tuple1; + } +}; + +// Special case of no tuples. +template<> +struct tuple_cat_impl<0> { + using ReturnType = TupleImpl<0>; + static EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + ReturnType run() {return ReturnType{}; } +}; + +// For use in make_tuple, unwraps a reference_wrapper. 
+template +struct unwrap_reference_wrapper { using type = T; }; + +template +struct unwrap_reference_wrapper > { using type = T&; }; + +// For use in make_tuple, decays a type and unwraps a reference_wrapper. +template +struct unwrap_decay { + using type = typename unwrap_reference_wrapper::type>::type; +}; + +/** + * Alternative to std::tuple that can be used on device. + */ +template +using tuple = TupleImpl; + +/** + * Utility for determining a tuple's size. + */ +template +struct tuple_size; + +template +struct tuple_size< tuple > : std::integral_constant {}; + +/** + * Gets an element of a tuple. + * \tparam Idx index of the element. + * \tparam Types ... tuple element types. + * \param tuple the tuple. + * \return a reference to the desired element. + */ +template +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +const typename tuple_get_impl::ReturnType& +get(const tuple& tuple) { + return tuple_get_impl::run(tuple); +} + +template +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +typename tuple_get_impl::ReturnType& +get(tuple& tuple) { + return tuple_get_impl::run(tuple); +} + +/** + * Concatenate multiple tuples. + * \param tuples ... list of tuples. + * \return concatenated tuple. + */ +template::type>::value...>::value>::type> +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +typename tuple_cat_impl::type...>::ReturnType +tuple_cat(Tuples&&... tuples) { + return tuple_cat_impl::type...>::run(std::forward(tuples)...); +} + +/** + * Tie arguments together into a tuple. + */ +template > +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +ReturnType tie(Args&... args) EIGEN_NOEXCEPT { + return ReturnType{args...}; +} + +/** + * Create a tuple of values from the supplied arguments (element types are decayed, and a reference_wrapper argument is unwrapped). + */ +template ::type...> > +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +ReturnType make_tuple(Args&&... args) { + return ReturnType{std::forward(args)...}; +} + +/** + * Forward a set of arguments as a tuple.
+ */ +template +EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +tuple forward_as_tuple(Args&&... args) { + return tuple(std::forward(args)...); +} + +} // namespace tuple_impl +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_TUPLE_GPU diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h index 81ae2a32d..084e37e30 100755 --- a/Eigen/src/Core/util/Meta.h +++ b/Eigen/src/Core/util/Meta.h @@ -648,6 +648,60 @@ struct invoke_result { }; #endif +// C++14 integer/index_sequence. +#if defined(__cpp_lib_integer_sequence) && __cpp_lib_integer_sequence >= 201304L && EIGEN_MAX_CPP_VER >= 14 + +using std::integer_sequence; +using std::make_integer_sequence; + +using std::index_sequence; +using std::make_index_sequence; + +#else + +template +struct integer_sequence { + static EIGEN_CONSTEXPR size_t size() EIGEN_NOEXCEPT { return sizeof...(Ints); } +}; + +template +struct append_integer; + +template +struct append_integer, N> { + using type = integer_sequence; +}; + +template +struct generate_integer_sequence { + using type = typename append_integer::type, N-1>::type; +}; + +template +struct generate_integer_sequence { + using type = integer_sequence; +}; + +template +using make_integer_sequence = typename generate_integer_sequence::type; + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +#endif + +// Reduces a sequence of bools to true if all are true, false otherwise. +template +using reduce_all = std::is_same, integer_sequence >; + +// Reduces a sequence of bools to true if any are true, false if all false. 
+template +using reduce_any = std::integral_constant, integer_sequence >::value>; + struct meta_yes { char a[1]; }; struct meta_no { char a[2]; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 769e883f1..675742c04 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -289,6 +289,7 @@ ei_add_test(random_matrix) ei_add_test(initializer_list_construction) ei_add_test(diagonal_matrix_variadic_ctor) ei_add_test(serializer) +ei_add_test(tuple_test) add_executable(bug1213 bug1213.cpp bug1213_main.cpp) diff --git a/test/tuple_test.cpp b/test/tuple_test.cpp new file mode 100644 index 000000000..b40c457b2 --- /dev/null +++ b/test/tuple_test.cpp @@ -0,0 +1,123 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2021 The Eigen Team +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include +#include + +using namespace Eigen::internal; +using Eigen::internal::tuple_impl::tuple; + +void basic_tuple_test() { + // Construction. + tuple<> tuple0 {}; + tuple tuple1 {1}; + tuple tuple2 {3, 5.0f}; + tuple tuple3 {7, 11.0f, 13.0}; + // Default construction. + tuple<> tuple0default; + EIGEN_UNUSED_VARIABLE(tuple0default) + tuple tuple1default; + EIGEN_UNUSED_VARIABLE(tuple1default) + tuple tuple2default; + EIGEN_UNUSED_VARIABLE(tuple2default) + tuple tuple3default; + EIGEN_UNUSED_VARIABLE(tuple3default) + + // Assignment. + tuple<> tuple0b = tuple0; + EIGEN_UNUSED_VARIABLE(tuple0b) + decltype(tuple1) tuple1b = tuple1; + EIGEN_UNUSED_VARIABLE(tuple1b) + decltype(tuple2) tuple2b = tuple2; + EIGEN_UNUSED_VARIABLE(tuple2b) + decltype(tuple3) tuple3b = tuple3; + EIGEN_UNUSED_VARIABLE(tuple3b) + + // get. 
+ VERIFY_IS_EQUAL(tuple_impl::get<0>(tuple3), 7); + VERIFY_IS_EQUAL(tuple_impl::get<1>(tuple3), 11.0f); + VERIFY_IS_EQUAL(tuple_impl::get<2>(tuple3), 13.0); + + // tuple_impl::tuple_size. + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 0); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 1); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 2); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 3); + + // tuple_impl::tuple_cat. + auto tuple2cat3 = tuple_impl::tuple_cat(tuple2, tuple3); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 5); + VERIFY_IS_EQUAL(tuple_impl::get<1>(tuple2cat3), 5.0f); + VERIFY_IS_EQUAL(tuple_impl::get<3>(tuple2cat3), 11.0f); + auto tuple3cat0 = tuple_impl::tuple_cat(tuple3, tuple0); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 3); + auto singlecat = tuple_impl::tuple_cat(tuple3); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 3); + auto emptycat = tuple_impl::tuple_cat(); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 0); + auto tuple0cat1cat2cat3 = tuple_impl::tuple_cat(tuple0, tuple1, tuple2, tuple3); + VERIFY_IS_EQUAL(tuple_impl::tuple_size::value, 6); + + // make_tuple. + // The tuple types should use values for the second and fourth parameters. + double tmp = 20; + auto tuple_make = tuple_impl::make_tuple(int(10), tmp, float(20.0f), tuple0); + VERIFY( (std::is_same > >::value) ); + VERIFY_IS_EQUAL(tuple_impl::get<1>(tuple_make), tmp); + + // forward_as_tuple. + // The tuple types should use references for the second and fourth parameters. + auto tuple_forward = tuple_impl::forward_as_tuple(int(10), tmp, float(20.0f), tuple0); + VERIFY( (std::is_same& > >::value) ); + VERIFY_IS_EQUAL(tuple_impl::get<1>(tuple_forward), tmp); + + // tie. + auto tuple_tie = tuple_impl::tie(tuple0, tuple1, tuple2, tuple3); + VERIFY( (std::is_same >::value) ); + VERIFY_IS_EQUAL( (tuple_impl::get<1>(tuple_impl::get<2>(tuple_tie))), 5.0 ); + // Modify value and ensure tuple2 is updated.
+ tuple_impl::get<1>(tuple_impl::get<2>(tuple_tie)) = 10.0; + VERIFY_IS_EQUAL( (tuple_impl::get<1>(tuple2)), 10.0 ); + + // Assignment. + int x = -1; + float y = -1; + double z = -1; + tuple_impl::tie(x, y, z) = tuple3; + VERIFY_IS_EQUAL(x, tuple_impl::get<0>(tuple3)); + VERIFY_IS_EQUAL(y, tuple_impl::get<1>(tuple3)); + VERIFY_IS_EQUAL(z, tuple_impl::get<2>(tuple3)); + tuple tuple3c(-2, -2, -2); + tuple3c = std::move(tuple3b); + VERIFY_IS_EQUAL(tuple_impl::get<0>(tuple3c), tuple_impl::get<0>(tuple3)); + VERIFY_IS_EQUAL(tuple_impl::get<1>(tuple3c), tuple_impl::get<1>(tuple3)); + VERIFY_IS_EQUAL(tuple_impl::get<2>(tuple3c), tuple_impl::get<2>(tuple3)); +} + +void eigen_tuple_test() { + tuple tuple; + tuple_impl::get<0>(tuple).setRandom(); + tuple_impl::get<1>(tuple).setRandom(10, 10); + + auto tuple_tie = tuple_impl::tie(tuple_impl::get<0>(tuple), tuple_impl::get<1>(tuple)); + tuple_impl::get<1>(tuple_tie).setIdentity(); + VERIFY(tuple_impl::get<1>(tuple).isIdentity()); +} + +EIGEN_DECLARE_TEST(tuple) +{ + CALL_SUBTEST(basic_tuple_test()); + CALL_SUBTEST(eigen_tuple_test()); +} -- GitLab