[go: up one dir, main page]

ucd-trie 0.1.7

A trie for storing Unicode codepoint sets and maps.
Documentation
use std::borrow::Borrow;
use std::collections::HashMap;
use std::error;
use std::fmt;
use std::io;
use std::result;

use super::{TrieSetSlice, CHUNK_SIZE};

// This implementation was pretty much cribbed from raphlinus' contribution
// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
//
// The fundamental principle guiding this implementation is to take advantage
// of the fact that similar Unicode codepoints are often grouped together, and
// that most boolean Unicode properties are quite sparse over the entire space
// of Unicode codepoints.
//
// To do this, we represent sets using something like a trie (which gives us
// prefix compression). The "final" states of the trie are embedded in leaves
// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
// integer corresponds to whether a particular codepoint is in the set or not.
// These chunks are not just a compact representation of the final states of
// the trie, but are also a form of suffix compression. In particular, if
// multiple ranges of 64 contiguous codepoints map have the same set membership
// ordering, then they all map to the exact same chunk in the trie.
//
// We organize this structure by partitioning the space of Unicode codepoints
// into three disjoint sets. The first set corresponds to codepoints
// [0, 0x800), the second [0x800, 0x1000) and the third [0x10000, 0x110000).
// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
// codepoints, respectively.
//
// Each partition has its own tree with its own root. The first partition is
// the simplest, since the tree is completely flat. In particular, to determine
// the set membership of a Unicode codepoint (that is less than `0x800`), we
// do the following (where `cp` is the codepoint we're testing):
//
//     let chunk_address = cp >> 6;
//     let chunk_bit = cp & 0b111111;
//     let chunk = tree1[cp >> 6];
//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
//
// We do something similar for the second partition:
//
//     // we subtract 0x20 since (0x800 >> 6) == 0x20.
//     let child_address = (cp >> 6) - 0x20;
//     let chunk_address = tree2_level1[child_address];
//     let chunk_bit = cp & 0b111111;
//     let chunk = tree2_level2[chunk_address];
//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
//
// And so on for the third partition.
//
// Note that as a special case, if the second or third partitions are empty,
// then the trie will store empty slices for those levels. The `contains`
// check knows to return `false` in those cases.

const CHUNKS: usize = 0x110000 / CHUNK_SIZE;

/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
pub type Result<T> = result::Result<T, Error>;

/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate. There is
    /// no work-around for this error at this time.
    GaveUp,
}

impl error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Error::InvalidCodepoint(cp) => write!(
                f,
                "could not construct trie set containing an \
                 invalid Unicode codepoint: 0x{:X}",
                cp
            ),
            Error::GaveUp => {
                write!(f, "could not compress codepoint set into a trie")
            }
        }
    }
}

impl From<Error> for io::Error {
    fn from(err: Error) -> io::Error {
        io::Error::new(io::ErrorKind::Other, err)
    }
}

/// An owned trie set.
#[derive(Clone)]
pub struct TrieSetOwned {
    tree1_level1: Vec<u64>,
    tree2_level1: Vec<u8>,
    tree2_level2: Vec<u64>,
    tree3_level1: Vec<u8>,
    tree3_level2: Vec<u8>,
    tree3_level3: Vec<u64>,
}

impl fmt::Debug for TrieSetOwned {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "TrieSetOwned(...)")
    }
}

impl TrieSetOwned {
    fn new(all: &[bool]) -> Result<TrieSetOwned> {
        let mut bitvectors = Vec::with_capacity(CHUNKS);
        for i in 0..CHUNKS {
            let mut bitvector = 0u64;
            for j in 0..CHUNK_SIZE {
                if all[i * CHUNK_SIZE + j] {
                    bitvector |= 1 << j;
                }
            }
            bitvectors.push(bitvector);
        }

        let tree1_level1 =
            bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();

        let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
            &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
        )?;
        if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
            tree2_level1.clear();
            tree2_level2.clear();
        }

        let (mid, mut tree3_level3) = compress_postfix_leaves(
            &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
        )?;
        let (mut tree3_level1, mut tree3_level2) =
            compress_postfix_mid(&mid, 64)?;
        if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
            tree3_level1.clear();
            tree3_level2.clear();
            tree3_level3.clear();
        }

        Ok(TrieSetOwned {
            tree1_level1,
            tree2_level1,
            tree2_level2,
            tree3_level1,
            tree3_level2,
            tree3_level3,
        })
    }

    /// Create a new trie set from a set of Unicode scalar values.
    ///
    /// This returns an error if a set could not be sufficiently compressed to
    /// fit into a trie.
    pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
    where
        I: IntoIterator<Item = C>,
        C: Borrow<char>,
    {
        let mut all = vec![false; 0x110000];
        for s in scalars {
            all[*s.borrow() as usize] = true;
        }
        TrieSetOwned::new(&all)
    }

    /// Create a new trie set from a set of Unicode scalar values.
    ///
    /// This returns an error if a set could not be sufficiently compressed to
    /// fit into a trie. This also returns an error if any of the given
    /// codepoints are greater than `0x10FFFF`.
    pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
    where
        I: IntoIterator<Item = C>,
        C: Borrow<u32>,
    {
        let mut all = vec![false; 0x110000];
        for cp in codepoints {
            let cp = *cp.borrow();
            if cp > 0x10FFFF {
                return Err(Error::InvalidCodepoint(cp));
            }
            all[cp as usize] = true;
        }
        TrieSetOwned::new(&all)
    }

    /// Return this set as a slice.
    #[inline(always)]
    pub fn as_slice(&self) -> TrieSetSlice<'_> {
        TrieSetSlice {
            tree1_level1: &self.tree1_level1,
            tree2_level1: &self.tree2_level1,
            tree2_level2: &self.tree2_level2,
            tree3_level1: &self.tree3_level1,
            tree3_level2: &self.tree3_level2,
            tree3_level3: &self.tree3_level3,
        }
    }

    /// Returns true if and only if the given Unicode scalar value is in this
    /// set.
    pub fn contains_char(&self, c: char) -> bool {
        self.as_slice().contains_char(c)
    }

    /// Returns true if and only if the given codepoint is in this set.
    ///
    /// If the given value exceeds the codepoint range (i.e., it's greater
    /// than `0x10FFFF`), then this returns false.
    pub fn contains_u32(&self, cp: u32) -> bool {
        self.as_slice().contains_u32(cp)
    }
}

fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
    let mut root = vec![];
    let mut children = vec![];
    let mut bychild = HashMap::new();
    for &chunk in chunks {
        if !bychild.contains_key(&chunk) {
            let start = bychild.len();
            if start > ::std::u8::MAX as usize {
                return Err(Error::GaveUp);
            }
            bychild.insert(chunk, start as u8);
            children.push(chunk);
        }
        root.push(bychild[&chunk]);
    }
    Ok((root, children))
}

fn compress_postfix_mid(
    chunks: &[u8],
    chunk_size: usize,
) -> Result<(Vec<u8>, Vec<u8>)> {
    let mut root = vec![];
    let mut children = vec![];
    let mut bychild = HashMap::new();
    for i in 0..(chunks.len() / chunk_size) {
        let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
        if !bychild.contains_key(chunk) {
            let start = bychild.len();
            if start > ::std::u8::MAX as usize {
                return Err(Error::GaveUp);
            }
            bychild.insert(chunk, start as u8);
            children.extend(chunk);
        }
        root.push(bychild[chunk]);
    }
    Ok((root, children))
}

#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        let mut set = vec![];
        for &(start, end) in ranges {
            for cp in start..end + 1 {
                set.push(cp);
            }
        }
        set
    }

    #[test]
    fn set1() {
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char(''));
        assert!(!set.contains_char('😼'));
    }

    #[test]
    fn set_combined() {
        let set = mk(&['a', 'b', 'β', '', '😼']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char(''));
        assert!(set.contains_char('😼'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char(''));
        assert!(!set.contains_char('🐲'));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    assert!(trie.contains_u32(cp) == hashset.contains(&cp));
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}