use core::cmp;
use crate::memmem::{prefilter::Pre, util};
/// Forward (left-to-right) substring searcher.
///
/// A thin wrapper around the per-needle `TwoWay` state; `Forward::new`
/// builds that state and `Forward::find` consumes it.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Forward(TwoWay);
/// Reverse (right-to-left) substring searcher.
///
/// A thin wrapper around the per-needle `TwoWay` state; `Reverse::new`
/// builds that state and `Reverse::rfind` consumes it.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Reverse(TwoWay);
/// Per-needle search state shared by the `Forward` and `Reverse` searchers.
#[derive(Clone, Copy, Debug)]
struct TwoWay {
/// Approximate set of the bytes occurring in the needle. The search loops
/// probe one byte of each window against it to discard windows cheaply.
byteset: ApproximateByteSet,
/// Needle index at which every window comparison starts: the searchers
/// compare the needle on one side of this index first and the other side
/// second.
critical_pos: usize,
/// How far the window moves when the second comparison phase fails.
shift: Shift,
}
impl Forward {
/// Builds a forward searcher for `needle`.
///
/// An empty needle gets the placeholder state from `TwoWay::empty()`.
/// Otherwise this computes the needle's approximate byte set and its
/// minimal and maximal suffixes, and derives from those the critical
/// position and the shift strategy stored in the returned searcher.
pub(crate) fn new(needle: &[u8]) -> Forward {
if needle.is_empty() {
return Forward(TwoWay::empty());
}
let byteset = ApproximateByteSet::new(needle);
let min_suffix = Suffix::forward(needle, SuffixKind::Minimal);
let max_suffix = Suffix::forward(needle, SuffixKind::Maximal);
// Use whichever suffix starts further to the right; on a tie the maximal
// suffix wins.
let (period_lower_bound, critical_pos) =
if min_suffix.pos > max_suffix.pos {
(min_suffix.period, min_suffix.pos)
} else {
(max_suffix.period, max_suffix.pos)
};
let shift = Shift::forward(needle, period_lower_bound, critical_pos);
Forward(TwoWay { byteset, critical_pos, shift })
}
/// Scans `haystack` from the left and returns the start offset of the
/// first window at which `needle` is found, or `None` if there is none.
///
/// The stored state was derived from the needle given to `new`, so the
/// same needle is expected here. The needle must be non-empty and no
/// longer than the haystack; both conditions are only checked by debug
/// assertions. `pre`, if provided, is a prefilter that the search loops
/// may call to jump ahead to a candidate position.
#[inline(always)]
pub(crate) fn find(
&self,
pre: Option<&mut Pre<'_>>,
haystack: &[u8],
needle: &[u8],
) -> Option<usize> {
debug_assert!(!needle.is_empty(), "needle should not be empty");
debug_assert!(needle.len() <= haystack.len(), "haystack too short");
// Dispatch on the strategy chosen at construction time.
match self.0.shift {
Shift::Small { period } => {
self.find_small_imp(pre, haystack, needle, period)
}
Shift::Large { shift } => {
self.find_large_imp(pre, haystack, needle, shift)
}
}
}
/// Test-only wrapper around `find` that accepts any input: an empty needle
/// matches at offset 0, and a haystack shorter than the needle never
/// matches.
#[cfg(test)]
fn find_general(
&self,
pre: Option<&mut Pre<'_>>,
haystack: &[u8],
needle: &[u8],
) -> Option<usize> {
if needle.is_empty() {
Some(0)
} else if haystack.len() < needle.len() {
None
} else {
self.find(pre, haystack, needle)
}
}
/// Search loop used when the strategy is `Shift::Small { period }`.
///
/// Each window is checked in two phases: first the needle bytes from the
/// critical position to the end (left to right), then the bytes before the
/// critical position (right to left). After a period-sized advance, the
/// local `shift` records a prefix of the next window that is not compared
/// again.
#[inline(always)]
fn find_small_imp(
&self,
mut pre: Option<&mut Pre<'_>>,
haystack: &[u8],
needle: &[u8],
period: usize,
) -> Option<usize> {
// Index of the needle's final byte; also the offset of the byte probed
// against the byte set in each window.
let last_byte = needle.len() - 1;
// `pos` is the start of the current window. `shift` is non-zero only
// right after a period-sized advance; needle indices below it are not
// compared in that window.
// NOTE(review): presumably those bytes are already known to match because
// the needle repeats with `period` — confirm against the algorithm's proof.
let mut pos = 0;
let mut shift = 0;
while pos + needle.len() <= haystack.len() {
// First phase starts at the critical position, or later if a prefix is
// being skipped.
let mut i = cmp::max(self.0.critical_pos, shift);
if let Some(pre) = pre.as_mut() {
if pre.should_call() {
// The prefilter reports an offset relative to `pos`; if it reports
// `None`, the `?` ends the whole search with `None`.
pos += pre.call(&haystack[pos..], needle)?;
// Nothing is known about the window we jumped to, so drop the skipped
// prefix and restart the first phase at the critical position.
shift = 0;
i = self.0.critical_pos;
// The jump may leave too little haystack for a full window.
if pos + needle.len() > haystack.len() {
return None;
}
}
}
// If the window's last byte occurs nowhere in the needle, no window
// containing that byte can match, so step past it entirely.
if !self.0.byteset.contains(haystack[pos + last_byte]) {
pos += needle.len();
shift = 0;
continue;
}
// Phase one: compare rightwards until a mismatch or the needle's end.
while i < needle.len() && needle[i] == haystack[pos + i] {
i += 1;
}
if i < needle.len() {
// Mismatch at needle index `i`: advance by its distance from the
// critical position, plus one.
pos += i - self.0.critical_pos + 1;
shift = 0;
} else {
// Phase two: compare leftwards from the critical position down to
// index `shift`.
let mut j = self.0.critical_pos;
while j > shift && needle[j] == haystack[pos + j] {
j -= 1;
}
// The scan reached `shift` and that byte matches as well: full match.
if j <= shift && needle[shift] == haystack[pos + shift] {
return Some(pos);
}
// Left part mismatched: advance by the period and let the next window
// skip its first `needle.len() - period` bytes.
pos += period;
shift = needle.len() - period;
}
}
None
}
/// Search loop used when the strategy is `Shift::Large { shift }`.
///
/// Same two-phase window check as `find_small_imp`, but with no skipped
/// prefix: every window is compared in full, and a mismatch in the part
/// before the critical position advances the window by the fixed `shift`.
#[inline(always)]
fn find_large_imp(
&self,
mut pre: Option<&mut Pre<'_>>,
haystack: &[u8],
needle: &[u8],
shift: usize,
) -> Option<usize> {
// Offset, within a window, of the byte probed against the byte set.
let last_byte = needle.len() - 1;
// Start of the current window.
let mut pos = 0;
'outer: while pos + needle.len() <= haystack.len() {
if let Some(pre) = pre.as_mut() {
if pre.should_call() {
// The prefilter reports an offset relative to `pos`; `None` ends the
// search through `?`.
pos += pre.call(&haystack[pos..], needle)?;
// The jump may leave too little haystack for a full window.
if pos + needle.len() > haystack.len() {
return None;
}
}
}
// A last byte that occurs nowhere in the needle rules out every window
// containing it.
if !self.0.byteset.contains(haystack[pos + last_byte]) {
pos += needle.len();
continue;
}
// Phase one: compare rightwards from the critical position.
let mut i = self.0.critical_pos;
while i < needle.len() && needle[i] == haystack[pos + i] {
i += 1;
}
if i < needle.len() {
// Mismatch at needle index `i`: advance by its distance from the
// critical position, plus one.
pos += i - self.0.critical_pos + 1;
} else {
// Phase two: compare the bytes before the critical position, right to
// left; any mismatch moves the window by the fixed shift.
for j in (0..self.0.critical_pos).rev() {
if needle[j] != haystack[pos + j] {
pos += shift;
continue 'outer;
}
}
return Some(pos);
}
}
None
}
}
impl Reverse {
/// Builds a reverse searcher for `needle`.
///
/// An empty needle gets the placeholder state from `TwoWay::empty()`.
/// Otherwise this computes the needle's approximate byte set and its
/// minimal and maximal reverse suffixes, and derives from those the
/// critical position and the shift strategy stored in the returned
/// searcher.
pub(crate) fn new(needle: &[u8]) -> Reverse {
if needle.is_empty() {
return Reverse(TwoWay::empty());
}
let byteset = ApproximateByteSet::new(needle);
let min_suffix = Suffix::reverse(needle, SuffixKind::Minimal);
let max_suffix = Suffix::reverse(needle, SuffixKind::Maximal);
// Mirror image of the forward case: use whichever candidate has the
// smaller position; on a tie the maximal one wins.
let (period_lower_bound, critical_pos) =
if min_suffix.pos < max_suffix.pos {
(min_suffix.period, min_suffix.pos)
} else {
(max_suffix.period, max_suffix.pos)
};
let shift = Shift::reverse(needle, period_lower_bound, critical_pos);
Reverse(TwoWay { byteset, critical_pos, shift })
}
/// Scans `haystack` from the right and returns the start offset of the
/// first window (counting from the end) at which `needle` is found, or
/// `None` if there is none.
///
/// The stored state was derived from the needle given to `new`, so the
/// same needle is expected here. The needle must be non-empty and no
/// longer than the haystack; both conditions are only checked by debug
/// assertions.
#[inline(always)]
pub(crate) fn rfind(
&self,
haystack: &[u8],
needle: &[u8],
) -> Option<usize> {
debug_assert!(!needle.is_empty(), "needle should not be empty");
debug_assert!(needle.len() <= haystack.len(), "haystack too short");
// Dispatch on the strategy chosen at construction time.
match self.0.shift {
Shift::Small { period } => {
self.rfind_small_imp(haystack, needle, period)
}
Shift::Large { shift } => {
self.rfind_large_imp(haystack, needle, shift)
}
}
}
/// Test-only wrapper around `rfind` that accepts any input: an empty
/// needle matches at `haystack.len()`, and a haystack shorter than the
/// needle never matches.
#[cfg(test)]
fn rfind_general(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() {
Some(haystack.len())
} else if haystack.len() < needle.len() {
None
} else {
self.rfind(haystack, needle)
}
}
/// Reverse search loop used when the strategy is `Shift::Small { period }`.
///
/// Each window is checked in two phases: first the needle bytes before the
/// critical position (right to left), then the bytes from the critical
/// position onwards (left to right). After a period-sized step, the local
/// `shift` records a tail of the next window that is not compared again.
#[inline(always)]
fn rfind_small_imp(
&self,
haystack: &[u8],
needle: &[u8],
period: usize,
) -> Option<usize> {
let nlen = needle.len();
// `pos` is the exclusive end of the current window, i.e. the window is
// `haystack[pos - nlen..pos]`. `shift` is an exclusive upper bound on the
// needle indices compared in the window; it is below `nlen` only right
// after a period-sized step.
// NOTE(review): presumably the skipped tail is already known to match
// because the needle repeats with `period` — confirm against the
// algorithm's proof.
let mut pos = haystack.len();
let mut shift = nlen;
while pos >= nlen {
// If the window's first byte occurs nowhere in the needle, no window
// containing that byte can match, so step past it entirely.
if !self.0.byteset.contains(haystack[pos - nlen]) {
pos -= nlen;
shift = nlen;
continue;
}
// Phase one: compare leftwards, starting just below the critical
// position (or below `shift` if that is smaller).
let mut i = cmp::min(self.0.critical_pos, shift);
while i > 0 && needle[i - 1] == haystack[pos - nlen + i - 1] {
i -= 1;
}
// `i > 0` means the loop stopped on a mismatch. The explicit `needle[0]`
// test only matters when `i` started at 0 and the loop never ran.
if i > 0 || needle[0] != haystack[pos - nlen] {
pos -= self.0.critical_pos - i + 1;
shift = nlen;
} else {
// Phase two: compare rightwards from the critical position up to, but
// not including, index `shift`.
let mut j = self.0.critical_pos;
while j < shift && needle[j] == haystack[pos - nlen + j] {
j += 1;
}
// Every compared index matched: report the window's start offset.
if j >= shift {
return Some(pos - nlen);
}
// Right part mismatched: step back by the period and let the next
// window skip the needle indices from `period` upwards.
pos -= period;
shift = period;
}
}
None
}
/// Reverse search loop used when the strategy is `Shift::Large { shift }`.
///
/// Same two-phase window check as `rfind_small_imp`, but with no skipped
/// tail: every window is compared in full, and a mismatch in the part from
/// the critical position onwards steps the window back by the fixed
/// `shift`.
#[inline(always)]
fn rfind_large_imp(
&self,
haystack: &[u8],
needle: &[u8],
shift: usize,
) -> Option<usize> {
let nlen = needle.len();
// Exclusive end of the current window, `haystack[pos - nlen..pos]`.
let mut pos = haystack.len();
while pos >= nlen {
// A first byte that occurs nowhere in the needle rules out every window
// containing it.
if !self.0.byteset.contains(haystack[pos - nlen]) {
pos -= nlen;
continue;
}
// Phase one: compare leftwards from just below the critical position.
let mut i = self.0.critical_pos;
while i > 0 && needle[i - 1] == haystack[pos - nlen + i - 1] {
i -= 1;
}
// `i > 0` means the loop stopped on a mismatch. The explicit `needle[0]`
// test only matters when the critical position is 0 and the loop never
// ran.
if i > 0 || needle[0] != haystack[pos - nlen] {
pos -= self.0.critical_pos - i + 1;
} else {
// Phase two: compare rightwards from the critical position to the end
// of the needle.
let mut j = self.0.critical_pos;
while j < nlen && needle[j] == haystack[pos - nlen + j] {
j += 1;
}
if j == nlen {
return Some(pos - nlen);
}
// Right part mismatched: step back by the fixed shift.
pos -= shift;
}
}
None
}
}
impl TwoWay {
    /// Placeholder state used by the searchers when the needle is empty:
    /// an empty byte set, a critical position of 0 and a zero-length
    /// `Shift::Large`.
    fn empty() -> TwoWay {
        let byteset = ApproximateByteSet::new(b"");
        let shift = Shift::Large { shift: 0 };
        TwoWay { byteset, critical_pos: 0, shift }
    }
}
/// Strategy for moving the search window when a candidate passes the first
/// comparison phase but fails the second.
#[derive(Clone, Copy, Debug)]
enum Shift {
    /// Move by `period`. The searchers then skip re-comparing part of the
    /// next window.
    Small { period: usize },
    /// Move by the fixed distance `shift`; every window is compared in full.
    Large { shift: usize },
}

impl Shift {
    /// Chooses the shift strategy for a forward search.
    ///
    /// `Shift::Large` is used when the critical position lies at or beyond
    /// the middle of the needle, or when `util::is_suffix` rejects the part
    /// before the critical position against the first `period_lower_bound`
    /// bytes of the part after it (NOTE(review): exact argument semantics
    /// of `util::is_suffix` are defined elsewhere — confirm). Otherwise
    /// `period_lower_bound` is used as the period of a `Shift::Small`.
    fn forward(
        needle: &[u8],
        period_lower_bound: usize,
        critical_pos: usize,
    ) -> Shift {
        // The large shift is the length of the longer of the two halves.
        let right_len = needle.len() - critical_pos;
        let fallback = Shift::Large { shift: cmp::max(critical_pos, right_len) };
        if critical_pos * 2 >= needle.len() {
            return fallback;
        }
        let (left, right) = needle.split_at(critical_pos);
        if util::is_suffix(&right[..period_lower_bound], left) {
            Shift::Small { period: period_lower_bound }
        } else {
            fallback
        }
    }

    /// Chooses the shift strategy for a reverse search.
    ///
    /// Mirror image of `forward`: `Shift::Large` is used when the part from
    /// the critical position onwards is at least half of the needle, or
    /// when `util::is_prefix` rejects that part against the last
    /// `period_lower_bound` bytes of the part before it (NOTE(review):
    /// exact argument semantics of `util::is_prefix` are defined elsewhere
    /// — confirm). Otherwise `period_lower_bound` is used as the period of
    /// a `Shift::Small`.
    fn reverse(
        needle: &[u8],
        period_lower_bound: usize,
        critical_pos: usize,
    ) -> Shift {
        // The large shift is the length of the longer of the two halves.
        let right_len = needle.len() - critical_pos;
        let fallback = Shift::Large { shift: cmp::max(critical_pos, right_len) };
        if right_len * 2 >= needle.len() {
            return fallback;
        }
        let (left, right) = needle.split_at(critical_pos);
        let tail_start = left.len() - period_lower_bound;
        if util::is_prefix(&left[tail_start..], right) {
            Shift::Small { period: period_lower_bound }
        } else {
            fallback
        }
    }
}
/// A suffix of a needle selected by `Suffix::forward` or `Suffix::reverse`,
/// together with the period computed for it.
#[derive(Debug)]
struct Suffix {
    /// For `forward`, the index at which the suffix starts. For `reverse`,
    /// the exclusive end index of the selected prefix `needle[..pos]`.
    pos: usize,
    /// Period computed for the selected suffix.
    period: usize,
}

impl Suffix {
    /// Scans `needle` left to right and returns the minimal or maximal
    /// suffix (per `kind`) as a start position plus period.
    ///
    /// The needle must be non-empty (checked by a debug assertion only).
    fn forward(needle: &[u8], kind: SuffixKind) -> Suffix {
        debug_assert!(!needle.is_empty());
        // Best suffix found so far and its period.
        let (mut best, mut period) = (0, 1);
        // Start of the suffix currently challenging `best`, and how many
        // leading bytes the two have been found to share.
        let (mut challenger, mut matched) = (1, 0);
        while let Some(&candidate) = needle.get(challenger + matched) {
            let current = needle[best + matched];
            match kind.cmp(current, candidate) {
                SuffixOrdering::Accept => {
                    best = challenger;
                    period = 1;
                    challenger += 1;
                    matched = 0;
                }
                SuffixOrdering::Skip => {
                    challenger += matched + 1;
                    matched = 0;
                    period = challenger - best;
                }
                // A whole period matched: hop the challenger over it.
                SuffixOrdering::Push if matched + 1 == period => {
                    challenger += period;
                    matched = 0;
                }
                SuffixOrdering::Push => matched += 1,
            }
        }
        Suffix { pos: best, period }
    }

    /// Mirror image of `forward`: scans `needle` right to left and returns
    /// the selected prefix `needle[..pos]` plus its period.
    ///
    /// The needle must be non-empty (checked by a debug assertion only). A
    /// one-byte needle yields `pos == 1` and `period == 1` directly.
    fn reverse(needle: &[u8], kind: SuffixKind) -> Suffix {
        debug_assert!(!needle.is_empty());
        let len = needle.len();
        // Exclusive end of the best candidate so far and its period.
        let (mut best, mut period) = (len, 1);
        if len == 1 {
            return Suffix { pos: best, period };
        }
        // Exclusive end of the candidate currently challenging `best`, and
        // how many trailing bytes the two have been found to share.
        let (mut challenger, mut matched) = (len - 1, 0);
        while matched < challenger {
            let current = needle[best - matched - 1];
            let candidate = needle[challenger - matched - 1];
            match kind.cmp(current, candidate) {
                SuffixOrdering::Accept => {
                    best = challenger;
                    period = 1;
                    challenger -= 1;
                    matched = 0;
                }
                SuffixOrdering::Skip => {
                    challenger -= matched + 1;
                    matched = 0;
                    period = best - challenger;
                }
                // A whole period matched: hop the challenger over it.
                SuffixOrdering::Push if matched + 1 == period => {
                    challenger -= period;
                    matched = 0;
                }
                SuffixOrdering::Push => matched += 1,
            }
        }
        Suffix { pos: best, period }
    }
}

/// Which extreme suffix the scan in `Suffix` is looking for.
#[derive(Clone, Copy, Debug)]
enum SuffixKind {
    /// Prefer candidates whose differing byte is smaller.
    Minimal,
    /// Prefer candidates whose differing byte is larger.
    Maximal,
}

/// Outcome of comparing one byte of the current best suffix with the
/// corresponding byte of a candidate.
#[derive(Clone, Copy, Debug)]
enum SuffixOrdering {
    /// The candidate is better and replaces the current suffix.
    Accept,
    /// The candidate is worse and is discarded.
    Skip,
    /// The bytes are equal; the comparison continues with the next byte.
    Push,
}

impl SuffixKind {
    /// Compares `candidate` against `current` for this kind of suffix.
    ///
    /// Equal bytes yield `Push`. Otherwise the candidate is `Accept`ed when
    /// it differs in the direction this kind prefers (smaller for
    /// `Minimal`, larger for `Maximal`) and `Skip`ped when it does not.
    fn cmp(self, current: u8, candidate: u8) -> SuffixOrdering {
        if candidate == current {
            return SuffixOrdering::Push;
        }
        let prefers_smaller = match self {
            SuffixKind::Minimal => true,
            SuffixKind::Maximal => false,
        };
        if (candidate < current) == prefers_smaller {
            SuffixOrdering::Accept
        } else {
            SuffixOrdering::Skip
        }
    }
}
/// A 64-bit approximation of a set of bytes.
///
/// Each byte is mapped to bit `byte % 64`, so bytes that are equal modulo 64
/// share a bit. Membership tests can therefore report false positives but
/// never false negatives.
#[derive(Clone, Copy, Debug)]
struct ApproximateByteSet(u64);

impl ApproximateByteSet {
    /// Builds the set containing every byte of `needle`.
    fn new(needle: &[u8]) -> ApproximateByteSet {
        let bits = needle
            .iter()
            .fold(0u64, |acc, &byte| acc | (1u64 << (byte % 64)));
        ApproximateByteSet(bits)
    }

    /// Returns `false` only if `byte` is definitely not in the set; `true`
    /// means it may be.
    #[inline(always)]
    fn contains(&self, byte: u8) -> bool {
        let mask = 1u64 << (byte % 64);
        self.0 & mask != 0
    }
}
#[cfg(all(test, feature = "std", not(miri)))]
mod tests {
    use quickcheck::quickcheck;
    use super::*;
    define_memmem_quickcheck_tests!(
        super::simpletests::twoway_find,
        super::simpletests::twoway_rfind
    );

    /// Runs `Suffix::forward` and returns the selected suffix as a slice of
    /// `needle`, along with its period.
    fn get_suffix_forward(needle: &[u8], kind: SuffixKind) -> (&[u8], usize) {
        let Suffix { pos, period } = Suffix::forward(needle, kind);
        (&needle[pos..], period)
    }

    /// Runs `Suffix::reverse` and returns the selected prefix as a slice of
    /// `needle`, along with its period.
    fn get_suffix_reverse(needle: &[u8], kind: SuffixKind) -> (&[u8], usize) {
        let Suffix { pos, period } = Suffix::reverse(needle, kind);
        (&needle[..pos], period)
    }

    /// Every non-empty suffix of `bytes`, longest first.
    fn suffixes(bytes: &[u8]) -> Vec<&[u8]> {
        let mut all = Vec::with_capacity(bytes.len());
        for start in 0..bytes.len() {
            all.push(&bytes[start..]);
        }
        all
    }

    /// Reference implementation: the lexicographically greatest suffix.
    fn naive_maximal_suffix_forward(needle: &[u8]) -> &[u8] {
        suffixes(needle).into_iter().max().unwrap()
    }

    /// Reference implementation for the reverse direction: reverse the
    /// needle, take its greatest suffix, and reverse the result back.
    fn naive_maximal_suffix_reverse(needle: &[u8]) -> Vec<u8> {
        let reversed: Vec<u8> = needle.iter().rev().cloned().collect();
        naive_maximal_suffix_forward(&reversed)
            .iter()
            .rev()
            .cloned()
            .collect()
    }

    #[test]
    fn suffix_forward() {
        // check!(kind, needle, expected suffix, expected period)
        macro_rules! check {
            ($kind:ident, $given:expr, $expected:expr, $period:expr) => {{
                let (bytes, period) =
                    get_suffix_forward($given.as_bytes(), SuffixKind::$kind);
                let text = std::str::from_utf8(bytes).unwrap();
                assert_eq!(($expected, $period), (text, period));
            }};
        }
        check!(Minimal, "a", "a", 1);
        check!(Maximal, "a", "a", 1);
        check!(Minimal, "ab", "ab", 2);
        check!(Maximal, "ab", "b", 1);
        check!(Minimal, "ba", "a", 1);
        check!(Maximal, "ba", "ba", 2);
        check!(Minimal, "abc", "abc", 3);
        check!(Maximal, "abc", "c", 1);
        check!(Minimal, "acb", "acb", 3);
        check!(Maximal, "acb", "cb", 2);
        check!(Minimal, "cba", "a", 1);
        check!(Maximal, "cba", "cba", 3);
        check!(Minimal, "abcabc", "abcabc", 3);
        check!(Maximal, "abcabc", "cabc", 3);
        check!(Minimal, "abcabcabc", "abcabcabc", 3);
        check!(Maximal, "abcabcabc", "cabcabc", 3);
        check!(Minimal, "abczz", "abczz", 5);
        check!(Maximal, "abczz", "zz", 1);
        check!(Minimal, "zzabc", "abc", 3);
        check!(Maximal, "zzabc", "zzabc", 5);
        check!(Minimal, "aaa", "aaa", 1);
        check!(Maximal, "aaa", "aaa", 1);
        check!(Minimal, "foobar", "ar", 2);
        check!(Maximal, "foobar", "r", 1);
    }

    #[test]
    fn suffix_reverse() {
        // check!(kind, needle, expected prefix, expected period)
        macro_rules! check {
            ($kind:ident, $given:expr, $expected:expr, $period:expr) => {{
                let (bytes, period) =
                    get_suffix_reverse($given.as_bytes(), SuffixKind::$kind);
                let text = std::str::from_utf8(bytes).unwrap();
                assert_eq!(($expected, $period), (text, period));
            }};
        }
        check!(Minimal, "a", "a", 1);
        check!(Maximal, "a", "a", 1);
        check!(Minimal, "ab", "a", 1);
        check!(Maximal, "ab", "ab", 2);
        check!(Minimal, "ba", "ba", 2);
        check!(Maximal, "ba", "b", 1);
        check!(Minimal, "abc", "a", 1);
        check!(Maximal, "abc", "abc", 3);
        check!(Minimal, "acb", "a", 1);
        check!(Maximal, "acb", "ac", 2);
        check!(Minimal, "cba", "cba", 3);
        check!(Maximal, "cba", "c", 1);
        check!(Minimal, "abcabc", "abca", 3);
        check!(Maximal, "abcabc", "abcabc", 3);
        check!(Minimal, "abcabcabc", "abcabca", 3);
        check!(Maximal, "abcabcabc", "abcabcabc", 3);
        check!(Minimal, "abczz", "a", 1);
        check!(Maximal, "abczz", "abczz", 5);
        check!(Minimal, "zzabc", "zza", 3);
        check!(Maximal, "zzabc", "zz", 1);
        check!(Minimal, "aaa", "aaa", 1);
        check!(Maximal, "aaa", "aaa", 1);
    }

    quickcheck! {
        // The forward maximal suffix must agree with the naive reference.
        fn qc_suffix_forward_maximal(bytes: Vec<u8>) -> bool {
            bytes.is_empty() || {
                let (got, _) = get_suffix_forward(&bytes, SuffixKind::Maximal);
                got == naive_maximal_suffix_forward(&bytes)
            }
        }

        // The reverse maximal suffix must agree with the naive reference.
        fn qc_suffix_reverse_maximal(bytes: Vec<u8>) -> bool {
            bytes.is_empty() || {
                let (got, _) = get_suffix_reverse(&bytes, SuffixKind::Maximal);
                naive_maximal_suffix_reverse(&bytes) == got
            }
        }
    }
}
#[cfg(test)]
mod simpletests {
    use super::*;

    /// Forward search entry point handed to the shared test macros: builds
    /// a searcher for `needle` and runs it without a prefilter.
    pub(crate) fn twoway_find(
        haystack: &[u8],
        needle: &[u8],
    ) -> Option<usize> {
        let searcher = Forward::new(needle);
        searcher.find_general(None, haystack, needle)
    }

    /// Reverse search entry point handed to the shared test macros.
    pub(crate) fn twoway_rfind(
        haystack: &[u8],
        needle: &[u8],
    ) -> Option<usize> {
        let searcher = Reverse::new(needle);
        searcher.rfind_general(haystack, needle)
    }

    define_memmem_simple_tests!(twoway_find, twoway_rfind);

    /// Regression test: a reverse search for "abab" in "ababaz" must find
    /// the match at offset 0.
    #[test]
    fn regression_rev_small_period() {
        let (haystack, needle) = ("ababaz", "abab");
        let got = twoway_rfind(haystack.as_bytes(), needle.as_bytes());
        assert_eq!(Some(0), got);
    }
}