From 80a75eb53fd1a1a09cd190a28765870013e695e8 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 5 Dec 2025 18:25:47 +0000 Subject: [PATCH 1/9] Fix naming of predux_half for RVV when LMUL > 1 --- Eigen/src/Core/arch/RVV10/PacketMath2.h | 20 ++++++++++---------- Eigen/src/Core/arch/RVV10/PacketMathFP16.h | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMath2.h b/Eigen/src/Core/arch/RVV10/PacketMath2.h index 1fda51131..e230ba16b 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMath2.h +++ b/Eigen/src/Core/arch/RVV10/PacketMath2.h @@ -266,7 +266,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet2Xi>::type -predux_half_dowto4(const Packet4Xi& a) { +predux_half(const Packet4Xi& a) { return __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(a, 0), __riscv_vget_v_i32m4_i32m2(a, 1), unpacket_traits::size); } @@ -275,7 +275,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet1Xi>::type -predux_half_dowto4(const Packet2Xi& a) { +predux_half(const Packet2Xi& a) { return __riscv_vadd_vv_i32m1(__riscv_vget_v_i32m2_i32m1(a, 0), __riscv_vget_v_i32m2_i32m1(a, 1), unpacket_traits::size); } @@ -611,7 +611,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet2Xf>::type -predux_half_dowto4(const Packet4Xf& a) { +predux_half(const Packet4Xf& a) { return __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(a, 0), __riscv_vget_v_f32m4_f32m2(a, 1), unpacket_traits::size); } @@ -620,7 +620,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet1Xf>::type -predux_half_dowto4(const Packet2Xf& a) { +predux_half(const Packet2Xf& a) { return __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(a, 0), __riscv_vget_v_f32m2_f32m1(a, 1), unpacket_traits::size); } @@ -876,7 +876,7 @@ template EIGEN_STRONG_INLINE typename 
std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet2Xl>::type -predux_half_dowto4(const Packet4Xl& a) { +predux_half(const Packet4Xl& a) { return __riscv_vadd_vv_i64m2(__riscv_vget_v_i64m4_i64m2(a, 0), __riscv_vget_v_i64m4_i64m2(a, 1), unpacket_traits::size); } @@ -885,7 +885,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet1Xl>::type -predux_half_dowto4(const Packet2Xl& a) { +predux_half(const Packet2Xl& a) { return __riscv_vadd_vv_i64m1(__riscv_vget_v_i64m2_i64m1(a, 0), __riscv_vget_v_i64m2_i64m1(a, 1), unpacket_traits::size); } @@ -1222,7 +1222,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet2Xd>::type -predux_half_dowto4(const Packet4Xd& a) { +predux_half(const Packet4Xd& a) { return __riscv_vfadd_vv_f64m2(__riscv_vget_v_f64m4_f64m2(a, 0), __riscv_vget_v_f64m4_f64m2(a, 1), unpacket_traits::size); } @@ -1231,7 +1231,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet1Xd>::type -predux_half_dowto4(const Packet2Xd& a) { +predux_half(const Packet2Xd& a) { return __riscv_vfadd_vv_f64m1(__riscv_vget_v_f64m2_f64m1(a, 0), __riscv_vget_v_f64m2_f64m1(a, 1), unpacket_traits::size); } @@ -1486,7 +1486,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet2Xs>::type -predux_half_dowto4(const Packet4Xs& a) { +predux_half(const Packet4Xs& a) { return __riscv_vadd_vv_i16m2(__riscv_vget_v_i16m4_i16m2(a, 0), __riscv_vget_v_i16m4_i16m2(a, 1), unpacket_traits::size); } @@ -1495,7 +1495,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, Packet1Xs>::type -predux_half_dowto4(const Packet2Xs& a) { +predux_half(const Packet2Xs& a) { return __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(a, 0), __riscv_vget_v_i16m2_i16m1(a, 1), unpacket_traits::size); } diff --git 
a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h index fbda19138..848e0ca0a 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h @@ -811,7 +811,7 @@ template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, PacketXh>::type -predux_half_dowto4(const Packet2Xh& a) { +predux_half(const Packet2Xh& a) { return __riscv_vfadd_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), unpacket_traits::size); } -- GitLab From 7a5e4bd4e049180a54648f6c00d7753ac2730554 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 10 Dec 2025 03:36:26 +0000 Subject: [PATCH 2/9] Fix FP16 for RVV so that it will compile for gcc. --- Eigen/src/Core/arch/RVV10/PacketMathFP16.h | 74 +++++++++++++--------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h index 848e0ca0a..d3cbf933a 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h @@ -16,8 +16,10 @@ namespace Eigen { namespace internal { -typedef vfloat16m1_t Packet1Xh __attribute__((riscv_rvv_vector_bits(EIGEN_RISCV64_RVV_VL))); -typedef vfloat16m2_t Packet2Xh __attribute__((riscv_rvv_vector_bits(EIGEN_RISCV64_RVV_VL * 2))); +typedef eigen_packet_wrapper + Packet1Xh; +typedef eigen_packet_wrapper + Packet2Xh; #if EIGEN_RISCV64_DEFAULT_LMUL == 1 typedef Packet1Xh PacketXh; @@ -145,7 +147,7 @@ EIGEN_STRONG_INLINE PacketXh ptrue(const PacketXh& /*a*/) { template <> EIGEN_STRONG_INLINE PacketXh pzero(const PacketXh& /*a*/) { - return __riscv_vfmv_v_f_f16m1(static_cast(0.0), unpacket_traits::size); + return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size); } template <> @@ -155,7 +157,7 @@ EIGEN_STRONG_INLINE PacketXh pabs(const PacketXh& a) { template <> EIGEN_STRONG_INLINE PacketXh pset1(const Eigen::half& from) { - 
return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(from), unpacket_traits::size); + return __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(from), unpacket_traits::size); } template <> @@ -166,8 +168,9 @@ EIGEN_STRONG_INLINE PacketXh pset1frombits(numext::uint16_t from) { template <> EIGEN_STRONG_INLINE PacketXh plset(const Eigen::half& a) { PacketXh idx = - __riscv_vfcvt_f_x_v_f16m1(__riscv_vid_v_i16m1(unpacket_traits::size), unpacket_traits::size); - return __riscv_vfadd_vf_f16m1(idx, a, unpacket_traits::size); + __riscv_vfcvt_f_x_v_f16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(unpacket_traits::size)), + unpacket_traits::size); + return __riscv_vfadd_vf_f16m1(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); } template <> @@ -222,13 +225,14 @@ EIGEN_STRONG_INLINE PacketXh pnmsub(const PacketXh& a, const PacketXh& b, const template <> EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); PacketXh nans = - __riscv_vfmv_v_f_f16m1((std::numeric_limits::quiet_NaN)(), unpacket_traits::size); + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); - return __riscv_vfmin_vv_f16m1_tum(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmin_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> @@ -243,13 +247,14 @@ EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, template <> EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); PacketXh nans = - __riscv_vfmv_v_f_f16m1((std::numeric_limits::quiet_NaN)(), unpacket_traits::size); + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); 
PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); - return __riscv_vfmax_vv_f16m1_tum(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmax_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> @@ -283,7 +288,7 @@ EIGEN_STRONG_INLINE PacketXh pcmp_eq(const PacketXh& a, const PacketXh template <> EIGEN_STRONG_INLINE PacketXh pcmp_lt_or_nan(const PacketXh& a, const PacketXh& b) { PacketMask16 mask = __riscv_vmfge_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast(0.0), mask, + return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast<_Float16>(0.0), mask, unpacket_traits::size); } @@ -380,7 +385,7 @@ EIGEN_STRONG_INLINE PacketXh print(const PacketXh& a) { const PacketXh abs_a = pabs(a); PacketMask16 mask = __riscv_vmfne_vv_f16m1_b16(a, a, unpacket_traits::size); - const PacketXh x = __riscv_vfadd_vv_f16m1_tum(mask, a, a, a, unpacket_traits::size); + const PacketXh x = __riscv_vfadd_vv_f16m1_tumu(mask, a, a, a, unpacket_traits::size); const PacketXh new_x = __riscv_vfcvt_f_x_v_f16m1(__riscv_vfcvt_x_f_v_i16m1(a, unpacket_traits::size), unpacket_traits::size); @@ -394,7 +399,7 @@ EIGEN_STRONG_INLINE PacketXh pfloor(const PacketXh& a) { PacketXh tmp = print(a); // If greater, subtract one. 
PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, tmp, unpacket_traits::size); - return __riscv_vfsub_vf_f16m1_tum(mask, tmp, tmp, static_cast(1.0), unpacket_traits::size); + return __riscv_vfsub_vf_f16m1_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); } template <> @@ -407,7 +412,7 @@ EIGEN_STRONG_INLINE PacketXh preverse(const PacketXh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux(const PacketXh& a) { return static_cast(__riscv_vfmv_f(__riscv_vfredusum_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(static_cast(0.0), unpacket_traits::size), + a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size), unpacket_traits::size))); } @@ -442,15 +447,17 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul(const PacketXh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_min(const PacketXh& a) { + const Eigen::half max = (std::numeric_limits::max)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1((std::numeric_limits::max)(), unpacket_traits::size), + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), unpacket_traits::size))); } template <> EIGEN_STRONG_INLINE Eigen::half predux_max(const PacketXh& a) { + const Eigen::half lowest = -(std::numeric_limits::max)(); // most negative finite half; ::min() is the smallest positive normal return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(-(std::numeric_limits::max)(), unpacket_traits::size), + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(lowest), unpacket_traits::size), unpacket_traits::size))); } @@ -487,7 +494,7 @@ EIGEN_STRONG_INLINE Packet2Xh ptrue(const Packet2Xh& /*a*/) { template <> EIGEN_STRONG_INLINE Packet2Xh pzero(const Packet2Xh& /*a*/) { - return __riscv_vfmv_v_f_f16m2(static_cast(0.0), unpacket_traits::size); + return __riscv_vfmv_v_f_f16m2(static_cast<_Float16>(0.0), unpacket_traits::size); } template <> @@ -497,7 +504,7 @@ EIGEN_STRONG_INLINE 
Packet2Xh pabs(const Packet2Xh& a) { template <> 
EIGEN_STRONG_INLINE Packet2Xh pset1(const Eigen::half& from) { - return __riscv_vfmv_v_f_f16m2(static_cast<_Float16>(from), unpacket_traits::size); + return __riscv_vfmv_v_f_f16m2(numext::bit_cast<_Float16>(from), unpacket_traits::size); } template <> @@ -507,9 +514,10 @@ EIGEN_STRONG_INLINE Packet2Xh pset1frombits(numext::uint16_t from) { template <> EIGEN_STRONG_INLINE Packet2Xh plset(const Eigen::half& a) { - Packet2Xh idx = __riscv_vfcvt_f_x_v_f16m2(__riscv_vid_v_i16m2(unpacket_traits::size), - unpacket_traits::size); - return __riscv_vfadd_vf_f16m2(idx, a, unpacket_traits::size); + Packet2Xh idx = __riscv_vfcvt_f_x_v_f16m2( + __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(unpacket_traits::size)), + unpacket_traits::size); + return __riscv_vfadd_vf_f16m2(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); } template <> @@ -564,13 +572,14 @@ EIGEN_STRONG_INLINE Packet2Xh pnmsub(const Packet2Xh& a, const Packet2Xh& b, con template <> EIGEN_STRONG_INLINE Packet2Xh pmin(const Packet2Xh& a, const Packet2Xh& b) { + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); Packet2Xh nans = - __riscv_vfmv_v_f_f16m2((std::numeric_limits::quiet_NaN)(), unpacket_traits::size); + __riscv_vfmv_v_f_f16m2(numext::bit_cast<_Float16>(nan), unpacket_traits::size); PacketMask8 mask = __riscv_vmfeq_vv_f16m2_b8(a, a, unpacket_traits::size); PacketMask8 mask2 = __riscv_vmfeq_vv_f16m2_b8(b, b, unpacket_traits::size); mask = __riscv_vmand_mm_b8(mask, mask2, unpacket_traits::size); - return __riscv_vfmin_vv_f16m2_tum(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmin_vv_f16m2_tumu(mask, nans, a, b, unpacket_traits::size); } template <> @@ -585,13 +594,14 @@ EIGEN_STRONG_INLINE Packet2Xh pmin(const Packet2Xh& template <> EIGEN_STRONG_INLINE Packet2Xh pmax(const Packet2Xh& a, const Packet2Xh& b) { + const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); Packet2Xh nans = - __riscv_vfmv_v_f_f16m2((std::numeric_limits::quiet_NaN)(), 
unpacket_traits::size); + __riscv_vfmv_v_f_f16m2(numext::bit_cast<_Float16>(nan), unpacket_traits::size); PacketMask8 mask = __riscv_vmfeq_vv_f16m2_b8(a, a, unpacket_traits::size); PacketMask8 mask2 = __riscv_vmfeq_vv_f16m2_b8(b, b, unpacket_traits::size); mask = __riscv_vmand_mm_b8(mask, mask2, unpacket_traits::size); - return __riscv_vfmax_vv_f16m2_tum(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmax_vv_f16m2_tumu(mask, nans, a, b, unpacket_traits::size); } template <> @@ -628,7 +638,7 @@ EIGEN_STRONG_INLINE Packet2Xh pcmp_eq(const Packet2Xh& a, const Packe template <> EIGEN_STRONG_INLINE Packet2Xh pcmp_lt_or_nan(const Packet2Xh& a, const Packet2Xh& b) { PacketMask8 mask = __riscv_vmfge_vv_f16m2_b8(a, b, unpacket_traits::size); - return __riscv_vfmerge_vfm_f16m2(ptrue(a), static_cast(0.0), mask, + return __riscv_vfmerge_vfm_f16m2(ptrue(a), static_cast<_Float16>(0.0), mask, unpacket_traits::size); } @@ -730,7 +740,7 @@ EIGEN_STRONG_INLINE Packet2Xh print(const Packet2Xh& a) { const Packet2Xh abs_a = pabs(a); PacketMask8 mask = __riscv_vmfne_vv_f16m2_b8(a, a, unpacket_traits::size); - const Packet2Xh x = __riscv_vfadd_vv_f16m2_tum(mask, a, a, a, unpacket_traits::size); + const Packet2Xh x = __riscv_vfadd_vv_f16m2_tumu(mask, a, a, a, unpacket_traits::size); const Packet2Xh new_x = __riscv_vfcvt_f_x_v_f16m2( __riscv_vfcvt_x_f_v_i16m2(a, unpacket_traits::size), unpacket_traits::size); @@ -744,7 +754,7 @@ EIGEN_STRONG_INLINE Packet2Xh pfloor(const Packet2Xh& a) { Packet2Xh tmp = print(a); // If greater, subtract one. 
PacketMask8 mask = __riscv_vmflt_vv_f16m2_b8(a, tmp, unpacket_traits::size); - return __riscv_vfsub_vf_f16m2_tum(mask, tmp, tmp, static_cast(1.0), unpacket_traits::size); + return __riscv_vfsub_vf_f16m2_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); } template <> @@ -758,7 +768,7 @@ EIGEN_STRONG_INLINE Packet2Xh preverse(const Packet2Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux(const Packet2Xh& a) { return static_cast(__riscv_vfmv_f(__riscv_vfredusum_vs_f16m2_f16m1( - a, __riscv_vfmv_v_f_f16m1(static_cast(0.0), unpacket_traits::size / 4), + a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size / 2), unpacket_traits::size))); } @@ -770,15 +780,17 @@ EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet2Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet2Xh& a) { + const Eigen::half max = (std::numeric_limits::max)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m2_f16m1( - a, __riscv_vfmv_v_f_f16m1((std::numeric_limits::max)(), unpacket_traits::size / 4), + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size / 2), unpacket_traits::size))); } template <> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet2Xh& a) { + const Eigen::half lowest = -(std::numeric_limits::max)(); // most negative finite half; ::min() is the smallest positive normal return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m2_f16m1( - a, __riscv_vfmv_v_f_f16m1(-(std::numeric_limits::max)(), unpacket_traits::size / 4), + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(lowest), unpacket_traits::size / 2), unpacket_traits::size))); } -- GitLab From 790a4636e1d2bc864732648d43f03a3d7a0de41e Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 12 Dec 2025 13:18:25 +0000 Subject: [PATCH 3/9] Add basic support for packetmath for BF16 RVV.
--- Eigen/Core | 3 + Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 724 +++++++++++++++++++ Eigen/src/Core/util/ConfigureVectorization.h | 4 + 3 files changed, 731 insertions(+) create mode 100644 Eigen/src/Core/arch/RVV10/PacketMathBF16.h diff --git a/Eigen/Core b/Eigen/Core index 9f81658b0..86608f3c8 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -288,6 +288,9 @@ using std::ptrdiff_t; #if defined EIGEN_VECTORIZE_RVV10FP16 #include "src/Core/arch/RVV10/PacketMathFP16.h" #endif +#if defined EIGEN_VECTORIZE_RVV10BF16 +#include "src/Core/arch/RVV10/PacketMathBF16.h" +#endif #elif defined EIGEN_VECTORIZE_ZVECTOR #include "src/Core/arch/ZVector/PacketMath.h" #include "src/Core/arch/ZVector/MathFunctions.h" diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h new file mode 100644 index 000000000..6ee52b5f1 --- /dev/null +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -0,0 +1,724 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 Chip Kerchner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+ +#ifndef EIGEN_PACKET_MATH_BF16_RVV10_H +#define EIGEN_PACKET_MATH_BF16_RVV10_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +typedef eigen_packet_wrapper + Packet1Xbf; +typedef eigen_packet_wrapper + Packet2Xbf; + +#if EIGEN_RISCV64_DEFAULT_LMUL == 1 +typedef Packet1Xbf PacketXbf; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet1Xbf type; + typedef Packet1Xbf half; + + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = rvv_packet_size_selector::size, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasReduxp = 0, + + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 0, + HasSqrt = 1, + HasTanh = 0, + HasErf = 0 + }; +}; + +#else +typedef Packet2Xbf PacketXbf; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet2Xbf type; + typedef Packet1Xbf half; + + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = rvv_packet_size_selector::size, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasReduxp = 0, + + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 0, + HasSqrt = 1, + HasTanh = 0, + HasErf = 0 + }; +}; +#endif + +template <> +struct unpacket_traits { + typedef bfloat16 type; + typedef Packet1Xbf half; // Half not yet implemented + typedef Packet1Xs integer_packet; + typedef numext::uint8_t mask_t; + + enum { + size = rvv_packet_size_selector::size, + alignment = rvv_packet_alignment_selector::alignment, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template <> +struct 
unpacket_traits { + typedef bfloat16 type; + typedef Packet1Xbf half; + typedef Packet2Xs integer_packet; + typedef numext::uint8_t mask_t; + + enum { + size = rvv_packet_size_selector::size, + alignment = rvv_packet_alignment_selector::alignment, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +/********************************* PacketXbf ************************************/ + +EIGEN_STRONG_INLINE Packet2Xf Bf16ToF32(const PacketXbf& a) { + return __riscv_vfwcvtbf16_f_f_v_f32m2(a, unpacket_traits::size); +} + +EIGEN_STRONG_INLINE PacketXbf F32ToBf16(const Packet2Xf& a) { + return __riscv_vfncvtbf16_f_f_w_bf16m1(a, unpacket_traits::size); +} + +// Narrows an f32 comparison mask (lanes all-ones or all-zeros) to bfloat16 lanes by truncating the raw +// bits. F32ToBf16 must not be used for masks: the all-ones pattern is a NaN, and the floating-point +// narrowing conversion does not preserve NaN payloads, so "true" lanes would not come out as 0xffff. +EIGEN_STRONG_INLINE PacketXbf F32MaskToBf16(const Packet2Xf& a) { + return __riscv_vreinterpret_v_u16m1_bf16m1( + __riscv_vncvt_x_x_w_u16m1(__riscv_vreinterpret_v_f32m2_u32m2(a), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf ptrue(const PacketXbf& /*a*/) { + return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pzero(const PacketXbf& /*a*/) { + return __riscv_vreinterpret_bf16m1( + __riscv_vmv_v_x_i16m1(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pabs(const PacketXbf& a) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vx_u16m1( + __riscv_vreinterpret_v_bf16m1_u16m1(a), 0x7fffu, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pset1(const bfloat16& from) { + return __riscv_vreinterpret_bf16m1( + __riscv_vmv_v_x_i16m1(numext::bit_cast(from), 
unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pset1frombits(numext::uint16_t from) { + return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf plset(const bfloat16& a) { + return F32ToBf16(plset(static_cast(a))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf padd(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf 
psub(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pnegate(const PacketXbf& a) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vx_u16m1( + __riscv_vreinterpret_v_bf16m1_u16m1(a), 0x8000u, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pconj(const PacketXbf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmul(const PacketXbf& a, const PacketXbf& b) { + Packet2Xf c; + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(pzero(c), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pdiv(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmadd(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmsub(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pnmadd(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pnmsub(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { + return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmin(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmin(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmin(const 
PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pcmp_le(const PacketXbf& a, const PacketXbf& b) { + return F32MaskToBf16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pcmp_lt(const PacketXbf& a, const PacketXbf& b) { + return F32MaskToBf16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pcmp_eq(const PacketXbf& a, const PacketXbf& b) { + return F32MaskToBf16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pcmp_lt_or_nan(const PacketXbf& a, const PacketXbf& b) { + return F32MaskToBf16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +// Logical Operations are not supported for bfloat16, so reinterpret casts +template <> +EIGEN_STRONG_INLINE PacketXbf pand(const PacketXbf& a, const PacketXbf& b) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1( + __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf por(const PacketXbf& a, const PacketXbf& b) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vor_vv_u16m1( + __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pxor(const PacketXbf& a, const PacketXbf& b) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vv_u16m1( + 
__riscv_vreinterpret_v_bf16m1_u16m1(a), 
__riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pandnot(const PacketXbf& a, const PacketXbf& b) { + return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1( + __riscv_vreinterpret_v_bf16m1_u16m1(a), + __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size), + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pload(const bfloat16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from), + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf ploadu(const bfloat16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from), + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf ploaddup(const bfloat16* from) { + PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); + return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf ploadquad(const bfloat16* from) { + PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, + unpacket_traits::size); + return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE void pstore(bfloat16* to, const PacketXbf& from) { + EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from, + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const PacketXbf& from) { + EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from, + unpacket_traits::size); +} + +template <> +EIGEN_DEVICE_FUNC inline PacketXbf pgather(const bfloat16* from, Index stride) { + return __riscv_vlse16_v_bf16m1(reinterpret_cast(from), stride * 
sizeof(bfloat16), + unpacket_traits::size); +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const PacketXbf& from, Index stride) { + __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from, unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 pfirst(const PacketXbf& a) { + return numext::bit_cast(__riscv_vmv_x_s_i16m1_i16(__riscv_vreinterpret_v_bf16m1_i16m1(a))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf psqrt(const PacketXbf& a) { + return F32ToBf16(psqrt(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf print(const PacketXbf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf pfloor(const PacketXbf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE PacketXbf preverse(const PacketXbf& a) { + return __riscv_vreinterpret_v_i16m1_bf16m1(preverse(__riscv_vreinterpret_v_bf16m1_i16m1(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux(const PacketXbf& a) { + return static_cast(predux(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_mul(const PacketXbf& a) { + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_min(const PacketXbf& a) { + return static_cast(predux_min(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_max(const PacketXbf& a) { + return static_cast(predux_max(Bf16ToF32(a))); +} + +template +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + bfloat16 buffer[unpacket_traits::size * N]; + int i = 0; + + for (i = 0; i < N; i++) { + __riscv_vsse16(reinterpret_cast<__bf16*>(&buffer[i]), N * sizeof(bfloat16), kernel.packet[i], + unpacket_traits::size); + } + + for (i = 0; i < N; i++) { + kernel.packet[i] = __riscv_vle16_v_bf16m1(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]), + unpacket_traits::size); + } +} + +/********************************* Packet2Xbf
************************************/ + +EIGEN_STRONG_INLINE Packet4Xf Bf16ToF32(const Packet2Xbf& a) { + return __riscv_vfwcvtbf16_f_f_v_f32m4(a, unpacket_traits::size); +} + +EIGEN_STRONG_INLINE Packet2Xbf F32ToBf16(const Packet4Xf& a) { + return __riscv_vfncvtbf16_f_f_w_bf16m2(a, unpacket_traits::size); +} + +// Narrows an f32 comparison mask (lanes all-ones or all-zeros) to bfloat16 lanes by truncating the raw +// bits. F32ToBf16 must not be used for masks: the all-ones pattern is a NaN, and the floating-point +// narrowing conversion does not preserve NaN payloads, so "true" lanes would not come out as 0xffff. +EIGEN_STRONG_INLINE Packet2Xbf F32MaskToBf16(const Packet4Xf& a) { + return __riscv_vreinterpret_v_u16m2_bf16m2( + __riscv_vncvt_x_x_w_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(a), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf ptrue(const Packet2Xbf& /*a*/) { + return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(0xffffu, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pzero(const Packet2Xbf& /*a*/) { + return __riscv_vreinterpret_bf16m2( + __riscv_vmv_v_x_i16m2(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vx_u16m2( + __riscv_vreinterpret_v_bf16m2_u16m2(a), 0x7fffu, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pset1(const bfloat16& from) { + return __riscv_vreinterpret_bf16m2( + __riscv_vmv_v_x_i16m2(numext::bit_cast(from), unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pset1frombits(numext::uint16_t from) { + return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(from, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf plset(const bfloat16& a) { + return F32ToBf16(plset(static_cast(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf padd(const Packet2Xbf& a, const 
Packet2Xbf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf psub(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pnegate(const Packet2Xbf& a) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vxor_vx_u16m2( + __riscv_vreinterpret_v_bf16m2_u16m2(a), 0x8000u, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf 
pconj(const Packet2Xbf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmul(const Packet2Xbf& a, const Packet2Xbf& b) { + Packet4Xf c; + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(pzero(c), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pdiv(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmadd(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmsub(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pnmadd(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pnmsub(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) { + return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), a, b, unpacket_traits::size))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) { + 
return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pcmp_le(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32MaskToBf16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pcmp_lt(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32MaskToBf16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pcmp_eq(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32MaskToBf16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pcmp_lt_or_nan(const Packet2Xbf& a, const Packet2Xbf& b) { + return F32MaskToBf16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +// Logical Operations are not supported for bfloat16, so reinterpret casts +template <> +EIGEN_STRONG_INLINE Packet2Xbf pand(const Packet2Xbf& a, const Packet2Xbf& b) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a), + __riscv_vreinterpret_v_bf16m2_u16m2(b), + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf por(const Packet2Xbf& a, const Packet2Xbf& b) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vor_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a), + __riscv_vreinterpret_v_bf16m2_u16m2(b), + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pxor(const Packet2Xbf& a, const Packet2Xbf& b) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vxor_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a), + __riscv_vreinterpret_v_bf16m2_u16m2(b), + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pandnot(const Packet2Xbf& a, const Packet2Xbf& b) { + return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vv_u16m2( + __riscv_vreinterpret_v_bf16m2_u16m2(a), + 
__riscv_vnot_v_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(b), unpacket_traits::size), + unpacket_traits::size)); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pload(const bfloat16* from) { + EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_bf16m2(reinterpret_cast(from), + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf ploadu(const bfloat16* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_bf16m2(reinterpret_cast(from), + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf ploaddup(const bfloat16* from) { + Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size); + idx = __riscv_vand_vx_u16m2(idx, 0xfffeu, unpacket_traits::size); + return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf ploadquad(const bfloat16* from) { + Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size); + idx = __riscv_vsrl_vx_u16m2(__riscv_vand_vx_u16m2(idx, 0xfffcu, unpacket_traits::size), 1, + unpacket_traits::size); + return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet2Xbf& from) { + EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_bf16m2(reinterpret_cast<__bf16*>(to), from, + unpacket_traits::size); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet2Xbf& from) { + EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_bf16m2(reinterpret_cast<__bf16*>(to), from, + unpacket_traits::size); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet2Xbf pgather(const bfloat16* from, Index stride) { + return __riscv_vlse16_v_bf16m2(reinterpret_cast(from), stride * sizeof(bfloat16), + unpacket_traits::size); +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const Packet2Xbf& from, + Index stride) { + __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from, + unpacket_traits::size); +} +
+template <> +EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet2Xbf& a) { + return static_cast(__riscv_vmv_x_s_i16m2_i16(__riscv_vreinterpret_v_bf16m2_i16m2(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf psqrt(const Packet2Xbf& a) { + return F32ToBf16(psqrt(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf print(const Packet2Xbf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pfloor(const Packet2Xbf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf preverse(const Packet2Xbf& a) { + return __riscv_vreinterpret_v_i16m2_bf16m2(preverse(__riscv_vreinterpret_v_bf16m2_i16m2(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux(const Packet2Xbf& a) { + return static_cast(predux(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet2Xbf& a) { + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet2Xbf& a) { + return static_cast(predux_min(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet2Xbf& a) { + return static_cast(predux_max(Bf16ToF32(a))); +} + +template +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + bfloat16 buffer[unpacket_traits::size * N]; + int i = 0; + + for (i = 0; i < N; i++) { + __riscv_vsse16(reinterpret_cast<__bf16*>(&buffer[i]), N * sizeof(bfloat16), kernel.packet[i], + unpacket_traits::size); + } + + for (i = 0; i < N; i++) { + kernel.packet[i] = + __riscv_vle16_v_bf16m2(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]), + unpacket_traits::size); + } +} + +template +EIGEN_STRONG_INLINE +typename std::enable_if::value && (unpacket_traits::size % 8) == 0, + PacketXbf>::type +predux_half(const Packet2Xbf& a) { + return padd(__riscv_vget_v_bf16m2_bf16m1(a, 0), __riscv_vget_v_bf16m2_bf16m1(a, 1)); +} + +} // namespace internal +} // namespace Eigen + +#endif // 
EIGEN_PACKET_MATH_BF16_RVV10_H diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h index d41d05db1..f86d7a322 100644 --- a/Eigen/src/Core/util/ConfigureVectorization.h +++ b/Eigen/src/Core/util/ConfigureVectorization.h @@ -467,6 +467,10 @@ extern "C" { #endif #endif +#if defined(__riscv_zvfbfwma) +#define EIGEN_VECTORIZE_RVV10BF16 +#endif + #endif // defined(EIGEN_ARCH_RISCV) #elif (defined __s390x__ && defined __VEC__) -- GitLab From 269ab61161fe003fcf23986aa38317064d3d01d7 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 12 Dec 2025 15:14:59 +0000 Subject: [PATCH 4/9] Fix packetmath issues with BF16 RVV. --- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 28 ++++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index 6ee52b5f1..63259fe2d 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -48,10 +48,11 @@ struct packet_traits : default_packet_traits { HasSetLinear = 0, HasBlend = 0, HasReduxp = 0, + HasSign = 0, HasCmp = 1, HasDiv = 1, - HasRound = 1, + HasRound = 0, HasSin = 0, HasCos = 0, @@ -90,10 +91,11 @@ struct packet_traits : default_packet_traits { HasSetLinear = 0, HasBlend = 0, HasReduxp = 0, + HasSign = 0, HasCmp = 1, HasDiv = 1, - HasRound = 1, + HasRound = 0, HasSin = 0, HasCos = 0, @@ -150,7 +152,7 @@ EIGEN_STRONG_INLINE PacketXbf F32ToBf16(const Packet2Xf& a) { template <> EIGEN_STRONG_INLINE PacketXbf ptrue(const PacketXbf& /*a*/) { - return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); + return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(static_cast(0xffffu), unpacket_traits::size)); } template <> @@ -162,7 +164,7 @@ EIGEN_STRONG_INLINE PacketXbf pzero(const PacketXbf& /*a*/) { template <> EIGEN_STRONG_INLINE PacketXbf pabs(const PacketXbf& a) { return 
__riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vx_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), 0x7fffu, unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x7fffu), unpacket_traits::size)); } template <> @@ -194,7 +196,7 @@ EIGEN_STRONG_INLINE PacketXbf psub(const PacketXbf& a, const PacketXb template <> EIGEN_STRONG_INLINE PacketXbf pnegate(const PacketXbf& a) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vx_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), 0x8000u, unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x8000u), unpacket_traits::size)); } template <> @@ -260,7 +262,7 @@ EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, template <> EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { - return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); } template <> @@ -325,14 +327,14 @@ EIGEN_STRONG_INLINE PacketXbf ploadu(const bfloat16* from) { template <> EIGEN_STRONG_INLINE PacketXbf ploaddup(const bfloat16* from) { PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); + idx = __riscv_vand_vx_u16m1(idx, static_cast(0xfffeu), unpacket_traits::size); return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> EIGEN_STRONG_INLINE PacketXbf ploadquad(const bfloat16* from) { PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, + idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, static_cast(0xfffcu), unpacket_traits::size), 1, unpacket_traits::size); return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); } @@ -433,7 +435,7 @@ EIGEN_STRONG_INLINE Packet2Xbf F32ToBf16(const Packet4Xf& a) { template <> EIGEN_STRONG_INLINE Packet2Xbf ptrue(const Packet2Xbf& /*a*/) { - 
return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(0xffffu, unpacket_traits::size)); + return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(static_cast(0xffffu), unpacket_traits::size)); } template <> @@ -445,7 +447,7 @@ EIGEN_STRONG_INLINE Packet2Xbf pzero(const Packet2Xbf& /*a*/) { template <> EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) { return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vx_u16m2( - __riscv_vreinterpret_v_bf16m2_u16m2(a), 0x7fffu, unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m2_u16m2(a), static_cast(0x7fffu), unpacket_traits::size)); } template <> @@ -477,7 +479,7 @@ EIGEN_STRONG_INLINE Packet2Xbf psub(const Packet2Xbf& a, const Packe template <> EIGEN_STRONG_INLINE Packet2Xbf pnegate(const Packet2Xbf& a) { return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vxor_vx_u16m2( - __riscv_vreinterpret_v_bf16m2_u16m2(a), 0x8000u, unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m2_u16m2(a), static_cast(0x8000u), unpacket_traits::size)); } template <> @@ -611,14 +613,14 @@ EIGEN_STRONG_INLINE Packet2Xbf ploadu(const bfloat16* from) { template <> EIGEN_STRONG_INLINE Packet2Xbf ploaddup(const bfloat16* from) { Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size); - idx = __riscv_vand_vx_u16m2(idx, 0xfffeu, unpacket_traits::size); + idx = __riscv_vand_vx_u16m2(idx, static_cast(0xfffeu), unpacket_traits::size); return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size); } template <> EIGEN_STRONG_INLINE Packet2Xbf ploadquad(const bfloat16* from) { Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size); - idx = __riscv_vsrl_vx_u16m2(__riscv_vand_vx_u16m2(idx, 0xfffcu, unpacket_traits::size), 1, + idx = __riscv_vsrl_vx_u16m2(__riscv_vand_vx_u16m2(idx, static_cast(0xfffcu), unpacket_traits::size), 1, unpacket_traits::size); return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size); } -- GitLab From afd496549722f7bc4670a2f68e79a6038f3023fa Mon Sep 17 
00:00:00 2001 From: Chip Kerchner Date: Fri, 12 Dec 2025 19:16:32 +0000 Subject: [PATCH 5/9] Add psignbit - fixes array_cwise_22 failure. --- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index 63259fe2d..68168d3d8 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -199,6 +199,12 @@ EIGEN_STRONG_INLINE PacketXbf pnegate(const PacketXbf& a) { __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x8000u), unpacket_traits::size)); } +template <> +EIGEN_STRONG_INLINE PacketXbf psignbit(const PacketXbf& a) { + return __riscv_vreinterpret_v_i16m1_bf16m1(__riscv_vsra_vx_i16m1( + __riscv_vreinterpret_v_bf16m1_i16m1(a), 15, unpacket_traits::size)); +} + template <> EIGEN_STRONG_INLINE PacketXbf pconj(const PacketXbf& a) { return a; @@ -482,6 +488,12 @@ EIGEN_STRONG_INLINE Packet2Xbf pnegate(const Packet2Xbf& a) { __riscv_vreinterpret_v_bf16m2_u16m2(a), static_cast(0x8000u), unpacket_traits::size)); } +template <> +EIGEN_STRONG_INLINE Packet2Xbf psignbit(const Packet2Xbf& a) { + return __riscv_vreinterpret_v_i16m2_bf16m2(__riscv_vsra_vx_i16m2( + __riscv_vreinterpret_v_bf16m2_i16m2(a), 15, unpacket_traits::size)); +} + template <> EIGEN_STRONG_INLINE Packet2Xbf pconj(const Packet2Xbf& a) { return a; -- GitLab From 695ae6d9dad4d938e128bb0f9e450387e4db9d29 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 12 Dec 2025 21:07:45 +0000 Subject: [PATCH 6/9] Added pcasts for BF16 RVV. 
--- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index 68168d3d8..cb332a7e2 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -732,6 +732,26 @@ predux_half(const Packet2Xbf& a) { return padd(__riscv_vget_v_bf16m2_bf16m1(a, 0), __riscv_vget_v_bf16m2_bf16m1(a, 1)); } +template <> +EIGEN_STRONG_INLINE Packet1Xbf pcast(const Packet1Xs& a) { + return __riscv_vreinterpret_v_i16m1_bf16m1(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xbf pcast(const Packet2Xs& a) { + return __riscv_vreinterpret_v_i16m2_bf16m2(a); +} + +template <> +EIGEN_STRONG_INLINE Packet1Xs pcast(const Packet1Xbf& a) { + return __riscv_vreinterpret_v_bf16m1_i16m1(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2Xs pcast(const Packet2Xbf& a) { + return __riscv_vreinterpret_v_bf16m2_i16m2(a); +} + } // namespace internal } // namespace Eigen -- GitLab From ca773e36bcf12d4cc18c075e4af7b0ae8777b82e Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 15 Dec 2025 15:34:36 +0000 Subject: [PATCH 7/9] Use default_unpacket_traits. 
--- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index cb332a7e2..9ce501c6f 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -109,7 +109,7 @@ struct packet_traits : default_packet_traits { #endif template <> -struct unpacket_traits { +struct unpacket_traits : default_unpacket_traits { typedef bfloat16 type; typedef Packet1Xbf half; // Half not yet implemented typedef PacketXs integer_packet; @@ -118,14 +118,12 @@ struct unpacket_traits { enum { size = rvv_packet_size_selector::size, alignment = rvv_packet_alignment_selector::alignment, - vectorizable = true, - masked_load_available = false, - masked_store_available = false + vectorizable = true }; }; template <> -struct unpacket_traits { +struct unpacket_traits : default_unpacket_traits { typedef bfloat16 type; typedef Packet1Xbf half; typedef Packet2Xs integer_packet; @@ -134,9 +132,7 @@ struct unpacket_traits { enum { size = rvv_packet_size_selector::size, alignment = rvv_packet_alignment_selector::alignment, - vectorizable = true, - masked_load_available = false, - masked_store_available = false + vectorizable = true }; }; -- GitLab From 7e9990fcba94bd62d47c353cbabc1de20490627b Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 15 Dec 2025 20:54:19 +0000 Subject: [PATCH 8/9] Change PacketXbf to Packet1Xbf. 
--- Eigen/src/Core/arch/RVV10/PacketMathBF16.h | 182 ++++++++++----------- 1 file changed, 91 insertions(+), 91 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h index 9ce501c6f..80502593c 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h @@ -112,7 +112,7 @@ template <> struct unpacket_traits : default_unpacket_traits { typedef bfloat16 type; typedef Packet1Xbf half; // Half not yet implemented - typedef PacketXs integer_packet; + typedef Packet1Xs integer_packet; typedef numext::uint8_t mask_t; enum { @@ -136,292 +136,292 @@ struct unpacket_traits : default_unpacket_traits { }; }; -/********************************* PacketXbf ************************************/ +/********************************* Packet1Xbf ************************************/ -EIGEN_STRONG_INLINE Packet2Xf Bf16ToF32(const PacketXbf& a) { - return __riscv_vfwcvtbf16_f_f_v_f32m2(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet2Xf Bf16ToF32(const Packet1Xbf& a) { + return __riscv_vfwcvtbf16_f_f_v_f32m2(a, unpacket_traits::size); } -EIGEN_STRONG_INLINE PacketXbf F32ToBf16(const Packet2Xf& a) { +EIGEN_STRONG_INLINE Packet1Xbf F32ToBf16(const Packet2Xf& a) { return __riscv_vfncvtbf16_f_f_w_bf16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXbf ptrue(const PacketXbf& /*a*/) { - return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(static_cast(0xffffu), unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xbf ptrue(const Packet1Xbf& /*a*/) { + return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(static_cast(0xffffu), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pzero(const PacketXbf& /*a*/) { +EIGEN_STRONG_INLINE Packet1Xbf pzero(const Packet1Xbf& /*a*/) { return __riscv_vreinterpret_bf16m1( - __riscv_vmv_v_x_i16m1(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size)); + 
__riscv_vmv_v_x_i16m1(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pabs(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf pabs(const Packet1Xbf& a) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vx_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x7fffu), unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x7fffu), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pset1(const bfloat16& from) { +EIGEN_STRONG_INLINE Packet1Xbf pset1(const bfloat16& from) { return __riscv_vreinterpret_bf16m1( - __riscv_vmv_v_x_i16m1(numext::bit_cast(from), unpacket_traits::size)); + __riscv_vmv_v_x_i16m1(numext::bit_cast(from), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pset1frombits(numext::uint16_t from) { - return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xbf pset1frombits(numext::uint16_t from) { + return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf plset(const bfloat16& a) { +EIGEN_STRONG_INLINE Packet1Xbf plset(const bfloat16& a) { return F32ToBf16(plset(static_cast(a))); } template <> -EIGEN_STRONG_INLINE PacketXbf padd(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf padd(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf psub(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf psub(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pnegate(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf pnegate(const Packet1Xbf& a) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vx_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), 
static_cast(0x8000u), unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x8000u), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf psignbit(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf psignbit(const Packet1Xbf& a) { return __riscv_vreinterpret_v_i16m1_bf16m1(__riscv_vsra_vx_i16m1( - __riscv_vreinterpret_v_bf16m1_i16m1(a), 15, unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_i16m1(a), 15, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pconj(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf pconj(const Packet1Xbf& a) { return a; } template <> -EIGEN_STRONG_INLINE PacketXbf pmul(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmul(const Packet1Xbf& a, const Packet1Xbf& b) { Packet2Xf c; - return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(pzero(c), a, b, unpacket_traits::size)); + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(pzero(c), a, b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pdiv(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pdiv(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmadd(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { - return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xbf pmadd(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pmsub(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { - return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xbf pmsub(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) { + return 
F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pnmadd(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { - return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xbf pnmadd(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) { + return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pnmsub(const PacketXbf& a, const PacketXbf& b, const PacketXbf& c) { - return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size))); +EIGEN_STRONG_INLINE Packet1Xbf pnmsub(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) { + return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmin(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmin(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmin(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const 
Packet1Xbf& b) { return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pmax(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pcmp_le(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pcmp_le(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pcmp_lt(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pcmp_lt(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pcmp_eq(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pcmp_eq(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); } template <> -EIGEN_STRONG_INLINE PacketXbf pcmp_lt_or_nan(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pcmp_lt_or_nan(const Packet1Xbf& a, const Packet1Xbf& b) { return F32ToBf16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); } // Logical Operations are not supported for bfloat16, so reinterpret casts template <> -EIGEN_STRONG_INLINE PacketXbf pand(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pand(const Packet1Xbf& a, const Packet1Xbf& b) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf por(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf por(const Packet1Xbf& a, const Packet1Xbf& b) { return 
__riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vor_vv_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pxor(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pxor(const Packet1Xbf& a, const Packet1Xbf& b) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vv_u16m1( - __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pandnot(const PacketXbf& a, const PacketXbf& b) { +EIGEN_STRONG_INLINE Packet1Xbf pandnot(const Packet1Xbf& a, const Packet1Xbf& b) { return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1( __riscv_vreinterpret_v_bf16m1_u16m1(a), - __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size), - unpacket_traits::size)); + __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size), + unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXbf pload(const bfloat16* from) { +EIGEN_STRONG_INLINE Packet1Xbf pload(const bfloat16* from) { EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXbf ploadu(const bfloat16* from) { +EIGEN_STRONG_INLINE Packet1Xbf ploadu(const bfloat16* from) { EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXbf ploaddup(const bfloat16* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vand_vx_u16m1(idx, static_cast(0xfffeu), unpacket_traits::size); 
- return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xbf ploaddup(const bfloat16* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vand_vx_u16m1(idx, static_cast(0xfffeu), unpacket_traits::size); + return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXbf ploadquad(const bfloat16* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, static_cast(0xfffcu), unpacket_traits::size), 1, - unpacket_traits::size); - return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xbf ploadquad(const bfloat16* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, static_cast(0xfffcu), unpacket_traits::size), 1, + unpacket_traits::size); + return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstore(bfloat16* to, const PacketXbf& from) { +EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet1Xbf& from) { EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const PacketXbf& from) { +EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet1Xbf& from) { EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline PacketXbf pgather(const bfloat16* from, Index stride) { +EIGEN_DEVICE_FUNC inline Packet1Xbf pgather(const bfloat16* from, Index stride) { return __riscv_vlse16_v_bf16m1(reinterpret_cast(from), stride * sizeof(bfloat16), - unpacket_traits::size); + 
unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const PacketXbf& from, Index stride) { - __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from, unpacket_traits::size); +EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const Packet1Xbf& from, Index stride) { + __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE bfloat16 pfirst(const PacketXbf& a) { +EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet1Xbf& a) { return numext::bit_cast(__riscv_vmv_x_s_i16m1_i16(__riscv_vreinterpret_v_bf16m1_i16m1(a))); } template <> -EIGEN_STRONG_INLINE PacketXbf psqrt(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf psqrt(const Packet1Xbf& a) { return F32ToBf16(psqrt(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE PacketXbf print(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf print(const Packet1Xbf& a) { return F32ToBf16(print(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE PacketXbf pfloor(const PacketXbf& a) { +EIGEN_STRONG_INLINE Packet1Xbf pfloor(const Packet1Xbf& a) { return F32ToBf16(pfloor(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE PacketXbf preverse(const PacketXbf& a) { - return __riscv_vreinterpret_v_i16m1_bf16m1(preverse(__riscv_vreinterpret_v_bf16m1_i16m1(a))); +EIGEN_STRONG_INLINE Packet1Xbf preverse(const Packet1Xbf& a) { + return __riscv_vreinterpret_v_i16m1_bf16m1(preverse(__riscv_vreinterpret_v_bf16m1_i16m1(a))); } template <> -EIGEN_STRONG_INLINE bfloat16 predux(const PacketXbf& a) { +EIGEN_STRONG_INLINE bfloat16 predux(const Packet1Xbf& a) { return static_cast(predux(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE bfloat16 predux_mul(const PacketXbf& a) { +EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet1Xbf& a) { return static_cast(predux_mul(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE bfloat16 predux_min(const PacketXbf& a) { +EIGEN_STRONG_INLINE bfloat16 predux_min(const 
Packet1Xbf& a) { return static_cast(predux_min(Bf16ToF32(a))); } template <> -EIGEN_STRONG_INLINE bfloat16 predux_max(const PacketXbf& a) { +EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet1Xbf& a) { return static_cast(predux_max(Bf16ToF32(a))); } template -EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - bfloat16 buffer[unpacket_traits::size * N]; +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + bfloat16 buffer[unpacket_traits::size * N]; int i = 0; for (i = 0; i < N; i++) { __riscv_vsse16(reinterpret_cast<__bf16*>(&buffer[i]), N * sizeof(bfloat16), kernel.packet[i], - unpacket_traits::size); + unpacket_traits::size); } for (i = 0; i < N; i++) { - kernel.packet[i] = __riscv_vle16_v_bf16m1(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]), - unpacket_traits::size); + kernel.packet[i] = __riscv_vle16_v_bf16m1(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]), + unpacket_traits::size); } } @@ -723,9 +723,9 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, - PacketXbf>::type + Packet1Xbf>::type predux_half(const Packet2Xbf& a) { - return padd(__riscv_vget_v_bf16m2_bf16m1(a, 0), __riscv_vget_v_bf16m2_bf16m1(a, 1)); + return padd(__riscv_vget_v_bf16m2_bf16m1(a, 0), __riscv_vget_v_bf16m2_bf16m1(a, 1)); } template <> -- GitLab From 91b1cf4be7e13719d37dc5984a325512ec6f4f52 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 15 Dec 2025 20:55:48 +0000 Subject: [PATCH 9/9] Change PacketXh to Packet1Xh. 
--- Eigen/src/Core/arch/RVV10/PacketMathFP16.h | 370 ++++++++++----------- 1 file changed, 185 insertions(+), 185 deletions(-) diff --git a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h index d3cbf933a..f3e5924c1 100644 --- a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h +++ b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h @@ -110,7 +110,7 @@ template <> struct unpacket_traits { typedef Eigen::half type; typedef Packet1Xh half; // Half not yet implemented - typedef PacketXs integer_packet; + typedef Packet1Xs integer_packet; typedef numext::uint8_t mask_t; enum { @@ -138,351 +138,351 @@ struct unpacket_traits { }; }; -/********************************* PacketXh ************************************/ +/********************************* Packet1Xh ************************************/ template <> -EIGEN_STRONG_INLINE PacketXh ptrue(const PacketXh& /*a*/) { - return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xh ptrue(const Packet1Xh& /*a*/) { + return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pzero(const PacketXh& /*a*/) { - return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pzero(const Packet1Xh& /*a*/) { + return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pabs(const PacketXh& a) { - return __riscv_vfabs_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pabs(const Packet1Xh& a) { + return __riscv_vfabs_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pset1(const Eigen::half& from) { - return __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(from), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pset1(const Eigen::half& from) { + return __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(from), 
unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pset1frombits(numext::uint16_t from) { - return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xh pset1frombits(numext::uint16_t from) { + return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh plset(const Eigen::half& a) { - PacketXh idx = - __riscv_vfcvt_f_x_v_f16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(unpacket_traits::size)), - unpacket_traits::size); - return __riscv_vfadd_vf_f16m1(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh plset(const Eigen::half& a) { + Packet1Xh idx = + __riscv_vfcvt_f_x_v_f16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(unpacket_traits::size)), + unpacket_traits::size); + return __riscv_vfadd_vf_f16m1(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh padd(const PacketXh& a, const PacketXh& b) { - return __riscv_vfadd_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh padd(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfadd_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh psub(const PacketXh& a, const PacketXh& b) { - return __riscv_vfsub_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh psub(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfsub_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnegate(const PacketXh& a) { - return __riscv_vfneg_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnegate(const Packet1Xh& a) { + return __riscv_vfneg_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pconj(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet1Xh pconj(const Packet1Xh& a) { return a; } template <> -EIGEN_STRONG_INLINE PacketXh 
pmul(const PacketXh& a, const PacketXh& b) { - return __riscv_vfmul_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmul(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmul_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pdiv(const PacketXh& a, const PacketXh& b) { - return __riscv_vfdiv_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pdiv(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfdiv_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmadd(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfmadd_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmadd(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfmadd_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmsub(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfmsub_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmsub(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfmsub_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnmadd(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfnmsub_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnmadd(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfnmsub_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnmsub(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfnmadd_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnmsub(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfnmadd_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { 
+EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); - PacketXh nans = - __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); - PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); - mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); + Packet1Xh nans = + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); + PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); + mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); - return __riscv_vfmin_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmin_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { - return pmin(a, b); +EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { + return pmin(a, b); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { - return __riscv_vfmin_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmin_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); - PacketXh nans = - __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); - PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); - mask = __riscv_vmand_mm_b16(mask, mask2, 
unpacket_traits::size); + Packet1Xh nans = + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); + PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); + mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); - return __riscv_vfmax_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmax_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { - return pmax(a, b); +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { + return pmax(a, b); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { - return __riscv_vfmax_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmax_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_le(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfle_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_le(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfle_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_lt(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_lt(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, 
unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_eq(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_eq(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_lt_or_nan(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfge_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast<_Float16>(0.0), mask, - unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_lt_or_nan(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfge_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast<_Float16>(0.0), mask, + unpacket_traits::size); } // Logical Operations are not supported for half, so reinterpret casts template <> -EIGEN_STRONG_INLINE PacketXh pand(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pand(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vand_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh por(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh por(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vor_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), 
__riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pxor(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pxor(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vxor_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pandnot(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pandnot(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vand_vv_u16m1( __riscv_vreinterpret_v_f16m1_u16m1(a), - __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size), - unpacket_traits::size)); + __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size), + unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pload(const Eigen::half* from) { +EIGEN_STRONG_INLINE Packet1Xh pload(const Eigen::half* from) { EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_f16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploadu(const Eigen::half* from) { +EIGEN_STRONG_INLINE Packet1Xh ploadu(const Eigen::half* from) { EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_f16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploaddup(const Eigen::half* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); - return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh ploaddup(const Eigen::half* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = 
__riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); + return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploadquad(const Eigen::half* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, - unpacket_traits::size); - return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh ploadquad(const Eigen::half* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, + unpacket_traits::size); + return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const PacketXh& from) { +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet1Xh& from) { EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_f16m1(reinterpret_cast<_Float16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const PacketXh& from) { +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet1Xh& from) { EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_f16m1(reinterpret_cast<_Float16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline PacketXh pgather(const Eigen::half* from, Index stride) { +EIGEN_DEVICE_FUNC inline Packet1Xh pgather(const Eigen::half* from, Index stride) { return __riscv_vlse16_v_f16m1(reinterpret_cast(from), stride * sizeof(Eigen::half), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline void pscatter(Eigen::half* to, const PacketXh& from, Index stride) { - __riscv_vsse16(reinterpret_cast<_Float16*>(to), stride * sizeof(Eigen::half), from, unpacket_traits::size); +EIGEN_DEVICE_FUNC 
inline void pscatter(Eigen::half* to, const Packet1Xh& from, Index stride) { + __riscv_vsse16(reinterpret_cast<_Float16*>(to), stride * sizeof(Eigen::half), from, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE Eigen::half pfirst(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet1Xh& a) { return static_cast(__riscv_vfmv_f_s_f16m1_f16(a)); } template <> -EIGEN_STRONG_INLINE PacketXh psqrt(const PacketXh& a) { - return __riscv_vfsqrt_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh psqrt(const Packet1Xh& a) { + return __riscv_vfsqrt_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh print(const PacketXh& a) { - const PacketXh limit = pset1(static_cast(1 << 10)); - const PacketXh abs_a = pabs(a); +EIGEN_STRONG_INLINE Packet1Xh print(const Packet1Xh& a) { + const Packet1Xh limit = pset1(static_cast(1 << 10)); + const Packet1Xh abs_a = pabs(a); - PacketMask16 mask = __riscv_vmfne_vv_f16m1_b16(a, a, unpacket_traits::size); - const PacketXh x = __riscv_vfadd_vv_f16m1_tumu(mask, a, a, a, unpacket_traits::size); - const PacketXh new_x = __riscv_vfcvt_f_x_v_f16m1(__riscv_vfcvt_x_f_v_i16m1(a, unpacket_traits::size), - unpacket_traits::size); + PacketMask16 mask = __riscv_vmfne_vv_f16m1_b16(a, a, unpacket_traits::size); + const Packet1Xh x = __riscv_vfadd_vv_f16m1_tumu(mask, a, a, a, unpacket_traits::size); + const Packet1Xh new_x = __riscv_vfcvt_f_x_v_f16m1(__riscv_vfcvt_x_f_v_i16m1(a, unpacket_traits::size), + unpacket_traits::size); - mask = __riscv_vmflt_vv_f16m1_b16(abs_a, limit, unpacket_traits::size); - PacketXh signed_x = __riscv_vfsgnj_vv_f16m1(new_x, x, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(x, signed_x, mask, unpacket_traits::size); + mask = __riscv_vmflt_vv_f16m1_b16(abs_a, limit, unpacket_traits::size); + Packet1Xh signed_x = __riscv_vfsgnj_vv_f16m1(new_x, x, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(x, signed_x, mask, unpacket_traits::size); } 
template <> -EIGEN_STRONG_INLINE PacketXh pfloor(const PacketXh& a) { - PacketXh tmp = print(a); +EIGEN_STRONG_INLINE Packet1Xh pfloor(const Packet1Xh& a) { + Packet1Xh tmp = print(a); // If greater, subtract one. - PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, tmp, unpacket_traits::size); - return __riscv_vfsub_vf_f16m1_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); + PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, tmp, unpacket_traits::size); + return __riscv_vfsub_vf_f16m1_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh preverse(const PacketXh& a) { - PacketXsu idx = __riscv_vrsub_vx_u16m1(__riscv_vid_v_u16m1(unpacket_traits::size), - unpacket_traits::size - 1, unpacket_traits::size); - return __riscv_vrgather_vv_f16m1(a, idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh preverse(const Packet1Xh& a) { + Packet1Xsu idx = __riscv_vrsub_vx_u16m1(__riscv_vid_v_u16m1(unpacket_traits::size), + unpacket_traits::size - 1, unpacket_traits::size); + return __riscv_vrgather_vv_f16m1(a, idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE Eigen::half predux(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux(const Packet1Xh& a) { return static_cast(__riscv_vfmv_f(__riscv_vfredusum_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size), + unpacket_traits::size))); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_mul(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet1Xh& a) { // Multiply the vector by its reverse - PacketXh prod = __riscv_vfmul_vv_f16m1(preverse(a), a, unpacket_traits::size); - PacketXh half_prod; + Packet1Xh prod = __riscv_vfmul_vv_f16m1(preverse(a), a, unpacket_traits::size); + Packet1Xh half_prod; if (EIGEN_RISCV64_RVV_VL >= 1024) { - half_prod = 
__riscv_vslidedown_vx_f16m1(prod, 16, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 16, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } if (EIGEN_RISCV64_RVV_VL >= 512) { - half_prod = __riscv_vslidedown_vx_f16m1(prod, 8, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 8, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } if (EIGEN_RISCV64_RVV_VL >= 256) { - half_prod = __riscv_vslidedown_vx_f16m1(prod, 4, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 4, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } // Last reduction - half_prod = __riscv_vslidedown_vx_f16m1(prod, 2, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 2, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); - half_prod = __riscv_vslidedown_vx_f16m1(prod, 1, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 1, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); // The reduction is done to the first element. 
return pfirst(prod); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_min(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet1Xh& a) { const Eigen::half max = (std::numeric_limits::max)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), + unpacket_traits::size))); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_max(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet1Xh& a) { const Eigen::half min = (std::numeric_limits::min)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size), + unpacket_traits::size))); } template -EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - Eigen::half buffer[unpacket_traits::size * N]; +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + Eigen::half buffer[unpacket_traits::size * N]; int i = 0; for (i = 0; i < N; i++) { __riscv_vsse16(reinterpret_cast<_Float16*>(&buffer[i]), N * sizeof(Eigen::half), kernel.packet[i], - unpacket_traits::size); + unpacket_traits::size); } for (i = 0; i < N; i++) { - kernel.packet[i] = __riscv_vle16_v_f16m1(reinterpret_cast<_Float16*>(&buffer[i * unpacket_traits::size]), - unpacket_traits::size); + kernel.packet[i] = __riscv_vle16_v_f16m1(reinterpret_cast<_Float16*>(&buffer[i * unpacket_traits::size]), + unpacket_traits::size); } } -EIGEN_STRONG_INLINE Packet2Xf half2float(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet2Xf half2float(const Packet1Xh& a) { return __riscv_vfwcvt_f_f_v_f32m2(a, unpacket_traits::size); } -EIGEN_STRONG_INLINE PacketXh float2half(const Packet2Xf& a) { - return 
__riscv_vfncvt_f_f_w_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh float2half(const Packet2Xf& a) { + return __riscv_vfncvt_f_f_w_f16m1(a, unpacket_traits::size); } /********************************* Packet2Xh ************************************/ @@ -774,8 +774,8 @@ EIGEN_STRONG_INLINE Eigen::half predux(const Packet2Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet2Xh& a) { - return predux_mul(__riscv_vfmul_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), - unpacket_traits::size)); + return predux_mul(__riscv_vfmul_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), + unpacket_traits::size)); } template <> @@ -822,22 +822,22 @@ EIGEN_STRONG_INLINE Packet2Xh float2half(const Packet4Xf& a) { template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, - PacketXh>::type + Packet1Xh>::type predux_half(const Packet2Xh& a) { return __riscv_vfadd_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), - unpacket_traits::size); + unpacket_traits::size); } -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pcos) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pexp) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pexpm1) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog1p) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog2) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, preciprocal) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, prsqrt) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, psin) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, ptanh) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pcos) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pexp) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pexpm1) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog1p) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog2) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, preciprocal) 
+F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, prsqrt) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, psin) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, ptanh) F16_PACKET_FUNCTION(Packet4Xf, Packet2Xh, pcos) F16_PACKET_FUNCTION(Packet4Xf, Packet2Xh, pexp) @@ -863,22 +863,22 @@ struct type_casting_traits { }; template <> -EIGEN_STRONG_INLINE PacketXh pcast(const PacketXs& a) { - return __riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcast(const Packet1Xs& a) { + return __riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXs pcast(const PacketXh& a) { - return __riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xs pcast(const Packet1Xh& a) { + return __riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh preinterpret(const PacketXs& a) { +EIGEN_STRONG_INLINE Packet1Xh preinterpret(const Packet1Xs& a) { return __riscv_vreinterpret_v_i16m1_f16m1(a); } template <> -EIGEN_STRONG_INLINE PacketXs preinterpret(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet1Xs preinterpret(const Packet1Xh& a) { return __riscv_vreinterpret_v_f16m1_i16m1(a); } @@ -903,29 +903,29 @@ EIGEN_STRONG_INLINE Packet2Xs preinterpret(const Packet2Xh } template <> -EIGEN_STRONG_INLINE Packet4Xs pcast(const PacketXh& a, const PacketXh& b, const PacketXh& c, - const PacketXh& d) { - return __riscv_vcreate_v_i16m1_i16m4(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(c, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(d, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet4Xs pcast(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c, + const Packet1Xh& d) { + return __riscv_vcreate_v_i16m1_i16m4(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(c, 
unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(d, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE Packet2Xh pcast(const PacketXs& a, const PacketXs& b) { - return __riscv_vcreate_v_f16m1_f16m2(__riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size), - __riscv_vfcvt_f_x_v_f16m1(b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet2Xh pcast(const Packet1Xs& a, const Packet1Xs& b) { + return __riscv_vcreate_v_f16m1_f16m2(__riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size), + __riscv_vfcvt_f_x_v_f16m1(b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE Packet2Xh pcast(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet2Xh pcast(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vcreate_v_f16m1_f16m2(a, b); } template <> -EIGEN_STRONG_INLINE Packet2Xs pcast(const PacketXh& a, const PacketXh& b) { - return __riscv_vcreate_v_i16m1_i16m2(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet2Xs pcast(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vcreate_v_i16m1_i16m2(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size)); } } // namespace internal -- GitLab