use std::io;
use automaton::Automaton;
use buffer::Buffer;
use dfa::{self, DFA};
use error::Result;
use nfa::{self, NFA};
use packed;
use prefilter::PrefilterState;
use state_id::StateID;
use Match;
/// A multi-pattern searcher built from a set of patterns, backed by either
/// an NFA or a DFA.
///
/// The type parameter `S` is the state identifier representation and
/// defaults to `usize`.
#[derive(Clone, Debug)]
pub struct AhoCorasick<S: StateID = usize> {
/// The underlying automaton implementation (NFA or DFA).
imp: Imp<S>,
/// The match semantics, cloned from the NFA when the automaton is built.
match_kind: MatchKind,
}
impl AhoCorasick {
    /// Builds an automaton from `patterns` using a freshly created
    /// `AhoCorasickBuilder` with no options changed.
    pub fn new<I, P>(patterns: I) -> AhoCorasick
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        let builder = AhoCorasickBuilder::new();
        builder.build(patterns)
    }

    /// Builds an automaton from `patterns` after letting
    /// `AhoCorasickBuilder::auto_configure` pick settings based on the
    /// pattern slice.
    pub fn new_auto_configured<B>(patterns: &[B]) -> AhoCorasick
    where
        B: AsRef<[u8]>,
    {
        let mut builder = AhoCorasickBuilder::new();
        builder.auto_configure(patterns);
        builder.build(patterns)
    }
}
impl<S: StateID> AhoCorasick<S> {
/// Returns `true` if an "earliest" search over `haystack` reports any match.
pub fn is_match<B: AsRef<[u8]>>(&self, haystack: B) -> bool {
    let first = self.earliest_find(haystack);
    first.is_some()
}
/// Runs the underlying automaton's "earliest" search over `haystack`,
/// starting at offset 0 from the automaton's start state, and returns the
/// match it reports, if any.
pub fn earliest_find<B: AsRef<[u8]>>(&self, haystack: B) -> Option<Match> {
    let mut prefilter = PrefilterState::new(self.max_pattern_len());
    let mut state = self.imp.start_state();
    let bytes = haystack.as_ref();
    self.imp.earliest_find_at(&mut prefilter, bytes, 0, &mut state)
}
/// Searches `haystack` from offset 0 and returns the match reported by the
/// underlying automaton's stateless search, if any.
pub fn find<B: AsRef<[u8]>>(&self, haystack: B) -> Option<Match> {
    let max_len = self.max_pattern_len();
    let mut prefilter = PrefilterState::new(max_len);
    let bytes = haystack.as_ref();
    self.imp.find_at_no_state(&mut prefilter, bytes, 0)
}
/// Returns an iterator over successive matches in `haystack`.
///
/// The iterator borrows both the automaton (`'a`) and the haystack (`'b`).
pub fn find_iter<'a, 'b, B: ?Sized + AsRef<[u8]>>(
    &'a self,
    haystack: &'b B,
) -> FindIter<'a, 'b, S> {
    let bytes: &'b [u8] = haystack.as_ref();
    FindIter::new(self, bytes)
}
/// Returns an iterator over overlapping matches in `haystack`.
///
/// # Panics
///
/// Panics (inside `FindOverlappingIter::new`) when the automaton's match
/// kind does not support overlapping searches.
pub fn find_overlapping_iter<'a, 'b, B: ?Sized + AsRef<[u8]>>(
    &'a self,
    haystack: &'b B,
) -> FindOverlappingIter<'a, 'b, S> {
    let bytes: &'b [u8] = haystack.as_ref();
    FindOverlappingIter::new(self, bytes)
}
/// Returns a copy of `haystack` in which every match is replaced by the
/// entry of `replace_with` at the index of the matching pattern.
///
/// # Panics
///
/// Panics when `replace_with.len()` differs from `self.pattern_count()`.
pub fn replace_all<B>(&self, haystack: &str, replace_with: &[B]) -> String
where
    B: AsRef<str>,
{
    assert_eq!(
        replace_with.len(),
        self.pattern_count(),
        "replace_all requires a replacement for every pattern \
         in the automaton"
    );
    // The haystack length is only a starting capacity; the output may grow.
    let mut out = String::with_capacity(haystack.len());
    self.replace_all_with(haystack, &mut out, |m, _matched, buf| {
        let replacement: &str = replace_with[m.pattern()].as_ref();
        buf.push_str(replacement);
        true
    });
    out
}
/// Returns a copy of `haystack` in which every match is replaced by the
/// entry of `replace_with` at the index of the matching pattern.
///
/// # Panics
///
/// Panics when `replace_with.len()` differs from `self.pattern_count()`.
pub fn replace_all_bytes<B>(
    &self,
    haystack: &[u8],
    replace_with: &[B],
) -> Vec<u8>
where
    B: AsRef<[u8]>,
{
    assert_eq!(
        replace_with.len(),
        self.pattern_count(),
        "replace_all_bytes requires a replacement for every pattern \
         in the automaton"
    );
    // The haystack length is only a starting capacity; the output may grow.
    let mut out = Vec::with_capacity(haystack.len());
    self.replace_all_with_bytes(haystack, &mut out, |m, _matched, buf| {
        let replacement: &[u8] = replace_with[m.pattern()].as_ref();
        buf.extend_from_slice(replacement);
        true
    });
    out
}
/// Replaces matches in `haystack` using a caller-supplied closure, appending
/// the result to `dst`.
///
/// For each match produced by `find_iter`, the text between the previous
/// match and this one is appended to `dst`, then `replace_with` is called
/// with the match, the matched text and `dst`. The closure appends whatever
/// replacement it wants.
///
/// If the closure returns `true`, replacement continues with the next match.
/// If it returns `false`, searching stops and the rest of the haystack after
/// that match is appended to `dst` unchanged.
///
/// # Panics
///
/// Slicing `haystack` panics if a match offset is not on a UTF-8 character
/// boundary.
pub fn replace_all_with<F>(
    &self,
    haystack: &str,
    dst: &mut String,
    mut replace_with: F,
) where
    F: FnMut(&Match, &str, &mut String) -> bool,
{
    let mut last_match = 0;
    for mat in self.find_iter(haystack) {
        dst.push_str(&haystack[last_match..mat.start()]);
        last_match = mat.end();
        let matched = &haystack[mat.start()..mat.end()];
        // Honor the closure's verdict: `false` means "stop replacing".
        if !replace_with(&mat, matched, dst) {
            break;
        }
    }
    // Copy whatever follows the last processed match verbatim.
    dst.push_str(&haystack[last_match..]);
}
/// Replaces matches in `haystack` using a caller-supplied closure, appending
/// the result to `dst`.
///
/// For each match produced by `find_iter`, the bytes between the previous
/// match and this one are appended to `dst`, then `replace_with` is called
/// with the match, the matched bytes and `dst`. The closure appends whatever
/// replacement it wants.
///
/// If the closure returns `true`, replacement continues with the next match.
/// If it returns `false`, searching stops and the rest of the haystack after
/// that match is appended to `dst` unchanged.
pub fn replace_all_with_bytes<F>(
    &self,
    haystack: &[u8],
    dst: &mut Vec<u8>,
    mut replace_with: F,
) where
    F: FnMut(&Match, &[u8], &mut Vec<u8>) -> bool,
{
    let mut last_match = 0;
    for mat in self.find_iter(haystack) {
        dst.extend_from_slice(&haystack[last_match..mat.start()]);
        last_match = mat.end();
        let matched = &haystack[mat.start()..mat.end()];
        // Honor the closure's verdict: `false` means "stop replacing".
        if !replace_with(&mat, matched, dst) {
            break;
        }
    }
    // Copy whatever follows the last processed match verbatim.
    dst.extend_from_slice(&haystack[last_match..]);
}
pub fn stream_find_iter<'a, R: io::Read>(
&'a self,
rdr: R,
) -> StreamFindIter<'a, R, S> {
StreamFindIter::new(self, rdr)
}
/// Copies `rdr` to `wtr`, replacing every match with the entry of
/// `replace_with` at the index of the matching pattern.
///
/// Returns the first I/O error from reading or writing.
///
/// # Panics
///
/// Panics when `replace_with.len()` differs from `self.pattern_count()`.
/// Also panics (inside `StreamChunkIter::new`) when the automaton's match
/// kind does not support stream searching.
pub fn stream_replace_all<R, W, B>(
    &self,
    rdr: R,
    wtr: W,
    replace_with: &[B],
) -> io::Result<()>
where
    R: io::Read,
    W: io::Write,
    B: AsRef<[u8]>,
{
    assert_eq!(
        replace_with.len(),
        self.pattern_count(),
        "stream_replace_all requires a replacement for every pattern \
         in the automaton"
    );
    self.stream_replace_all_with(rdr, wtr, |m, _matched, out| {
        let replacement: &[u8] = replace_with[m.pattern()].as_ref();
        out.write_all(replacement)
    })
}
/// Copies `rdr` to `wtr`, letting `replace_with` decide what to write for
/// each match.
///
/// Non-matching bytes are written to `wtr` unchanged. For every match the
/// closure is called with the match, the matched bytes and the writer.
/// Any error from reading, writing or the closure stops processing and is
/// returned.
///
/// # Panics
///
/// Panics (inside `StreamChunkIter::new`) when the automaton's match kind
/// does not support stream searching.
pub fn stream_replace_all_with<R, W, F>(
    &self,
    rdr: R,
    mut wtr: W,
    mut replace_with: F,
) -> io::Result<()>
where
    R: io::Read,
    W: io::Write,
    F: FnMut(&Match, &[u8], &mut W) -> io::Result<()>,
{
    let mut chunks = StreamChunkIter::new(self, rdr);
    loop {
        // `StreamChunkIter::next` lends out its buffer, so each chunk must be
        // fully consumed before the next one is requested.
        let chunk = match chunks.next() {
            None => return Ok(()),
            Some(result) => result?,
        };
        match chunk {
            StreamChunk::Match { bytes, mat } => {
                replace_with(&mat, bytes, &mut wtr)?;
            }
            StreamChunk::NonMatch { bytes, .. } => {
                wtr.write_all(bytes)?;
            }
        }
    }
}
/// Returns the match kind reported by the underlying automaton.
pub fn match_kind(&self) -> &MatchKind {
    let imp = &self.imp;
    imp.match_kind()
}
/// Returns the maximum pattern length reported by the underlying automaton.
pub fn max_pattern_len(&self) -> usize {
    let imp = &self.imp;
    imp.max_pattern_len()
}
/// Returns the number of patterns reported by the underlying automaton.
pub fn pattern_count(&self) -> usize {
    let imp = &self.imp;
    imp.pattern_count()
}
/// Returns `true` if this automaton's match kind allows overlapping
/// searches.
pub fn supports_overlapping(&self) -> bool {
    let kind = &self.match_kind;
    kind.supports_overlapping()
}
/// Returns `true` if this automaton's match kind allows stream searching.
pub fn supports_stream(&self) -> bool {
    let kind = &self.match_kind;
    kind.supports_stream()
}
/// Returns the `heap_bytes()` figure reported by the underlying NFA or DFA.
pub fn heap_bytes(&self) -> usize {
    let imp = &self.imp;
    match *imp {
        Imp::DFA(ref dfa) => dfa.heap_bytes(),
        Imp::NFA(ref nfa) => nfa.heap_bytes(),
    }
}
}
/// The concrete automaton behind an `AhoCorasick`: either the NFA produced
/// by the NFA builder, or a DFA built from that NFA.
#[derive(Clone, Debug)]
enum Imp<S: StateID> {
/// Search directly with the NFA.
NFA(NFA<S>),
/// Search with a DFA.
DFA(DFA<S>),
}
/// Every method here forwards to the same-named method on whichever
/// automaton (`NFA` or `DFA`) this value wraps.
impl<S: StateID> Imp<S> {
    /// Returns the match kind of the wrapped automaton.
    fn match_kind(&self) -> &MatchKind {
        match *self {
            Imp::DFA(ref d) => d.match_kind(),
            Imp::NFA(ref n) => n.match_kind(),
        }
    }

    /// Returns the start state of the wrapped automaton.
    fn start_state(&self) -> S {
        match *self {
            Imp::DFA(ref d) => d.start_state(),
            Imp::NFA(ref n) => n.start_state(),
        }
    }

    /// Returns the maximum pattern length of the wrapped automaton.
    fn max_pattern_len(&self) -> usize {
        match *self {
            Imp::DFA(ref d) => d.max_pattern_len(),
            Imp::NFA(ref n) => n.max_pattern_len(),
        }
    }

    /// Returns the number of patterns in the wrapped automaton.
    fn pattern_count(&self) -> usize {
        match *self {
            Imp::DFA(ref d) => d.pattern_count(),
            Imp::NFA(ref n) => n.pattern_count(),
        }
    }

    /// Forwards an overlapping search starting at `at`; `state_id` and
    /// `match_index` are caller-owned state that the automaton updates.
    #[inline(always)]
    fn overlapping_find_at(
        &self,
        prestate: &mut PrefilterState,
        haystack: &[u8],
        at: usize,
        state_id: &mut S,
        match_index: &mut usize,
    ) -> Option<Match> {
        match *self {
            Imp::DFA(ref d) => d.overlapping_find_at(
                prestate,
                haystack,
                at,
                state_id,
                match_index,
            ),
            Imp::NFA(ref n) => n.overlapping_find_at(
                prestate,
                haystack,
                at,
                state_id,
                match_index,
            ),
        }
    }

    /// Forwards an "earliest" search starting at `at`; `state_id` is
    /// caller-owned state that the automaton updates.
    #[inline(always)]
    fn earliest_find_at(
        &self,
        prestate: &mut PrefilterState,
        haystack: &[u8],
        at: usize,
        state_id: &mut S,
    ) -> Option<Match> {
        match *self {
            Imp::DFA(ref d) => {
                d.earliest_find_at(prestate, haystack, at, state_id)
            }
            Imp::NFA(ref n) => {
                n.earliest_find_at(prestate, haystack, at, state_id)
            }
        }
    }

    /// Forwards a search starting at `at` that carries no automaton state
    /// between calls.
    #[inline(always)]
    fn find_at_no_state(
        &self,
        prestate: &mut PrefilterState,
        haystack: &[u8],
        at: usize,
    ) -> Option<Match> {
        match *self {
            Imp::DFA(ref d) => d.find_at_no_state(prestate, haystack, at),
            Imp::NFA(ref n) => n.find_at_no_state(prestate, haystack, at),
        }
    }
}
/// An iterator over successive matches in a haystack.
///
/// `'a` is the lifetime of the automaton and `'b` that of the haystack.
#[derive(Debug)]
pub struct FindIter<'a, 'b, S: 'a + StateID> {
    /// The automaton used for searching.
    fsm: &'a Imp<S>,
    /// Prefilter state carried across searches.
    prestate: PrefilterState,
    /// The bytes being searched.
    haystack: &'b [u8],
    /// Offset at which the next search starts.
    pos: usize,
}

impl<'a, 'b, S: StateID> FindIter<'a, 'b, S> {
    /// Creates an iterator positioned at the start of `haystack`.
    fn new(ac: &'a AhoCorasick<S>, haystack: &'b [u8]) -> FindIter<'a, 'b, S> {
        FindIter {
            fsm: &ac.imp,
            prestate: PrefilterState::new(ac.max_pattern_len()),
            haystack: haystack,
            pos: 0,
        }
    }
}

impl<'a, 'b, S: StateID> Iterator for FindIter<'a, 'b, S> {
    type Item = Match;

    /// Searches from the current position and advances past the match found.
    fn next(&mut self) -> Option<Match> {
        // `pos` can be one past the end after an empty match at the end.
        if self.pos > self.haystack.len() {
            return None;
        }
        let search = self.fsm.find_at_no_state(
            &mut self.prestate,
            self.haystack,
            self.pos,
        );
        let found = match search {
            Some(found) => found,
            None => return None,
        };
        // A match ending where the search started is empty; step forward by
        // one so the next search cannot report the same position again.
        self.pos = if found.end() == self.pos {
            self.pos + 1
        } else {
            found.end()
        };
        Some(found)
    }
}
/// An iterator over overlapping matches in a haystack.
///
/// `'a` is the lifetime of the automaton and `'b` that of the haystack.
#[derive(Debug)]
pub struct FindOverlappingIter<'a, 'b, S: 'a + StateID> {
    /// The automaton used for searching.
    fsm: &'a Imp<S>,
    /// Prefilter state carried across searches.
    prestate: PrefilterState,
    /// The bytes being searched.
    haystack: &'b [u8],
    /// Offset at which the next search starts.
    pos: usize,
    /// Initialised to 0 and not read anywhere in this file.
    last_match_end: usize,
    /// Automaton state carried from one search to the next.
    state_id: S,
    /// Match index carried from one search to the next; the automaton
    /// updates it.
    match_index: usize,
}

impl<'a, 'b, S: StateID> FindOverlappingIter<'a, 'b, S> {
    /// Creates an iterator positioned at the start of `haystack`.
    ///
    /// # Panics
    ///
    /// Panics when `ac` does not support overlapping searches.
    fn new(
        ac: &'a AhoCorasick<S>,
        haystack: &'b [u8],
    ) -> FindOverlappingIter<'a, 'b, S> {
        assert!(
            ac.supports_overlapping(),
            "automaton does not support overlapping searches"
        );
        FindOverlappingIter {
            fsm: &ac.imp,
            prestate: PrefilterState::new(ac.max_pattern_len()),
            haystack: haystack,
            pos: 0,
            last_match_end: 0,
            state_id: ac.imp.start_state(),
            match_index: 0,
        }
    }
}

impl<'a, 'b, S: StateID> Iterator for FindOverlappingIter<'a, 'b, S> {
    type Item = Match;

    /// Resumes the overlapping search from the saved position and state.
    fn next(&mut self) -> Option<Match> {
        let found = self.fsm.overlapping_find_at(
            &mut self.prestate,
            self.haystack,
            self.pos,
            &mut self.state_id,
            &mut self.match_index,
        );
        // The next search resumes at the end of the match just reported.
        if let Some(ref m) = found {
            self.pos = m.end();
        }
        found
    }
}
/// An iterator over the matches found while reading from a stream.
///
/// This wraps a `StreamChunkIter` and discards the non-matching chunks.
#[derive(Debug)]
pub struct StreamFindIter<'a, R, S: 'a + StateID> {
    /// The chunk iterator that does the reading and searching.
    it: StreamChunkIter<'a, R, S>,
}

impl<'a, R: io::Read, S: StateID> StreamFindIter<'a, R, S> {
    /// Creates a stream match iterator over `rdr`.
    ///
    /// # Panics
    ///
    /// Panics (inside `StreamChunkIter::new`) when `ac` does not support
    /// stream searching.
    fn new(ac: &'a AhoCorasick<S>, rdr: R) -> StreamFindIter<'a, R, S> {
        let chunks = StreamChunkIter::new(ac, rdr);
        StreamFindIter { it: chunks }
    }
}

impl<'a, R: io::Read, S: StateID> Iterator for StreamFindIter<'a, R, S> {
    type Item = io::Result<Match>;

    /// Returns the next match, or the next I/O error, skipping over all
    /// non-matching chunks.
    fn next(&mut self) -> Option<io::Result<Match>> {
        while let Some(item) = self.it.next() {
            match item {
                Err(err) => return Some(Err(err)),
                Ok(StreamChunk::Match { mat, .. }) => return Some(Ok(mat)),
                Ok(StreamChunk::NonMatch { .. }) => continue,
            }
        }
        None
    }
}
/// Reads from a stream and splits it into alternating chunks of
/// non-matching bytes and matches.
///
/// This is not an `Iterator`: the chunks returned by `next` borrow from the
/// internal buffer.
#[derive(Debug)]
struct StreamChunkIter<'a, R, S: 'a + StateID> {
/// The automaton used for searching.
fsm: &'a Imp<S>,
/// Prefilter state carried across searches.
prestate: PrefilterState,
/// The reader supplying the bytes to search.
rdr: R,
/// The buffer that is filled from `rdr` and searched.
buf: Buffer,
/// Automaton state carried between searches; reset to the start state
/// after every match.
state_id: S,
/// Offset into the buffer at which the next search starts.
search_pos: usize,
/// Amount added to buffer-relative offsets to turn them into offsets
/// relative to the start of the stream.
absolute_pos: usize,
/// Offset into the buffer up to which bytes have already been returned
/// to the caller.
report_pos: usize,
/// A match that has been found but not yet returned, because the
/// non-matching bytes before it are returned first.
pending_match: Option<Match>,
/// Initialised from `ac.is_match("")`; while set, it allows one extra
/// search after the reader is exhausted and everything has been reported.
has_empty_match_at_end: bool,
}
/// One piece of a searched stream, borrowing its bytes from the
/// `StreamChunkIter` buffer for the lifetime `'r`.
#[derive(Debug)]
enum StreamChunk<'r> {
/// Bytes that are not part of any match; `start` is the offset of the
/// first byte relative to the start of the stream.
NonMatch { bytes: &'r [u8], start: usize },
/// The bytes of a match together with the match itself, whose offsets
/// have been shifted by the iterator's `absolute_pos`.
Match { bytes: &'r [u8], mat: Match },
}
impl<'a, R: io::Read, S: StateID> StreamChunkIter<'a, R, S> {
/// Creates a chunk iterator that reads from `rdr` and searches with `ac`.
///
/// # Panics
///
/// Panics when `ac` does not support stream searching.
fn new(ac: &'a AhoCorasick<S>, rdr: R) -> StreamChunkIter<'a, R, S> {
    assert!(
        ac.supports_stream(),
        "stream searching is only supported for Standard match semantics"
    );
    // Both the prefilter state and the buffer are sized from the maximum
    // pattern length.
    let max_len = ac.max_pattern_len();
    StreamChunkIter {
        fsm: &ac.imp,
        prestate: PrefilterState::new(max_len),
        rdr: rdr,
        buf: Buffer::new(max_len),
        state_id: ac.imp.start_state(),
        search_pos: 0,
        absolute_pos: 0,
        report_pos: 0,
        pending_match: None,
        // Remember whether the automaton matches the empty string, so that a
        // trailing empty match can still be searched for at end of input.
        has_empty_match_at_end: ac.is_match(""),
    }
}
/// Advances the stream search and returns the next chunk: either a run of
/// bytes that are not part of any match, or a match together with its bytes.
///
/// Returns `None` once `fill` reports no more data and everything has been
/// reported. An error from filling the buffer is returned as `Some(Err(..))`.
fn next<'r>(&'r mut self) -> Option<io::Result<StreamChunk<'r>>> {
loop {
// A match found on an earlier pass is returned first. Its offsets are
// still buffer-relative here, so the bytes are sliced before the match is
// passed through `increment(absolute_pos)` (presumably shifting it to
// stream-relative offsets — confirm against `Match::increment`).
if let Some(mut mat) = self.pending_match.take() {
let bytes = &self.buf.buffer()[mat.start()..mat.end()];
self.report_pos = mat.end();
mat = mat.increment(self.absolute_pos);
return Some(Ok(StreamChunk::Match { bytes, mat }));
}
// The whole buffer has been searched: flush, roll and refill.
if self.search_pos >= self.buf.len() {
// Return any bytes that `unreported` says can be handed out before the
// buffer contents are changed.
if let Some(end) = self.unreported() {
let bytes = &self.buf.buffer()[self.report_pos..end];
let start = self.absolute_pos + self.report_pos;
self.report_pos = end;
return Some(Ok(StreamChunk::NonMatch { bytes, start }));
}
// Roll only when the buffer holds at least `min_buffer_len` bytes. The
// positions are adjusted first so they are consistent with the buffer
// after `roll` (presumably `roll` keeps the trailing `min_buffer_len`
// bytes — confirm against `Buffer::roll`).
if self.buf.len() >= self.buf.min_buffer_len() {
self.report_pos -=
self.buf.len() - self.buf.min_buffer_len();
self.absolute_pos +=
self.search_pos - self.buf.min_buffer_len();
self.search_pos = self.buf.min_buffer_len();
self.buf.roll();
}
match self.buf.fill(&mut self.rdr) {
Err(err) => return Some(Err(err)),
// `Ok(false)` is treated as end of input (presumably EOF — confirm
// against `Buffer::fill`).
Ok(false) => {
// Return whatever is left in the buffer that has not been reported.
if self.report_pos < self.buf.len() {
let bytes = &self.buf.buffer()[self.report_pos..];
let start = self.absolute_pos + self.report_pos;
self.report_pos = self.buf.len();
let chunk = StreamChunk::NonMatch { bytes, start };
return Some(Ok(chunk));
} else {
// Everything has been reported. Stop, unless the automaton matches
// the empty string, in which case fall through to one last search.
if !self.has_empty_match_at_end {
return None;
}
// Clear the flag so the extra search happens only once.
self.has_empty_match_at_end = false;
}
}
Ok(true) => {}
}
}
let result = self.fsm.earliest_find_at(
&mut self.prestate,
self.buf.buffer(),
self.search_pos,
&mut self.state_id,
);
match result {
// No match in the rest of the buffer; mark it all as searched.
None => {
self.search_pos = self.buf.len();
}
Some(mat) => {
// The next search starts from the automaton's start state.
self.state_id = self.fsm.start_state();
// An empty match at the search position would be found again, so move
// forward by one byte in that case.
if mat.end() == self.search_pos {
self.search_pos += 1;
} else {
self.search_pos = mat.end();
}
// Keep the match for the next pass; non-matching bytes that precede it
// are returned first.
self.pending_match = Some(mat.clone());
if self.report_pos < mat.start() {
let bytes =
&self.buf.buffer()[self.report_pos..mat.start()];
let start = self.absolute_pos + self.report_pos;
self.report_pos = mat.start();
let chunk = StreamChunk::NonMatch { bytes, start };
return Some(Ok(chunk));
}
}
}
}
}
/// Returns the buffer offset up to which unreported bytes may be handed
/// out, or `None` when there are none.
///
/// The offset is `search_pos` minus the buffer's minimum length, clamped at
/// zero; it is returned only when it lies beyond `report_pos`.
fn unreported(&self) -> Option<usize> {
    let min_len = self.buf.min_buffer_len();
    let end = self.search_pos.saturating_sub(min_len);
    if end > self.report_pos {
        Some(end)
    } else {
        None
    }
}
}
/// A builder for configuring and constructing an `AhoCorasick` automaton.
#[derive(Clone, Debug)]
pub struct AhoCorasickBuilder {
/// Builder for the NFA, which is always constructed.
nfa_builder: nfa::Builder,
/// Builder for the DFA, used only when `dfa` is true.
dfa_builder: dfa::Builder,
/// Whether to build a DFA from the NFA. Starts as `false`.
dfa: bool,
}
impl Default for AhoCorasickBuilder {
fn default() -> AhoCorasickBuilder {
AhoCorasickBuilder::new()
}
}
impl AhoCorasickBuilder {
    /// Creates a builder with fresh NFA and DFA builders and DFA
    /// construction turned off.
    pub fn new() -> AhoCorasickBuilder {
        let nfa_builder = nfa::Builder::new();
        let dfa_builder = dfa::Builder::new();
        AhoCorasickBuilder {
            nfa_builder: nfa_builder,
            dfa_builder: dfa_builder,
            dfa: false,
        }
    }

    /// Builds an automaton for `patterns` with `usize` state identifiers.
    ///
    /// # Panics
    ///
    /// Panics if `build_with_size` returns an error.
    pub fn build<I, P>(&self, patterns: I) -> AhoCorasick
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        let built = self.build_with_size::<usize, I, P>(patterns);
        built.expect("usize state ID type should always work")
    }

    /// Builds an automaton for `patterns` with state identifiers of type
    /// `S`.
    ///
    /// The NFA is always built. When the `dfa` option is on, a DFA is built
    /// from it and used instead.
    ///
    /// # Errors
    ///
    /// Returns any error from the NFA builder or the DFA builder.
    pub fn build_with_size<S, I, P>(
        &self,
        patterns: I,
    ) -> Result<AhoCorasick<S>>
    where
        S: StateID,
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        let nfa = self.nfa_builder.build(patterns)?;
        // Record the match kind before the NFA is possibly moved.
        let match_kind = nfa.match_kind().clone();
        if !self.dfa {
            return Ok(AhoCorasick {
                imp: Imp::NFA(nfa),
                match_kind: match_kind,
            });
        }
        let dfa = self.dfa_builder.build(&nfa)?;
        Ok(AhoCorasick { imp: Imp::DFA(dfa), match_kind: match_kind })
    }

    /// Picks settings from the number of patterns.
    ///
    /// With at most 100 patterns, DFA construction is enabled and byte
    /// classes are disabled. With at most 5000 patterns, DFA construction is
    /// enabled. Otherwise the current settings are left untouched.
    pub fn auto_configure<B: AsRef<[u8]>>(
        &mut self,
        patterns: &[B],
    ) -> &mut AhoCorasickBuilder {
        let count = patterns.len();
        if count <= 100 {
            self.dfa(true);
            self.byte_classes(false);
        } else if count <= 5000 {
            self.dfa(true);
        }
        self
    }

    /// Forwards `kind` to the NFA builder's `match_kind` option.
    pub fn match_kind(&mut self, kind: MatchKind) -> &mut Self {
        self.nfa_builder.match_kind(kind);
        self
    }

    /// Forwards `yes` to the NFA builder's `anchored` option.
    pub fn anchored(&mut self, yes: bool) -> &mut Self {
        self.nfa_builder.anchored(yes);
        self
    }

    /// Forwards `yes` to the NFA builder's `ascii_case_insensitive` option.
    pub fn ascii_case_insensitive(&mut self, yes: bool) -> &mut Self {
        self.nfa_builder.ascii_case_insensitive(yes);
        self
    }

    /// Forwards `depth` to the NFA builder's `dense_depth` option.
    pub fn dense_depth(&mut self, depth: usize) -> &mut Self {
        self.nfa_builder.dense_depth(depth);
        self
    }

    /// Chooses whether `build_with_size` builds a DFA from the NFA.
    pub fn dfa(&mut self, yes: bool) -> &mut Self {
        self.dfa = yes;
        self
    }

    /// Forwards `yes` to the NFA builder's `prefilter` option.
    pub fn prefilter(&mut self, yes: bool) -> &mut Self {
        self.nfa_builder.prefilter(yes);
        self
    }

    /// Forwards `yes` to the DFA builder's `byte_classes` option.
    pub fn byte_classes(&mut self, yes: bool) -> &mut Self {
        self.dfa_builder.byte_classes(yes);
        self
    }

    /// Forwards `yes` to the DFA builder's `premultiply` option.
    pub fn premultiply(&mut self, yes: bool) -> &mut Self {
        self.dfa_builder.premultiply(yes);
        self
    }
}
/// The match semantics an automaton is built with.
///
/// `Standard` is the default. In this file it is the only kind that
/// supports overlapping and stream searches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchKind {
    /// Standard match semantics; the default.
    Standard,
    /// Leftmost-first match semantics.
    LeftmostFirst,
    /// Leftmost-longest match semantics.
    LeftmostLongest,
    /// Hidden variant that keeps callers from matching exhaustively; it is
    /// not a real match kind.
    #[doc(hidden)]
    __Nonexhaustive,
}

impl Default for MatchKind {
    /// Returns `MatchKind::Standard`.
    fn default() -> Self {
        MatchKind::Standard
    }
}
impl MatchKind {
    /// Overlapping searches are available only for standard semantics.
    fn supports_overlapping(&self) -> bool {
        self.is_standard()
    }

    /// Stream searches are available only for standard semantics.
    fn supports_stream(&self) -> bool {
        self.is_standard()
    }

    /// Returns `true` for `MatchKind::Standard`.
    pub(crate) fn is_standard(&self) -> bool {
        match *self {
            MatchKind::Standard => true,
            MatchKind::LeftmostFirst
            | MatchKind::LeftmostLongest
            | MatchKind::__Nonexhaustive => false,
        }
    }

    /// Returns `true` for `LeftmostFirst` and `LeftmostLongest`.
    pub(crate) fn is_leftmost(&self) -> bool {
        match *self {
            MatchKind::LeftmostFirst | MatchKind::LeftmostLongest => true,
            MatchKind::Standard | MatchKind::__Nonexhaustive => false,
        }
    }

    /// Returns `true` for `MatchKind::LeftmostFirst`.
    pub(crate) fn is_leftmost_first(&self) -> bool {
        MatchKind::LeftmostFirst == *self
    }

    /// Maps this match kind to the corresponding `packed::MatchKind`.
    ///
    /// Returns `None` for `Standard`, which has no counterpart here.
    ///
    /// # Panics
    ///
    /// Panics on the hidden `__Nonexhaustive` variant.
    pub(crate) fn as_packed(&self) -> Option<packed::MatchKind> {
        let kind = match *self {
            MatchKind::Standard => return None,
            MatchKind::LeftmostLongest => packed::MatchKind::LeftmostLongest,
            MatchKind::LeftmostFirst => packed::MatchKind::LeftmostFirst,
            MatchKind::__Nonexhaustive => unreachable!(),
        };
        Some(kind)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the public types are `Send`, `Sync` and
    /// unwind safe. The test passes simply by compiling.
    #[test]
    fn oibits() {
        use std::panic::{RefUnwindSafe, UnwindSafe};

        fn assert_auto_traits<T>()
        where
            T: Send + Sync + RefUnwindSafe + UnwindSafe,
        {
        }

        assert_auto_traits::<AhoCorasick>();
        assert_auto_traits::<AhoCorasickBuilder>();
    }
}