use std::io::Result as IoResult;
use std::io::Read;
use std::io::Error as IoError;
use std::io::ErrorKind;
use std::fmt;
use std::error::Error;
pub struct Decoder<R> {
source: R,
remaining_chunks_size: Option<usize>,
}
impl<R> Decoder<R> where R: Read {
pub fn new(source: R) -> Decoder<R> {
Decoder {
source: source,
remaining_chunks_size: None,
}
}
}
impl<R> Read for Decoder<R> where R: Read {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
if self.remaining_chunks_size.is_none() {
let mut chunk_size = Vec::new();
loop {
let byte = match self.source.by_ref().bytes().next() {
Some(b) => try!(b),
None => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
};
if byte == b'\r' {
break;
}
chunk_size.push(byte);
}
match self.source.by_ref().bytes().next() {
Some(Ok(b'\n')) => (),
_ => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError)),
}
let chunk_size = match String::from_utf8(chunk_size) {
Ok(c) => c,
Err(_) => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError))
};
let chunk_size = match usize::from_str_radix(&chunk_size, 16) {
Ok(c) => c,
Err(_) => return Err(IoError::new(ErrorKind::InvalidInput, DecoderError))
};
if chunk_size == 0 {
if try!(self.source.by_ref().bytes().next().unwrap_or(Ok(0))) != b'\r' {
return Err(IoError::new(ErrorKind::InvalidInput, DecoderError));
}
if try!(self.source.by_ref().bytes().next().unwrap_or(Ok(0))) != b'\n' {
return Err(IoError::new(ErrorKind::InvalidInput, DecoderError));
}
return Ok(0);
}
self.remaining_chunks_size = Some(chunk_size);
return self.read(buf);
}
assert!(self.remaining_chunks_size.is_some());
if buf.len() < *self.remaining_chunks_size.as_ref().unwrap() {
let read = try!(self.source.read(buf));
*self.remaining_chunks_size.as_mut().unwrap() -= read;
return Ok(read);
}
assert!(buf.len() >= *self.remaining_chunks_size.as_ref().unwrap());
let remaining_chunks_size = *self.remaining_chunks_size.as_ref().unwrap();
let buf = &mut buf[.. remaining_chunks_size];
let read = try!(self.source.read(buf));
*self.remaining_chunks_size.as_mut().unwrap() -= read;
if read == remaining_chunks_size {
self.remaining_chunks_size = None;
if try!(self.source.by_ref().bytes().next().unwrap_or(Ok(0))) != b'\r' {
return Err(IoError::new(ErrorKind::InvalidInput, DecoderError));
}
if try!(self.source.by_ref().bytes().next().unwrap_or(Ok(0))) != b'\n' {
return Err(IoError::new(ErrorKind::InvalidInput, DecoderError));
}
}
return Ok(read);
}
}
/// Human-readable text shared by `DecoderError`'s `Display` and `description`.
const DECODER_ERROR_MESSAGE: &str = "Error while decoding chunks";

/// Payload of the `io::Error`s the chunk decoder builds when the chunked framing is malformed.
#[derive(Clone, Copy, Debug)]
struct DecoderError;

impl fmt::Display for DecoderError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(DECODER_ERROR_MESSAGE)
    }
}

impl Error for DecoderError {
    fn description(&self) -> &str {
        DECODER_ERROR_MESSAGE
    }
}
#[cfg(test)]
mod test {
    use super::Decoder;
    use std::io;
    use std::io::Read;

    /// Decodes `input` to EOF with a fresh `Decoder`, returning the payload or the first error.
    fn decode(input: &str) -> io::Result<String> {
        let mut decoder = Decoder::new(io::Cursor::new(input.as_bytes().to_vec()));
        let mut decoded = String::new();
        decoder.read_to_string(&mut decoded)?;
        Ok(decoded)
    }

    #[test]
    fn test_valid_chunk_decode() {
        assert_eq!(
            decode("3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n").unwrap(),
            "hello world!!!"
        );
    }

    /// Reading through a 1-byte buffer exercises the partial-chunk path on every chunk.
    #[test]
    fn test_one_byte_reads() {
        let source = io::Cursor::new(b"3\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n\r\n".to_vec());
        let mut decoder = Decoder::new(source);
        let mut decoded = Vec::new();
        let mut byte = [0u8; 1];
        while decoder.read(&mut byte).unwrap() == 1 {
            decoded.push(byte[0]);
        }
        assert_eq!(&decoded[..], &b"hello world!!!"[..]);
    }

    /// The first chunk declares 2 bytes but carries 3, so its trailing CRLF is missing.
    #[test]
    fn invalid_input1() {
        let err = decode("2\r\nhel\r\nb\r\nlo world!!!\r\n0\r\n").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// The chunk-size line ends with a bare `\r` instead of CRLF.
    #[test]
    fn invalid_input2() {
        let err = decode("3\rhel\r\nb\r\nlo world!!!\r\n0\r\n").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    /// The chunk size is not hexadecimal.
    #[test]
    fn invalid_chunk_size() {
        let err = decode("zz\r\nhel\r\n0\r\n\r\n").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }
}