use std::fmt;
use std::ops::{Drop,Deref,DerefMut};
use std::marker::Send;
use std::thread;
use std::thread::ThreadId;
/// Panic message used when the wrapped value is accessed (dereferenced, taken, formatted or
/// cloned) from a thread other than the one that created the wrapper.
const DEREF_ERROR: &str = "Dereferenced SendWrapper<T> variable from a thread different to the one it has been created with.";
/// Panic message used when the wrapper is dropped on a thread other than the one that created it.
const DROP_ERROR: &str = "Dropped SendWrapper<T> variable from a thread different to the one it has been created with.";
/// Wrapper that owns a heap-allocated `T` and remembers the thread it was created on.
///
/// `take` and the `Deref`, `DerefMut`, `Debug`, `Clone` and `Drop` implementations in this file
/// compare the current thread against the recorded one before touching the wrapped value.
pub struct SendWrapper<T> {
/// Raw pointer obtained from `Box::into_raw`; turned back into a `Box` by `take` or `drop`.
data: *mut T,
/// Id of the thread that constructed this wrapper; compared against in `valid`.
thread_id: ThreadId,
}
impl<T> SendWrapper<T> {
pub fn new(data: T) -> SendWrapper<T> {
SendWrapper {
data: Box::into_raw(Box::new(data)),
thread_id: thread::current().id()
}
}
pub fn valid(&self) -> bool {
self.thread_id == thread::current().id()
}
pub fn take(self) -> T {
if !self.valid() {
panic!(DEREF_ERROR);
}
let result = unsafe { Box::from_raw(self.data) };
std::mem::forget(self);
*result
}
}
// SAFETY: every path in this file that touches the wrapped value (`take`, `deref`, `deref_mut`,
// `fmt`, `clone`, `drop`) first checks `valid()`, so the value is only accessed on the thread
// that created the wrapper even when the wrapper itself has been moved elsewhere.
unsafe impl<T> Send for SendWrapper<T> { }
// SAFETY: same reasoning as for `Send`; through a shared reference another thread can only call
// `valid()`, which reads `thread_id`, or hit one of the guarded paths listed above.
unsafe impl<T> Sync for SendWrapper<T> { }
impl<T> Deref for SendWrapper<T> {
type Target = T;
fn deref(&self) -> &T {
if !self.valid() {
panic!(DEREF_ERROR);
}
unsafe {
&*self.data
}
}
}
impl<T> DerefMut for SendWrapper<T> {
fn deref_mut(&mut self) -> &mut T {
if !self.valid() {
panic!(DEREF_ERROR);
}
unsafe {
&mut *self.data
}
}
}
impl<T> Drop for SendWrapper<T> {
fn drop(&mut self) {
if self.valid() {
unsafe {
let _dropper = Box::from_raw(self.data);
}
} else {
if !std::thread::panicking() {
panic!(DROP_ERROR);
}
}
}
}
impl<T: fmt::Debug> fmt::Debug for SendWrapper<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.valid() {
panic!(DEREF_ERROR);
}
f.debug_struct("SendWrapper")
.field("data", unsafe { &*self.data })
.field("thread_id", &self.thread_id)
.finish()
}
}
impl<T: Clone> Clone for SendWrapper<T> {
fn clone(&self) -> Self {
if !self.valid() {
panic!(DEREF_ERROR);
}
Self {
data: Box::into_raw(Box::new(unsafe { &*self.data }.clone())),
thread_id: self.thread_id,
}
}
}
#[cfg(test)]
mod tests {
    // `super::` resolves in every edition; a bare `use SendWrapper;` only works in 2015.
    use super::SendWrapper;
    use std::ops::Deref;
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    /// Dereferencing works on the creating thread, including after a round trip through
    /// another thread.
    #[test]
    fn test_deref() {
        let (sender, receiver) = channel();
        let w = SendWrapper::new(Rc::new(42));
        {
            let _x = w.deref();
        }
        let t = thread::spawn(move || {
            // Move the wrapper to another thread and straight back.
            sender.send(w).unwrap();
        });
        let w2 = receiver.recv().unwrap();
        {
            let _x = w2.deref();
        }
        assert!(t.join().is_ok());
    }

    /// Dereferencing on a foreign thread panics.
    #[test]
    fn test_deref_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _x = w.deref();
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    /// Dropping on a foreign thread panics.
    #[test]
    fn test_drop_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _x = w;
        });
        let join_result = t.join();
        assert!(join_result.is_err());
    }

    /// `valid` is true on the creating thread and false elsewhere.
    #[test]
    fn test_valid() {
        let w = SendWrapper::new(Rc::new(42));
        assert!(w.valid());
        // Hand the wrapper back so it is dropped on the owning thread; dropping it on the
        // spawned thread would panic and hide the outcome of the check.
        let t = thread::spawn(move || {
            let valid = w.valid();
            (valid, w)
        });
        let (valid_elsewhere, _w) = t.join().unwrap();
        assert!(!valid_elsewhere);
    }

    /// `take` returns the wrapped value on the creating thread.
    #[test]
    fn test_take() {
        let w = SendWrapper::new(Rc::new(42));
        let inner: Rc<usize> = w.take();
        assert_eq!(42, *inner);
    }

    /// `take` panics on a foreign thread.
    #[test]
    fn test_take_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.take();
        });
        assert!(t.join().is_err());
    }

    /// `Arc<T>` is only `Send` when `T: Sync`, so this compiles only if the wrapper is `Sync`.
    #[test]
    fn test_sync() {
        let arc = Arc::new(SendWrapper::new(42));
        let remote = Arc::clone(&arc);
        let t = thread::spawn(move || {
            // Bind to a named variable so the closure really captures the `Arc`.
            let _held = remote;
        });
        // The spawned thread only released its reference; `arc` is still alive here, so the
        // wrapper itself is dropped on this (the creating) thread.
        assert!(t.join().is_ok());
    }

    /// `Debug` output contains both fields.
    #[test]
    fn test_debug() {
        let w = SendWrapper::new(Rc::new(42));
        let info = format!("{:?}", w);
        assert!(info.contains("SendWrapper {"));
        assert!(info.contains("data: 42,"));
        assert!(info.contains("thread_id: ThreadId("));
    }

    /// Formatting on a foreign thread panics.
    #[test]
    fn test_debug_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = format!("{:?}", w);
        });
        assert!(t.join().is_err());
    }

    /// A clone formats identically to its source.
    #[test]
    fn test_clone() {
        let w1 = SendWrapper::new(Rc::new(42));
        let w2 = w1.clone();
        assert_eq!(format!("{:?}", w1), format!("{:?}", w2));
    }

    /// Cloning on a foreign thread panics.
    #[test]
    fn test_clone_panic() {
        let w = SendWrapper::new(Rc::new(42));
        let t = thread::spawn(move || {
            let _ = w.clone();
        });
        assert!(t.join().is_err());
    }
}