use std::cmp;
use std::fmt;
use std::io;
use crate::bytes;
use crate::compress::Encoder;
use crate::crc32::CheckSummer;
use crate::decompress::{decompress_len, Decoder};
use crate::error::Error;
use crate::frame::{
compress_frame, ChunkType, CHUNK_HEADER_AND_CRC_SIZE,
MAX_COMPRESS_BLOCK_SIZE, STREAM_BODY, STREAM_IDENTIFIER,
};
use crate::MAX_BLOCK_SIZE;
/// Minimum destination size required by `Inner::read_frame` (checked there
/// with a `debug_assert!`): room for the stream identifier, one chunk header
/// with its CRC, and `MAX_COMPRESS_BLOCK_SIZE` bytes of frame data.
///
/// `FrameEncoder` allocates its internal buffer with this size and uses the
/// same value as the threshold for encoding directly into a caller's buffer.
const MAX_READ_FRAME_ENCODER_BLOCK_SIZE: usize = STREAM_IDENTIFIER.len()
+ CHUNK_HEADER_AND_CRC_SIZE
+ MAX_COMPRESS_BLOCK_SIZE;
/// A reader that decodes a framed, compressed stream pulled from an
/// underlying reader and yields the decompressed bytes through `io::Read`.
pub struct FrameDecoder<R: io::Read> {
/// The underlying reader supplying framed input.
r: R,
/// Decoder used for the payload of compressed chunks.
dec: Decoder,
/// Computes the masked CRC32C that is compared against each data chunk's
/// stored checksum.
checksummer: CheckSummer,
/// Scratch buffer (allocated with `MAX_COMPRESS_BLOCK_SIZE` bytes) that
/// receives chunk headers and raw chunk bodies.
src: Vec<u8>,
/// Buffer (allocated with `MAX_BLOCK_SIZE` bytes) holding the decoded bytes
/// of the most recent data chunk.
dst: Vec<u8>,
/// Index into `dst` of the first byte not yet handed to the caller.
dsts: usize,
/// Index into `dst` one past the last valid decoded byte.
dste: usize,
/// Set once the leading stream identifier chunk has been seen.
read_stream_ident: bool,
}
impl<R: io::Read> FrameDecoder<R> {
    /// Builds a decoder that pulls framed input from `rdr`.
    ///
    /// Both working buffers are allocated up front: the raw-chunk buffer with
    /// `MAX_COMPRESS_BLOCK_SIZE` bytes and the decoded-data buffer with
    /// `MAX_BLOCK_SIZE` bytes.
    pub fn new(rdr: R) -> FrameDecoder<R> {
        let dec = Decoder::new();
        let checksummer = CheckSummer::new();
        let src = vec![0; MAX_COMPRESS_BLOCK_SIZE];
        let dst = vec![0; MAX_BLOCK_SIZE];
        Self {
            r: rdr,
            dec,
            checksummer,
            src,
            dst,
            // An empty `dsts..dste` window means no decoded bytes are pending.
            dsts: 0,
            dste: 0,
            read_stream_ident: false,
        }
    }

    /// Returns a shared reference to the wrapped reader.
    pub fn get_ref(&self) -> &R {
        let FrameDecoder { r, .. } = self;
        r
    }

    /// Returns an exclusive reference to the wrapped reader.
    pub fn get_mut(&mut self) -> &mut R {
        let FrameDecoder { r, .. } = self;
        r
    }
}
impl<R: io::Read> io::Read for FrameDecoder<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
macro_rules! fail {
($err:expr) => {
return Err(io::Error::from($err));
};
}
loop {
if self.dsts < self.dste {
let len = cmp::min(self.dste - self.dsts, buf.len());
let dste = self.dsts.checked_add(len).unwrap();
buf[0..len].copy_from_slice(&self.dst[self.dsts..dste]);
self.dsts = dste;
return Ok(len);
}
if !read_exact_eof(&mut self.r, &mut self.src[0..4])? {
return Ok(0);
}
let ty = ChunkType::from_u8(self.src[0]);
if !self.read_stream_ident {
if ty != Ok(ChunkType::Stream) {
fail!(Error::StreamHeader { byte: self.src[0] });
}
self.read_stream_ident = true;
}
let len64 = bytes::read_u24_le(&self.src[1..]) as u64;
if len64 > self.src.len() as u64 {
fail!(Error::UnsupportedChunkLength {
len: len64,
header: false,
});
}
let len = len64 as usize;
match ty {
Err(b) if 0x02 <= b && b <= 0x7F => {
fail!(Error::UnsupportedChunkType { byte: b });
}
Err(b) if 0x80 <= b && b <= 0xFD => {
self.r.read_exact(&mut self.src[0..len])?;
}
Err(b) => {
unreachable!("BUG: unhandled chunk type: {}", b);
}
Ok(ChunkType::Padding) => {
self.r.read_exact(&mut self.src[0..len])?;
}
Ok(ChunkType::Stream) => {
if len != STREAM_BODY.len() {
fail!(Error::UnsupportedChunkLength {
len: len64,
header: true,
})
}
self.r.read_exact(&mut self.src[0..len])?;
if &self.src[0..len] != STREAM_BODY {
fail!(Error::StreamHeaderMismatch {
bytes: self.src[0..len].to_vec(),
});
}
}
Ok(ChunkType::Uncompressed) => {
let expected_sum = bytes::io_read_u32_le(&mut self.r)?;
let n = len - 4;
if n > self.dst.len() {
fail!(Error::UnsupportedChunkLength {
len: n as u64,
header: false,
});
}
self.r.read_exact(&mut self.dst[0..n])?;
let got_sum =
self.checksummer.crc32c_masked(&self.dst[0..n]);
if expected_sum != got_sum {
fail!(Error::Checksum {
expected: expected_sum,
got: got_sum,
});
}
self.dsts = 0;
self.dste = n;
}
Ok(ChunkType::Compressed) => {
let expected_sum = bytes::io_read_u32_le(&mut self.r)?;
let sn = len - 4;
if sn > self.src.len() {
fail!(Error::UnsupportedChunkLength {
len: len64,
header: false,
});
}
self.r.read_exact(&mut self.src[0..sn])?;
let dn = decompress_len(&self.src)?;
if dn > self.dst.len() {
fail!(Error::UnsupportedChunkLength {
len: dn as u64,
header: false,
});
}
self.dec
.decompress(&self.src[0..sn], &mut self.dst[0..dn])?;
let got_sum =
self.checksummer.crc32c_masked(&self.dst[0..dn]);
if expected_sum != got_sum {
fail!(Error::Checksum {
expected: expected_sum,
got: got_sum,
});
}
self.dsts = 0;
self.dste = dn;
}
}
}
}
}
impl<R: fmt::Debug + io::Read> fmt::Debug for FrameDecoder<R> {
    /// Formats the decoder for debugging, printing a `"[...]"` placeholder in
    /// place of the contents of the two byte buffers.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("FrameDecoder");
        out.field("r", &self.r);
        out.field("dec", &self.dec);
        out.field("checksummer", &self.checksummer);
        // Buffer contents are elided; only their presence is shown.
        out.field("src", &"[...]");
        out.field("dst", &"[...]");
        out.field("dsts", &self.dsts);
        out.field("dste", &self.dste);
        out.field("read_stream_ident", &self.read_stream_ident);
        out.finish()
    }
}
/// A reader that compresses the bytes of an underlying reader and yields
/// the framed, compressed output through `io::Read`.
pub struct FrameEncoder<R: io::Read> {
/// The underlying reader together with the state needed to encode a frame.
inner: Inner<R>,
/// Buffer (allocated with `MAX_READ_FRAME_ENCODER_BLOCK_SIZE` bytes) that
/// holds an encoded frame when the caller's buffer is too small to receive
/// it directly.
dst: Vec<u8>,
/// Index into `dst` of the first encoded byte not yet handed to the caller.
dsts: usize,
/// Index into `dst` one past the last valid encoded byte.
dste: usize,
}
/// The frame-encoding state of a `FrameEncoder`.
///
/// It is a separate struct from the encoder's `dst` buffer, which lets
/// `FrameEncoder::read` call `read_frame` on it while passing `dst` as the
/// mutable destination.
struct Inner<R: io::Read> {
/// The underlying reader supplying uncompressed input.
r: R,
/// Encoder handed to `compress_frame` for each frame.
enc: Encoder,
/// Checksummer handed to `compress_frame` for each frame.
checksummer: CheckSummer,
/// Buffer (allocated with `MAX_BLOCK_SIZE` bytes) that receives the input
/// for one frame.
src: Vec<u8>,
/// Set once the stream identifier has been written to the output.
wrote_stream_ident: bool,
}
impl<R: io::Read> FrameEncoder<R> {
    /// Builds an encoder that pulls uncompressed input from `rdr`.
    ///
    /// The input buffer is allocated with `MAX_BLOCK_SIZE` bytes and the
    /// output buffer with `MAX_READ_FRAME_ENCODER_BLOCK_SIZE` bytes.
    pub fn new(rdr: R) -> FrameEncoder<R> {
        let inner = Inner {
            r: rdr,
            enc: Encoder::new(),
            checksummer: CheckSummer::new(),
            src: vec![0; MAX_BLOCK_SIZE],
            wrote_stream_ident: false,
        };
        let dst = vec![0; MAX_READ_FRAME_ENCODER_BLOCK_SIZE];
        // An empty `dsts..dste` window means no encoded bytes are pending.
        Self { inner, dst, dsts: 0, dste: 0 }
    }

    /// Returns a shared reference to the wrapped reader.
    pub fn get_ref(&self) -> &R {
        let inner = &self.inner;
        &inner.r
    }

    /// Returns an exclusive reference to the wrapped reader.
    pub fn get_mut(&mut self) -> &mut R {
        let inner = &mut self.inner;
        &mut inner.r
    }

    /// Moves as many pending encoded bytes as fit from the internal buffer
    /// into `buf`, advances the read position, and returns the number of
    /// bytes copied (possibly zero).
    fn read_from_dst(&mut self, buf: &mut [u8]) -> usize {
        let start = self.dsts;
        let n = cmp::min(self.dste - start, buf.len());
        let end = start + n;
        buf[..n].copy_from_slice(&self.dst[start..end]);
        self.dsts = end;
        n
    }
}
impl<R: io::Read> io::Read for FrameEncoder<R> {
    /// Copies encoded bytes into `buf`, encoding a new frame from the
    /// underlying reader when nothing is pending.
    ///
    /// Returns `Ok(0)` when `buf` is empty or when the underlying reader
    /// yields no more input.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying reader or from encoding a
    /// frame.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // A zero-length read must not touch any state. Without this guard,
        // `read_from_dst` would return 0 even though bytes are pending, and
        // the refill branch below would overwrite them with the next frame.
        if buf.is_empty() {
            return Ok(0);
        }
        // First drain whatever is left of a previously encoded frame.
        let count = self.read_from_dst(buf);
        if count > 0 {
            Ok(count)
        } else if buf.len() >= MAX_READ_FRAME_ENCODER_BLOCK_SIZE {
            // The caller's buffer can hold a whole frame: encode straight
            // into it and skip the intermediate copy.
            self.inner.read_frame(buf)
        } else {
            // Encode into the internal buffer, then hand out as much of it
            // as fits; the remainder is served by later calls.
            let count = self.inner.read_frame(&mut self.dst)?;
            self.dsts = 0;
            self.dste = count;
            Ok(self.read_from_dst(buf))
        }
    }
}
impl<R: io::Read> Inner<R> {
    /// Reads one block of input from the underlying reader and encodes it as
    /// a single frame at the start of `dst`, preceded by the stream
    /// identifier if it has not been written yet.
    ///
    /// Returns the number of bytes written to `dst`, or `Ok(0)` when the
    /// underlying reader returned no data. `dst` is expected to hold at
    /// least `MAX_READ_FRAME_ENCODER_BLOCK_SIZE` bytes.
    fn read_frame(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        debug_assert!(dst.len() >= MAX_READ_FRAME_ENCODER_BLOCK_SIZE);

        // Exactly one read is issued per frame; a short read yields a
        // correspondingly short frame.
        let input_len = match self.r.read(&mut self.src)? {
            0 => return Ok(0),
            n => n,
        };

        // The stream identifier is emitted once, ahead of the first frame.
        let prefix_len = if self.wrote_stream_ident {
            0
        } else {
            let ident_len = STREAM_IDENTIFIER.len();
            dst[0..ident_len].copy_from_slice(STREAM_IDENTIFIER);
            self.wrote_stream_ident = true;
            ident_len
        };

        // Split the remaining space into the chunk header (with CRC) and the
        // payload area so both can be passed as disjoint mutable slices.
        let (chunk_header, payload) =
            dst[prefix_len..].split_at_mut(CHUNK_HEADER_AND_CRC_SIZE);

        // NOTE(review): the trailing `true` presumably tells `compress_frame`
        // to always place the frame data in `payload` - confirm against its
        // definition in `crate::frame`.
        let frame_data = compress_frame(
            &mut self.enc,
            self.checksummer,
            &self.src[..input_len],
            chunk_header,
            payload,
            true,
        )?;
        Ok(prefix_len + CHUNK_HEADER_AND_CRC_SIZE + frame_data.len())
    }
}
impl<R: fmt::Debug + io::Read> fmt::Debug for FrameEncoder<R> {
    /// Formats the encoder for debugging, printing a `"[...]"` placeholder in
    /// place of the output buffer's contents.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("FrameEncoder");
        out.field("inner", &self.inner);
        // Buffer contents are elided; only their presence is shown.
        out.field("dst", &"[...]");
        out.field("dsts", &self.dsts);
        out.field("dste", &self.dste);
        out.finish()
    }
}
impl<R: fmt::Debug + io::Read> fmt::Debug for Inner<R> {
    /// Formats the encoder state for debugging, printing a `"[...]"`
    /// placeholder in place of the input buffer's contents.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Inner");
        out.field("r", &self.r);
        out.field("enc", &self.enc);
        out.field("checksummer", &self.checksummer);
        // Buffer contents are elided; only their presence is shown.
        out.field("src", &"[...]");
        out.field("wrote_stream_ident", &self.wrote_stream_ident);
        out.finish()
    }
}
/// Fills `buf` completely from `rdr`, distinguishing a clean end of stream
/// from a truncated one.
///
/// Returns `Ok(true)` when `buf` was filled and `Ok(false)` when the reader
/// was already at end of stream, i.e. not a single byte could be read.
///
/// # Errors
///
/// Returns an error of kind `UnexpectedEof` when the stream ends after only
/// part of `buf` was filled, and propagates any other error from `rdr`.
fn read_exact_eof<R: io::Read>(
    rdr: &mut R,
    buf: &mut [u8],
) -> io::Result<bool> {
    // Nothing to fill: trivially successful, matching `read_exact` on an
    // empty buffer.
    if buf.is_empty() {
        return Ok(true);
    }
    // The first read tells a clean EOF (zero bytes) apart from data. Retry on
    // `Interrupted`, as `read_exact` itself does.
    let first = loop {
        match rdr.read(buf) {
            Ok(n) => break n,
            Err(ref err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    };
    if first == 0 {
        return Ok(false);
    }
    // Some bytes arrived, so the rest must follow; running out of input now
    // means the stream is truncated and `read_exact` reports `UnexpectedEof`.
    rdr.read_exact(&mut buf[first..])?;
    Ok(true)
}