diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h index 4877b6d8090cdd0ae4ebf8260c3d50c1242aab70..5b7bb262e8ed2c6688878b63b6498470670ba554 100644 --- a/Eigen/src/Core/arch/SVE/PacketMath.h +++ b/Eigen/src/Core/arch/SVE/PacketMath.h @@ -108,6 +108,12 @@ EIGEN_STRONG_INLINE PacketXi psub(const PacketXi& a, const PacketXi& b return svsub_s32_z(svptrue_b32(), a, b); } +template <> +EIGEN_STRONG_INLINE PacketXi pabsdiff(const PacketXi& a, const PacketXi& b) +{ + return svabd_s32_z(svptrue_b32(), a, b); +} + template <> EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) { @@ -153,7 +159,7 @@ EIGEN_STRONG_INLINE PacketXi pmax(const PacketXi& a, const PacketXi& b template <> EIGEN_STRONG_INLINE PacketXi pcmp_le(const PacketXi& a, const PacketXi& b) { - return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); + return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu); } template <> @@ -213,13 +219,13 @@ EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) template EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) { - return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N))); + return svreinterpret_s32_u32(svlsr_n_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), N)); } template EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) { - return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N)); + return svlsl_n_s32_z(svptrue_b32(), a, N); } template <> @@ -286,6 +292,17 @@ EIGEN_STRONG_INLINE numext::int32_t pfirst(const PacketXi& a) return svlasta_s32(svpfalse_b(), a); } +template <> +EIGEN_STRONG_INLINE PacketXi pselect(const PacketXi& mask, const PacketXi& a, const PacketXi& b) +{ +#if defined(__ARM_FEATURE_SVE2) + return svbsl(a, b, mask); +#else + PacketXi mask_inv = svnot_s32_z(svptrue_b32(), mask); + return svorr_s32_z(svptrue_b32(), svand_s32_z(svptrue_b32(), a, mask), svand_s32_z(svptrue_b32(), b, mask_inv)); +#endif +} + template <> EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) { @@ -424,6 +441,12 @@ struct unpacket_traits { }; }; +template <> +EIGEN_STRONG_INLINE void prefetch(const float* addr) +{ + svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); +} + template <> EIGEN_STRONG_INLINE PacketXf pset1(const float& from) { @@ -456,6 +479,18 @@ EIGEN_STRONG_INLINE PacketXf psub(const PacketXf& a, const PacketXf& b return svsub_f32_z(svptrue_b32(), a, b); } +template <> +EIGEN_STRONG_INLINE PacketXf paddsub(const PacketXf& a, const PacketXf& b) +{ + return svadd_f32_x(svptrue_b32(), a, svneg_f32_m(b, svdupq_n_b32(1,0,1,0), b)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pabsdiff(const PacketXf& a, const PacketXf& b) +{ + return svabd_f32_z(svptrue_b32(), a, b); +} + template <> EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) { @@ -480,6 +515,12 @@ EIGEN_STRONG_INLINE PacketXf pdiv(const PacketXf& a, const PacketXf& b return svdiv_f32_z(svptrue_b32(), a, b); } +template <> +EIGEN_STRONG_INLINE PacketXf psqrt(const PacketXf& a) +{ + return svsqrt_f32_z(svptrue_b32(), a); +} + template <> EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) { @@ -527,7 +568,7 @@ EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, template <> EIGEN_STRONG_INLINE PacketXf pcmp_le(const PacketXf& a, const PacketXf& b) { - return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); + return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu)); } template <> @@ -652,6 +693,22 @@ EIGEN_STRONG_INLINE float pfirst(const PacketXf& a) return svlasta_f32(svpfalse_b(), a); } + +template <> +EIGEN_STRONG_INLINE PacketXf pselect(const PacketXf& mask, const PacketXf& a, const PacketXf& b) +{ +#if defined(__ARM_FEATURE_SVE2) + return svreinterpret_f32(svbsl(svreinterpret_u32_f32(a), svreinterpret_u32_f32(b), svreinterpret_u32_f32(mask))); +#else + svuint32_t mask_ = svreinterpret_u32_f32(mask); + svuint32_t mask_inv_ = svnot_u32_z(svptrue_b32(), mask_); + svuint32_t a_ = svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), mask_); + svuint32_t b_ = svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(b), mask_inv_); + return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), a_, b_)); +#endif +} + + template <> EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a) { @@ -727,6 +784,12 @@ EIGEN_STRONG_INLINE float predux_max(const PacketXf& a) return svmaxv_f32(svptrue_b32(), a); } +template <> +EIGEN_STRONG_INLINE bool predux_any(const PacketXf& a) +{ + return svptest_any(svptrue_b32(), svcmpne_n_f32(svptrue_b32(), a, 0.0f)); +} + template EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) {