#![allow(dead_code)]
extern crate odds;
extern crate memchr;
use std::cmp;
use std::ptr;
use TwoWaySearcher;
use itertools::ZipSlices;
// imm8 control values for PCMPESTRI / PCMPESTRM with unsigned-byte operands
// (imm8 bits 1:0 = 00). Bits 3:2 select the aggregation operation named by
// each constant, per the Intel SDM description of the PCMPxSTRx imm8 byte.
const EQUAL_ANY: u8 = 0b0000;
// Byte-wise comparison of lane i against lane i (used by `pcmpestrm_eq_each`).
const EQUAL_EACH: u8 = 0b1000;
// Substring search, including partial matches at the end of the register
// (used by `pcmpestri_16`; `pcmpestri` hard-codes the same value as `0xc`).
const EQUAL_ORDERED: u8 = 0b1100;
/// Runs `PCMPESTRI` with imm8 `0xc` (the same bits as `EQUAL_ORDERED`) and
/// returns the index the instruction leaves in `ecx`.
///
/// * `needle` / `needle_len`: needle bytes packed into a `u64` and passed in
///   an xmm register. The length goes in `rax`.
/// * `text`, `offset`, `text_len`: the compared memory operand is
///   `[text.as_ptr() + offset]`. Its length goes in `rdx`.
///
/// `test_pcmpestri_1` shows a return value of 16 when the needle is not found.
///
/// # Safety
/// The memory operand is a 128-bit load from `text.as_ptr() + offset`. The
/// debug assertion only checks `offset + text_len <= text.len()`, so the
/// caller must make sure the full 16-byte read is acceptable.
/// NOTE(review): only 8 needle bytes fit in the `u64`, yet `needle_len` up to
/// 16 is accepted. The upper 64 bits of the xmm register are not set here, so
/// confirm callers keep `needle_len <= 8`.
#[inline(always)]
unsafe fn pcmpestri(text: &[u8], offset: usize, text_len: usize, needle: u64, needle_len: usize) -> u32 {
    debug_assert!(text_len + offset <= text.len());
    debug_assert!(needle_len <= 16);
    let res: u32;
    // Old-style asm!: template : outputs : inputs : clobbers : options.
    // $0 = ecx result, $1 = needle xmm, $2 = text pointer, $3 = offset.
    asm!("pcmpestri $1, [$2 + $3], $$0xc"
        : "={ecx}"(res)
        : "x"(needle), "r"(text.as_ptr()), "r"(offset), "{rax}"(needle_len), "{rdx}"(text_len) : "cc"
        : "intel" );
    res
}
/// Runs `PCMPESTRI` in `EQUAL_ORDERED` mode with a needle of up to 16 bytes,
/// built from two `u64` halves, against the 16 bytes at `text + offset`.
/// Returns the index left in `ecx`.
///
/// `movlhps` copies the low qword of `needle_2` into the high qword of the
/// register holding `needle_1`. As a result, `needle_1` supplies needle bytes
/// 0..8 and `needle_2` supplies bytes 8..16, which matches the layout produced
/// by `pat128`. `needle_len` goes in `rax` and `text_len` in `rdx`.
///
/// # Safety
/// The memory operand is a full 128-bit load from `text + offset` regardless
/// of `text_len`, so those 16 bytes must be readable.
/// NOTE(review): `movlhps` writes to the register holding the *input*
/// `needle_1` without declaring it as an output or clobber. The low 64 bits
/// are unchanged, so this is probably benign, but it is outside the asm
/// contract. Consider a read-write (`"+x"`) operand.
#[inline(always)]
unsafe fn pcmpestri_16(text: *const u8, offset: usize, text_len: usize,
                       needle_1: u64, needle_2: u64, needle_len: usize) -> u32 {
    let res: u32;
    // $1 = needle_1 xmm, $2 = needle_2 xmm, $3 = text, $4 = offset, $5 = imm8.
    asm!("movlhps $1, $2
         pcmpestri $1, [$3 + $4], $5"
        : "={ecx}"(res)
        : "x"(needle_1), "x"(needle_2), "r"(text), "r"(offset), "i"(EQUAL_ORDERED),
          "{rax}"(needle_len), "{rdx}"(text_len) : "cc"
        : "intel" );
    res
}
/// Compares 16 bytes at `needle + noffset` with 16 bytes at `text + offset`
/// using `PCMPESTRM` in `EQUAL_EACH` mode. Returns the low 64 bits of the
/// resulting `xmm0` mask.
///
/// The needle bytes are first loaded into `xmm0` with `movdqu`, and the mask
/// is read back from `xmm0`. `needle_len` goes in `rax` and `text_len` in
/// `rdx`. `shared_prefix` treats bit `i` of the result as "byte `i` compares
/// equal" and expects `0xffff` when all 16 lanes do.
///
/// # Safety
/// Both `needle + noffset` and `text + offset` are 16-byte loads and must
/// address 16 readable bytes.
#[inline(always)]
unsafe fn pcmpestrm_eq_each(text: *const u8, offset: usize, text_len: usize,
                            needle: *const u8, noffset: usize, needle_len: usize) -> u64 {
    let res: u64;
    // $1 = needle, $2 = noffset, $3 = text, $4 = offset, $5 = imm8.
    asm!("movdqu xmm0, [$1 + $2]
         pcmpestrm xmm0, [$3 + $4], $5"
        : "={xmm0}"(res)
        : "r"(needle), "r"(noffset), "r"(text), "r"(offset), "i"(EQUAL_EACH),
          "{rax}"(needle_len), "{rdx}"(text_len) : "cc"
        : "intel" );
    res
}
// Smoke test for `pcmpestri`. It finds 'c' at index 2 of "abc". For a 1-byte
// needle of value 0, which does not occur in "abc", it reports 16 (not found).
// NOTE(review): the instruction loads 16 bytes from a 3-byte literal.
#[test]
fn test_pcmpestri_1() {
    let text = b"abc";
    let n = b'c' as u64;
    unsafe {
        assert_eq!(pcmpestri(text, 0, text.len(), n, 1), 2);
        assert_eq!(pcmpestri(text, 0, text.len(), 0, 1), 16);
    }
}
#[inline(never)]
fn crit_period(pat: &[u8]) -> (usize, usize) {
let (i, p) = TwoWaySearcher::maximal_suffix(pat, false);
let (j, q) = TwoWaySearcher::maximal_suffix(pat, true);
if i >= j {
(i, p)
} else {
(j, q)
}
}
/// Convenience wrapper around `first_start_of_match_inner` that packs `pat`
/// (at most 16 bytes, otherwise this panics) into the two-`u64` needle form
/// with `pat128`.
fn first_start_of_match(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    assert!(pat.len() <= 16);
    let (lo_half, hi_half) = pat128(pat);
    first_start_of_match_inner(text, pat, lo_half, hi_half)
}
/// Finds the first position in `text` where `pat` could start. `pat` is at
/// most 16 bytes and is also supplied packed into `p1`/`p2`, as produced by
/// `pat128`.
///
/// Returns `Some((start, mlen))`, where `mlen` is the number of leading
/// pattern bytes seen to match at `start`. If `mlen < pat.len()`, the caller
/// must verify the candidate. Such a candidate can be a partial match cut off
/// at the end of a 16-byte block or at the end of `text`. In the byte-wise
/// prefix scan, `mlen` is `0` when a later byte mismatched. Returns `None`
/// when no candidate is found.
///
/// Bytes before the first 16-byte-aligned address are scanned one at a time.
/// After that, `pcmpestri_16` runs on aligned 16-byte blocks.
/// NOTE(review): the aligned loads can read up to 15 bytes past the end of
/// `text`. This presumably relies on an aligned 16-byte load never crossing a
/// page boundary; confirm.
/// NOTE(review): `pat[0]` panics if `pat` is empty.
fn first_start_of_match_inner(text: &[u8], pat: &[u8], p1: u64, p2: u64) -> Option<(usize, usize)> {
    let tp = text.as_ptr();
    // Distance of `tp` past the previous 16-byte boundary.
    let tp_align_offset = tp as usize & 0xF;
    // Bytes before the next 16-byte boundary. These are scanned one at a time
    // because the SIMD loop below only issues aligned loads.
    let init_len = if tp_align_offset != 0 { 16 - tp_align_offset } else { 0 };
    // The aligned base can lie before the start of `text`, and even before its
    // allocation. `pointer::offset` forbids that (it is UB), while
    // `wrapping_offset` computes the same address without that requirement.
    // The asm only reads from `tp_aligned + 16` onward when
    // `tp_align_offset != 0`.
    let tp_aligned = tp.wrapping_offset(-(tp_align_offset as isize));
    let patl = pat.len();
    debug_assert!(patl <= 16);
    // `offset` is relative to `tp_aligned`. The text index of the current
    // block is `offset - tp_align_offset`.
    let mut offset = 0;
    if init_len > 0 {
        for start in 0..cmp::min(init_len, text.len()) {
            if text[start] != pat[0] {
                continue;
            }
            // Count matching pattern bytes at `start`, up to the end of `text`.
            // On a mismatch the count drops to 0, but the position is still
            // reported and the caller re-verifies it.
            let mut mlen = 1;
            for (a, b) in ZipSlices::new(&text[start + 1..], &pat[1..]) {
                if a != b {
                    mlen = 0;
                    break;
                }
                mlen += 1;
            }
            return Some((start, mlen))
        }
        // The next block starts at the first aligned address inside `text`.
        offset += 16;
    }
    // Stop once fewer than `patl` bytes remain from the current block start.
    while text.len() >= offset - tp_align_offset + patl {
        unsafe {
            let tlen = text.len() - (offset - tp_align_offset);
            let ret = pcmpestri_16(tp_aligned, offset, tlen, p1, p2, patl) as usize;
            if ret == 16 {
                offset += 16;
            } else {
                // A hit near the end of the block may be a partial match of
                // only `16 - ret` bytes.
                let match_len = cmp::min(patl, 16 - ret);
                return Some((offset - tp_align_offset + ret, match_len));
            }
        }
    }
    None
}
/// Unaligned counterpart of `first_start_of_match_inner`. Walks `text` in
/// 16-byte steps from its start with `pcmpestri_16` and returns the first
/// candidate `(start, matched_len)`, or `None`.
///
/// `matched_len` is `pat_len` for a full hit. It is `16 - idx` when the hit is
/// a partial match running off the end of the 16-byte block, in which case the
/// caller must verify it.
///
/// # Safety
/// Each step loads 16 bytes from `text.as_ptr() + block`, for every block
/// start `block <= text.len() - pat_len`. The caller must make sure those
/// bytes are readable. In this file, callers pass subslices of a `safetext`
/// slice that withholds the last 16 bytes of the real text. Requires
/// `pat_len <= 16` and `pat_len <= text.len()`.
unsafe fn first_start_of_match_unaligned(text: &[u8], pat_len: usize, p1: u64, p2: u64) -> Option<(usize, usize)> {
    debug_assert!(pat_len <= 16);
    debug_assert!(pat_len <= text.len());
    let base = text.as_ptr();
    // Last index at which a full `pat_len`-byte match could still begin.
    let last_start = text.len() - pat_len;
    let mut block = 0;
    while block <= last_start {
        let idx = pcmpestri_16(base, block, text.len() - block, p1, p2, pat_len) as usize;
        if idx != 16 {
            return Some((block + idx, cmp::min(pat_len, 16 - idx)));
        }
        block += 16;
    }
    None
}
// `first_start_of_match` reports candidates, not verified matches.
// For short fixed inputs the exact result is pinned. For every window of
// 1..=16 bytes of `longer`, the reported `(pos, len)` must either be the exact
// first occurrence with the full length, or lie at or before the true first
// occurrence (a candidate the caller would still verify).
#[test]
fn test_first_start_of_match() {
    let text = b"abc";
    let longer = "longer text and so on";
    assert_eq!(first_start_of_match(text, b"d"), None);
    assert_eq!(first_start_of_match(text, b"c"), Some((2, 1)));
    assert_eq!(first_start_of_match(text, b"abc"), Some((0, 3)));
    assert_eq!(first_start_of_match(text, b"T"), None);
    // NUL-containing needles must not match bytes read past the text's end.
    assert_eq!(first_start_of_match(text, b"\0text"), None);
    assert_eq!(first_start_of_match(text, b"\0"), None);
    for wsz in 1..17 {
        for window in longer.as_bytes().windows(wsz) {
            let str_find = longer.find(::std::str::from_utf8(window).unwrap());
            assert!(str_find.is_some());
            let first_start = first_start_of_match(longer.as_bytes(), window);
            assert!(first_start.is_some());
            let (pos, len) = first_start.unwrap();
            assert!(len <= wsz);
            // `&&` binds tighter than `||`: exact full hit, or an earlier candidate.
            assert!(len == wsz && Some(pos) == str_find
                    || pos <= str_find.unwrap());
        }
    }
}
/// Finds the first occurrence of the 2-byte pattern `pat` in `text` and
/// returns its `(start, end)` byte range, or `None`.
///
/// Uses `memchr` to jump to each occurrence of `pat[1]` at index >= 1, then
/// checks whether the preceding byte equals `pat[0]`. An empty or 1-byte
/// `text` returns `None` instead of panicking on `&text[1..]`.
fn find_2byte_pat(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    debug_assert!(text.len() >= pat.len());
    debug_assert!(pat.len() == 2);
    let (first, second) = (pat[0], pat[1]);
    // Start at 1 so every hit of `second` has a preceding byte to compare.
    let mut off = 1;
    while off < text.len() {
        let hit = match memchr::memchr(second, &text[off..]) {
            Some(i) => off + i,
            None => break,
        };
        if text[hit - 1] == first {
            return Some((hit - 1, hit + 1));
        }
        off = hit + 1;
    }
    None
}
/// Finds the first occurrence of a short pattern and returns its
/// `(start, end)` byte range, or `None`. The pattern has `pat.len() <= 8`, so
/// it fits entirely in the low `u64` from `pat128`.
///
/// The search has two phases:
/// * `'search` runs `first_start_of_match_unaligned` over `safetext`, which is
///   `text` minus its last 16 bytes, so every unaligned 16-byte load stays
///   inside `text`.
/// * `'tail` continues from `pos` with `first_start_of_match_inner`, which
///   uses aligned 16-byte loads, to cover the end of `text`.
///
/// Candidates reported with `mlen < pat.len()` are verified byte by byte. On
/// a mismatch the search resumes at the next position.
///
/// Precondition: `text.len() >= pat.len()` (the `text.len() - pat.len()`
/// expressions rely on it). `find` checks this before calling.
/// NOTE(review): when `'search` ends via `None`, `pos` is not advanced, so
/// `'tail` rescans the region `'search` already covered.
fn find_short_pat(text: &[u8], pat: &[u8]) -> Option<(usize, usize)> {
    debug_assert!(pat.len() <= 8);
    // Only the low half is needed; the high half of the needle is passed as 0.
    let (r1, _) = pat128(pat);
    let safetext = &text[..cmp::max(text.len(), 16) - 16];
    let mut pos = 0;
    'search: loop {
        // Also guarantees `pat.len() <= safetext[pos..].len()` for the scan.
        if pos + pat.len() > safetext.len() {
            break;
        }
        match unsafe { first_start_of_match_unaligned(&safetext[pos..], pat.len(), r1, 0) } {
            None => break, Some((mpos, mlen)) => {
                pos += mpos;
                if mlen < pat.len() {
                    // No full match can start past this point.
                    if pos > text.len() - pat.len() {
                        return None;
                    }
                    // Verify the remaining bytes of the candidate.
                    for (a, b) in ZipSlices::new(&text[pos + mlen..], &pat[mlen..]) {
                        if a != b {
                            pos += 1;
                            continue 'search;
                        }
                    }
                }
                return Some((pos, pos + pat.len()));
            }
        }
    }
    'tail: loop {
        if pos > text.len() - pat.len() {
            return None;
        }
        match first_start_of_match_inner(&text[pos..], pat, r1, 0) {
            None => return None, Some((mpos, mlen)) => {
                pos += mpos;
                if mlen < pat.len() {
                    if pos > text.len() - pat.len() {
                        return None;
                    }
                    // `mlen` may be 0 here (see `first_start_of_match_inner`),
                    // in which case the whole pattern is compared.
                    for (a, b) in ZipSlices::new(&text[pos + mlen..], &pat[mlen..]) {
                        if a != b {
                            pos += 1;
                            continue 'tail;
                        }
                    }
                }
                return Some((pos, pos + pat.len()));
            }
        }
    }
}
/// Returns the byte range `(start, end)` of the first occurrence of `pattern`
/// in `text`, or `None` if it does not occur. An empty pattern matches at
/// `(0, 0)`.
///
/// The search strategy depends on pattern length:
/// * 1-byte patterns use `memchr`.
/// * Patterns of 2 to 6 bytes use `find_short_pat`.
/// * Longer patterns use a Two-Way style search built on `crit_period`.
///   Candidate positions for the right half come from PCMPESTRI over the
///   first 16 bytes of the right half (`right16`), and are then extended
///   with `shared_prefix`.
///
/// `memory` selects one of two Two-Way variants. They appear to mirror the
/// long-period and short-period cases of std's `TwoWaySearcher`.
/// * `memory == !0`: `pat[..crit_pos]` does not reappear at
///   `pat[period..]`. The shift after a left-half mismatch becomes
///   `max(crit_pos, len - crit_pos) + 1`, and nothing is remembered.
/// * Otherwise, `memory` counts how many leading pattern bytes are already
///   known to match after a period shift (`pat.len() - period`). Those bytes
///   are not compared again.
///
/// Each variant first scans `safetext` (text minus its last 16 bytes) with
/// unaligned loads. Both then finish in `'tail` with aligned loads, where
/// `memory` is no longer used.
pub fn find(text: &[u8], pattern: &[u8]) -> Option<(usize, usize)> {
    let pat = pattern;
    if pat.len() == 0 {
        return Some((0, 0));
    }
    if text.len() < pat.len() {
        return None;
    }
    if pat.len() == 1 {
        return memchr::memchr(pat[0], text).map(|i| (i, i + 1));
    } else if pat.len() <= 6 {
        return find_short_pat(text, pat);
    }
    let (crit_pos, mut period) = crit_period(pat);
    let mut memory;
    // NOTE(review): this slice panics unless `period + crit_pos <= pat.len()`.
    // That presumably holds for the factorization `crit_period` returns; confirm.
    if &pat[..crit_pos] == &pat[period.. period + crit_pos] {
        memory = 0;
    } else {
        memory = !0;
        period = cmp::max(crit_pos, pat.len() - crit_pos) + 1;
    }
    let (left, right) = pat.split_at(crit_pos);
    // Only the first 16 bytes of the right half fit in the PCMPESTRI needle.
    let (right16, _right17) = right.split_at(cmp::min(16, right.len()));
    assert!(right.len() != 0);
    let (r1, r2) = pat128(right);
    // Withhold the last 16 bytes so unaligned 16-byte loads stay inside `text`.
    let safetext = &text[..cmp::max(text.len(), 16) - 16];
    let mut pos = 0;
    if memory == !0 {
        'search: loop {
            if pos + pat.len() > safetext.len() {
                break;
            }
            let start = crit_pos;
            // Jump to the next position whose right half could begin matching.
            match unsafe { first_start_of_match_unaligned(&safetext[pos + start..], right16.len(), r1, r2) } {
                None => break, Some((mpos, mlen)) => {
                    pos += mpos;
                    // Extend the SIMD prefix to the full common prefix with `right`.
                    let mut pfxlen = mlen;
                    if pfxlen < right.len() {
                        pfxlen += shared_prefix(&text[pos + start + mlen..], &right[mlen..]);
                    }
                    if pfxlen != right.len() {
                        // Right-half mismatch at index `pfxlen`: shift past it.
                        pos += pfxlen + 1;
                        continue 'search;
                    } else {
                        // Right half fully matched; check the left half below.
                    }
                }
            }
            if left != &text[pos..pos + left.len()] {
                pos += period;
                continue 'search;
            }
            return Some((pos, pos + pat.len()));
        }
    } else {
        'search_memory: loop {
            if pos + pat.len() > safetext.len() {
                break;
            }
            let mut pfxlen = if memory == 0 {
                let start = crit_pos;
                match unsafe { first_start_of_match_unaligned(&safetext[pos + start..], right16.len(), r1, r2) } {
                    None => break, Some((mpos, mlen)) => {
                        pos += mpos;
                        mlen
                    }
                }
            } else {
                // The first `memory` pattern bytes are known to match, so the
                // right half is known up to `memory - crit_pos`. No underflow:
                // `memory == pat.len() - period`, and the slice check above
                // required `period + crit_pos <= pat.len()`.
                memory - crit_pos
            };
            if pfxlen < right.len() {
                pfxlen += shared_prefix(&text[pos + crit_pos + pfxlen..], &right[pfxlen..]);
            }
            if pfxlen != right.len() {
                pos += pfxlen + 1;
                memory = 0;
                continue 'search_memory;
            } else {
                // Right half fully matched; check the unknown part of the left half.
            }
            // When `memory > left.len()`, the whole left half is already known to match.
            if memory <= left.len() && &left[memory..] != &text[pos + memory..pos + left.len()] {
                pos += period;
                memory = pat.len() - period;
                continue 'search_memory;
            }
            return Some((pos, pos + pat.len()));
        }
    }
    // Finish the last part of `text` with aligned loads. Plain shifts, no memory.
    'tail: loop {
        if pos > text.len() - pat.len() {
            return None;
        }
        let start = crit_pos;
        match first_start_of_match_inner(&text[pos + start..], right16, r1, r2) {
            None => return None, Some((mpos, mlen)) => {
                pos += mpos;
                // `mlen` may be 0 (see `first_start_of_match_inner`); the
                // common prefix is then computed from scratch.
                let mut pfxlen = mlen;
                if pfxlen < right.len() {
                    pfxlen += shared_prefix(&text[pos + start + mlen..], &right[mlen..]);
                }
                if pfxlen != right.len() {
                    pos += pfxlen + 1;
                    continue 'tail;
                } else {
                    // Right half fully matched; check the left half below.
                }
            }
        }
        if left != &text[pos..pos + left.len()] {
            pos += period;
            continue 'tail;
        }
        return Some((pos, pos + pat.len()));
    }
}
// `find` must agree with `str::find` on the start position for every window
// of `longer`. This covers the memchr path, `find_short_pat`, and the
// Two-Way path, including a pattern longer than the 16-byte SIMD needle. The
// "abababab" cases use a highly periodic needle and presumably exercise the
// `memory` (short-period) branch.
#[test]
fn test_find() {
    let text = b"abc";
    assert_eq!(find(text, b"d"), None);
    assert_eq!(find(text, b"c"), Some((2, 3)));
    let longer = "longer text and so on, a bit more";
    for wsz in 1..longer.len() {
        for window in longer.as_bytes().windows(wsz) {
            let str_find = longer.find(::std::str::from_utf8(window).unwrap());
            assert!(str_find.is_some());
            assert_eq!(find(longer.as_bytes(), window).map(|(a, _)| a), str_find);
        }
    }
    // Right half longer than 16 bytes: the SIMD hit must be extended by `shared_prefix`.
    let pat = b"ger text and so on";
    assert!(pat.len() > 16);
    assert_eq!(Some((3, 3 + pat.len())), find(longer.as_bytes(), pat));
    let text = "cbabababcbabababab";
    let n = "abababab";
    assert_eq!(text.find(n), find(text.as_bytes(), n.as_bytes()).map(|(a, _)| a));
    let text = "cbababababababababababababababab";
    let n = "abababab";
    assert_eq!(text.find(n), find(text.as_bytes(), n.as_bytes()).map(|(a, _)| a));
}
/// Packs up to the first 16 bytes of `pat` into two `u64`s in native byte
/// order. `.0` holds bytes `0..8` and `.1` holds bytes `8..16`. Missing bytes
/// are zero, and bytes past index 15 are ignored. An empty `pat` yields
/// `(0, 0)`.
///
/// This is the needle format `pcmpestri_16` expects: it joins the two halves
/// into a single 128-bit register.
fn pat128(pat: &[u8]) -> (u64, u64) {
    let mut words = [0u64; 2];
    let len = cmp::min(16, pat.len());
    // SAFETY: `pat` has at least `len` readable bytes. `words` provides 16
    // writable bytes and `len <= 16`. The two regions do not overlap, and a
    // zero-length copy from an empty slice's pointer is allowed.
    unsafe {
        ptr::copy_nonoverlapping(pat.as_ptr(), words.as_mut_ptr() as *mut u8, len);
    }
    (words[0], words[1])
}
/// Returns the length of the longest common prefix of `text` and `pat`,
/// which is at most `min(text.len(), pat.len())`.
///
/// All but the last 16 bytes of the common range are compared 16 at a time
/// with `pcmpestrm_eq_each` (`EQUAL_EACH` mode), where a mask of `0xffff`
/// means the whole chunk matched. The final bytes (up to 16) are compared one
/// at a time. Because the last 16 bytes stay out of the SIMD loop, every
/// 16-byte load starts below `len - 16` and stays inside both slices.
pub fn shared_prefix(text: &[u8], pat: &[u8]) -> usize {
    let tp = text.as_ptr();
    let tlen = text.len();
    let pp = pat.as_ptr();
    let plen = pat.len();
    let len = cmp::min(tlen, plen);
    unsafe {
        // Part of the common range handled by the SIMD loop.
        let initial_part = len.saturating_sub(16);
        let mut prefix_len = 0;
        let mut offset = 0;
        while offset < initial_part {
            // Valid lanes in this chunk (the instruction clamps lengths to 16).
            // Lanes at and past `initial_tail` are invalid in both operands, and
            // the `0xffff` check relies on them reading as equal. The Intel SDM
            // documents that EQUAL_EACH forces both-invalid lanes to true.
            let initial_tail = initial_part - offset;
            let mask = pcmpestrm_eq_each(tp, offset, initial_tail, pp, offset, initial_tail);
            if mask != 0xffff {
                // Index of the first lane that compared unequal.
                let first_bit_set = (mask ^ 0xffff).trailing_zeros() as usize;
                prefix_len += first_bit_set;
                return prefix_len;
            } else {
                prefix_len += cmp::min(initial_tail, 16);
            }
            offset += 16;
        }
        // Here `prefix_len <= len`, so the unchecked slices are in bounds.
        let text_suffix = odds::slice_unchecked(text, prefix_len, len);
        let pat_suffix = odds::slice_unchecked(pat, prefix_len, len);
        for (&a, &b) in ZipSlices::new(text_suffix, pat_suffix) {
            if a != b {
                break;
            }
            prefix_len += 1;
        }
        prefix_len
    }
}
// `shared_prefix` checks. Cases:
// - identical inputs and a strict-prefix pattern;
// - a mismatch at index 0, and at index 16 (just past the first SIMD chunk);
// - every suffix of a 22-byte string compared with itself;
// - a 1024-byte buffer with a single difference at index 1000, compared from
//   every start offset so the difference falls in both the SIMD part and the
//   byte-wise tail.
#[test]
fn test_prefixlen() {
    let text_long = b"0123456789abcdefeffect";
    let text_long2 = b"9123456789abcdefeffect";
    let text_long3 = b"0123456789abcdefgffect";
    let plen = shared_prefix(text_long, text_long);
    assert_eq!(plen, text_long.len());
    let plen = shared_prefix(b"abcd", b"abc");
    assert_eq!(plen, 3);
    let plen = shared_prefix(b"abcd", b"abcf");
    assert_eq!(plen, 3);
    assert_eq!(0, shared_prefix(text_long, text_long2));
    assert_eq!(0, shared_prefix(text_long, &text_long[1..]));
    assert_eq!(16, shared_prefix(text_long, text_long3));
    for i in 0..text_long.len() + 1 {
        assert_eq!(text_long.len() - i, shared_prefix(&text_long[i..], &text_long[i..]));
    }
    let l1 = [7u8; 1024];
    let mut l2 = [7u8; 1024];
    let off = 1000;
    l2[off] = 0;
    for i in 0..off {
        let plen = shared_prefix(&l1[i..], &l2[i..]);
        assert_eq!(plen, off - i);
    }
}