use std::ptr;
use crate::bytes;
use crate::error::{Error, Result};
use crate::tag;
use crate::MAX_INPUT_SIZE;
// Newtype-wrapped copy of `tag::TAG_LOOKUP_TABLE`; it is indexed by a tag
// byte through `TagLookupTable::entry` to obtain a packed `TagEntry`.
const TAG_LOOKUP_TABLE: TagLookupTable = TagLookupTable(tag::TAG_LOOKUP_TABLE);
// `WORD_MASK[n]` keeps the low `n` bytes (n = 0..=4) of a 32-bit word. The
// decoder loads 4 little-endian bytes unconditionally and then masks with
// this table to discard the bytes that are not part of the value being read.
const WORD_MASK: [usize; 5] = [0, 0xFF, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF];
/// Returns the decompressed length declared in the header of `input`.
///
/// An empty `input` is reported as a length of `0` rather than as an error.
///
/// # Errors
///
/// Propagates any error produced by `Header::read` while parsing a
/// non-empty `input`.
pub fn decompress_len(input: &[u8]) -> Result<usize> {
    match input {
        [] => Ok(0),
        _ => Header::read(input).map(|hdr| hdr.decompress_len),
    }
}
/// Decoder that decompresses a byte buffer whose header declares the
/// decompressed length.
///
/// The type carries no state of its own; its methods take `&mut self` only.
#[derive(Clone, Debug, Default)]
pub struct Decoder {
// Zero-sized private placeholder field.
// NOTE(review): presumably here so the struct cannot be built with a
// literal outside this module and can gain fields later — confirm.
_dummy: (),
}
impl Decoder {
pub fn new() -> Decoder {
Decoder { _dummy: () }
}
pub fn decompress(
&mut self,
input: &[u8],
output: &mut [u8],
) -> Result<usize> {
if input.is_empty() {
return Err(Error::Empty);
}
let hdr = Header::read(input)?;
if hdr.decompress_len > output.len() {
return Err(Error::BufferTooSmall {
given: output.len() as u64,
min: hdr.decompress_len as u64,
});
}
let dst = &mut output[..hdr.decompress_len];
let mut dec =
Decompress { src: &input[hdr.len..], s: 0, dst: dst, d: 0 };
dec.decompress()?;
Ok(dec.dst.len())
}
pub fn decompress_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> {
let mut buf = vec![0; decompress_len(input)?];
let n = self.decompress(input, &mut buf)?;
buf.truncate(n);
Ok(buf)
}
}
/// Working state for decoding one compressed body into an output slice.
struct Decompress<'s, 'd> {
// Compressed bytes that follow the header.
src: &'s [u8],
// Read cursor: index of the next unread byte in `src`.
s: usize,
// Output slice; `Decoder::decompress` sizes it to exactly the length
// declared by the header.
dst: &'d mut [u8],
// Write cursor: number of bytes of `dst` produced so far.
d: usize,
}
impl<'s, 'd> Decompress<'s, 'd> {
/// Decodes every element in `src` into `dst`.
///
/// Each element starts with a tag byte. When the low two bits of the tag
/// are zero the element is a literal, otherwise it is a copy from earlier
/// output.
///
/// # Errors
///
/// Propagates errors from `read_literal` / `read_copy`, and returns
/// `Error::HeaderMismatch` if the number of bytes produced differs from
/// `dst.len()` once the input is exhausted.
fn decompress(&mut self) -> Result<()> {
while self.s < self.src.len() {
let byte = self.src[self.s];
self.s += 1;
if byte & 0b000000_11 == 0 {
// Literal: the upper six bits store `len - 1` (values >= 61 are
// reinterpreted by `read_literal`).
let len = (byte >> 2) as usize + 1;
self.read_literal(len)?;
} else {
self.read_copy(byte)?;
}
}
// All input consumed: the output must have been filled exactly.
if self.d != self.dst.len() {
return Err(Error::HeaderMismatch {
expected_len: self.dst.len() as u64,
got_len: self.d as u64,
});
}
Ok(())
}
/// Copies a literal run from `src` to `dst`.
///
/// `len` is the value decoded from the tag byte (1..=64). Values of 61 or
/// more mean the real length is stored in the following `len - 60` bytes.
///
/// # Errors
///
/// Returns `Error::Literal` when the extended length cannot be read or
/// when either buffer has fewer than `len` bytes remaining.
#[inline(always)]
fn read_literal(&mut self, len: usize) -> Result<()> {
debug_assert!(len <= 64);
let mut len = len as u64;
// Fast path for short literals: when both buffers have 16 bytes of room,
// copy a fixed 16 bytes and advance the cursors by only `len`.
if len <= 16
&& self.s + 16 <= self.src.len()
&& self.d + 16 <= self.dst.len()
{
unsafe {
// SAFETY: the condition above guarantees 16 readable bytes at
// `src[s..]` and 16 writable bytes at `dst[d..]`. `src` is a shared
// borrow and `dst` an exclusive one, so the regions cannot overlap.
let srcp = self.src.as_ptr().add(self.s);
let dstp = self.dst.as_mut_ptr().add(self.d);
ptr::copy_nonoverlapping(srcp, dstp, 16);
}
// Bytes written beyond `len` lie at or after the new `d`; they are not
// counted as output.
self.d += len as usize;
self.s += len as usize;
return Ok(());
}
// Extended length: the next `len - 60` bytes (1..=4) hold `length - 1`
// in little-endian order.
if len >= 61 {
// Four bytes are required here regardless of `byte_count`, because the
// length is read with a single 32-bit load and then masked.
if self.s as u64 + 4 > self.src.len() as u64 {
return Err(Error::Literal {
len: 4,
src_len: (self.src.len() - self.s) as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
let byte_count = len as usize - 60;
len = bytes::read_u32_le(&self.src[self.s..]) as u64;
len = (len & (WORD_MASK[byte_count] as u64)) + 1;
self.s += byte_count;
}
// Compared as remaining-space < len (in u64) so that a huge `len` cannot
// overflow an addition.
if ((self.src.len() - self.s) as u64) < len
|| ((self.dst.len() - self.d) as u64) < len
{
return Err(Error::Literal {
len: len,
src_len: (self.src.len() - self.s) as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
unsafe {
// SAFETY: the check above guarantees at least `len` bytes remain in both
// `src[s..]` and `dst[d..]`; the two slices are distinct borrows and so
// do not overlap.
let srcp = self.src.as_ptr().add(self.s);
let dstp = self.dst.as_mut_ptr().add(self.d);
ptr::copy_nonoverlapping(srcp, dstp, len as usize);
}
self.s += len as usize;
self.d += len as usize;
Ok(())
}
/// Performs a copy element: appends `len` bytes that start `offset` bytes
/// before the current write position in `dst`.
///
/// Source and destination may overlap (`offset < len`), which repeats the
/// most recent `offset` bytes.
///
/// # Errors
///
/// * Errors from `TagEntry::offset` when the offset bytes cannot be read.
/// * `Error::Offset` when `offset` is zero or reaches before the start of
///   the output.
/// * `Error::CopyWrite` when the copy would run past the end of `dst`.
#[inline(always)]
fn read_copy(&mut self, tag_byte: u8) -> Result<()> {
// Decode offset and length, then move the read cursor past the element.
// Everything after this point touches only `dst`.
let entry = TAG_LOOKUP_TABLE.entry(tag_byte);
let offset = entry.offset(self.src, self.s)?;
let len = entry.len();
self.s += entry.num_tag_bytes();
// Rejects both `offset == 0` and `d < offset` with one comparison: for
// `offset == 0` the wrapping subtraction yields `usize::MAX`, which is
// always >= `d`.
if self.d <= offset.wrapping_sub(1) {
return Err(Error::Offset {
offset: offset as u64,
dst_pos: self.d as u64,
});
}
// Final value of the write cursor, whichever path is taken below.
let end = self.d + len;
if offset >= 8 && len <= 16 && self.d + 16 <= self.dst.len() {
// Short copy with a gap of at least 8: two 8-byte moves.
unsafe {
// SAFETY: `d + 16 <= dst.len()` gives 16 writable bytes at `dstp`, and
// the offset check above gives `d >= offset`, so `srcp` is inside
// `dst`. With `offset >= 8`, the 8-byte source and destination of each
// individual copy do not overlap. The second copy may read bytes
// written by the first, which is why this is not one 16-byte copy.
let dstp = self.dst.as_mut_ptr().add(self.d);
let srcp = dstp.sub(offset);
ptr::copy_nonoverlapping(srcp, dstp, 8);
ptr::copy_nonoverlapping(srcp.add(8), dstp.add(8), 8);
}
} else if end + 24 <= self.dst.len() {
// General path with slack after `end`: copy in 16-byte chunks, allowing
// writes past `end` that stay inside `dst`.
// NOTE(review): the 24-byte slack appears sized to cover the worst-case
// overshoot of the widening loop below — confirm against the tag table's
// minimum copy length.
unsafe {
// SAFETY: `d >= offset` (checked above) keeps `srcp` inside `dst`; the
// `end + 24 <= dst.len()` guard is what keeps the 16-byte writes below
// inside `dst`.
let dest_len = self.dst.len();
let mut dstp = self.dst.as_mut_ptr().add(self.d);
let mut srcp = dstp.sub(offset);
// While the gap between source and destination is under 16, copy 16
// (possibly overlapping) bytes and advance only the destination by the
// gap. `srcp` stays fixed, so the gap doubles on every iteration.
loop {
debug_assert!(dstp >= srcp);
let diff = (dstp as usize) - (srcp as usize);
if diff >= 16 {
break;
}
debug_assert!(self.d + 16 <= dest_len);
// The regions can overlap here, hence `ptr::copy`.
ptr::copy(srcp, dstp, 16);
self.d += diff as usize;
dstp = dstp.add(diff);
}
// The gap is now >= 16, so each 16-byte chunk is non-overlapping.
while self.d < end {
ptr::copy_nonoverlapping(srcp, dstp, 16);
srcp = srcp.add(16);
dstp = dstp.add(16);
self.d += 16;
}
// `d` may have overshot `end`; it is reset after the branches.
}
} else {
// Near the end of `dst`: bounds-checked, byte-by-byte copy.
if end > self.dst.len() {
return Err(Error::CopyWrite {
len: len as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
while self.d != end {
self.dst[self.d] = self.dst[self.d - offset];
self.d += 1;
}
}
// The fast paths do not leave `d` at its final value; fix it up here.
self.d = end;
Ok(())
}
}
/// Parsed stream header.
#[derive(Debug)]
struct Header {
// Number of input bytes occupied by the header; the compressed body starts
// at this index.
len: usize,
// Decompressed length declared by the header.
decompress_len: usize,
}
impl Header {
#[inline(always)]
fn read(input: &[u8]) -> Result<Header> {
let (decompress_len, header_len) = bytes::read_varu64(input);
if header_len == 0 {
return Err(Error::Header);
}
if decompress_len > MAX_INPUT_SIZE {
return Err(Error::TooBig {
given: decompress_len as u64,
max: MAX_INPUT_SIZE,
});
}
Ok(Header { len: header_len, decompress_len: decompress_len as usize })
}
}
/// Table of 256 packed 16-bit entries, one per possible tag byte.
struct TagLookupTable([u16; 256]);
impl TagLookupTable {
#[inline(always)]
fn entry(&self, byte: u8) -> TagEntry {
TagEntry(self.0[byte as usize] as usize)
}
}
/// Packed description of one tag byte, as read by the accessors below:
/// bits 0..=7 are the copy length, bits 8..=10 are OR-ed into the copy offset
/// above the trailing offset bytes, and bits 11 and up are the number of
/// bytes that follow the tag byte.
struct TagEntry(usize);
impl TagEntry {
fn num_tag_bytes(&self) -> usize {
self.0 >> 11
}
fn len(&self) -> usize {
self.0 & 0xFF
}
fn offset(&self, src: &[u8], s: usize) -> Result<usize> {
let num_tag_bytes = self.num_tag_bytes();
let trailer =
if s + 4 <= src.len() {
unsafe {
let p = src.as_ptr().add(s);
bytes::loadu_u32_le(p) as usize & WORD_MASK[num_tag_bytes]
}
} else if num_tag_bytes == 1 {
if s >= src.len() {
return Err(Error::CopyRead {
len: 1,
src_len: (src.len() - s) as u64,
});
}
src[s] as usize
} else if num_tag_bytes == 2 {
if s + 1 >= src.len() {
return Err(Error::CopyRead {
len: 2,
src_len: (src.len() - s) as u64,
});
}
bytes::read_u16_le(&src[s..]) as usize
} else {
return Err(Error::CopyRead {
len: num_tag_bytes as u64,
src_len: (src.len() - s) as u64,
});
};
Ok((self.0 & 0b0000_0111_0000_0000) | trailer)
}
}