use crate::{Archive, Archived, Offset, RelPtr};
use bytecheck::{CheckBytes, Unreachable};
use core::{fmt, mem};
use std::error;
/// A half-open range `[start, end)` of byte positions within an archive.
///
/// The derived ordering compares `start` first and then `end`, so a `Vec<Interval>` can be kept
/// sorted and searched with `binary_search`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Interval {
    /// First position covered by the interval.
    pub start: usize,
    /// One past the last position covered by the interval.
    pub end: usize,
}

impl Interval {
    /// Returns `true` when `self.start < other.end && other.start < self.end`.
    ///
    /// Intervals that merely touch (one's `end` equals the other's `start`) are not considered
    /// overlapping.
    pub fn overlaps(&self, other: &Self) -> bool {
        let disjoint = self.end <= other.start || other.end <= self.start;
        !disjoint
    }
}
/// Errors raised while claiming regions of an archive's memory during validation.
#[derive(Debug)]
pub enum ArchiveMemoryError {
    /// The target of a relative pointer (`base + offset`) lies outside the archive.
    OutOfBounds {
        /// Position of the pointer's base relative to the start of the archive (computed as a
        /// signed distance and cast to `usize`).
        base: usize,
        /// Signed offset from `base` to the target.
        offset: isize,
        /// Length of the archive in bytes.
        archive_len: usize,
    },
    /// The claimed region starts inside the archive but extends past its end.
    Overrun {
        /// Position of the start of the claimed region.
        pos: usize,
        /// Size of the claimed region in bytes.
        size: usize,
        /// Length of the archive in bytes.
        archive_len: usize,
    },
    /// The target of a claim is not aligned to `align`.
    Unaligned {
        /// Position of the target relative to the start of the archive.
        pos: usize,
        /// Required alignment in bytes.
        align: usize,
    },
    /// The claimed region overlaps a region that was already claimed.
    ClaimOverlap {
        /// The previously claimed interval that was hit.
        previous: Interval,
        /// The interval whose claim was rejected.
        current: Interval,
    },
}
impl fmt::Display for ArchiveMemoryError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ArchiveMemoryError::OutOfBounds {
base,
offset,
archive_len,
} => write!(
f,
"out of bounds pointer: base {} offset {} in archive len {}",
base, offset, archive_len
),
ArchiveMemoryError::Overrun {
pos,
size,
archive_len,
} => write!(
f,
"archive overrun: pos {} size {} in archive len {}",
pos, size, archive_len
),
ArchiveMemoryError::Unaligned { pos, align } => write!(
f,
"unaligned pointer: pos {} unaligned for alignment {}",
pos, align
),
ArchiveMemoryError::ClaimOverlap { previous, current } => write!(
f,
"memory claim overlap: current [{}..{}] overlaps previous [{}..{}]",
current.start, current.end, previous.start, previous.end
),
}
}
}
impl error::Error for ArchiveMemoryError {}
/// Error produced when validating an archive.
///
/// `T` is the error type of the `CheckBytes` implementation being run.
#[derive(Debug)]
pub enum CheckArchiveError<T> {
    /// A memory claim failed (out of bounds, overrun, misaligned, or overlapping).
    MemoryError(ArchiveMemoryError),
    /// The archived bytes were rejected by a `CheckBytes` implementation.
    CheckBytes(T),
}
impl<T> From<ArchiveMemoryError> for CheckArchiveError<T> {
fn from(e: ArchiveMemoryError) -> Self {
Self::MemoryError(e)
}
}
impl<T: fmt::Display> fmt::Display for CheckArchiveError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CheckArchiveError::MemoryError(e) => write!(f, "archive memory error: {}", e),
CheckArchiveError::CheckBytes(e) => write!(f, "check bytes error: {}", e),
}
}
}
impl<T: fmt::Debug + fmt::Display> error::Error for CheckArchiveError<T> {}
/// Validation state for a single archive: its bounds and the byte ranges claimed so far.
pub struct ArchiveContext {
    /// Pointer to the first byte of the archive.
    // NOTE(review): stored as a raw pointer, so nothing ties the context to the buffer's
    // lifetime; callers must keep the buffer alive while claiming memory through this context.
    begin: *const u8,
    /// Length of the archive in bytes.
    len: usize,
    /// Claimed byte ranges, kept sorted so they can be binary-searched; a claim adjacent to an
    /// existing range is merged into it instead of being inserted separately.
    intervals: Vec<Interval>,
}
impl ArchiveContext {
pub fn new(bytes: &[u8]) -> Self {
const DEFAULT_INTERVALS_CAPACITY: usize = 64;
Self {
begin: bytes.as_ptr(),
len: bytes.len(),
intervals: Vec::with_capacity(DEFAULT_INTERVALS_CAPACITY),
}
}
pub unsafe fn claim<T: CheckBytes<ArchiveContext>>(
&mut self,
rel_ptr: &RelPtr,
count: usize,
) -> Result<*const u8, ArchiveMemoryError> {
let base = (rel_ptr as *const RelPtr).cast::<u8>();
let offset = rel_ptr.offset();
self.claim_bytes(
base,
offset,
count * mem::size_of::<T>(),
mem::align_of::<T>(),
)
}
pub unsafe fn claim_bytes(
&mut self,
base: *const u8,
offset: isize,
count: usize,
align: usize,
) -> Result<*const u8, ArchiveMemoryError> {
let base_pos = base.offset_from(self.begin);
if offset < -base_pos || offset > self.len as isize - base_pos {
Err(ArchiveMemoryError::OutOfBounds {
base: base_pos as usize,
offset,
archive_len: self.len,
})
} else {
let target_pos = (base_pos + offset) as usize;
if target_pos & (align - 1) != 0 {
Err(ArchiveMemoryError::Unaligned {
pos: target_pos,
align,
})
} else if count != 0 {
if self.len - target_pos < count {
Err(ArchiveMemoryError::Overrun {
pos: target_pos,
size: count,
archive_len: self.len,
})
} else {
let interval = Interval {
start: target_pos,
end: target_pos + count,
};
match self.intervals.binary_search(&interval) {
Ok(index) => Err(ArchiveMemoryError::ClaimOverlap {
previous: self.intervals[index],
current: interval,
}),
Err(index) => {
if index < self.intervals.len() {
if self.intervals[index].overlaps(&interval) {
return Err(ArchiveMemoryError::ClaimOverlap {
previous: self.intervals[index],
current: interval,
});
} else if self.intervals[index].start == interval.end {
self.intervals[index].start = interval.start;
return Ok(base.offset(offset));
}
}
if index > 0 {
if self.intervals[index - 1].overlaps(&interval) {
return Err(ArchiveMemoryError::ClaimOverlap {
previous: self.intervals[index - 1],
current: interval,
});
} else if self.intervals[index - 1].end == interval.start {
self.intervals[index - 1].end = interval.end;
return Ok(base.offset(offset));
}
}
self.intervals.insert(index, interval);
Ok(base.offset(offset))
}
}
}
} else {
Ok(base.offset(offset))
}
}
}
}
/// Validates the archived `T` whose root object starts at byte `pos` of `buf`, and returns a
/// reference to it.
///
/// The root object's bytes are claimed first (bounds, alignment). Then
/// `CheckBytes::check_bytes` validates its contents, using the same context so that nested
/// claims are tracked. The returned reference borrows `buf`.
///
/// # Errors
///
/// - [`CheckArchiveError::MemoryError`] if the root object cannot be claimed.
/// - [`CheckArchiveError::CheckBytes`] if the archived bytes fail validation.
pub fn check_archive<'a, T: Archive>(
    buf: &'a [u8],
    pos: usize,
) -> Result<&'a T::Archived, CheckArchiveError<<T::Archived as CheckBytes<ArchiveContext>>::Error>>
where
    T::Archived: CheckBytes<ArchiveContext>,
{
    let mut context = ArchiveContext::new(buf);
    // SAFETY: `buf.as_ptr()` is the start of the buffer the context was created from, so it is a
    // valid base for `claim_bytes`. A successful claim guarantees `bytes` is in bounds and aligned
    // for `T::Archived`, and `check_bytes` validates the contents before a reference is formed.
    // The reference's lifetime is tied to `buf`, so it cannot outlive the archive.
    unsafe {
        let bytes = context.claim_bytes(
            buf.as_ptr(),
            pos as isize,
            mem::size_of::<T::Archived>(),
            mem::align_of::<T::Archived>(),
        )?;
        Archived::<T>::check_bytes(bytes, &mut context).map_err(CheckArchiveError::CheckBytes)?;
        Ok(&*bytes.cast())
    }
}
impl CheckBytes<ArchiveContext> for RelPtr {
    type Error = Unreachable;

    /// Validates a `RelPtr` by validating its stored `Offset` at `bytes`.
    ///
    /// The pointee is not claimed or checked here.
    unsafe fn check_bytes<'a>(
        bytes: *const u8,
        context: &mut ArchiveContext,
    ) -> Result<&'a Self, Self::Error> {
        Offset::check_bytes(bytes, context)?;
        let rel_ptr = bytes.cast::<Self>();
        Ok(&*rel_ptr)
    }
}