use spinwait::SpinWait;
use std::cell::Cell;
use std::mem;
use std::ptr;
use std::sync::atomic::{fence, AtomicUsize, Ordering};
use thread_parker::ThreadParker;
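
// Per-thread queue node for the lock's intrusive wait list. The list is kept
// in two parts: a processed part whose `prev` pointers are filled in (with
// `queue_tail` set on the head), and an unprocessed part where only `next`
// pointers are set. This lets new waiters push themselves onto the queue
// without a lock, while the processed part still allows O(1) removal from
// the tail.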
struct ThreadData {
    parker: ThreadParker,
    queue_tail: Cell<*const ThreadData>,
    prev: Cell<*const ThreadData>,
    next: Cell<*const ThreadData>,
}
impl ThreadData {
    fn new() -> ThreadData {
        ThreadData {
            parker: ThreadParker::new(),
            queue_tail: Cell::new(ptr::null()),
            prev: Cell::new(ptr::null()),
            next: Cell::new(ptr::null()),
        }
    }
}
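
// Returns a ThreadData structure for the current thread, preferring a cached
// copy in thread-local storage and falling back to the caller-provided slot
// if the TLS value is unavailable.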
unsafe fn get_thread_data(local: &mut Option<ThreadData>) -> &ThreadData {
    // Try to read from thread-local storage, but fall through if the TLS
    // slot has already been destroyed.
    if !cfg!(windows) && !cfg!(all(feature = "nightly", target_os = "linux")) {
        thread_local!(static THREAD_DATA: ThreadData = ThreadData::new());
        if let Ok(tls) = THREAD_DATA.try_with(|x| x as *const ThreadData) {
            return &*tls;
        }
    }

    // Otherwise just create a ThreadData on the stack
    *local = Some(ThreadData::new());
    local.as_ref().unwrap()
}
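
// Layout of the state word: bit 0 is the lock bit, bit 1 guards the queue,
// and the remaining bits hold a pointer to the head of the queue of parked
// threads. This works because ThreadData is aligned to at least 4 bytes,
// which the assert in lock_slow checks.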
const LOCKED_BIT: usize = 1;
const QUEUE_LOCKED_BIT: usize = 2;
const QUEUE_MASK: usize = !3;
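
// Word-sized lock that manages its own queue of waiting threads instead of
// relying on an OS mutex, leaving only thread parking to the platform layer.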
pub struct WordLock {
    state: AtomicUsize,
}
impl WordLock {
    #[inline]
    pub fn new() -> WordLock {
        WordLock {
            state: AtomicUsize::new(0),
        }
    }
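
    // Acquires the lock. The fast path is a single compare-and-swap on an
    // uncontended lock; everything else goes through the non-inlined slow
    // path.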
    #[inline]
    pub unsafe fn lock(&self) {
        if self
            .state
            .compare_exchange_weak(0, LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            return;
        }
        self.lock_slow();
    }
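
    // Releases the lock. If there are queued threads and no other thread is
    // already handling the queue, the slow path wakes one of them up.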
    #[inline]
    pub unsafe fn unlock(&self) {
        let state = self.state.fetch_sub(LOCKED_BIT, Ordering::Release);
        if state & QUEUE_LOCKED_BIT != 0 || state & QUEUE_MASK == 0 {
            return;
        }
        self.unlock_slow();
    }
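
    // Slow path for lock: spin for a bounded time, then add ourselves to the
    // wait queue and park until an unlocker wakes us.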
    #[cold]
    #[inline(never)]
    unsafe fn lock_slow(&self) {
        let mut spinwait = SpinWait::new();
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // Grab the lock if it isn't locked, even if there is a queue on it
            if state & LOCKED_BIT == 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state | LOCKED_BIT,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(x) => state = x,
                }
                continue;
            }

            // If there is no queue, try spinning a few times
            if state & QUEUE_MASK == 0 && spinwait.spin() {
                state = self.state.load(Ordering::Relaxed);
                continue;
            }

            // Get our thread data and prepare it for parking. The alignment
            // check ensures the low bits of the pointer are free for flags.
            let mut thread_data = None;
            let thread_data = get_thread_data(&mut thread_data);
            assert!(mem::align_of_val(thread_data) > !QUEUE_MASK);
            thread_data.parker.prepare_park();

            // Add our thread to the front of the queue
            let queue_head = (state & QUEUE_MASK) as *const ThreadData;
            if queue_head.is_null() {
                thread_data.queue_tail.set(thread_data);
                thread_data.prev.set(ptr::null());
            } else {
                thread_data.queue_tail.set(ptr::null());
                thread_data.prev.set(ptr::null());
                thread_data.next.set(queue_head);
            }
            if let Err(x) = self.state.compare_exchange_weak(
                state,
                (state & !QUEUE_MASK) | thread_data as *const _ as usize,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                state = x;
                continue;
            }

            // Sleep until we are woken up by an unlock
            thread_data.parker.park();

            // Loop back and try locking again with a fresh view of the state
            spinwait.reset();
            state = self.state.load(Ordering::Relaxed);
        }
    }
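
    // Slow path for unlock: take the queue lock, fix up the queue's prev
    // pointers, and wake the thread at the tail of the queue (unless the
    // lock was re-taken in the meantime, in which case the next unlocker
    // handles the wakeup).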
    #[cold]
    #[inline(never)]
    unsafe fn unlock_slow(&self) {
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // If there is no queued thread, or another thread already holds
            // the queue lock and will take care of the wakeup, we are done.
            if state & QUEUE_LOCKED_BIT != 0 || state & QUEUE_MASK == 0 {
                return;
            }

            // Try to grab the queue lock
            match self.state.compare_exchange_weak(
                state,
                state | QUEUE_LOCKED_BIT,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(x) => state = x,
            }
        }

        // Now we have the queue lock and the queue is non-empty
        'outer: loop {
            // First, fill in the prev pointers for any newly added threads,
            // stopping at a node that was processed before (it has a non-null
            // queue_tail pointer).
            let queue_head = (state & QUEUE_MASK) as *const ThreadData;
            let mut queue_tail;
            let mut current = queue_head;
            loop {
                queue_tail = (*current).queue_tail.get();
                if !queue_tail.is_null() {
                    break;
                }
                let next = (*current).next.get();
                (*next).prev.set(current);
                current = next;
            }

            // Set queue_tail on the queue head to indicate that the whole
            // list now has its prev pointers set correctly.
            (*queue_head).queue_tail.set(queue_tail);

            // If the lock was re-acquired in the meantime, there is no point
            // waking a thread now; release the queue lock and let the next
            // unlocker take care of it.
            if state & LOCKED_BIT != 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state & !QUEUE_LOCKED_BIT,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(x) => state = x,
                }

                // Need an acquire fence before reading the new queue
                fence(Ordering::Acquire);
                continue;
            }

            // Remove the last thread from the queue and unlock the queue
            let new_tail = (*queue_tail).prev.get();
            if new_tail.is_null() {
                loop {
                    match self.state.compare_exchange_weak(
                        state,
                        state & LOCKED_BIT,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(x) => state = x,
                    }

                    // If the compare_exchange failed because a new thread was
                    // added to the queue then we need to re-scan the queue.
                    if state & QUEUE_MASK == 0 {
                        continue;
                    } else {
                        // Need an acquire fence before reading the new queue
                        fence(Ordering::Acquire);
                        continue 'outer;
                    }
                }
            } else {
                (*queue_head).queue_tail.set(new_tail);
                self.state.fetch_and(!QUEUE_LOCKED_BIT, Ordering::Release);
            }

            // Finally, wake up the thread we removed from the queue. No races
            // to worry about here: the thread is guaranteed to be parked and
            // we are the only ones who can wake it up.
            (*queue_tail).parker.unpark_lock().unpark();
            break;
        }
    }
}
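
// A minimal usage sketch, not part of the original file: a test that bangs on
// the lock from several threads and checks that a plain (non-atomic) counter
// guarded by it never loses an increment. The `Shared` wrapper and the thread
// and iteration counts are illustrative assumptions.
#[cfg(test)]
mod tests {
    use super::WordLock;
    use std::cell::UnsafeCell;
    use std::sync::Arc;
    use std::thread;

    // A WordLock guarding a plain counter. Sharing this is sound only
    // because `counter` is touched exclusively while `lock` is held.
    struct Shared {
        lock: WordLock,
        counter: UnsafeCell<usize>,
    }
    unsafe impl Sync for Shared {}

    #[test]
    fn counter_under_contention() {
        let shared = Arc::new(Shared {
            lock: WordLock::new(),
            counter: UnsafeCell::new(0),
        });
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..10_000 {
                        unsafe {
                            shared.lock.lock();
                            *shared.counter.get() += 1;
                            shared.lock.unlock();
                        }
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        // Every increment happened under the lock, so none may be lost.
        assert_eq!(unsafe { *shared.counter.get() }, 4 * 10_000);
    }
}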