use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
use crate::util::WakeList;
use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
use std::usize;
/// Sending half of the broadcast channel.
///
/// Every `Sender` clone and every `Receiver` of the same channel holds a
/// reference-counted handle to one `Shared` state object.
pub struct Sender<T> {
/// Channel state shared with all receivers and all other sender handles.
shared: Arc<Shared<T>>,
}
/// Receiving half of the broadcast channel.
pub struct Receiver<T> {
/// Channel state shared with the senders and all other receivers.
shared: Arc<Shared<T>>,
/// Position of the next value this receiver will read. It is compared
/// against `Slot::pos` when receiving and against `Tail::pos` in `len`.
next: u64,
}
/// Error types returned by the broadcast channel.
pub mod error {
    use std::fmt;

    /// Returned by `send` when the value could not be delivered; the unsent
    /// value is handed back in field `0`.
    #[derive(Debug)]
    pub struct SendError<T>(pub T);

    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl<T: fmt::Debug> std::error::Error for SendError<T> {}

    /// Error produced by the asynchronous receive path.
    #[derive(Debug, PartialEq, Eq, Clone)]
    pub enum RecvError {
        /// The channel is closed and nothing is left for this receiver.
        Closed,
        /// The receiver fell behind; the payload is the number of skipped
        /// values.
        Lagged(u64),
    }

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match *self {
                RecvError::Closed => f.write_str("channel closed"),
                RecvError::Lagged(skipped) => write!(f, "channel lagged by {}", skipped),
            }
        }
    }

    impl std::error::Error for RecvError {}

    /// Error produced by the non-blocking receive path.
    #[derive(Debug, PartialEq, Eq, Clone)]
    pub enum TryRecvError {
        /// No value is currently available for this receiver.
        Empty,
        /// The channel is closed and nothing is left for this receiver.
        Closed,
        /// The receiver fell behind; the payload is the number of skipped
        /// values.
        Lagged(u64),
    }

    impl fmt::Display for TryRecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match *self {
                TryRecvError::Empty => f.write_str("channel empty"),
                TryRecvError::Closed => f.write_str("channel closed"),
                TryRecvError::Lagged(skipped) => write!(f, "channel lagged by {}", skipped),
            }
        }
    }

    impl std::error::Error for TryRecvError {}
}
use self::error::{RecvError, SendError, TryRecvError};
/// State shared by every `Sender` and `Receiver` handle of one channel.
struct Shared<T> {
/// Ring buffer of slots. Its length is a power of two (the requested
/// capacity is rounded up in `Sender::new_with_receiver_count`).
buffer: Box<[RwLock<Slot<T>>]>,
/// `buffer.len() - 1`; ANDed with a position to obtain a slot index.
mask: usize,
/// Send position, receiver count, closed flag and waiter list, all guarded
/// by a single mutex.
tail: Mutex<Tail>,
/// Number of live `Sender` handles; the channel is closed when the last one
/// is dropped.
num_tx: AtomicUsize,
}
/// Mutable channel bookkeeping, protected by `Shared::tail`.
struct Tail {
/// Position that the next sent value will receive (wrapping `u64` counter).
pos: u64,
/// Number of live receivers.
rx_cnt: usize,
/// Set once the last `Sender` has been dropped.
closed: bool,
/// Intrusive list of receivers waiting for a value.
waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
}
/// One entry of the ring buffer.
struct Slot<T> {
/// Number of receivers that still have to read the stored value. `send`
/// sets it to the receiver count and each dropped `RecvGuard` decrements it.
rem: AtomicUsize,
/// Position of the value currently stored in this slot.
pos: u64,
/// The stored value; reset to `None` when the last outstanding `RecvGuard`
/// for this slot is dropped.
val: UnsafeCell<Option<T>>,
}
/// Node of the intrusive waiter list stored in `Tail::waiters`.
struct Waiter {
/// `true` while this node is linked into the waiter list.
queued: bool,
/// Waker of the task to notify when a value is sent or the channel closes.
waker: Option<Waker>,
/// Intrusive list pointers.
pointers: linked_list::Pointers<Waiter>,
/// Makes the type `!Unpin`; the list refers to nodes by raw pointer.
_p: PhantomPinned,
}
impl Waiter {
fn new() -> Self {
Self {
queued: false,
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
}
}
}
// Generates `Waiter::addr_of_pointers`, which the `linked_list::Link` impl for
// `Waiter` calls to get at the `pointers` field from a raw `NonNull<Waiter>`.
// NOTE(review): the macro is defined elsewhere; presumably it computes the
// field address without creating an intermediate reference — confirm there.
generate_addr_of_methods! {
impl<> Waiter {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
&self.pointers
}
}
}
/// Read access to a slot on behalf of one receiver.
///
/// Dropping the guard decrements the slot's `rem` counter and clears the value
/// when the counter was `1` (see the `Drop` impl).
struct RecvGuard<'a, T> {
/// Read lock on the slot being received from.
slot: RwLockReadGuard<'a, Slot<T>>,
}
/// Future returned by `Receiver::recv`.
struct Recv<'a, T> {
/// Receiver the value is read from.
receiver: &'a mut Receiver<T>,
/// Waiter node that gets linked into `Tail::waiters` while the future is
/// pending. It is only accessed while the tail mutex is held.
waiter: UnsafeCell<Waiter>,
}
// `Recv` is declared `Send`/`Sync` manually. In this file the `waiter` cell is
// only touched while the tail mutex is held (`recv_ref`, `notify_rx`, and
// `Drop for Recv`).
// NOTE(review): presumably the raw list pointers inside `Waiter` are what
// suppress the auto impls — confirm against `linked_list::Pointers`.
unsafe impl<'a, T: Send> Send for Recv<'a, T> {}
unsafe impl<'a, T: Send> Sync for Recv<'a, T> {}
/// Upper bound on `Tail::rx_cnt`; `new_receiver` panics with "max receivers"
/// when the count has already reached this value.
const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// Creates a broadcast channel, returning its sender together with one
/// receiver.
///
/// `capacity` is rounded up to the next power of two. The returned receiver
/// starts at position `0`, i.e. it observes every value sent on the channel.
///
/// # Panics
///
/// Panics if `capacity` is `0` or greater than `usize::MAX / 2` (asserted by
/// `Sender::new_with_receiver_count`).
#[track_caller]
pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    // SAFETY: the receiver count passed here (1) matches the single
    // `Receiver` constructed right below.
    let tx = unsafe { Sender::new_with_receiver_count(1, capacity) };
    let shared = tx.shared.clone();
    (tx, Receiver { shared, next: 0 })
}
// `Sender` and `Receiver` are declared `Send` and `Sync` whenever `T: Send`.
//
// NOTE(review): receivers clone the stored value through a *shared* read lock
// (`RecvGuard` holds an `RwLockReadGuard` and `clone_value` calls `T::clone`
// on the slot contents), so two receivers on different threads can run
// `T::clone` on the same value at the same time. That looks like it needs
// `T: Sync` in addition to `T: Send` — confirm before using these bounds with
// `Send + !Sync` payloads.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}
impl<T> Sender<T> {
#[track_caller]
pub fn new(capacity: usize) -> Self {
unsafe { Self::new_with_receiver_count(0, capacity) }
}
#[track_caller]
unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self {
assert!(capacity > 0, "broadcast channel capacity cannot be zero");
assert!(
capacity <= usize::MAX >> 1,
"broadcast channel capacity exceeded `usize::MAX / 2`"
);
capacity = capacity.next_power_of_two();
let mut buffer = Vec::with_capacity(capacity);
for i in 0..capacity {
buffer.push(RwLock::new(Slot {
rem: AtomicUsize::new(0),
pos: (i as u64).wrapping_sub(capacity as u64),
val: UnsafeCell::new(None),
}));
}
let shared = Arc::new(Shared {
buffer: buffer.into_boxed_slice(),
mask: capacity - 1,
tail: Mutex::new(Tail {
pos: 0,
rx_cnt: receiver_count,
closed: false,
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
});
Sender { shared }
}
pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
let mut tail = self.shared.tail.lock();
if tail.rx_cnt == 0 {
return Err(SendError(value));
}
let pos = tail.pos;
let rem = tail.rx_cnt;
let idx = (pos & self.shared.mask as u64) as usize;
tail.pos = tail.pos.wrapping_add(1);
let mut slot = self.shared.buffer[idx].write().unwrap();
slot.pos = pos;
slot.rem.with_mut(|v| *v = rem);
slot.val = UnsafeCell::new(Some(value));
drop(slot);
self.shared.notify_rx(tail);
Ok(rem)
}
pub fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
new_receiver(shared)
}
pub fn len(&self) -> usize {
let tail = self.shared.tail.lock();
let base_idx = (tail.pos & self.shared.mask as u64) as usize;
let mut low = 0;
let mut high = self.shared.buffer.len();
while low < high {
let mid = low + (high - low) / 2;
let idx = base_idx.wrapping_add(mid) & self.shared.mask;
if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 {
low = mid + 1;
} else {
high = mid;
}
}
self.shared.buffer.len() - low
}
pub fn is_empty(&self) -> bool {
let tail = self.shared.tail.lock();
let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize;
self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0
}
pub fn receiver_count(&self) -> usize {
let tail = self.shared.tail.lock();
tail.rx_cnt
}
pub fn same_channel(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.shared, &other.shared)
}
fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
self.shared.notify_rx(tail);
}
}
/// Registers one more receiver on `shared` and returns it.
///
/// The receiver starts at the current tail position, so it only observes
/// values sent after registration.
///
/// # Panics
///
/// Panics with "max receivers" if the receiver count already equals
/// `MAX_RECEIVERS`.
fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
    let next = {
        let mut tail = shared.tail.lock();
        if tail.rx_cnt == MAX_RECEIVERS {
            panic!("max receivers");
        }
        tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
        tail.pos
    };
    Receiver { shared, next }
}
/// Waiters that `notify_rx` has moved out of `Tail::waiters` and is about to
/// wake.
struct WaitersList<'a, T> {
/// The detached waiters, linked around a guard node.
list: GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
/// Set once `pop_back_locked` has returned `None`; `Drop` then has nothing
/// left to unlink.
is_empty: bool,
/// Channel state; `Drop` locks its tail mutex before unlinking leftovers.
shared: &'a Shared<T>,
}
impl<'a, T> Drop for WaitersList<'a, T> {
    /// Unlinks any waiters still in the list, without waking them.
    fn drop(&mut self) {
        if self.is_empty {
            return;
        }

        // Waiter nodes may only be touched while the tail lock is held.
        let _tail = self.shared.tail.lock();
        loop {
            if self.list.pop_back().is_none() {
                break;
            }
        }
    }
}
impl<'a, T> WaitersList<'a, T> {
    /// Wraps `unguarded_list` into a guarded list that uses `guard` as its
    /// guard node.
    fn new(
        unguarded_list: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
        guard: Pin<&'a Waiter>,
        shared: &'a Shared<T>,
    ) -> Self {
        let guard_node = NonNull::from(guard.get_ref());
        Self {
            list: unguarded_list.into_guarded(guard_node),
            is_empty: false,
            shared,
        }
    }

    /// Removes the last waiter from the list.
    ///
    /// The `_tail` argument is unused; requiring it makes callers show that
    /// they hold the tail lock. Once `None` has been returned the list is
    /// remembered as empty.
    fn pop_back_locked(&mut self, _tail: &mut Tail) -> Option<NonNull<Waiter>> {
        let popped = self.list.pop_back();
        self.is_empty |= popped.is_none();
        popped
    }
}
impl<T> Shared<T> {
/// Wakes every receiver currently queued in `tail.waiters` and releases the
/// tail lock.
///
/// `tail` must be the guard of `self.tail`. The lock is released before any
/// waker is invoked and re-acquired in between batches when there are more
/// waiters than one `WakeList` can hold.
fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) {
// Guard node for the guarded list. It is pinned on the stack and declared
// before `list`, so it is dropped after the list.
let guard = Waiter::new();
pin!(guard);
// Move all current waiters into a private list, leaving `tail.waiters`
// empty; waiters queued from now on are not part of this notification.
let mut list = WaitersList::new(std::mem::take(&mut tail.waiters), guard.as_ref(), self);
let mut wakers = WakeList::new();
'outer: loop {
// Collect wakers while the lock is held, up to the batch capacity.
while wakers.can_push() {
match list.pop_back_locked(&mut tail) {
Some(mut waiter) => {
// The tail lock is held here, which is what guards waiter nodes.
let waiter = unsafe { waiter.as_mut() };
assert!(waiter.queued);
waiter.queued = false;
if let Some(waker) = waiter.waker.take() {
wakers.push(waker);
}
}
None => {
break 'outer;
}
}
}
// Batch is full: release the lock, wake, then lock again to continue.
drop(tail);
wakers.wake_all();
tail = self.tail.lock();
}
// Release the lock before waking the final batch.
drop(tail);
wakers.wake_all();
}
}
impl<T> Clone for Sender<T> {
    /// Creates another sender for the same channel and counts it in `num_tx`.
    fn clone(&self) -> Sender<T> {
        let handle = Sender {
            shared: self.shared.clone(),
        };
        handle.shared.num_tx.fetch_add(1, SeqCst);
        handle
    }
}
impl<T> Drop for Sender<T> {
    /// Decrements the sender count; the last sender closes the channel.
    fn drop(&mut self) {
        let senders_before = self.shared.num_tx.fetch_sub(1, SeqCst);
        if senders_before == 1 {
            self.close_channel();
        }
    }
}
impl<T> Receiver<T> {
    /// Returns the distance between the channel's send position and this
    /// receiver's read position, i.e. the number of values sent that this
    /// receiver has not yet moved past.
    pub fn len(&self) -> usize {
        let next_send_pos = self.shared.tail.lock().pos;
        // Positions are wrapping `u64` counters everywhere else in this file
        // (`wrapping_add` / `wrapping_sub`), so the distance is computed the
        // same way. A plain `-` would panic in builds with overflow checks
        // once `tail.pos` has wrapped around while `self.next` has not.
        next_send_pos.wrapping_sub(self.next) as usize
    }

    /// Returns `true` if there is no value waiting for this receiver.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if both receivers belong to the same channel.
    pub fn same_channel(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.shared, &other.shared)
    }

    /// Tries to obtain a guard for the next value.
    ///
    /// * `Ok(guard)` — the value at `self.next` is available; the cursor has
    ///   been advanced by one.
    /// * `Err(Empty)` — nothing has been sent at `self.next` yet. If `waiter`
    ///   is `Some`, the waker is stored and the waiter node is queued.
    /// * `Err(Closed)` — nothing is available and the channel is closed.
    /// * `Err(Lagged(n))` — `n` values were overwritten before they could be
    ///   read; the cursor has been moved to the oldest value still stored.
    fn recv_ref(
        &mut self,
        waiter: Option<(&UnsafeCell<Waiter>, &Waker)>,
    ) -> Result<RecvGuard<'_, T>, TryRecvError> {
        let idx = (self.next & self.shared.mask as u64) as usize;

        // The slot that holds the next value once it has been sent.
        let mut slot = self.shared.buffer[idx].read().unwrap();

        if slot.pos != self.next {
            // `send` takes the tail lock first and the slot lock second.
            // Release the slot lock before taking the tail lock so both paths
            // acquire the locks in the same order.
            drop(slot);

            let mut old_waker = None;

            let mut tail = self.shared.tail.lock();

            // Re-acquire the slot lock, now under the tail lock.
            slot = self.shared.buffer[idx].read().unwrap();

            // Re-check: a send may have written the slot in the window in
            // which no lock was held.
            if slot.pos != self.next {
                let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);

                if next_pos == self.next {
                    // The slot still holds the value from one lap earlier, so
                    // there is nothing for this receiver yet.
                    if tail.closed {
                        return Err(TryRecvError::Closed);
                    }

                    if let Some((waiter, waker)) = waiter {
                        // The tail lock is held, which is what guards access
                        // to the waiter node.
                        unsafe {
                            waiter.with_mut(|ptr| {
                                // Replace the stored waker unless it already
                                // wakes the same task.
                                match (*ptr).waker {
                                    Some(ref w) if w.will_wake(waker) => {}
                                    _ => {
                                        old_waker = std::mem::replace(
                                            &mut (*ptr).waker,
                                            Some(waker.clone()),
                                        );
                                    }
                                }

                                // Queue the node only once.
                                if !(*ptr).queued {
                                    (*ptr).queued = true;
                                    tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
                                }
                            });
                        }
                    }

                    // Drop the replaced waker only after the locks are
                    // released.
                    drop(slot);
                    drop(tail);
                    drop(old_waker);

                    return Err(TryRecvError::Empty);
                }

                // The slot has been overwritten with a newer value: this
                // receiver lagged. Jump to the oldest position still stored.
                let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64);

                let missed = next.wrapping_sub(self.next);

                drop(tail);

                // The oldest stored position is exactly the one wanted, so
                // nothing was actually missed.
                if missed == 0 {
                    self.next = self.next.wrapping_add(1);

                    return Ok(RecvGuard { slot });
                }

                self.next = next;

                return Err(TryRecvError::Lagged(missed));
            }
        }

        self.next = self.next.wrapping_add(1);

        Ok(RecvGuard { slot })
    }
}
impl<T: Clone> Receiver<T> {
    /// Creates a new receiver on the same channel.
    ///
    /// The new receiver starts at the current send position, so it only
    /// observes values sent after this call, regardless of where `self` is.
    pub fn resubscribe(&self) -> Self {
        new_receiver(self.shared.clone())
    }

    /// Waits for the next value and returns a clone of it.
    ///
    /// Returns `RecvError::Lagged` if values were skipped and
    /// `RecvError::Closed` once the channel is closed and drained.
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        Recv::new(self).await
    }

    /// Returns a clone of the next value if one is available right now.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        let guard = self.recv_ref(None)?;
        match guard.clone_value() {
            Some(value) => Ok(value),
            None => Err(TryRecvError::Closed),
        }
    }

    /// Drives `recv` to completion on the current thread.
    ///
    /// NOTE(review): `crate::future::block_on` is defined elsewhere; confirm
    /// whether it may be called from within an asynchronous context.
    pub fn blocking_recv(&mut self) -> Result<T, RecvError> {
        let fut = self.recv();
        crate::future::block_on(fut)
    }
}
impl<T> Drop for Receiver<T> {
    /// Unregisters the receiver and releases its claim on every value it has
    /// not read yet.
    fn drop(&mut self) {
        // Deregister and remember how far the channel had got at that point;
        // values sent later no longer count this receiver.
        let until = {
            let mut tail = self.shared.tail.lock();
            tail.rx_cnt -= 1;
            tail.pos
        };

        // Read (and immediately drop) each pending value so that the slots'
        // `rem` counters are decremented.
        while self.next < until {
            match self.recv_ref(None) {
                // Either a value was consumed or the cursor jumped forward
                // after lagging; keep going in both cases.
                Ok(_) | Err(TryRecvError::Lagged(..)) => {}
                Err(TryRecvError::Closed) => break,
                Err(TryRecvError::Empty) => panic!("unexpected empty broadcast channel"),
            }
        }
    }
}
impl<'a, T> Recv<'a, T> {
fn new(receiver: &'a mut Receiver<T>) -> Recv<'a, T> {
Recv {
receiver,
waiter: UnsafeCell::new(Waiter {
queued: false,
waker: None,
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
}),
}
}
fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) {
unsafe {
is_unpin::<&mut Receiver<T>>();
let me = self.get_unchecked_mut();
(me.receiver, &me.waiter)
}
}
}
impl<'a, T> Future for Recv<'a, T>
where
    T: Clone,
{
    type Output = Result<T, RecvError>;

    /// Polls for the next value, queueing this future's waiter when nothing
    /// is available yet.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
        ready!(crate::trace::trace_leaf(cx));

        let (receiver, waiter) = self.project();

        let outcome = match receiver.recv_ref(Some((waiter, cx.waker()))) {
            // The guard is released as soon as the value has been cloned.
            Ok(guard) => guard.clone_value().ok_or(RecvError::Closed),
            Err(TryRecvError::Empty) => return Poll::Pending,
            Err(TryRecvError::Lagged(skipped)) => Err(RecvError::Lagged(skipped)),
            Err(TryRecvError::Closed) => Err(RecvError::Closed),
        };

        Poll::Ready(outcome)
    }
}
impl<'a, T> Drop for Recv<'a, T> {
fn drop(&mut self) {
let mut tail = self.receiver.shared.tail.lock();
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
if queued {
unsafe {
self.waiter.with_mut(|ptr| {
tail.waiters.remove((&mut *ptr).into());
});
}
}
}
}
// Intrusive-list glue for `Waiter`: handles are plain `NonNull<Waiter>`
// pointers, so converting between a handle and a raw node is the identity.
unsafe impl linked_list::Link for Waiter {
type Handle = NonNull<Waiter>;
type Target = Waiter;
// Returns the pointer stored in the handle.
fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
*handle
}
// Turns a raw node pointer back into a handle (the same pointer).
unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
ptr
}
// Locates the list pointers inside the node via the generated
// `Waiter::addr_of_pointers`.
unsafe fn pointers(target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
Waiter::addr_of_pointers(target)
}
}
impl<T> fmt::Debug for Sender<T> {
    /// Writes the fixed string `broadcast::Sender`; no channel state is shown.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("broadcast::Sender")
    }
}
impl<T> fmt::Debug for Receiver<T> {
    /// Writes the fixed string `broadcast::Receiver`; no channel state is
    /// shown.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("broadcast::Receiver")
    }
}
impl<'a, T> RecvGuard<'a, T> {
    /// Returns a clone of the value stored in the guarded slot, or `None` if
    /// the slot holds no value.
    fn clone_value(&self) -> Option<T>
    where
        T: Clone,
    {
        self.slot.val.with(|ptr| {
            // SAFETY: this guard holds the slot's read lock, and `send` only
            // replaces the value while holding the write lock.
            let stored: &Option<T> = unsafe { &*ptr };
            stored.clone()
        })
    }
}
impl<'a, T> Drop for RecvGuard<'a, T> {
    /// Marks the value as read by one more receiver and clears it once the
    /// last expected receiver is done.
    fn drop(&mut self) {
        let remaining_before = self.slot.rem.fetch_sub(1, SeqCst);
        if remaining_before == 1 {
            // SAFETY: this was the last outstanding reader of the value.
            self.slot.val.with_mut(|ptr| unsafe { *ptr = None });
        }
    }
}
/// Compile-time assertion helper: only type-checks when `T: Unpin`. It does
/// nothing at run time.
fn is_unpin<T: Unpin>() {}
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;

    /// A sender built with `Sender::new` starts without receivers, and the
    /// count follows every subscribe, resubscribe and drop.
    #[test]
    fn receiver_count_on_sender_constructor() {
        let tx = Sender::<i32>::new(16);
        assert_eq!(tx.receiver_count(), 0);

        let first = tx.subscribe();
        assert_eq!(tx.receiver_count(), 1);

        let second = first.resubscribe();
        assert_eq!(tx.receiver_count(), 2);

        let third = tx.subscribe();
        assert_eq!(tx.receiver_count(), 3);

        drop(third);
        drop(first);
        assert_eq!(tx.receiver_count(), 1);

        drop(second);
        assert_eq!(tx.receiver_count(), 0);
    }

    /// A channel built with `channel` starts with exactly one receiver.
    #[cfg(not(loom))]
    #[test]
    fn receiver_count_on_channel_constructor() {
        let (tx, rx) = channel::<i32>(16);
        assert_eq!(tx.receiver_count(), 1);

        let _second = rx.resubscribe();
        assert_eq!(tx.receiver_count(), 2);
    }
}