[go: up one dir, main page]

xml/
util.rs

1use std::fmt;
2use std::io::{self, Read};
3use std::str::{self, FromStr};
4
/// Reasons why pulling the next character out of a byte stream can fail.
#[derive(Debug)]
pub(crate) enum CharReadError {
    /// The stream ended after some, but not all, bytes of a character were read.
    UnexpectedEof,
    /// The collected bytes are not valid UTF-8.
    Utf8(str::Utf8Error),
    /// The underlying reader failed, or the bytes are invalid for the active encoding.
    Io(io::Error),
}

// Both conversions exist so that `?` / `.into()` can be used on the error paths;
// they are marked cold because errors are not expected while decoding.

impl From<io::Error> for CharReadError {
    #[cold]
    fn from(e: io::Error) -> Self {
        CharReadError::Io(e)
    }
}

impl From<str::Utf8Error> for CharReadError {
    #[cold]
    fn from(e: str::Utf8Error) -> Self {
        CharReadError::Utf8(e)
    }
}

impl fmt::Display for CharReadError {
    /// Writes a short human-readable description, including the wrapped cause if any.
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnexpectedEof => f.write_str("unexpected end of stream"),
            Self::Utf8(e) => write!(f, "UTF-8 decoding error: {e}"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}
37
/// Character encoding used for parsing
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum Encoding {
    /// Explicitly UTF-8 only
    Utf8,
    /// UTF-8 fallback, but can be any 8-bit encoding
    Default,
    /// ISO-8859-1
    Latin1,
    /// US-ASCII
    Ascii,
    /// Big-Endian
    Utf16Be,
    /// Little-Endian
    Utf16Le,
    /// Unknown endianness yet, will be sniffed
    Utf16,
    /// Not determined yet, may be sniffed to be anything
    Unknown,
}

/// Returns `true` when `varcase` equals `lower` ignoring ASCII case.
///
/// `lower` must already be lowercase; only `varcase` is case-folded.
// Rustc inlines eq_ignore_ascii_case and creates kilobytes of code!
#[inline(never)]
fn icmp(lower: &str, varcase: &str) -> bool {
    // `zip` stops at the shorter input, so the lengths have to be compared explicitly.
    // Without this check every prefix of a label (including the empty string) and every
    // label followed by extra characters (e.g. "iso-8859-15") would count as a match.
    lower.len() == varcase.len()
        && lower.bytes().zip(varcase.bytes()).all(|(l, v)| l == v.to_ascii_lowercase())
}

/// Returns `true` when `val` is, ignoring ASCII case, one of the lowercase `labels`.
fn is_one_of(labels: &[&str], val: &str) -> bool {
    labels.iter().any(|label| icmp(label, val))
}

impl FromStr for Encoding {
    type Err = &'static str;

    /// Parses an encoding label (as found in an XML declaration), ignoring ASCII case.
    ///
    /// Only complete labels are accepted. The endianness-specific UTF-16 labels map to
    /// [`Encoding::Utf16`], i.e. the byte order is still sniffed from the data.
    ///
    /// # Errors
    ///
    /// Returns `"unknown encoding name"` for any label that is not recognised.
    fn from_str(val: &str) -> Result<Self, Self::Err> {
        if is_one_of(&["utf-8", "utf8"], val) {
            Ok(Self::Utf8)
        } else if is_one_of(&["iso-8859-1", "latin1"], val) {
            Ok(Self::Latin1)
        } else if is_one_of(&["utf-16", "utf16", "utf-16le", "utf-16be", "utf16le", "utf16be"], val) {
            // The LE/BE spellings are listed explicitly so that they keep parsing
            // now that labels are no longer matched by prefix.
            Ok(Self::Utf16)
        } else if is_one_of(&["ascii", "us-ascii"], val) {
            Ok(Self::Ascii)
        } else {
            Err("unknown encoding name")
        }
    }
}

impl fmt::Display for Encoding {
    /// Writes the canonical label of the encoding family; `Default` is shown as UTF-8
    /// and all UTF-16 variants are shown as `UTF-16`.
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Utf8 |
            Self::Default => "UTF-8",
            Self::Latin1 => "ISO-8859-1",
            Self::Ascii => "US-ASCII",
            Self::Utf16Be |
            Self::Utf16Le |
            Self::Utf16 => "UTF-16",
            Self::Unknown => "(unknown)",
        })
    }
}
99
100pub(crate) struct CharReader {
101    pub encoding: Encoding,
102}
103
104impl CharReader {
105    pub const fn new() -> Self {
106        Self { encoding: Encoding::Unknown }
107    }
108
109    #[inline]
110    pub fn next_char_from<R: Read>(&mut self, source: &mut R) -> Result<Option<char>, CharReadError> {
111        let mut bytes = source.bytes();
112        const MAX_CODEPOINT_LEN: usize = 4;
113
114        let mut buf = [0u8; MAX_CODEPOINT_LEN];
115        let mut pos = 0;
116        while pos < MAX_CODEPOINT_LEN {
117            let next = match bytes.next() {
118                Some(Ok(b)) => b,
119                Some(Err(e)) => return Err(e.into()),
120                None if pos == 0 => return Ok(None),
121                None => return Err(CharReadError::UnexpectedEof),
122            };
123
124            match self.encoding {
125                Encoding::Utf8 | Encoding::Default => {
126                    // fast path for ASCII subset
127                    if pos == 0 && next.is_ascii() {
128                        return Ok(Some(next.into()));
129                    }
130
131                    buf[pos] = next;
132                    pos += 1;
133
134                    match str::from_utf8(&buf[..pos]) {
135                        Ok(s) => return Ok(s.chars().next()), // always Some(..)
136                        Err(_) if pos < MAX_CODEPOINT_LEN => continue,
137                        Err(e) => return Err(e.into()),
138                    }
139                },
140                Encoding::Latin1 => {
141                    return Ok(Some(next.into()));
142                },
143                Encoding::Ascii => {
144                    return if next.is_ascii() {
145                        Ok(Some(next.into()))
146                    } else {
147                        Err(CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, "char is not ASCII")))
148                    };
149                },
150                Encoding::Unknown | Encoding::Utf16 => {
151                    buf[pos] = next;
152                    pos += 1;
153                    if let Some(value) = self.sniff_bom(&buf[..pos], &mut pos) {
154                        return value;
155                    }
156                },
157                Encoding::Utf16Be => {
158                    buf[pos] = next;
159                    pos += 1;
160                    if pos == 2 {
161                        if let Some(Ok(c)) = char::decode_utf16([u16::from_be_bytes(buf[..2].try_into().unwrap())]).next() {
162                            return Ok(Some(c));
163                        }
164                    } else if pos == 4 {
165                        return Self::surrogate([u16::from_be_bytes(buf[..2].try_into().unwrap()), u16::from_be_bytes(buf[2..4].try_into().unwrap())]);
166                    }
167                },
168                Encoding::Utf16Le => {
169                    buf[pos] = next;
170                    pos += 1;
171                    if pos == 2 {
172                        if let Some(Ok(c)) = char::decode_utf16([u16::from_le_bytes(buf[..2].try_into().unwrap())]).next() {
173                            return Ok(Some(c));
174                        }
175                    } else if pos == 4 {
176                        return Self::surrogate([u16::from_le_bytes(buf[..2].try_into().unwrap()), u16::from_le_bytes(buf[2..4].try_into().unwrap())]);
177                    }
178                },
179            }
180        }
181        Err(CharReadError::Io(io::ErrorKind::InvalidData.into()))
182    }
183
184    #[cold]
185    fn sniff_bom(&mut self, buf: &[u8], pos: &mut usize) -> Option<Result<Option<char>, CharReadError>> {
186        // sniff BOM
187        if buf.len() <= 3 && [0xEF, 0xBB, 0xBF].starts_with(buf) {
188            if buf.len() == 3 && self.encoding != Encoding::Utf16 {
189                *pos = 0;
190                self.encoding = Encoding::Utf8;
191            }
192        } else if buf.len() <= 2 && [0xFE, 0xFF].starts_with(buf) {
193            if buf.len() == 2 {
194                *pos = 0;
195                self.encoding = Encoding::Utf16Be;
196            }
197        } else if buf.len() <= 2 && [0xFF, 0xFE].starts_with(buf) {
198            if buf.len() == 2 {
199                *pos = 0;
200                self.encoding = Encoding::Utf16Le;
201            }
202        } else if buf.len() == 1 && self.encoding == Encoding::Utf16 {
203            // sniff ASCII char in UTF-16
204            self.encoding = if buf[0] == 0 { Encoding::Utf16Be } else { Encoding::Utf16Le };
205        } else {
206            // UTF-8 is the default, but XML decl can change it to other 8-bit encoding
207            self.encoding = Encoding::Default;
208            if buf.len() == 1 && buf[0].is_ascii() {
209                return Some(Ok(Some(buf[0].into())));
210            }
211        }
212        None
213    }
214
215    fn surrogate(buf: [u16; 2]) -> Result<Option<char>, CharReadError> {
216        char::decode_utf16(buf).next().transpose()
217            .map_err(|e| CharReadError::Io(io::Error::new(io::ErrorKind::InvalidData, e)))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::{CharReadError, CharReader, Encoding};
    use std::io;

    /// Decodes the first character of `input` with a reader preset to `encoding`.
    fn first_char(encoding: Encoding, mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader { encoding }.next_char_from(&mut input)
    }

    /// Decodes the first character of `input` with a freshly created reader.
    fn sniffed(mut input: &[u8]) -> Result<Option<char>, CharReadError> {
        CharReader::new().next_char_from(&mut input)
    }

    #[test]
    fn test_next_char_from() {
        // correct ASCII
        assert_eq!(sniffed(b"correct").unwrap(), Some('c'));

        // UTF-8 BOM followed by a multi-byte character
        assert_eq!(sniffed(b"\xEF\xBB\xBF\xE2\x80\xA2!").unwrap(), Some('•'));
        // UTF-8 BOM followed by ASCII
        assert_eq!(sniffed(b"\xEF\xBB\xBFx123").unwrap(), Some('x'));
        // nothing after the BOM
        assert_eq!(sniffed(b"\xEF\xBB\xBF").unwrap(), None);
        // stream ends inside the BOM
        assert!(matches!(sniffed(b"\xEF\xBB"), Err(CharReadError::UnexpectedEof)));
        // starts like a BOM, but the third byte does not match
        assert!(sniffed(b"\xEF\xBB\x42").is_err());

        // UTF-16 BOMs, big-endian and little-endian
        assert_eq!(sniffed(b"\xFE\xFF\x00\x42").unwrap(), Some('B'));
        assert_eq!(sniffed(b"\xFF\xFE\x42\x00").unwrap(), Some('B'));
        // nothing after the UTF-16 BOM
        assert_eq!(sniffed(b"\xFF\xFE").unwrap(), None);
        // half a code unit after the UTF-16 BOM
        assert!(matches!(sniffed(b"\xFF\xFE\x00"), Err(CharReadError::UnexpectedEof)));

        // correct BMP
        let cyrillic = "правильно".as_bytes();
        assert_eq!(sniffed(cyrillic).unwrap(), Some('п'));
        // the same bytes read as UTF-16 code units
        assert_eq!(first_char(Encoding::Utf16Be, cyrillic).unwrap(), Some('킿'));
        assert_eq!(first_char(Encoding::Utf16Le, cyrillic).unwrap(), Some('뿐'));

        // UTF-16 of unknown endianness
        assert!(first_char(Encoding::Utf16, b"\xD8\xD8\x80").is_err());
        assert_eq!(first_char(Encoding::Utf16, b"\x00\x42").unwrap(), Some('B'));
        assert_eq!(first_char(Encoding::Utf16, b"\x42\x00").unwrap(), Some('B'));
        // UTF-8 BOM while UTF-16 is expected
        assert!(first_char(Encoding::Utf16, &[0xEF, 0xBB, 0xBF, 0xFF, 0xFF]).is_err());

        // half a code unit
        assert!(first_char(Encoding::Utf16Be, b"\x00").is_err());

        // correct non-BMP
        assert_eq!(sniffed("😊".as_bytes()).unwrap(), Some('😊'));

        // empty
        assert_eq!(sniffed(b"").unwrap(), None);

        // incomplete code point
        match sniffed(b"\xf0\x9f\x98").unwrap_err() {
            CharReadError::UnexpectedEof => {},
            e => panic!("Unexpected result: {e:?}")
        }

        // invalid code point
        match sniffed(b"\xff\x9f\x98\x32").unwrap_err() {
            CharReadError::Utf8(_) => {},
            e => panic!("Unexpected result: {e:?}"),
        }

        // error during read
        struct ErrorReader;
        impl io::Read for ErrorReader {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::Other, "test error"))
            }
        }

        let mut failing = ErrorReader;
        match CharReader::new().next_char_from(&mut failing).unwrap_err() {
            CharReadError::Io(ref e) if e.kind() == io::ErrorKind::Other &&
                                        e.to_string().contains("test error") => {},
            e => panic!("Unexpected result: {e:?}")
        }
    }
}