use core::{ fmt, mem, str };
use core::ptr::{ self, NonNull };
use core::ops::{ Deref, DerefMut };
use core::cell::Cell;
use memsec::{ memzero, malloc, malloc_sized, free, mprotect, Prot };
/// Owner of a value stored in memsec-allocated memory whose pages are made
/// inaccessible whenever no read/write guard is outstanding.
pub struct SecKey<T: ?Sized> {
    /// Address of the protected allocation that holds the value.
    ptr: NonNull<T>,
    /// Number of guards handed out by `read`/`write` that have not yet been
    /// dropped; drives when the pages are unlocked and re-locked.
    count: Cell<usize>,
}
impl<T> SecKey<T> {
pub fn new(mut t: T) -> Result<SecKey<T>, T> {
unsafe {
match Self::from_ptr(&t) {
Some(output) => {
memzero(&mut t as *mut T as *mut u8, mem::size_of::<T>());
mem::forget(t);
Ok(output)
},
None => Err(t)
}
}
}
#[inline]
pub unsafe fn from_ptr(t: *const T) -> Option<SecKey<T>> {
Self::with(move |memptr| ptr::copy_nonoverlapping(t, memptr, 1))
}
pub unsafe fn with<F>(f: F) -> Option<SecKey<T>>
where F: FnOnce(*mut T)
{
let memptr = malloc()?;
f(memptr.as_ptr());
mprotect(memptr, Prot::NoAccess);
Some(SecKey {
ptr: memptr,
count: Cell::new(0)
})
}
}
impl<T: Copy> SecKey<T> {
pub fn from_ref(t: &T) -> Option<SecKey<T>> {
unsafe { Self::from_ptr(t) }
}
}
impl<T: Default> SecKey<T> {
    /// Allocates protected memory holding `T::default()`, then lets `f`
    /// mutate it in place before the pages are locked.
    ///
    /// Returns `None` if the allocation fails.
    pub fn with_default<F>(f: F) -> Option<SecKey<T>>
    where
        F: FnOnce(&mut T),
    {
        let init = move |slot: *mut T| {
            // SAFETY: `slot` points to writable, uninitialized storage for one
            // `T`; it is initialized here before any reference to it is formed.
            let value: &mut T = unsafe {
                ptr::write(slot, T::default());
                &mut *slot
            };
            f(value);
        };
        // SAFETY: `init` always fully initializes the slot before returning.
        unsafe { Self::with(init) }
    }
}
impl SecKey<[u8]> {
    /// Copies `src` into newly allocated protected memory, then zeroes `src`.
    ///
    /// Returns `None`, leaving `src` untouched, if the allocation fails.
    pub fn from_bytes(src: &mut [u8]) -> Option<SecKey<[u8]>> {
        let len = src.len();
        unsafe {
            let mut memptr = malloc_sized(len)?;
            let dst: *mut u8 = memptr.as_mut().as_mut_ptr();
            ptr::copy_nonoverlapping(src.as_ptr(), dst, len);
            // Lock the copy first, then wipe the caller's buffer.
            mprotect(memptr, Prot::NoAccess);
            memzero(src.as_mut_ptr(), len);
            Some(SecKey { ptr: memptr, count: Cell::new(0) })
        }
    }
}
impl SecKey<str> {
pub fn from_str(src: &mut str) -> Option<SecKey<str>> {
unsafe {
let src = src.as_bytes_mut();
let mut memptr = malloc_sized(src.len())?;
ptr::copy_nonoverlapping(
src.as_ptr(),
memptr.as_mut().as_mut_ptr(),
src.len()
);
let strptr = NonNull::new_unchecked(
str::from_utf8_unchecked_mut(memptr.as_mut()) as *mut str
);
mprotect(strptr, Prot::NoAccess);
memzero(src.as_mut_ptr(), src.len());
Some(SecKey {
ptr: strptr,
count: Cell::new(0),
})
}
}
}
impl<T: ?Sized> SecKey<T> {
    /// Releases one guard's hold on the protected pages.
    ///
    /// Decrements the guard counter. When the last guard is released, it
    /// switches the pages back to `Prot::NoAccess`.
    ///
    /// # Safety
    ///
    /// Must be called exactly once per guard returned by `read`/`write`, so the
    /// counter is at least 1 on entry.
    #[inline]
    unsafe fn lock(&self) {
        let count = self.count.get();
        self.count.set(count - 1);
        if count <= 1 {
            mprotect(self.ptr, Prot::NoAccess);
        }
    }

    /// Returns a guard giving shared, read-only access to the value.
    ///
    /// The first outstanding guard switches the pages to `Prot::ReadOnly`.
    /// Further guards only bump the counter.
    #[inline]
    pub fn read(&self) -> SecReadGuard<T> {
        let count = self.count.get();
        self.count.set(count + 1);
        if count == 0 {
            // SAFETY: `self.ptr` is the live allocation owned by this key.
            unsafe { mprotect(self.ptr, Prot::ReadOnly) };
        }
        SecReadGuard(self)
    }

    /// Returns a guard giving exclusive, read-write access to the value.
    ///
    /// The pages are switched to `Prot::ReadWrite` until the guard is dropped.
    #[inline]
    pub fn write(&mut self) -> SecWriteGuard<T> {
        // `&mut self` proves that no guard borrowing this key is alive. A
        // nonzero counter can therefore only come from a guard that was leaked
        // (e.g. via the safe `mem::forget`) and never ran its destructor. Such a
        // guard may have left the pages read-only, and writing through the new
        // guard would then fault. So reset the counter to exactly this guard
        // and always unlock for writing.
        self.count.set(1);
        // SAFETY: `self.ptr` is the live allocation owned by this key.
        unsafe { mprotect(self.ptr, Prot::ReadWrite) };
        SecWriteGuard(self)
    }
}
impl<T: ?Sized> fmt::Debug for SecKey<T> {
    /// Shows only the allocation address and the guard counter, never the
    /// protected contents.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut tuple = f.debug_tuple("SecKey");
        tuple.field(&format_args!("{:p}", self.ptr));
        tuple.field(&self.count);
        tuple.finish()
    }
}
impl<T: ?Sized> fmt::Pointer for SecKey<T> {
    /// Formats the address of the protected allocation.
    ///
    /// Delegates to the pointer's own `fmt::Pointer` impl so that formatter
    /// options supplied by the caller (width, fill, alignment) are honored.
    /// A nested `write!` with a fresh `{:p}` spec would silently drop them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Pointer::fmt(&self.ptr, f)
    }
}
impl<T: ?Sized> Drop for SecKey<T> {
    /// Unlocks the pages, runs the value's destructor in place, and releases
    /// the allocation through memsec's `free`.
    fn drop(&mut self) {
        let raw = self.ptr;
        unsafe {
            // The pages may currently be NoAccess; the destructor needs to
            // touch them.
            mprotect(raw, Prot::ReadWrite);
            ptr::drop_in_place(raw.as_ptr());
            free(raw);
        }
    }
}
pub struct SecReadGuard<'a, T: 'a + ?Sized>(&'a SecKey<T>);
impl<'a, T: 'a + ?Sized> Deref for SecReadGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { self.0.ptr.as_ref() }
}
}
impl<'a, T: 'a + ?Sized> Drop for SecReadGuard<'a, T> {
fn drop(&mut self) {
unsafe { self.0.lock() }
}
}
pub struct SecWriteGuard<'a, T: 'a + ?Sized>(&'a mut SecKey<T>);
impl<'a, T: 'a + ?Sized> Deref for SecWriteGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { self.0.ptr.as_ref() }
}
}
impl<'a, T: 'a + ?Sized> DerefMut for SecWriteGuard<'a, T> {
fn deref_mut(&mut self) -> &mut T {
unsafe { self.0.ptr.as_mut() }
}
}
impl<'a, T: 'a + ?Sized> Drop for SecWriteGuard<'a, T> {
fn drop(&mut self) {
unsafe { self.0.lock() }
}
}