use std::io;
use std::io::prelude::*;
#[cfg(feature = "tokio")]
use futures::Poll;
#[cfg(feature = "tokio")]
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
#[cfg(feature = "parallel")]
use crate::stream::MtStreamBuilder;
use crate::stream::{Action, Check, Status, Stream};
/// A compression adapter: data written to it is compressed by `data` and the
/// resulting bytes are forwarded to the wrapped writer.
pub struct XzEncoder<W: Write> {
    // The underlying compression stream.
    data: Stream,
    // The wrapped writer; becomes `None` only once `finish` has taken it.
    obj: Option<W>,
    // Compressed bytes produced by `data` but not yet written to `obj`.
    buf: Vec<u8>,
}
/// A decompression adapter: compressed data written to it is decoded by
/// `data` and the decompressed bytes are forwarded to the wrapped writer.
pub struct XzDecoder<W: Write> {
    // The underlying decompression stream.
    data: Stream,
    // The wrapped writer; becomes `None` once `finish` has taken it.
    obj: Option<W>,
    // Decompressed bytes produced by `data` but not yet written to `obj`.
    buf: Vec<u8>,
}
impl<W: Write> XzEncoder<W> {
#[inline]
pub fn new(obj: W, level: u32) -> XzEncoder<W> {
let stream = Stream::new_easy_encoder(level, Check::Crc64).unwrap();
XzEncoder::new_stream(obj, stream)
}
#[cfg(feature = "parallel")]
pub fn new_parallel(obj: W, level: u32) -> XzEncoder<W> {
let stream = MtStreamBuilder::new()
.preset(level)
.check(Check::Crc64)
.threads(num_cpus::get() as u32)
.encoder()
.unwrap();
Self::new_stream(obj, stream)
}
#[inline]
pub fn new_stream(obj: W, stream: Stream) -> XzEncoder<W> {
XzEncoder {
data: stream,
obj: Some(obj),
buf: Vec::with_capacity(32 * 1024),
}
}
#[inline]
pub fn get_ref(&self) -> &W {
self.obj.as_ref().unwrap()
}
#[inline]
pub fn get_mut(&mut self) -> &mut W {
self.obj.as_mut().unwrap()
}
fn dump(&mut self) -> io::Result<()> {
while !self.buf.is_empty() {
let n = self.obj.as_mut().unwrap().write(&self.buf)?;
self.buf.drain(..n);
}
Ok(())
}
#[inline]
pub fn try_finish(&mut self) -> io::Result<()> {
loop {
self.dump()?;
let res = self.data.process_vec(&[], &mut self.buf, Action::Finish)?;
if res == Status::StreamEnd {
break;
}
}
self.dump()
}
#[inline]
pub fn finish(mut self) -> io::Result<W> {
self.try_finish()?;
Ok(self.obj.take().unwrap())
}
#[inline]
pub fn total_out(&self) -> u64 {
self.data.total_out()
}
#[inline]
pub fn total_in(&self) -> u64 {
self.data.total_in()
}
}
impl<W: Write> Write for XzEncoder<W> {
    /// Compresses a prefix of `data` and returns its length.
    ///
    /// Previously produced output is first handed to the wrapped writer. The
    /// loop repeats until the stream consumes at least one input byte, so
    /// `Ok(0)` is returned only for empty `data`.
    #[inline]
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        loop {
            self.dump()?;
            let total_in = self.total_in();
            self.data.process_vec(data, &mut self.buf, Action::Run)?;
            // The delta is bounded by `data.len()`, so the cast cannot truncate.
            let written = (self.total_in() - total_in) as usize;
            if written > 0 || data.is_empty() {
                return Ok(written);
            }
        }
    }

    /// Performs a full flush of the stream (`Action::FullFlush`), writes all
    /// resulting bytes to the wrapped writer, then flushes the wrapped writer.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        loop {
            self.dump()?;
            let status = self
                .data
                .process_vec(&[], &mut self.buf, Action::FullFlush)?;
            if status == Status::StreamEnd {
                break;
            }
        }
        // The call that reported `StreamEnd` may itself have produced output;
        // it must reach the wrapped writer before that writer is flushed.
        self.dump()?;
        self.obj.as_mut().unwrap().flush()
    }
}
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for XzEncoder<W> {
    /// Finishes the compressed stream, then shuts down the wrapped writer.
    ///
    /// `try_nb!` turns a `WouldBlock` from finishing into `NotReady`, so the
    /// shutdown can be polled again.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        let inner = self.get_mut();
        inner.shutdown()
    }
}
impl<W: Read + Write> Read for XzEncoder<W> {
    /// Reads are passed straight through to the wrapped object, untouched by
    /// the encoder.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let inner = self.get_mut();
        inner.read(buf)
    }
}
// Marker impl: reads go straight to the wrapped object via the `Read` impl.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for XzEncoder<W>
where
    W: AsyncRead + AsyncWrite,
{
}
impl<W: Write> Drop for XzEncoder<W> {
    /// Best-effort finish of the stream when the encoder is dropped.
    ///
    /// Errors are discarded because `drop` cannot report them. Nothing is done
    /// if `finish` already took the wrapped writer.
    #[inline]
    fn drop(&mut self) {
        if self.obj.is_none() {
            return;
        }
        let _ = self.try_finish();
    }
}
impl<W: Write> XzDecoder<W> {
#[inline]
pub fn new(obj: W) -> XzDecoder<W> {
let stream = Stream::new_stream_decoder(u64::MAX, 0).unwrap();
XzDecoder::new_stream(obj, stream)
}
#[cfg(feature = "parallel")]
pub fn new_parallel(obj: W) -> Self {
let stream = MtStreamBuilder::new()
.memlimit_stop(u64::MAX)
.threads(num_cpus::get() as u32)
.decoder()
.unwrap();
Self::new_stream(obj, stream)
}
#[inline]
pub fn new_multi_decoder(obj: W) -> XzDecoder<W> {
let stream = Stream::new_stream_decoder(u64::MAX, liblzma_sys::LZMA_CONCATENATED).unwrap();
XzDecoder::new_stream(obj, stream)
}
#[inline]
pub fn new_stream(obj: W, stream: Stream) -> XzDecoder<W> {
XzDecoder {
data: stream,
obj: Some(obj),
buf: Vec::with_capacity(32 * 1024),
}
}
#[inline]
pub fn get_ref(&self) -> &W {
self.obj.as_ref().unwrap()
}
#[inline]
pub fn get_mut(&mut self) -> &mut W {
self.obj.as_mut().unwrap()
}
fn dump(&mut self) -> io::Result<()> {
if !self.buf.is_empty() {
self.obj.as_mut().unwrap().write_all(&self.buf)?;
self.buf.clear();
}
Ok(())
}
fn try_finish(&mut self) -> io::Result<()> {
loop {
self.dump()?;
let res = self.data.process_vec(&[], &mut self.buf, Action::Finish)?;
if self.buf.is_empty() && res == Status::MemNeeded {
let msg = "xz compressed stream is truncated or otherwise corrupt";
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, msg));
}
if res == Status::StreamEnd {
break;
}
}
self.dump()
}
#[inline]
pub fn finish(&mut self) -> io::Result<W> {
self.try_finish()?;
Ok(self.obj.take().unwrap())
}
#[inline]
pub fn total_out(&self) -> u64 {
self.data.total_out()
}
#[inline]
pub fn total_in(&self) -> u64 {
self.data.total_in()
}
}
impl<W: Write> Write for XzDecoder<W> {
    /// Decodes a prefix of the compressed `data` and returns its length.
    ///
    /// Previously decoded output is first handed to the wrapped writer.
    /// `Ok(0)` is returned for empty `data`, or when the stream reports
    /// `Status::StreamEnd` without consuming further input.
    #[inline]
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        loop {
            self.dump()?;
            let consumed_before = self.total_in();
            let status = self.data.process_vec(data, &mut self.buf, Action::Run)?;
            let consumed = (self.total_in() - consumed_before) as usize;
            let made_progress = consumed > 0;
            let nothing_more_to_do = data.is_empty() || status == Status::StreamEnd;
            if made_progress || nothing_more_to_do {
                break Ok(consumed);
            }
        }
    }

    /// Writes buffered decompressed bytes to the wrapped writer, then flushes it.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        self.dump()?;
        self.get_mut().flush()
    }
}
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for XzDecoder<W> {
    /// Finishes decoding, then shuts down the wrapped writer.
    ///
    /// `try_nb!` turns a `WouldBlock` from finishing into `NotReady`, so the
    /// shutdown can be polled again.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        let inner = self.get_mut();
        inner.shutdown()
    }
}
impl<W: Read + Write> Read for XzDecoder<W> {
    /// Reads are passed straight through to the wrapped object, untouched by
    /// the decoder.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let inner = self.get_mut();
        inner.read(buf)
    }
}
// Marker impl: reads go straight to the wrapped object via the `Read` impl.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for XzDecoder<W>
where
    W: AsyncRead + AsyncWrite,
{
}
impl<W: Write> Drop for XzDecoder<W> {
    /// Best-effort finish of decoding when the decoder is dropped.
    ///
    /// Errors are discarded because `drop` cannot report them. Nothing is done
    /// if `finish` already took the wrapped writer.
    #[inline]
    fn drop(&mut self) {
        if self.obj.is_none() {
            return;
        }
        let _ = self.try_finish();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::LzmaOptions;
    use quickcheck::quickcheck;
    use std::iter::repeat;

    /// Writes `input` into `encoder` (whose inner writer is a decoder that
    /// collects into a `Vec`), finishes both layers and returns the decoded bytes.
    fn round_trip(mut encoder: XzEncoder<XzDecoder<Vec<u8>>>, input: &[u8]) -> Vec<u8> {
        encoder.write_all(input).unwrap();
        encoder.finish().unwrap().finish().unwrap()
    }

    #[test]
    fn smoke() {
        let mut encoder = XzEncoder::new(XzDecoder::new(Vec::new()), 6);
        encoder.write_all(b"12834").unwrap();
        let tail: String = repeat("12345").take(100000).collect();
        encoder.write_all(tail.as_bytes()).unwrap();
        let out = encoder.finish().unwrap().finish().unwrap();

        assert_eq!(&out[0..5], b"12834");
        assert_eq!(out.len(), 500005);
        assert_eq!(format!("12834{}", tail).as_bytes(), &*out);
    }

    #[test]
    fn write_empty() {
        let mut encoder = XzEncoder::new(XzDecoder::new(Vec::new()), 6);
        encoder.write(b"").unwrap();
        let out = encoder.finish().unwrap().finish().unwrap();
        assert_eq!(&out[..], b"");
    }

    #[test]
    fn qc_lzma1() {
        fn prop(input: Vec<u8>) -> bool {
            let decoder_stream = Stream::new_lzma_decoder(u64::MAX).unwrap();
            let decoder = XzDecoder::new_stream(Vec::new(), decoder_stream);
            let options = LzmaOptions::new_preset(6).unwrap();
            let encoder_stream = Stream::new_lzma_encoder(&options).unwrap();
            let encoder = XzEncoder::new_stream(decoder, encoder_stream);
            input == round_trip(encoder, &input)
        }
        quickcheck(prop as fn(Vec<u8>) -> bool);
    }

    #[test]
    fn qc() {
        fn prop(input: Vec<u8>) -> bool {
            let encoder = XzEncoder::new(XzDecoder::new(Vec::new()), 6);
            input == round_trip(encoder, &input)
        }
        quickcheck(prop as fn(Vec<u8>) -> bool);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn qc_parallel_encode() {
        fn prop(input: Vec<u8>) -> bool {
            let encoder = XzEncoder::new_parallel(XzDecoder::new(Vec::new()), 6);
            input == round_trip(encoder, &input)
        }
        quickcheck(prop as fn(Vec<u8>) -> bool);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn qc_parallel_decode() {
        fn prop(input: Vec<u8>) -> bool {
            let encoder = XzEncoder::new(XzDecoder::new_parallel(Vec::new()), 6);
            input == round_trip(encoder, &input)
        }
        quickcheck(prop as fn(Vec<u8>) -> bool);
    }
}