use core::arch::asm;
use core::mem::size_of;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use crate::with_dit;
/// Byte-wise equality of two 128-bit vectors, issued as explicit inline assembly.
///
/// Yields the same value as `_mm_cmpeq_epi8`: every result byte is `0xFF` where the
/// corresponding bytes of `a` and `b` are equal and `0x00` where they differ.
///
/// When the crate is compiled with the `avx` target feature the three-operand
/// `vpcmpeqb` form is emitted; otherwise the two-operand `pcmpeqb` form is used and
/// the register holding `a` is overwritten with the result.
#[must_use]
#[inline(always)]
fn cmpeq_epi8(a: __m128i, b: __m128i) -> __m128i {
    // NOTE(review): presumably inline asm is used instead of the intrinsic so the
    // optimizer cannot see through the comparison — confirm with the crate docs.
    let eq: __m128i;
    // SAFETY: both instructions only read and write the xmm registers named as
    // operands; they access no memory, leave the flags untouched and do not use
    // the stack, which is exactly what `options(...)` declares.
    // NOTE(review): assumes this module is only built for targets where SSE2 is
    // enabled; the gating is not visible here — confirm.
    unsafe {
        if cfg!(target_feature = "avx") {
            asm!("vpcmpeqb {dst}, {lhs}, {rhs}",
                dst = lateout(xmm_reg) eq,
                lhs = in(xmm_reg) a,
                rhs = in(xmm_reg) b,
                options(pure, nomem, preserves_flags, nostack));
        } else {
            asm!("pcmpeqb {acc}, {rhs}",
                acc = inlateout(xmm_reg) a => eq,
                rhs = in(xmm_reg) b,
                options(pure, nomem, preserves_flags, nostack));
        }
    }
    eq
}
/// Bitwise AND of two 128-bit vectors, issued as explicit inline assembly.
///
/// Yields the same value as `_mm_and_si128`.
///
/// When the crate is compiled with the `avx` target feature the three-operand
/// `vpand` form is emitted; otherwise the two-operand `pand` form is used and the
/// register holding `a` is overwritten with the result.
#[must_use]
#[inline(always)]
fn and_si128(a: __m128i, b: __m128i) -> __m128i {
    // NOTE(review): presumably inline asm is used instead of the intrinsic so the
    // optimizer cannot reason about the accumulated mask — confirm with the crate docs.
    let conj: __m128i;
    // SAFETY: both instructions only read and write the xmm registers named as
    // operands; they access no memory, leave the flags untouched and do not use
    // the stack, which is exactly what `options(...)` declares.
    // NOTE(review): assumes this module is only built for targets where SSE2 is
    // enabled; the gating is not visible here — confirm.
    unsafe {
        if cfg!(target_feature = "avx") {
            asm!("vpand {dst}, {lhs}, {rhs}",
                dst = lateout(xmm_reg) conj,
                lhs = in(xmm_reg) a,
                rhs = in(xmm_reg) b,
                options(pure, nomem, preserves_flags, nostack));
        } else {
            asm!("pand {acc}, {rhs}",
                acc = inlateout(xmm_reg) a => conj,
                rhs = in(xmm_reg) b,
                options(pure, nomem, preserves_flags, nostack));
        }
    }
    conj
}
/// Collects the most significant bit of each byte of `a` into an integer, issued as
/// explicit inline assembly.
///
/// Yields the same bits as `_mm_movemask_epi8`: bit `i` of the result is the top bit
/// of byte `i`, so only the low 16 bits can be set.
///
/// `vpmovmskb` is emitted when the crate is compiled with the `avx` target feature,
/// `pmovmskb` otherwise; both write the 32-bit view of the destination register.
#[must_use]
#[inline(always)]
fn movemask_epi8(a: __m128i) -> u32 {
    // NOTE(review): presumably inline asm is used instead of the intrinsic so the
    // optimizer cannot reason about the extracted mask — confirm with the crate docs.
    let bits: u32;
    // SAFETY: both instructions read one xmm register and write one general-purpose
    // register; they access no memory, leave the flags untouched and do not use the
    // stack, which is exactly what `options(...)` declares.
    // NOTE(review): assumes this module is only built for targets where SSE2 is
    // enabled; the gating is not visible here — confirm.
    unsafe {
        if cfg!(target_feature = "avx") {
            asm!("vpmovmskb {out:e}, {src}",
                out = lateout(reg) bits,
                src = in(xmm_reg) a,
                options(pure, nomem, preserves_flags, nostack));
        } else {
            asm!("pmovmskb {out:e}, {src}",
                out = lateout(reg) bits,
                src = in(xmm_reg) a,
                options(pure, nomem, preserves_flags, nostack));
        }
    }
    bits
}
/// Compares `n` bytes at `a` and `b`, processing 16-byte vectors with SSE2.
///
/// Every branch and loop condition in this function depends only on `n`; the loaded
/// bytes flow exclusively through `cmpeq_epi8`, `and_si128` and `movemask_epi8`.
/// The vector part leaves fewer than 16 bytes unprocessed; those, together with the
/// accumulated per-lane difference bits, are handed to
/// `crate::generic::constant_time_eq_impl`, whose return value is the result.
///
/// # Safety
///
/// `a` and `b` must each be valid for reads of `n` bytes. No alignment is required:
/// all vector loads use `_mm_loadu_si128`.
#[must_use]
#[inline(always)]
unsafe fn constant_time_eq_sse2(mut a: *const u8, mut b: *const u8, mut n: usize) -> bool {
    const LANES: usize = size_of::<__m128i>();
    const STRIDE: usize = LANES * 2;

    // One bit per byte lane of the folded vector result; a set bit marks a lane in
    // which at least one compared byte pair differed.
    let diff_bits: u32 = if n >= STRIDE {
        // Seed two independent accumulators from the first two vectors.
        // SAFETY: `n >= STRIDE`, so `STRIDE` bytes are readable behind each pointer.
        let (mut acc0, mut acc1) = unsafe {
            let lhs0 = _mm_loadu_si128(a.cast::<__m128i>());
            let rhs0 = _mm_loadu_si128(b.cast::<__m128i>());
            let lhs1 = _mm_loadu_si128(a.add(LANES).cast::<__m128i>());
            let rhs1 = _mm_loadu_si128(b.add(LANES).cast::<__m128i>());
            a = a.add(STRIDE);
            b = b.add(STRIDE);
            (cmpeq_epi8(lhs0, rhs0), cmpeq_epi8(lhs1, rhs1))
        };
        n -= STRIDE;

        // Fold every further pair of vectors into the accumulators.
        while n >= STRIDE {
            // SAFETY: the loop condition guarantees `STRIDE` more readable bytes.
            let (eq0, eq1) = unsafe {
                let lhs0 = _mm_loadu_si128(a.cast::<__m128i>());
                let rhs0 = _mm_loadu_si128(b.cast::<__m128i>());
                let lhs1 = _mm_loadu_si128(a.add(LANES).cast::<__m128i>());
                let rhs1 = _mm_loadu_si128(b.add(LANES).cast::<__m128i>());
                a = a.add(STRIDE);
                b = b.add(STRIDE);
                (cmpeq_epi8(lhs0, rhs0), cmpeq_epi8(lhs1, rhs1))
            };
            n -= STRIDE;
            acc0 = and_si128(acc0, eq0);
            acc1 = and_si128(acc1, eq1);
        }

        // At most one whole vector can remain after the paired loop.
        if n >= LANES {
            // SAFETY: `n >= LANES`, so one more vector is readable behind each pointer.
            let eq = unsafe {
                let lhs = _mm_loadu_si128(a.cast::<__m128i>());
                let rhs = _mm_loadu_si128(b.cast::<__m128i>());
                a = a.add(LANES);
                b = b.add(LANES);
                cmpeq_epi8(lhs, rhs)
            };
            n -= LANES;
            acc0 = and_si128(acc0, eq);
        }

        // All-equal lanes give 0xFFFF from the movemask; flipping the low 16 bits
        // turns that into "bit set means mismatch".
        movemask_epi8(and_si128(acc0, acc1)) ^ 0xFFFF
    } else if n >= LANES {
        // Exactly one whole vector fits.
        // SAFETY: `n >= LANES`, so one vector is readable behind each pointer.
        let eq = unsafe {
            let lhs = _mm_loadu_si128(a.cast::<__m128i>());
            let rhs = _mm_loadu_si128(b.cast::<__m128i>());
            a = a.add(LANES);
            b = b.add(LANES);
            cmpeq_epi8(lhs, rhs)
        };
        n -= LANES;
        movemask_epi8(eq) ^ 0xFFFF
    } else {
        // Nothing was compared with vectors, so no difference has been seen yet.
        0
    };

    // SAFETY: `a` and `b` were advanced by exactly the number of bytes subtracted
    // from `n`, so the remaining `n` bytes are still within the caller's ranges.
    // NOTE(review): the exact contract of `constant_time_eq_impl` is not visible
    // here — confirm it only requires `n` readable bytes behind each pointer.
    unsafe { crate::generic::constant_time_eq_impl(a, b, n, diff_bits.into()) }
}
/// Compares two byte slices for equality.
///
/// Slices of different lengths compare unequal without their contents being read;
/// otherwise the contents are compared by `constant_time_eq_sse2`. The whole check
/// runs inside `with_dit`.
// NOTE(review): `with_dit` is defined elsewhere in the crate; presumably it enables
// data-independent timing where the target supports it — confirm.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    with_dit(|| {
        let len = a.len();
        if len != b.len() {
            return false;
        }
        let lhs: *const u8 = a.as_ptr();
        let rhs: *const u8 = b.as_ptr();
        // SAFETY: both slices are exactly `len` bytes long, so each pointer is
        // valid for reads of `len` bytes.
        unsafe { constant_time_eq_sse2(lhs, rhs, len) }
    })
}
/// Compares two byte arrays of the same compile-time length `N` for equality.
///
/// The contents are compared by `constant_time_eq_sse2`, and the call runs inside
/// `with_dit`.
// NOTE(review): `with_dit` is defined elsewhere in the crate; presumably it enables
// data-independent timing where the target supports it — confirm.
#[must_use]
pub fn constant_time_eq_n<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
    with_dit(|| {
        let lhs: *const u8 = a.as_ptr();
        let rhs: *const u8 = b.as_ptr();
        // SAFETY: both arrays hold exactly `N` bytes, so each pointer is valid for
        // reads of `N` bytes.
        unsafe { constant_time_eq_sse2(lhs, rhs, N) }
    })
}