use std::{ fmt, mem, ptr };
use std::ops::{ Deref, DerefMut };
use std::cell::Cell;
use memsec::{ memzero, malloc, free, mprotect, Prot };
#[cfg(feature = "place")] use std::ops::{ Place, Placer, InPlace };
/// Owning handle to a single `T` stored in memory obtained from `memsec::malloc`.
///
/// The constructors in this file leave the allocation at `Prot::NoAccess`;
/// `read` and `write` change the protection for as long as their guards live.
pub struct SecKey<T> {
/// Address of the `memsec` allocation that holds the value.
ptr: *mut T,
/// Number of unlocks not yet matched by a `lock` call (one per outstanding guard).
count: Cell<usize>
}
#[cfg(feature = "place")]
/// Marker type implementing `Placer`, so a value can be placed directly into
/// a `memsec` allocation (only built with the `place` feature).
pub struct SecHeap;
#[cfg(feature = "place")]
/// In-progress placement: wraps the raw pointer returned by `memsec::malloc`
/// until `InPlace::finalize` turns it into a `SecKey`.
pub struct SecPtr<T>(*mut T);
#[cfg(feature = "place")]
impl<T: Sized> Placer<T> for SecHeap {
    type Place = SecPtr<T>;

    /// Allocates `size_of::<T>()` bytes with `memsec::malloc` and wraps the
    /// resulting pointer in a `SecPtr`.
    ///
    /// # Panics
    ///
    /// Panics when `memsec::malloc` returns `None`.
    fn make_place(self) -> Self::Place {
        let raw: *mut T = match unsafe { malloc(mem::size_of::<T>()) } {
            Some(raw) => raw,
            None => panic!("memsec::malloc fail: {}", mem::size_of::<T>()),
        };
        SecPtr(raw)
    }
}
#[cfg(feature = "place")]
impl<T> Place<T> for SecPtr<T> {
    /// Returns the raw address the placed value must be written to.
    fn pointer(&mut self) -> *mut T {
        // The wrapped pointer is `Copy`, so destructuring copies it out.
        let SecPtr(raw) = *self;
        raw
    }
}
#[cfg(feature = "place")]
impl<T> InPlace<T> for SecPtr<T> {
    type Owner = SecKey<T>;

    /// Completes the placement: sets the allocation to `Prot::NoAccess` and
    /// hands the pointer over to a new `SecKey` whose guard count is zero.
    unsafe fn finalize(self) -> Self::Owner {
        let SecPtr(raw) = self;
        mprotect(raw, Prot::NoAccess);
        SecKey {
            ptr: raw,
            count: Cell::new(0),
        }
    }
}
impl<T> Default for SecKey<T> where T: Default {
    /// Builds a `SecKey` holding `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics when `SecKey::new` reports failure (it returns `Err` when the
    /// `memsec` allocation could not be made).
    fn default() -> Self {
        match SecKey::new(T::default()) {
            Ok(key) => key,
            Err(_) => panic!("memsec::malloc fail: {}", mem::size_of::<T>()),
        }
    }
}
impl<T> SecKey<T> where T: Sized {
    /// Moves `t` into a freshly allocated `memsec` region.
    ///
    /// On success the bytes of the local `t` are overwritten with `memzero`
    /// and `t` is forgotten, so its destructor does not run on the zeroed
    /// copy. If the allocation fails, `t` is handed back untouched in `Err`.
    pub fn new(mut t: T) -> Result<SecKey<T>, T> {
        let protected = unsafe { Self::from_raw(&t) };
        if let Some(key) = protected {
            // Wipe the source bytes; ownership now lives in `key`.
            unsafe { memzero(&mut t, mem::size_of::<T>()) };
            mem::forget(t);
            Ok(key)
        } else {
            Err(t)
        }
    }

    /// Bitwise-copies one `T` from `t` into a new `memsec` allocation and
    /// sets that allocation to `Prot::NoAccess`.
    ///
    /// Returns `None` when `memsec::malloc` fails.
    ///
    /// # Safety
    ///
    /// `t` must be valid for reading one `T`. The value is duplicated
    /// bit-for-bit and the returned `SecKey` drops its copy in place when it
    /// is dropped, so the caller must make sure the original is not dropped
    /// as well if `T` owns resources.
    pub unsafe fn from_raw(t: *const T) -> Option<SecKey<T>> {
        let allocation: Option<*mut T> = malloc(mem::size_of::<T>());
        allocation.map(|memptr| {
            ptr::copy_nonoverlapping(t, memptr, 1);
            mprotect(memptr, Prot::NoAccess);
            SecKey {
                ptr: memptr,
                count: Cell::new(0),
            }
        })
    }
}
impl<T> SecKey<T> {
    /// Registers one more reader; the first unlock switches the allocation
    /// to `Prot::ReadOnly`.
    fn read_unlock(&self) {
        let count = self.count.get();
        self.count.set(count + 1);
        if count == 0 {
            unsafe { mprotect(self.ptr, Prot::ReadOnly) };
        }
    }

    /// Registers a writer; when no other unlock is outstanding the
    /// allocation is switched to `Prot::ReadWrite`.
    fn write_unlock(&self) {
        let count = self.count.get();
        self.count.set(count + 1);
        if count == 0 {
            unsafe { mprotect(self.ptr, Prot::ReadWrite) };
        }
    }

    /// Releases one unlock; when the last one is released the allocation is
    /// switched back to `Prot::NoAccess`.
    fn lock(&self) {
        let count = self.count.get();
        // Saturate instead of `count - 1`: the `count <= 1` branch below
        // already admits a zero count, which must not underflow the counter.
        self.count.set(count.saturating_sub(1));
        if count <= 1 {
            unsafe { mprotect(self.ptr, Prot::NoAccess) };
        }
    }

    /// Unlocks the value for reading.
    ///
    /// The returned guard dereferences to `&T` and calls `lock` when dropped.
    #[inline]
    pub fn read(&self) -> SecReadGuard<T> {
        self.read_unlock();
        SecReadGuard(self)
    }

    /// Unlocks the value for reading and writing.
    ///
    /// The returned guard dereferences to `&mut T` and calls `lock` when
    /// dropped.
    #[inline]
    pub fn write(&mut self) -> SecWriteGuard<T> {
        // `&mut self` proves that no guard borrowed from this key is still
        // alive, so a non-zero count can only be left over from guards that
        // were leaked (e.g. via `mem::forget`). With a stale count,
        // `write_unlock` would skip its `mprotect` call and the guard would
        // hand out `&mut T` into memory that may still be `ReadOnly`.
        self.count.set(0);
        self.write_unlock();
        SecWriteGuard(self)
    }
}
impl<T> fmt::Debug for SecKey<T> {
    /// Prints a placeholder with the current guard count; the protected
    /// value itself is never formatted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let outstanding = self.count.get();
        f.write_fmt(format_args!("** sec key ({}) **", outstanding))
    }
}
impl<T> fmt::Pointer for SecKey<T> {
    /// Prints the address of the underlying allocation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let address = self.ptr;
        f.write_fmt(format_args!("{:p}", address))
    }
}
impl<T> Drop for SecKey<T> {
    /// Makes the allocation accessible again, runs `T`'s destructor in
    /// place, then returns the memory through `memsec::free`.
    fn drop(&mut self) {
        let raw = self.ptr;
        unsafe {
            // The region must be writable before the value can be dropped.
            mprotect(raw, Prot::ReadWrite);
            ptr::drop_in_place(raw);
            free(raw);
        }
    }
}
/// Guard returned by `SecKey::read`; dereferences to `&T` and calls
/// `SecKey::lock` when dropped.
pub struct SecReadGuard<'a, T: 'a>(&'a SecKey<T>);
impl<'a, T: 'a> Deref for SecReadGuard<'a, T> {
    type Target = T;

    /// Borrows the value stored behind the key's pointer.
    fn deref(&self) -> &T {
        let raw: *const T = self.0.ptr;
        unsafe { &*raw }
    }
}
impl<'a, T: 'a> Drop for SecReadGuard<'a, T> {
fn drop(&mut self) {
self.0.lock();
}
}
/// Guard returned by `SecKey::write`; dereferences to `&T` / `&mut T` and
/// calls `SecKey::lock` when dropped.
pub struct SecWriteGuard<'a, T: 'a>(&'a mut SecKey<T>);
impl<'a, T: 'a> Deref for SecWriteGuard<'a, T> {
    type Target = T;

    /// Borrows the value stored behind the key's pointer.
    fn deref(&self) -> &T {
        let raw: *const T = self.0.ptr;
        unsafe { &*raw }
    }
}
impl<'a, T: 'a> DerefMut for SecWriteGuard<'a, T> {
    /// Mutably borrows the value stored behind the key's pointer.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.0.ptr;
        unsafe { &mut *raw }
    }
}
impl<'a, T: 'a> Drop for SecWriteGuard<'a, T> {
fn drop(&mut self) {
self.0.lock();
}
}