diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h index 4877b6d8090cdd0ae4ebf8260c3d50c1242aab70..841a69606ab06d2bcb6574f0cc1a121458847a04 100644 --- a/Eigen/src/Core/arch/SVE/PacketMath.h +++ b/Eigen/src/Core/arch/SVE/PacketMath.h @@ -33,8 +33,9 @@ struct sve_packet_size_selector { enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) }; }; -/********************************* int32 **************************************/ +/********************************* int **************************************/ typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL))); +typedef svuint32_t PacketXui __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL))); template <> struct packet_traits : default_packet_traits { @@ -63,6 +64,33 @@ struct packet_traits : default_packet_traits { }; }; +template <> +struct packet_traits : default_packet_traits { + typedef PacketXui type; + typedef PacketXui half; // Half not implemented yet + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = sve_packet_size_selector::size, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasReduxp = 0 // Not implemented in SVE + }; +}; + template <> struct unpacket_traits { typedef numext::int32_t type; @@ -77,235 +105,304 @@ struct unpacket_traits { }; template <> -EIGEN_STRONG_INLINE void prefetch(const numext::int32_t* addr) -{ - svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); -} +struct unpacket_traits { + typedef numext::uint32_t type; + typedef PacketXui half; // Half not yet implemented + enum { + size = sve_packet_size_selector::size, + alignment = Aligned64, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; -template <> -EIGEN_STRONG_INLINE PacketXi pset1(const numext::int32_t& from) -{ - return svdup_n_s32(from); -} +template<> EIGEN_STRONG_INLINE void prefetch(const numext::int32_t* addr) +{ svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); } -template <> -EIGEN_STRONG_INLINE PacketXi plset(const numext::int32_t& a) +template<> EIGEN_STRONG_INLINE void prefetch(const numext::uint32_t* addr) +{ svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); } + +template<> EIGEN_STRONG_INLINE PacketXi pset1(const numext::int32_t& from) +{ return svdup_n_s32(from); } + +template<> EIGEN_STRONG_INLINE PacketXui pset1(const numext::uint32_t& from) +{ return svdup_n_u32(from); } + +template<> EIGEN_STRONG_INLINE PacketXi plset(const numext::int32_t& a) { numext::int32_t c[packet_traits::size]; for (int i = 0; i < packet_traits::size; i++) c[i] = i; return svadd_s32_z(svptrue_b32(), pset1(a), svld1_s32(svptrue_b32(), c)); } -template <> -EIGEN_STRONG_INLINE PacketXi padd(const PacketXi& a, const PacketXi& b) +template<> EIGEN_STRONG_INLINE PacketXui plset(const numext::uint32_t& a) { - return svadd_s32_z(svptrue_b32(), a, b); + numext::uint32_t c[packet_traits::size]; + for (uint i = 0; i < packet_traits::size; i++) c[i] = i; + return svadd_u32_z(svptrue_b32(), pset1(a), svld1_u32(svptrue_b32(), c)); } -template <> -EIGEN_STRONG_INLINE PacketXi psub(const PacketXi& a, const PacketXi& b) -{ - return svsub_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi padd(const PacketXi& a, const PacketXi& b) +{ return svadd_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) -{ - return svneg_s32_z(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE PacketXui padd(const PacketXui& a, const PacketXui& b) +{ return svadd_u32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) -{ - return a; -} +template<> EIGEN_STRONG_INLINE PacketXi psub(const PacketXi& a, const PacketXi& b) +{ return svsub_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pmul(const PacketXi& a, const PacketXi& b) -{ - return svmul_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXui psub(const PacketXui& a, const PacketXui& b) +{ return svsub_u32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pdiv(const PacketXi& a, const PacketXi& b) -{ - return svdiv_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi pabsdiff(const PacketXi& a, const PacketXi& b) +{ return svabd_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) -{ - return svmla_s32_z(svptrue_b32(), c, a, b); -} +template<> EIGEN_STRONG_INLINE PacketXui pabsdiff(const PacketXui& a, const PacketXui& b) +{ return svabd_u32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pmin(const PacketXi& a, const PacketXi& b) -{ - return svmin_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) +{ return svneg_s32_z(svptrue_b32(), a); } -template <> -EIGEN_STRONG_INLINE PacketXi pmax(const PacketXi& a, const PacketXi& b) -{ - return svmax_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) { return a; } -template <> -EIGEN_STRONG_INLINE PacketXi pcmp_le(const PacketXi& a, const PacketXi& b) -{ - return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); -} +template<> EIGEN_STRONG_INLINE PacketXui pconj(const PacketXui& a) { return a; } -template <> -EIGEN_STRONG_INLINE PacketXi pcmp_lt(const PacketXi& a, const PacketXi& b) -{ - return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); -} +template<> EIGEN_STRONG_INLINE PacketXi pmul(const PacketXi& a, const PacketXi& b) +{ return svmul_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pcmp_eq(const PacketXi& a, const PacketXi& b) -{ - return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu); -} +template<> EIGEN_STRONG_INLINE PacketXui pmul(const PacketXui& a, const PacketXui& b) +{ return svmul_u32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi ptrue(const PacketXi& /*a*/) -{ - return svdup_n_s32_z(svptrue_b32(), 0xffffffffu); -} +template<> EIGEN_STRONG_INLINE PacketXi pdiv(const PacketXi& a, const PacketXi& b) +{ return svdiv_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pzero(const PacketXi& /*a*/) -{ - return svdup_n_s32_z(svptrue_b32(), 0); -} +template<> EIGEN_STRONG_INLINE PacketXui pdiv(const PacketXui& a, const PacketXui& b) +{ return svdiv_u32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pand(const PacketXi& a, const PacketXi& b) -{ - return svand_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) +{ return svmla_s32_z(svptrue_b32(), c, a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi por(const PacketXi& a, const PacketXi& b) -{ - return svorr_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXui pmadd(const PacketXui& a, const PacketXui& b, const PacketXui& c) +{ return svmla_u32_z(svptrue_b32(), c, a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pxor(const PacketXi& a, const PacketXi& b) -{ - return sveor_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXi pmin(const PacketXi& a, const PacketXi& b) +{ return svmin_s32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXi pandnot(const PacketXi& a, const PacketXi& b) -{ - return svbic_s32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXui pmin(const PacketXui& a, const PacketXui& b) +{ return svmin_u32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXi pmax(const PacketXi& a, const PacketXi& b) +{ return svmax_s32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXui pmax(const PacketXui& a, const PacketXui& b) +{ return svmax_u32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXi pcmp_le(const PacketXi& a, const PacketXi& b) +{ return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXui pcmp_le(const PacketXui& a, const PacketXui& b) +{ return svdup_n_u32_z(svcmple_u32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXi pcmp_lt(const PacketXi& a, const PacketXi& b) +{ return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXui pcmp_lt(const PacketXui& a, const PacketXui& b) +{ return svdup_n_u32_z(svcmplt_u32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXi pcmp_eq(const PacketXi& a, const PacketXi& b) +{ return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXui pcmp_eq(const PacketXui& a, const PacketXui& b) +{ return svdup_n_u32_z(svcmpeq_u32(svptrue_b32(), a, b), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXi ptrue(const PacketXi& /*a*/) +{ return svdup_n_s32_z(svptrue_b32(), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXui ptrue(const PacketXui& /*a*/) +{ return svdup_n_u32_z(svptrue_b32(), 0xffffffffu); } + +template<> EIGEN_STRONG_INLINE PacketXi pzero(const PacketXi& /*a*/) +{ return svdup_n_s32_z(svptrue_b32(), 0); } + +template<> EIGEN_STRONG_INLINE PacketXui pzero(const PacketXui& /*a*/) +{ return svdup_n_u32_z(svptrue_b32(), 0); } + +template<> EIGEN_STRONG_INLINE PacketXi pand(const PacketXi& a, const PacketXi& b) +{ return svand_s32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXui pand(const PacketXui& a, const PacketXui& b) +{ return svand_u32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXi por(const PacketXi& a, const PacketXi& b) +{ return svorr_s32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXui por(const PacketXui& a, const PacketXui& b) +{ return svorr_u32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXi pxor(const PacketXi& a, const PacketXi& b) +{ return sveor_s32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXui pxor(const PacketXui& a, const PacketXui& b) +{ return sveor_u32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXi pandnot(const PacketXi& a, const PacketXi& b) +{ return svbic_s32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXui pandnot(const PacketXui& a, const PacketXui& b) +{ return svbic_u32_z(svptrue_b32(), a, b); } template EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) -{ - return svasrd_n_s32_z(svptrue_b32(), a, N); -} +{ return svasrd_n_s32_z(svptrue_b32(), a, N); } + +template +EIGEN_STRONG_INLINE PacketXui parithmetic_shift_right(PacketXui a) +{ return svlsr_n_u32_z(svptrue_b32(), a, N); } template EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) -{ - return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N))); -} +{ return svreinterpret_s32_u32(svlsr_n_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), N)); } + +template +EIGEN_STRONG_INLINE PacketXui plogical_shift_right(PacketXui a) +{ return svlsr_n_u32_z(svptrue_b32(), a, N); } template EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) -{ - return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N)); -} +{ return svlsl_n_s32_z(svptrue_b32(), a, N); } -template <> -EIGEN_STRONG_INLINE PacketXi pload(const numext::int32_t* from) +template +EIGEN_STRONG_INLINE PacketXui plogical_shift_left(PacketXui a) +{ return svlsl_n_u32_z(svptrue_b32(), a, N); } + +template<> EIGEN_STRONG_INLINE PacketXi pload(const numext::int32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from); } + +template<> EIGEN_STRONG_INLINE PacketXui pload(const numext::uint32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return svld1_u32(svptrue_b32(), from); } + +template<> EIGEN_STRONG_INLINE PacketXi ploadu(const numext::int32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from); } + +template<> EIGEN_STRONG_INLINE PacketXui ploadu(const numext::uint32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return svld1_u32(svptrue_b32(), from); } + +template<> EIGEN_STRONG_INLINE PacketXi ploaddup(const numext::int32_t* from) { - EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from); + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + return svld1_gather_u32index_s32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE PacketXi ploadu(const numext::int32_t* from) +template<> EIGEN_STRONG_INLINE PacketXui ploaddup(const numext::uint32_t* from) { - EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from); + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + return svld1_gather_u32index_u32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE PacketXi ploaddup(const numext::int32_t* from) +template<> EIGEN_STRONG_INLINE PacketXi ploadquad(const numext::int32_t* from) { svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...} return svld1_gather_u32index_s32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE PacketXi ploadquad(const numext::int32_t* from) +template<> EIGEN_STRONG_INLINE PacketXui ploadquad(const numext::uint32_t* from) { svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...} - return svld1_gather_u32index_s32(svptrue_b32(), from, indices); + return svld1_gather_u32index_u32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE void pstore(numext::int32_t* to, const PacketXi& from) +template<> EIGEN_STRONG_INLINE void pstore(numext::int32_t* to, const PacketXi& from) +{ EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from); } + +template<> EIGEN_STRONG_INLINE void pstore(numext::uint32_t* to, const PacketXui& from) +{ EIGEN_DEBUG_ALIGNED_STORE svst1_u32(svptrue_b32(), to, from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(numext::int32_t* to, const PacketXi& from) +{ EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(numext::uint32_t* to, const PacketXui& from) +{ EIGEN_DEBUG_UNALIGNED_STORE svst1_u32(svptrue_b32(), to, from); } + +template<> EIGEN_DEVICE_FUNC inline PacketXi pgather(const numext::int32_t* from, Index stride) { - EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from); + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svint32_t indices = svindex_s32(0, stride); + return svld1_gather_s32index_s32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE void pstoreu(numext::int32_t* to, const PacketXi& from) +template<> EIGEN_DEVICE_FUNC inline PacketXui pgather(const numext::uint32_t* from, Index stride) { - EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from); + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svuint32_t indices = svindex_u32(0, stride); + return svld1_gather_u32index_u32(svptrue_b32(), from, indices); } -template <> -EIGEN_DEVICE_FUNC inline PacketXi pgather(const numext::int32_t* from, Index stride) +template<> EIGEN_DEVICE_FUNC inline void pscatter(numext::int32_t* to, const PacketXi& from, Index stride) { // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} svint32_t indices = svindex_s32(0, stride); - return svld1_gather_s32index_s32(svptrue_b32(), from, indices); + svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from); } -template <> -EIGEN_DEVICE_FUNC inline void pscatter(numext::int32_t* to, const PacketXi& from, Index stride) +template<> EIGEN_DEVICE_FUNC inline void pscatter(numext::uint32_t* to, const PacketXui& from, Index stride) { // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} - svint32_t indices = svindex_s32(0, stride); - svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from); + svuint32_t indices = svindex_u32(0, stride); + svst1_scatter_u32index_u32(svptrue_b32(), to, indices, from); } -template <> -EIGEN_STRONG_INLINE numext::int32_t pfirst(const PacketXi& a) +template<> EIGEN_STRONG_INLINE numext::int32_t pfirst(const PacketXi& a) { // svlasta returns the first element if all predicate bits are 0 return svlasta_s32(svpfalse_b(), a); } -template <> -EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) +template<> EIGEN_STRONG_INLINE numext::uint32_t pfirst(const PacketXui& a) { - return svrev_s32(a); + // svlasta returns the first element if all predicate bits are 0 + return svlasta_u32(svpfalse_b(), a); } -template <> -EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) +template<> EIGEN_STRONG_INLINE PacketXi pselect(const PacketXi& mask, const PacketXi& a, const PacketXi& b) { - return svabs_s32_z(svptrue_b32(), a); +#if __ARM_FEATURE_SVE2 + return svbsl(a, b, mask); +#else + PacketXi mask_inv = svnot_s32_z(svptrue_b32(), mask); + return svorr_s32_z(svptrue_b32(), svand_s32_z(svptrue_b32(), a, mask), svand_s32_z(svptrue_b32(), b, mask_inv)); +#endif } -template <> -EIGEN_STRONG_INLINE numext::int32_t predux(const PacketXi& a) +template<> EIGEN_STRONG_INLINE PacketXui pselect(const PacketXui& mask, const PacketXui& a, const PacketXui& b) { - return static_cast(svaddv_s32(svptrue_b32(), a)); +#if __ARM_FEATURE_SVE2 + return svbsl(a, b, mask); +#else + PacketXui mask_inv = svnot_u32_z(svptrue_b32(), mask); + return svorr_u32_z(svptrue_b32(), svand_u32_z(svptrue_b32(), a, mask), svand_u32_z(svptrue_b32(), b, mask_inv)); +#endif } -template <> -EIGEN_STRONG_INLINE numext::int32_t predux_mul(const PacketXi& a) +template<> EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) { return svrev_s32(a); } + +template<> EIGEN_STRONG_INLINE PacketXui preverse(const PacketXui& a) { return svrev_u32(a); } + +template<> EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) { return svabs_s32_z(svptrue_b32(), a); } + +template<> EIGEN_STRONG_INLINE numext::int32_t predux(const PacketXi& a) +{ return static_cast(svaddv_s32(svptrue_b32(), a)); } + +template<> EIGEN_STRONG_INLINE numext::uint32_t predux(const PacketXui& a) +{ return static_cast(svaddv_u32(svptrue_b32(), a)); } + +template<> EIGEN_STRONG_INLINE numext::int32_t predux_mul(const PacketXi& a) { EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); @@ -339,18 +436,52 @@ EIGEN_STRONG_INLINE numext::int32_t predux_mul(const PacketXi& a) return pfirst(prod); } -template <> -EIGEN_STRONG_INLINE numext::int32_t predux_min(const PacketXi& a) +template<> EIGEN_STRONG_INLINE numext::uint32_t predux_mul(const PacketXui& a) { - return svminv_s32(svptrue_b32(), a); -} + EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), + EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); -template <> -EIGEN_STRONG_INLINE numext::int32_t predux_max(const PacketXi& a) -{ - return svmaxv_s32(svptrue_b32(), a); + // Multiply the vector by its reverse + svuint32_t prod = svmul_u32_z(svptrue_b32(), a, svrev_u32(a)); + svuint32_t half_prod; + + // Extract the high half of the vector. Depending on the VL more reductions need to be done + if (EIGEN_ARM64_SVE_VL >= 2048) { + half_prod = svtbl_u32(prod, svindex_u32(32, 1)); + prod = svmul_u32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 1024) { + half_prod = svtbl_u32(prod, svindex_u32(16, 1)); + prod = svmul_u32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 512) { + half_prod = svtbl_u32(prod, svindex_u32(8, 1)); + prod = svmul_u32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 256) { + half_prod = svtbl_u32(prod, svindex_u32(4, 1)); + prod = svmul_u32_z(svptrue_b32(), prod, half_prod); + } + // Last reduction + half_prod = svtbl_u32(prod, svindex_u32(2, 1)); + prod = svmul_u32_z(svptrue_b32(), prod, half_prod); + + // The reduction is done to the first element. + return pfirst(prod); } +template<> EIGEN_STRONG_INLINE numext::int32_t predux_min(const PacketXi& a) +{ return svminv_s32(svptrue_b32(), a); } + +template<> EIGEN_STRONG_INLINE numext::uint32_t predux_min(const PacketXui& a) +{ return svminv_u32(svptrue_b32(), a); } + +template<> EIGEN_STRONG_INLINE numext::int32_t predux_max(const PacketXi& a) +{ return svmaxv_s32(svptrue_b32(), a); } + +template<> EIGEN_STRONG_INLINE numext::uint32_t predux_max(const PacketXui& a) +{ return svmaxv_u32(svptrue_b32(), a); } + template EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { int buffer[packet_traits::size * N] = {0}; @@ -366,6 +497,21 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { } } +template +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + int buffer[packet_traits::size * N] = {0}; + int i = 0; + + PacketXui stride_index = svindex_u32(0, N); + + for (i = 0; i < N; i++) { + svst1_scatter_u32index_u32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]); + } + for (i = 0; i < N; i++) { + kernel.packet[i] = svld1_u32(svptrue_b32(), buffer + i * packet_traits::size); + } +} + /********************************* float32 ************************************/ typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL))); @@ -424,192 +570,119 @@ struct unpacket_traits { }; }; -template <> -EIGEN_STRONG_INLINE PacketXf pset1(const float& from) -{ - return svdup_n_f32(from); -} +template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) +{ svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); } -template <> -EIGEN_STRONG_INLINE PacketXf pset1frombits(numext::uint32_t from) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from)); -} +template<> EIGEN_STRONG_INLINE PacketXf pset1(const float& from) +{ return svdup_n_f32(from); } -template <> -EIGEN_STRONG_INLINE PacketXf plset(const float& a) +template<> EIGEN_STRONG_INLINE PacketXf pset1frombits(numext::uint32_t from) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from)); } + +template<> EIGEN_STRONG_INLINE PacketXf plset(const float& a) { float c[packet_traits::size]; for (int i = 0; i < packet_traits::size; i++) c[i] = i; return svadd_f32_z(svptrue_b32(), pset1(a), svld1_f32(svptrue_b32(), c)); } -template <> -EIGEN_STRONG_INLINE PacketXf padd(const PacketXf& a, const PacketXf& b) -{ - return svadd_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf padd(const PacketXf& a, const PacketXf& b) +{ return svadd_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf psub(const PacketXf& a, const PacketXf& b) -{ - return svsub_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf psub(const PacketXf& a, const PacketXf& b) +{ return svsub_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) -{ - return svneg_f32_z(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE PacketXf paddsub(const PacketXf& a, const PacketXf& b) +{ return svadd_f32_x(svptrue_b32(), a, svneg_f32_m(b, svdupq_n_b32(1,0,1,0), b)); } -template <> -EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) -{ - return a; -} +template<> EIGEN_STRONG_INLINE PacketXf pabsdiff(const PacketXf& a, const PacketXf& b) +{ return svabd_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmul(const PacketXf& a, const PacketXf& b) -{ - return svmul_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) +{ return svneg_f32_z(svptrue_b32(), a); } -template <> -EIGEN_STRONG_INLINE PacketXf pdiv(const PacketXf& a, const PacketXf& b) -{ - return svdiv_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) { return a; } -template <> -EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) -{ - return svmla_f32_z(svptrue_b32(), c, a, b); -} +template <> EIGEN_STRONG_INLINE PacketXf pmul(const PacketXf& a, const PacketXf& b) +{ return svmul_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) -{ - return svmin_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pdiv(const PacketXf& a, const PacketXf& b) +{ return svdiv_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) -{ - return pmin(a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf psqrt(const PacketXf& a) +{ return svsqrt_f32_z(svptrue_b32(), a); } -template <> -EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) -{ - return svminnm_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) +{ return svmla_f32_z(svptrue_b32(), c, a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) -{ - return svmax_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ return svmin_f32_z(svptrue_b32(), a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) -{ - return pmax(a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ return pmin(a, b); } -template <> -EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) -{ - return svmaxnm_f32_z(svptrue_b32(), a, b); -} +template<> EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ return svminnm_f32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ return svmax_f32_z(svptrue_b32(), a, b); } + +template<> EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ return pmax(a, b); } + +template<> EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ return svmaxnm_f32_z(svptrue_b32(), a, b); } // Float comparisons in SVE return svbool (predicate). Use svdup to set active // lanes to 1 (0xffffffffu) and inactive lanes to 0. -template <> -EIGEN_STRONG_INLINE PacketXf pcmp_le(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); -} +template<> EIGEN_STRONG_INLINE PacketXf pcmp_le(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu)); } -template <> -EIGEN_STRONG_INLINE PacketXf pcmp_lt(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); -} +template<> EIGEN_STRONG_INLINE PacketXf pcmp_lt(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); } -template <> -EIGEN_STRONG_INLINE PacketXf pcmp_eq(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu)); -} +template<> EIGEN_STRONG_INLINE PacketXf pcmp_eq(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu)); } // Do a predicate inverse (svnot_b_z) on the predicate resulted from the // greater/equal comparison (svcmpge_f32). Then fill a float vector with the // active elements. -template <> -EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu)); -} +template<> EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu)); } -template <> -EIGEN_STRONG_INLINE PacketXf pfloor(const PacketXf& a) -{ - return svrintm_f32_z(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE PacketXf pfloor(const PacketXf& a) +{ return svrintm_f32_z(svptrue_b32(), a); } -template <> -EIGEN_STRONG_INLINE PacketXf ptrue(const PacketXf& /*a*/) -{ - return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu)); -} +template<> EIGEN_STRONG_INLINE PacketXf ptrue(const PacketXf& /*a*/) +{ return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu)); } // Logical Operations are not supported for float, so reinterpret casts -template <> -EIGEN_STRONG_INLINE PacketXf pand(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); -} +template<> EIGEN_STRONG_INLINE PacketXf pand(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); } -template <> -EIGEN_STRONG_INLINE PacketXf por(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); -} +template<> EIGEN_STRONG_INLINE PacketXf por(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); } -template <> -EIGEN_STRONG_INLINE PacketXf pxor(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); -} +template<> EIGEN_STRONG_INLINE PacketXf pxor(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); } -template <> -EIGEN_STRONG_INLINE PacketXf pandnot(const PacketXf& a, const PacketXf& b) -{ - return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); -} +template<> EIGEN_STRONG_INLINE PacketXf pandnot(const PacketXf& a, const PacketXf& b) +{ return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); } -template <> -EIGEN_STRONG_INLINE PacketXf pload(const float* from) -{ - EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from); -} +template<> EIGEN_STRONG_INLINE PacketXf pload(const float* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from); } -template <> -EIGEN_STRONG_INLINE PacketXf ploadu(const float* from) -{ - EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from); -} +template<> EIGEN_STRONG_INLINE PacketXf ploadu(const float* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from); } -template <> -EIGEN_STRONG_INLINE PacketXf ploaddup(const float* from) +template<> EIGEN_STRONG_INLINE PacketXf ploaddup(const float* from) { svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} return svld1_gather_u32index_f32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE PacketXf ploadquad(const float* from) +template<> EIGEN_STRONG_INLINE PacketXf ploadquad(const float* from) { svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} @@ -617,72 +690,65 @@ EIGEN_STRONG_INLINE PacketXf ploadquad(const float* from) return svld1_gather_u32index_f32(svptrue_b32(), from, indices); } -template <> -EIGEN_STRONG_INLINE void pstore(float* to, const PacketXf& from) -{ - EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from); -} +template<> EIGEN_STRONG_INLINE void pstore(float* to, const PacketXf& from) +{ EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from); } -template <> -EIGEN_STRONG_INLINE void pstoreu(float* to, const PacketXf& from) -{ - EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from); -} +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const PacketXf& from) +{ EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from); } -template <> -EIGEN_DEVICE_FUNC inline PacketXf pgather(const float* from, Index stride) +template<> EIGEN_DEVICE_FUNC inline PacketXf pgather(const float* from, Index stride) { // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} svint32_t indices = svindex_s32(0, stride); return svld1_gather_s32index_f32(svptrue_b32(), from, indices); } -template <> -EIGEN_DEVICE_FUNC inline void pscatter(float* to, const PacketXf& from, Index stride) +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const PacketXf& from, Index stride) { // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} svint32_t indices = svindex_s32(0, stride); svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from); } -template <> -EIGEN_STRONG_INLINE float pfirst(const PacketXf& a) +template<> EIGEN_STRONG_INLINE float pfirst(const PacketXf& a) { // svlasta returns the first element if all predicate bits are 0 return svlasta_f32(svpfalse_b(), a); } -template <> -EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) -{ - return svrev_f32(a); -} -template <> -EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) +template<> EIGEN_STRONG_INLINE PacketXf pselect(const PacketXf& mask, const PacketXf& a, const PacketXf& b) { - return svabs_f32_z(svptrue_b32(), a); +#if __ARM_FEATURE_SVE2 + return svreinterpret_f32(svbsl(svreinterpret_u32_f32(a), svreinterpret_u32_f32(b), svreinterpret_u32_f32(mask))); +#else + svuint32_t mask_ = svreinterpret_u32_f32(mask); + svuint32_t mask_inv_ = svnot_u32_z(svptrue_b32(), mask_); + svuint32_t a_ = svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), mask_); + svuint32_t b_ = svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(b), mask_inv_); + return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), a_, b_)); +#endif } + +template<> EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) { return svrev_f32(a); } + +template <> EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) +{ return svabs_f32_z(svptrue_b32(), a); } + // TODO(tellenbach): Should this go into MathFunctions.h? If so, change for // all vector extensions and the generic version. template <> EIGEN_STRONG_INLINE PacketXf pfrexp(const PacketXf& a, PacketXf& exponent) -{ - return pfrexp_generic(a, exponent); -} +{ return pfrexp_generic(a, exponent); } -template <> -EIGEN_STRONG_INLINE float predux(const PacketXf& a) -{ - return svaddv_f32(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE float predux(const PacketXf& a) +{ return svaddv_f32(svptrue_b32(), a); } // Other reduction functions: // mul // Only works for SVE Vls multiple of 128 -template <> -EIGEN_STRONG_INLINE float predux_mul(const PacketXf& a) +template<> EIGEN_STRONG_INLINE float predux_mul(const PacketXf& a) { EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); @@ -715,17 +781,14 @@ EIGEN_STRONG_INLINE float predux_mul(const PacketXf& a) return pfirst(prod); } -template <> -EIGEN_STRONG_INLINE float predux_min(const PacketXf& a) -{ - return svminv_f32(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE float predux_min(const PacketXf& a) +{ return svminv_f32(svptrue_b32(), a); } -template <> -EIGEN_STRONG_INLINE float predux_max(const PacketXf& a) -{ - return svmaxv_f32(svptrue_b32(), a); -} +template<> EIGEN_STRONG_INLINE float predux_max(const PacketXf& a) +{ return svmaxv_f32(svptrue_b32(), a); } + +template<> EIGEN_STRONG_INLINE bool predux_any(const PacketXf& a) +{ return svptest_any(svptrue_b32(), svcmpne_n_f32(svptrue_b32(), a, 0.0f)); } template EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) @@ -746,9 +809,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) template<> EIGEN_STRONG_INLINE PacketXf pldexp(const PacketXf& a, const PacketXf& exponent) -{ - return pldexp_generic(a, exponent); -} +{ return pldexp_generic(a, exponent); } } // namespace internal } // namespace Eigen