diff --git a/Eigen/Core b/Eigen/Core
index 9f81658b0b57e6c7117b6f02f30455af14f5abf3..86608f3c8a0253cc2c0780a403f10e21d589be54 100644
--- a/Eigen/Core
+++ b/Eigen/Core
@@ -288,6 +288,9 @@ using std::ptrdiff_t;
 #if defined EIGEN_VECTORIZE_RVV10FP16
 #include "src/Core/arch/RVV10/PacketMathFP16.h"
 #endif
+#if defined EIGEN_VECTORIZE_RVV10BF16
+#include "src/Core/arch/RVV10/PacketMathBF16.h"
+#endif
 #elif defined EIGEN_VECTORIZE_ZVECTOR
 #include "src/Core/arch/ZVector/PacketMath.h"
 #include "src/Core/arch/ZVector/MathFunctions.h"
diff --git a/Eigen/src/Core/arch/RVV10/PacketMathBF16.h b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h
new file mode 100644
index 0000000000000000000000000000000000000000..80502593cfdc0d5404359cbc7aaf312338c65ad1
--- /dev/null
+++ b/Eigen/src/Core/arch/RVV10/PacketMathBF16.h
@@ -0,0 +1,776 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2025 Chip Kerchner
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_PACKET_MATH_BF16_RVV10_H
+#define EIGEN_PACKET_MATH_BF16_RVV10_H
+
+// IWYU pragma: private
+#include "../../InternalHeaderCheck.h"
+
+namespace Eigen {
+namespace internal {
+
+typedef eigen_packet_wrapper
+    Packet1Xbf;
+typedef eigen_packet_wrapper
+    Packet2Xbf;
+
+#if EIGEN_RISCV64_DEFAULT_LMUL == 1
+typedef Packet1Xbf PacketXbf;
+
+template <>
+struct packet_traits : default_packet_traits {
+  typedef Packet1Xbf type;
+  typedef Packet1Xbf half;
+
+  enum {
+    Vectorizable = 1,
+    AlignedOnScalar = 1,
+    size = rvv_packet_size_selector::size,
+
+    HasAdd = 1,
+    HasSub = 1,
+    HasShift = 1,
+    HasMul = 1,
+    HasNegate = 1,
+    HasAbs = 1,
+    HasArg = 0,
+    HasAbs2 = 1,
+    HasMin = 1,
+    HasMax = 1,
+    HasConj = 1,
+    HasSetLinear = 0,
+    HasBlend = 0,
+    HasReduxp = 0,
+    HasSign = 0,
+
+    HasCmp = 1,
+    HasDiv = 1,
+    HasRound = 0,
+
+    HasSin = 0,
+    HasCos = 0,
+    HasLog = 0,
+    HasExp = 0,
+    HasSqrt = 1,
+    HasTanh = 0,
+    HasErf = 0
+  };
+};
+
+#else
+typedef Packet2Xbf PacketXbf;
+
+template <>
+struct packet_traits : default_packet_traits {
+  typedef Packet2Xbf type;
+  typedef Packet1Xbf half;
+
+  enum {
+    Vectorizable = 1,
+    AlignedOnScalar = 1,
+    size = rvv_packet_size_selector::size,
+
+    HasAdd = 1,
+    HasSub = 1,
+    HasShift = 1,
+    HasMul = 1,
+    HasNegate = 1,
+    HasAbs = 1,
+    HasArg = 0,
+    HasAbs2 = 1,
+    HasMin = 1,
+    HasMax = 1,
+    HasConj = 1,
+    HasSetLinear = 0,
+    HasBlend = 0,
+    HasReduxp = 0,
+    HasSign = 0,
+
+    HasCmp = 1,
+    HasDiv = 1,
+    HasRound = 0,
+
+    HasSin = 0,
+    HasCos = 0,
+    HasLog = 0,
+    HasExp = 0,
+    HasSqrt = 1,
+    HasTanh = 0,
+    HasErf = 0
+  };
+};
+#endif
+
+template <>
+struct unpacket_traits : default_unpacket_traits {
+  typedef bfloat16 type;
+  typedef Packet1Xbf half;  // Half not yet implemented
+  typedef Packet1Xs integer_packet;
+  typedef numext::uint8_t mask_t;
+
+  enum {
+    size = rvv_packet_size_selector::size,
+    alignment = rvv_packet_alignment_selector::alignment,
+    vectorizable = true
+  };
+};
+
+template <>
+struct unpacket_traits : default_unpacket_traits {
+  typedef bfloat16 type;
+  typedef Packet1Xbf half;
+  typedef Packet2Xs integer_packet;
+  typedef numext::uint8_t mask_t;
+
+  enum {
+    size = rvv_packet_size_selector::size,
+    alignment = rvv_packet_alignment_selector::alignment,
+    vectorizable = true
+  };
+};
+
+/********************************* Packet1Xbf ************************************/
+
+// BF16 <-> FP32 conversions for the m1 BF16 packet (the FP32 side is m2). Most arithmetic below
+// widens with Bf16ToF32, computes in FP32 and narrows the result back with F32ToBf16.
+EIGEN_STRONG_INLINE Packet2Xf Bf16ToF32(const Packet1Xbf& a) {
+  return __riscv_vfwcvtbf16_f_f_v_f32m2(a, unpacket_traits::size);
+}
+
+EIGEN_STRONG_INLINE Packet1Xbf F32ToBf16(const Packet2Xf& a) {
+  return __riscv_vfncvtbf16_f_f_w_bf16m1(a, unpacket_traits::size);
+}
+
+// Narrows an FP32 comparison mask (every lane all-ones or all-zeros) to 16-bit lanes by integer
+// truncation. F32ToBf16 must not be used for masks: an all-ones FP32 lane is a NaN, and the BF16
+// conversion yields the canonical NaN (0x7fc0) instead of the 0xffff all-ones pattern (cf. ptrue).
+EIGEN_STRONG_INLINE Packet1Xbf F32MaskToBf16Mask(const Packet2Xf& a) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(
+      __riscv_vnsrl_wx_u16m1(__riscv_vreinterpret_v_f32m2_u32m2(a), 0, unpacket_traits<Packet1Xbf>::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf ptrue(const Packet1Xbf& /*a*/) {
+  return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(static_cast(0xffffu), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pzero(const Packet1Xbf& /*a*/) {
+  return __riscv_vreinterpret_bf16m1(
+      __riscv_vmv_v_x_i16m1(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pabs(const Packet1Xbf& a) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vx_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x7fffu), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pset1(const bfloat16& from) {
+  return __riscv_vreinterpret_bf16m1(
+      __riscv_vmv_v_x_i16m1(numext::bit_cast(from), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pset1frombits(numext::uint16_t from) {
+  return __riscv_vreinterpret_bf16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf plset(const bfloat16& a) {
+  return F32ToBf16(plset(static_cast(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf padd(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf psub(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pnegate(const Packet1Xbf& a) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vx_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a), static_cast(0x8000u), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf psignbit(const Packet1Xbf& a) {
+  return __riscv_vreinterpret_v_i16m1_bf16m1(__riscv_vsra_vx_i16m1(
+      __riscv_vreinterpret_v_bf16m1_i16m1(a), 15, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pconj(const Packet1Xbf& a) {
+  return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmul(const Packet1Xbf& a, const Packet1Xbf& b) {
+  Packet2Xf c;
+  // Widening BF16 multiply-accumulate into a zero FP32 accumulator, i.e. a * b computed in FP32.
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(pzero(c), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pdiv(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmadd(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmsub(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pnmadd(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pnmsub(const Packet1Xbf& a, const Packet1Xbf& b, const Packet1Xbf& c) {
+  return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m2(Bf16ToF32(c), a, b, unpacket_traits::size)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmin(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pmax(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+// Comparisons return all-ones (true) / all-zeros (false) lanes, so the FP32 result is narrowed as a bit mask.
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pcmp_le(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pcmp_lt(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pcmp_eq(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pcmp_lt_or_nan(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+// Logical Operations are not supported for bfloat16, so reinterpret casts
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pand(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf por(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vor_vv_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pxor(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vxor_vv_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a), __riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pandnot(const Packet1Xbf& a, const Packet1Xbf& b) {
+  return __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vand_vv_u16m1(
+      __riscv_vreinterpret_v_bf16m1_u16m1(a),
+      __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_bf16m1_u16m1(b), unpacket_traits::size),
+      unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pload(const bfloat16* from) {
+  EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from),
+                                                         unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf ploadu(const bfloat16* from) {
+  EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_bf16m1(reinterpret_cast(from),
+                                                           unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf ploaddup(const bfloat16* from) {
+  Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size);
+  idx = __riscv_vand_vx_u16m1(idx, static_cast(0xfffeu), unpacket_traits::size);
+  return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf ploadquad(const bfloat16* from) {
+  Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size);
+  idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, static_cast(0xfffcu), unpacket_traits::size), 1,
+                              unpacket_traits::size);
+  return __riscv_vloxei16_v_bf16m1(reinterpret_cast(from), idx, unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet1Xbf& from) {
+  EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from,
+                                                   unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet1Xbf& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_bf16m1(reinterpret_cast<__bf16*>(to), from,
+                                                     unpacket_traits::size);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet1Xbf pgather(const bfloat16* from, Index stride) {
+  return __riscv_vlse16_v_bf16m1(reinterpret_cast(from), stride * sizeof(bfloat16),
+                                 unpacket_traits::size);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const Packet1Xbf& from, Index stride) {
+  __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from, unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet1Xbf& a) {
+  return numext::bit_cast(__riscv_vmv_x_s_i16m1_i16(__riscv_vreinterpret_v_bf16m1_i16m1(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf psqrt(const Packet1Xbf& a) {
+  return F32ToBf16(psqrt(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf print(const Packet1Xbf& a) {
+  return F32ToBf16(print(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pfloor(const Packet1Xbf& a) {
+  return F32ToBf16(pfloor(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf preverse(const Packet1Xbf& a) {
+  return __riscv_vreinterpret_v_i16m1_bf16m1(preverse(__riscv_vreinterpret_v_bf16m1_i16m1(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux(const Packet1Xbf& a) {
+  return static_cast(predux(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet1Xbf& a) {
+  return static_cast(predux_mul(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet1Xbf& a) {
+  return static_cast(predux_min(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet1Xbf& a) {
+  return static_cast(predux_max(Bf16ToF32(a)));
+}
+
+template
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) {
+  bfloat16 buffer[unpacket_traits::size * N];
+  int i = 0;
+
+  for (i = 0; i < N; i++) {
+    __riscv_vsse16(reinterpret_cast<__bf16*>(&buffer[i]), N * sizeof(bfloat16), kernel.packet[i],
+                   unpacket_traits::size);
+  }
+
+  for (i = 0; i < N; i++) {
+    kernel.packet[i] = __riscv_vle16_v_bf16m1(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]),
+                                              unpacket_traits::size);
+  }
+}
+
+/********************************* Packet2Xbf ************************************/
+
+// BF16 <-> FP32 conversions for the m2 BF16 packet (the FP32 side is m4).
+EIGEN_STRONG_INLINE Packet4Xf Bf16ToF32(const Packet2Xbf& a) {
+  return __riscv_vfwcvtbf16_f_f_v_f32m4(a, unpacket_traits::size);
+}
+
+EIGEN_STRONG_INLINE Packet2Xbf F32ToBf16(const Packet4Xf& a) {
+  return __riscv_vfncvtbf16_f_f_w_bf16m2(a, unpacket_traits::size);
+}
+
+// m2 counterpart of the m1 F32MaskToBf16Mask: truncates 32-bit mask lanes to 16 bits, no FP conversion.
+EIGEN_STRONG_INLINE Packet2Xbf F32MaskToBf16Mask(const Packet4Xf& a) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(
+      __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(a), 0, unpacket_traits<Packet2Xbf>::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf ptrue(const Packet2Xbf& /*a*/) {
+  return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(static_cast(0xffffu), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pzero(const Packet2Xbf& /*a*/) {
+  return __riscv_vreinterpret_bf16m2(
+      __riscv_vmv_v_x_i16m2(numext::bit_cast(static_cast<__bf16>(0.0)), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pabs(const Packet2Xbf& a) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vx_u16m2(
+      __riscv_vreinterpret_v_bf16m2_u16m2(a), static_cast(0x7fffu), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pset1(const bfloat16& from) {
+  return __riscv_vreinterpret_bf16m2(
+      __riscv_vmv_v_x_i16m2(numext::bit_cast(from), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pset1frombits(numext::uint16_t from) {
+  return __riscv_vreinterpret_bf16m2(__riscv_vmv_v_x_u16m2(from, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf plset(const bfloat16& a) {
+  return F32ToBf16(plset(static_cast(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf padd(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf psub(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pnegate(const Packet2Xbf& a) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vxor_vx_u16m2(
+      __riscv_vreinterpret_v_bf16m2_u16m2(a), static_cast(0x8000u), unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf psignbit(const Packet2Xbf& a) {
+  return __riscv_vreinterpret_v_i16m2_bf16m2(__riscv_vsra_vx_i16m2(
+      __riscv_vreinterpret_v_bf16m2_i16m2(a), 15, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pconj(const Packet2Xbf& a) {
+  return a;
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmul(const Packet2Xbf& a, const Packet2Xbf& b) {
+  Packet4Xf c;
+  // Widening BF16 multiply-accumulate into a zero FP32 accumulator, i.e. a * b computed in FP32.
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(pzero(c), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pdiv(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmadd(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmsub(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(pnegate(c)), a, b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pnmadd(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) {
+  return F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), pnegate(a), b, unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pnmsub(const Packet2Xbf& a, const Packet2Xbf& b, const Packet2Xbf& c) {
+  return pnegate(F32ToBf16(__riscv_vfwmaccbf16_vv_f32m4(Bf16ToF32(c), a, b, unpacket_traits::size)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmin(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pmax(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+// Comparisons return all-ones (true) / all-zeros (false) lanes, so the FP32 result is narrowed as a bit mask.
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pcmp_le(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pcmp_lt(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pcmp_eq(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pcmp_lt_or_nan(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return F32MaskToBf16Mask(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+// Logical Operations are not supported for bfloat16, so reinterpret casts
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pand(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a),
+                                                                   __riscv_vreinterpret_v_bf16m2_u16m2(b),
+                                                                   unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf por(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vor_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a),
+                                                                  __riscv_vreinterpret_v_bf16m2_u16m2(b),
+                                                                  unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pxor(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vxor_vv_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(a),
+                                                                   __riscv_vreinterpret_v_bf16m2_u16m2(b),
+                                                                   unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pandnot(const Packet2Xbf& a, const Packet2Xbf& b) {
+  return __riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vand_vv_u16m2(
+      __riscv_vreinterpret_v_bf16m2_u16m2(a),
+      __riscv_vnot_v_u16m2(__riscv_vreinterpret_v_bf16m2_u16m2(b), unpacket_traits::size),
+      unpacket_traits::size));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pload(const bfloat16* from) {
+  EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_bf16m2(reinterpret_cast(from),
+                                                         unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf ploadu(const bfloat16* from) {
+  EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_bf16m2(reinterpret_cast(from),
+                                                           unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf ploaddup(const bfloat16* from) {
+  Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size);
+  idx = __riscv_vand_vx_u16m2(idx, static_cast(0xfffeu), unpacket_traits::size);
+  return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf ploadquad(const bfloat16* from) {
+  Packet2Xsu idx = __riscv_vid_v_u16m2(unpacket_traits::size);
+  idx = __riscv_vsrl_vx_u16m2(__riscv_vand_vx_u16m2(idx, static_cast(0xfffcu), unpacket_traits::size), 1,
+                              unpacket_traits::size);
+  return __riscv_vloxei16_v_bf16m2(reinterpret_cast(from), idx, unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet2Xbf& from) {
+  EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_bf16m2(reinterpret_cast<__bf16*>(to), from,
+                                                   unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet2Xbf& from) {
+  EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_bf16m2(reinterpret_cast<__bf16*>(to), from,
+                                                     unpacket_traits::size);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline Packet2Xbf pgather(const bfloat16* from, Index stride) {
+  return __riscv_vlse16_v_bf16m2(reinterpret_cast(from), stride * sizeof(bfloat16),
+                                 unpacket_traits::size);
+}
+
+template <>
+EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const Packet2Xbf& from,
+                                                           Index stride) {
+  __riscv_vsse16(reinterpret_cast<__bf16*>(to), stride * sizeof(bfloat16), from,
+                 unpacket_traits::size);
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet2Xbf& a) {
+  // Reinterpret the raw lane bits; a static_cast would numerically convert the int16 value instead.
+  return numext::bit_cast<bfloat16>(__riscv_vmv_x_s_i16m2_i16(__riscv_vreinterpret_v_bf16m2_i16m2(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf psqrt(const Packet2Xbf& a) {
+  return F32ToBf16(psqrt(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf print(const Packet2Xbf& a) {
+  return F32ToBf16(print(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pfloor(const Packet2Xbf& a) {
+  return F32ToBf16(pfloor(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf preverse(const Packet2Xbf& a) {
+  return __riscv_vreinterpret_v_i16m2_bf16m2(preverse(__riscv_vreinterpret_v_bf16m2_i16m2(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux(const Packet2Xbf& a) {
+  return static_cast(predux(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet2Xbf& a) {
+  return static_cast(predux_mul(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet2Xbf& a) {
+  return static_cast(predux_min(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet2Xbf& a) {
+  return static_cast(predux_max(Bf16ToF32(a)));
+}
+
+template
+EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) {
+  bfloat16 buffer[unpacket_traits::size * N];
+  int i = 0;
+
+  for (i = 0; i < N; i++) {
+    __riscv_vsse16(reinterpret_cast<__bf16*>(&buffer[i]), N * sizeof(bfloat16), kernel.packet[i],
+                   unpacket_traits::size);
+  }
+
+  for (i = 0; i < N; i++) {
+    kernel.packet[i] =
+        __riscv_vle16_v_bf16m2(reinterpret_cast<__bf16*>(&buffer[i * unpacket_traits::size]),
+                               unpacket_traits::size);
+  }
+}
+
+template
+EIGEN_STRONG_INLINE
+    typename std::enable_if::value && (unpacket_traits::size % 8) == 0,
+                            Packet1Xbf>::type
+predux_half(const Packet2Xbf& a) {
+  return padd(__riscv_vget_v_bf16m2_bf16m1(a, 0), __riscv_vget_v_bf16m2_bf16m1(a, 1));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xbf pcast(const Packet1Xs& a) {
+  return __riscv_vreinterpret_v_i16m1_bf16m1(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xbf pcast(const Packet2Xs& a) {
+  return __riscv_vreinterpret_v_i16m2_bf16m2(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet1Xs pcast(const Packet1Xbf& a) {
+  return __riscv_vreinterpret_v_bf16m1_i16m1(a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet2Xs pcast(const Packet2Xbf& a) {
+  return __riscv_vreinterpret_v_bf16m2_i16m2(a);
+}
+
+}  // namespace internal
+}  // namespace Eigen
+
+#endif  // EIGEN_PACKET_MATH_BF16_RVV10_H
diff --git a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h
index d3cbf933a219325d1035136acac44f11414d462e..f3e5924c136ec94284024519c44a33054a90ab4d 100644
--- a/Eigen/src/Core/arch/RVV10/PacketMathFP16.h
+++ b/Eigen/src/Core/arch/RVV10/PacketMathFP16.h
@@ -110,7 +110,7 @@ template <>
 struct unpacket_traits {
   typedef Eigen::half type;
   typedef Packet1Xh half;  // Half not yet implemented
-  typedef PacketXs integer_packet;
+  typedef Packet1Xs integer_packet;
   typedef numext::uint8_t mask_t;
 
   enum {
@@ -138,351 +138,351 @@ struct unpacket_traits {
   };
 };
 
-/********************************* PacketXh ************************************/
+/********************************* Packet1Xh ************************************/
 
 template <>
-EIGEN_STRONG_INLINE PacketXh ptrue(const PacketXh& /*a*/) { - return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xh ptrue(const Packet1Xh& /*a*/) { + return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(0xffffu, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pzero(const PacketXh& /*a*/) { - return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pzero(const Packet1Xh& /*a*/) { + return __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pabs(const PacketXh& a) { - return __riscv_vfabs_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pabs(const Packet1Xh& a) { + return __riscv_vfabs_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pset1(const Eigen::half& from) { - return __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(from), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pset1(const Eigen::half& from) { + return __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(from), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pset1frombits(numext::uint16_t from) { - return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet1Xh pset1frombits(numext::uint16_t from) { + return __riscv_vreinterpret_f16m1(__riscv_vmv_v_x_u16m1(from, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh plset(const Eigen::half& a) { - PacketXh idx = - __riscv_vfcvt_f_x_v_f16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(unpacket_traits::size)), - unpacket_traits::size); - return __riscv_vfadd_vf_f16m1(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh plset(const Eigen::half& a) { + Packet1Xh idx = + 
__riscv_vfcvt_f_x_v_f16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(unpacket_traits::size)), + unpacket_traits::size); + return __riscv_vfadd_vf_f16m1(idx, numext::bit_cast<_Float16>(a), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh padd(const PacketXh& a, const PacketXh& b) { - return __riscv_vfadd_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh padd(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfadd_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh psub(const PacketXh& a, const PacketXh& b) { - return __riscv_vfsub_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh psub(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfsub_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnegate(const PacketXh& a) { - return __riscv_vfneg_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnegate(const Packet1Xh& a) { + return __riscv_vfneg_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pconj(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet1Xh pconj(const Packet1Xh& a) { return a; } template <> -EIGEN_STRONG_INLINE PacketXh pmul(const PacketXh& a, const PacketXh& b) { - return __riscv_vfmul_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmul(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmul_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pdiv(const PacketXh& a, const PacketXh& b) { - return __riscv_vfdiv_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pdiv(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfdiv_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmadd(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfmadd_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmadd(const 
Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfmadd_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmsub(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfmsub_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmsub(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfmsub_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnmadd(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfnmsub_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnmadd(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfnmsub_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pnmsub(const PacketXh& a, const PacketXh& b, const PacketXh& c) { - return __riscv_vfnmadd_vv_f16m1(a, b, c, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pnmsub(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c) { + return __riscv_vfnmadd_vv_f16m1(a, b, c, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); - PacketXh nans = - __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); - PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); - mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); + Packet1Xh nans = + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); + PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); + mask = __riscv_vmand_mm_b16(mask, 
mask2, unpacket_traits::size); - return __riscv_vfmin_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmin_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { - return pmin(a, b); +EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { + return pmin(a, b); } template <> -EIGEN_STRONG_INLINE PacketXh pmin(const PacketXh& a, const PacketXh& b) { - return __riscv_vfmin_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmin(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmin_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { const Eigen::half nan = (std::numeric_limits::quiet_NaN)(); - PacketXh nans = - __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); - PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); - mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); + Packet1Xh nans = + __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(nan), unpacket_traits::size); + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, a, unpacket_traits::size); + PacketMask16 mask2 = __riscv_vmfeq_vv_f16m1_b16(b, b, unpacket_traits::size); + mask = __riscv_vmand_mm_b16(mask, mask2, unpacket_traits::size); - return __riscv_vfmax_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); + return __riscv_vfmax_vv_f16m1_tumu(mask, nans, a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const PacketXh& a, const PacketXh& b) { - return pmax(a, b); +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { + return pmax(a, b); } template <> -EIGEN_STRONG_INLINE PacketXh pmax(const 
PacketXh& a, const PacketXh& b) { - return __riscv_vfmax_vv_f16m1(a, b, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pmax(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vfmax_vv_f16m1(a, b, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_le(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfle_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_le(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfle_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_lt(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_lt(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_eq(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_eq(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfeq_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(pzero(a), ptrue(a), mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pcmp_lt_or_nan(const PacketXh& a, const PacketXh& b) { - PacketMask16 mask = __riscv_vmfge_vv_f16m1_b16(a, b, unpacket_traits::size); - return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast<_Float16>(0.0), mask, - 
unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcmp_lt_or_nan(const Packet1Xh& a, const Packet1Xh& b) { + PacketMask16 mask = __riscv_vmfge_vv_f16m1_b16(a, b, unpacket_traits::size); + return __riscv_vfmerge_vfm_f16m1(ptrue(a), static_cast<_Float16>(0.0), mask, + unpacket_traits::size); } // Logical Operations are not supported for half, so reinterpret casts template <> -EIGEN_STRONG_INLINE PacketXh pand(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pand(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vand_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh por(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh por(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vor_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pxor(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pxor(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vxor_vv_u16m1( - __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); + __riscv_vreinterpret_v_f16m1_u16m1(a), __riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pandnot(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet1Xh pandnot(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vand_vv_u16m1( __riscv_vreinterpret_v_f16m1_u16m1(a), - 
__riscv_vnot_v_u16m1(__riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size), - unpacket_traits::size)); + __riscv_vnot_v_u16m1(__riscv_vreinterpret_v_f16m1_u16m1(b), unpacket_traits::size), + unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE PacketXh pload(const Eigen::half* from) { +EIGEN_STRONG_INLINE Packet1Xh pload(const Eigen::half* from) { EIGEN_DEBUG_ALIGNED_LOAD return __riscv_vle16_v_f16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploadu(const Eigen::half* from) { +EIGEN_STRONG_INLINE Packet1Xh ploadu(const Eigen::half* from) { EIGEN_DEBUG_UNALIGNED_LOAD return __riscv_vle16_v_f16m1(reinterpret_cast(from), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploaddup(const Eigen::half* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); - return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh ploaddup(const Eigen::half* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vand_vx_u16m1(idx, 0xfffeu, unpacket_traits::size); + return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh ploadquad(const Eigen::half* from) { - PacketXsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); - idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, - unpacket_traits::size); - return __riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh ploadquad(const Eigen::half* from) { + Packet1Xsu idx = __riscv_vid_v_u16m1(unpacket_traits::size); + idx = __riscv_vsrl_vx_u16m1(__riscv_vand_vx_u16m1(idx, 0xfffcu, unpacket_traits::size), 1, + unpacket_traits::size); + return 
__riscv_vloxei16_v_f16m1(reinterpret_cast(from), idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const PacketXh& from) { +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet1Xh& from) { EIGEN_DEBUG_ALIGNED_STORE __riscv_vse16_v_f16m1(reinterpret_cast<_Float16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const PacketXh& from) { +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet1Xh& from) { EIGEN_DEBUG_UNALIGNED_STORE __riscv_vse16_v_f16m1(reinterpret_cast<_Float16*>(to), from, - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline PacketXh pgather(const Eigen::half* from, Index stride) { +EIGEN_DEVICE_FUNC inline Packet1Xh pgather(const Eigen::half* from, Index stride) { return __riscv_vlse16_v_f16m1(reinterpret_cast(from), stride * sizeof(Eigen::half), - unpacket_traits::size); + unpacket_traits::size); } template <> -EIGEN_DEVICE_FUNC inline void pscatter(Eigen::half* to, const PacketXh& from, Index stride) { - __riscv_vsse16(reinterpret_cast<_Float16*>(to), stride * sizeof(Eigen::half), from, unpacket_traits::size); +EIGEN_DEVICE_FUNC inline void pscatter(Eigen::half* to, const Packet1Xh& from, Index stride) { + __riscv_vsse16(reinterpret_cast<_Float16*>(to), stride * sizeof(Eigen::half), from, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE Eigen::half pfirst(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet1Xh& a) { return static_cast(__riscv_vfmv_f_s_f16m1_f16(a)); } template <> -EIGEN_STRONG_INLINE PacketXh psqrt(const PacketXh& a) { - return __riscv_vfsqrt_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh psqrt(const Packet1Xh& a) { + return __riscv_vfsqrt_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh print(const PacketXh& a) { - const PacketXh limit = pset1(static_cast(1 << 10)); - const 
PacketXh abs_a = pabs(a); +EIGEN_STRONG_INLINE Packet1Xh print(const Packet1Xh& a) { + const Packet1Xh limit = pset1(static_cast(1 << 10)); + const Packet1Xh abs_a = pabs(a); - PacketMask16 mask = __riscv_vmfne_vv_f16m1_b16(a, a, unpacket_traits::size); - const PacketXh x = __riscv_vfadd_vv_f16m1_tumu(mask, a, a, a, unpacket_traits::size); - const PacketXh new_x = __riscv_vfcvt_f_x_v_f16m1(__riscv_vfcvt_x_f_v_i16m1(a, unpacket_traits::size), - unpacket_traits::size); + PacketMask16 mask = __riscv_vmfne_vv_f16m1_b16(a, a, unpacket_traits::size); + const Packet1Xh x = __riscv_vfadd_vv_f16m1_tumu(mask, a, a, a, unpacket_traits::size); + const Packet1Xh new_x = __riscv_vfcvt_f_x_v_f16m1(__riscv_vfcvt_x_f_v_i16m1(a, unpacket_traits::size), + unpacket_traits::size); - mask = __riscv_vmflt_vv_f16m1_b16(abs_a, limit, unpacket_traits::size); - PacketXh signed_x = __riscv_vfsgnj_vv_f16m1(new_x, x, unpacket_traits::size); - return __riscv_vmerge_vvm_f16m1(x, signed_x, mask, unpacket_traits::size); + mask = __riscv_vmflt_vv_f16m1_b16(abs_a, limit, unpacket_traits::size); + Packet1Xh signed_x = __riscv_vfsgnj_vv_f16m1(new_x, x, unpacket_traits::size); + return __riscv_vmerge_vvm_f16m1(x, signed_x, mask, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh pfloor(const PacketXh& a) { - PacketXh tmp = print(a); +EIGEN_STRONG_INLINE Packet1Xh pfloor(const Packet1Xh& a) { + Packet1Xh tmp = print(a); // If greater, subtract one. 
- PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, tmp, unpacket_traits::size); - return __riscv_vfsub_vf_f16m1_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); + PacketMask16 mask = __riscv_vmflt_vv_f16m1_b16(a, tmp, unpacket_traits::size); + return __riscv_vfsub_vf_f16m1_tumu(mask, tmp, tmp, static_cast<_Float16>(1.0), unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh preverse(const PacketXh& a) { - PacketXsu idx = __riscv_vrsub_vx_u16m1(__riscv_vid_v_u16m1(unpacket_traits::size), - unpacket_traits::size - 1, unpacket_traits::size); - return __riscv_vrgather_vv_f16m1(a, idx, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh preverse(const Packet1Xh& a) { + Packet1Xsu idx = __riscv_vrsub_vx_u16m1(__riscv_vid_v_u16m1(unpacket_traits::size), + unpacket_traits::size - 1, unpacket_traits::size); + return __riscv_vrgather_vv_f16m1(a, idx, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE Eigen::half predux(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux(const Packet1Xh& a) { return static_cast(__riscv_vfmv_f(__riscv_vfredusum_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(static_cast<_Float16>(0.0), unpacket_traits::size), + unpacket_traits::size))); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_mul(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet1Xh& a) { // Multiply the vector by its reverse - PacketXh prod = __riscv_vfmul_vv_f16m1(preverse(a), a, unpacket_traits::size); - PacketXh half_prod; + Packet1Xh prod = __riscv_vfmul_vv_f16m1(preverse(a), a, unpacket_traits::size); + Packet1Xh half_prod; if (EIGEN_RISCV64_RVV_VL >= 1024) { - half_prod = __riscv_vslidedown_vx_f16m1(prod, 16, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 16, unpacket_traits::size); + prod 
= __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } if (EIGEN_RISCV64_RVV_VL >= 512) { - half_prod = __riscv_vslidedown_vx_f16m1(prod, 8, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 8, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } if (EIGEN_RISCV64_RVV_VL >= 256) { - half_prod = __riscv_vslidedown_vx_f16m1(prod, 4, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 4, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); } // Last reduction - half_prod = __riscv_vslidedown_vx_f16m1(prod, 2, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 2, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); - half_prod = __riscv_vslidedown_vx_f16m1(prod, 1, unpacket_traits::size); - prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); + half_prod = __riscv_vslidedown_vx_f16m1(prod, 1, unpacket_traits::size); + prod = __riscv_vfmul_vv_f16m1(prod, half_prod, unpacket_traits::size); // The reduction is done to the first element. 
return pfirst(prod); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_min(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet1Xh& a) { const Eigen::half max = (std::numeric_limits::max)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmin_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(max), unpacket_traits::size), + unpacket_traits::size))); } template <> -EIGEN_STRONG_INLINE Eigen::half predux_max(const PacketXh& a) { +EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet1Xh& a) { const Eigen::half min = (std::numeric_limits::min)(); return static_cast(__riscv_vfmv_f(__riscv_vfredmax_vs_f16m1_f16m1( - a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size), - unpacket_traits::size))); + a, __riscv_vfmv_v_f_f16m1(numext::bit_cast<_Float16>(min), unpacket_traits::size), + unpacket_traits::size))); } template -EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - Eigen::half buffer[unpacket_traits::size * N]; +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + Eigen::half buffer[unpacket_traits::size * N]; int i = 0; for (i = 0; i < N; i++) { __riscv_vsse16(reinterpret_cast<_Float16*>(&buffer[i]), N * sizeof(Eigen::half), kernel.packet[i], - unpacket_traits::size); + unpacket_traits::size); } for (i = 0; i < N; i++) { - kernel.packet[i] = __riscv_vle16_v_f16m1(reinterpret_cast<_Float16*>(&buffer[i * unpacket_traits::size]), - unpacket_traits::size); + kernel.packet[i] = __riscv_vle16_v_f16m1(reinterpret_cast<_Float16*>(&buffer[i * unpacket_traits::size]), + unpacket_traits::size); } } -EIGEN_STRONG_INLINE Packet2Xf half2float(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet2Xf half2float(const Packet1Xh& a) { return __riscv_vfwcvt_f_f_v_f32m2(a, unpacket_traits::size); } -EIGEN_STRONG_INLINE PacketXh float2half(const Packet2Xf& a) { - return 
__riscv_vfncvt_f_f_w_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh float2half(const Packet2Xf& a) { + return __riscv_vfncvt_f_f_w_f16m1(a, unpacket_traits::size); } /********************************* Packet2Xh ************************************/ @@ -774,8 +774,8 @@ EIGEN_STRONG_INLINE Eigen::half predux(const Packet2Xh& a) { template <> EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet2Xh& a) { - return predux_mul(__riscv_vfmul_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), - unpacket_traits::size)); + return predux_mul(__riscv_vfmul_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), + unpacket_traits::size)); } template <> @@ -822,22 +822,22 @@ EIGEN_STRONG_INLINE Packet2Xh float2half(const Packet4Xf& a) { template EIGEN_STRONG_INLINE typename std::enable_if::value && (unpacket_traits::size % 8) == 0, - PacketXh>::type + Packet1Xh>::type predux_half(const Packet2Xh& a) { return __riscv_vfadd_vv_f16m1(__riscv_vget_v_f16m2_f16m1(a, 0), __riscv_vget_v_f16m2_f16m1(a, 1), - unpacket_traits::size); + unpacket_traits::size); } -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pcos) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pexp) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, pexpm1) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog1p) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, plog2) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, preciprocal) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, prsqrt) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, psin) -F16_PACKET_FUNCTION(Packet2Xf, PacketXh, ptanh) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pcos) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pexp) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, pexpm1) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog1p) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, plog2) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, preciprocal) 
+F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, prsqrt) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, psin) +F16_PACKET_FUNCTION(Packet2Xf, Packet1Xh, ptanh) F16_PACKET_FUNCTION(Packet4Xf, Packet2Xh, pcos) F16_PACKET_FUNCTION(Packet4Xf, Packet2Xh, pexp) @@ -863,22 +863,22 @@ struct type_casting_traits { }; template <> -EIGEN_STRONG_INLINE PacketXh pcast(const PacketXs& a) { - return __riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xh pcast(const Packet1Xs& a) { + return __riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXs pcast(const PacketXh& a) { - return __riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size); +EIGEN_STRONG_INLINE Packet1Xs pcast(const Packet1Xh& a) { + return __riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size); } template <> -EIGEN_STRONG_INLINE PacketXh preinterpret(const PacketXs& a) { +EIGEN_STRONG_INLINE Packet1Xh preinterpret(const Packet1Xs& a) { return __riscv_vreinterpret_v_i16m1_f16m1(a); } template <> -EIGEN_STRONG_INLINE PacketXs preinterpret(const PacketXh& a) { +EIGEN_STRONG_INLINE Packet1Xs preinterpret(const Packet1Xh& a) { return __riscv_vreinterpret_v_f16m1_i16m1(a); } @@ -903,29 +903,29 @@ EIGEN_STRONG_INLINE Packet2Xs preinterpret(const Packet2Xh } template <> -EIGEN_STRONG_INLINE Packet4Xs pcast(const PacketXh& a, const PacketXh& b, const PacketXh& c, - const PacketXh& d) { - return __riscv_vcreate_v_i16m1_i16m4(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(c, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(d, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet4Xs pcast(const Packet1Xh& a, const Packet1Xh& b, const Packet1Xh& c, + const Packet1Xh& d) { + return __riscv_vcreate_v_i16m1_i16m4(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(c, 
unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(d, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE Packet2Xh pcast(const PacketXs& a, const PacketXs& b) { - return __riscv_vcreate_v_f16m1_f16m2(__riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size), - __riscv_vfcvt_f_x_v_f16m1(b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet2Xh pcast(const Packet1Xs& a, const Packet1Xs& b) { + return __riscv_vcreate_v_f16m1_f16m2(__riscv_vfcvt_f_x_v_f16m1(a, unpacket_traits::size), + __riscv_vfcvt_f_x_v_f16m1(b, unpacket_traits::size)); } template <> -EIGEN_STRONG_INLINE Packet2Xh pcast(const PacketXh& a, const PacketXh& b) { +EIGEN_STRONG_INLINE Packet2Xh pcast(const Packet1Xh& a, const Packet1Xh& b) { return __riscv_vcreate_v_f16m1_f16m2(a, b); } template <> -EIGEN_STRONG_INLINE Packet2Xs pcast(const PacketXh& a, const PacketXh& b) { - return __riscv_vcreate_v_i16m1_i16m2(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), - __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size)); +EIGEN_STRONG_INLINE Packet2Xs pcast(const Packet1Xh& a, const Packet1Xh& b) { + return __riscv_vcreate_v_i16m1_i16m2(__riscv_vfcvt_rtz_x_f_v_i16m1(a, unpacket_traits::size), + __riscv_vfcvt_rtz_x_f_v_i16m1(b, unpacket_traits::size)); } } // namespace internal diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h index d41d05db1d7ce83f5c17b25ade22fc49dbbdda0e..f86d7a32228ac192a65c9bb287457471672340f5 100644 --- a/Eigen/src/Core/util/ConfigureVectorization.h +++ b/Eigen/src/Core/util/ConfigureVectorization.h @@ -467,6 +467,10 @@ extern "C" { #endif #endif +#if defined(__riscv_zvfbfwma) +#define EIGEN_VECTORIZE_RVV10BF16 +#endif + #endif // defined(EIGEN_ARCH_RISCV) #elif (defined __s390x__ && defined __VEC__)