use crate::cast::cast;
use crate::Integer;
use gmp_mpfr_sys::gmp::{self, randstate_t};
use std::marker::PhantomData;
use std::mem;
use std::os::raw::{c_int, c_ulong, c_void};
#[cfg(not(ffi_panic_aborts))]
use std::panic::{self, AssertUnwindSafe};
use std::process;
use std::ptr;
#[repr(transparent)]
pub struct RandState<'a> {
inner: randstate_t,
phantom: PhantomData<&'a dyn RandGen>,
}
impl Default for RandState<'_> {
#[inline]
fn default() -> RandState<'static> {
RandState::new()
}
}
impl Clone for RandState<'_> {
#[inline]
fn clone(&self) -> RandState<'static> {
unsafe {
let mut inner = mem::zeroed();
gmp::randinit_set(&mut inner, self.as_raw());
let ptr = cast_ptr!(&inner, MpRandState);
if (*ptr).seed.d.is_null() {
panic!("`RandGen::boxed_clone` returned `None`");
}
RandState { inner, phantom: PhantomData }
}
}
}
impl Drop for RandState<'_> {
#[inline]
fn drop(&mut self) {
unsafe {
gmp::randclear(self.as_raw_mut());
}
}
}
unsafe impl Send for RandState<'_> {}
unsafe impl Sync for RandState<'_> {}
impl RandState<'_> {
#[inline]
pub fn new() -> RandState<'static> {
unsafe {
let mut inner = mem::zeroed();
gmp::randinit_default(&mut inner);
RandState { inner, phantom: PhantomData }
}
}
pub fn new_mersenne_twister() -> RandState<'static> {
unsafe {
let mut inner = mem::zeroed();
gmp::randinit_mt(&mut inner);
RandState { inner, phantom: PhantomData }
}
}
pub fn new_linear_congruential(
a: &Integer,
c: u32,
bits: u32,
) -> RandState<'static> {
unsafe {
let mut inner = mem::zeroed();
gmp::randinit_lc_2exp(
&mut inner,
a.as_raw(),
c.into(),
bits.into(),
);
RandState { inner, phantom: PhantomData }
}
}
pub fn new_linear_congruential_size(
size: u32,
) -> Option<RandState<'static>> {
unsafe {
let mut inner = mem::zeroed();
if gmp::randinit_lc_2exp_size(&mut inner, size.into()) != 0 {
Some(RandState { inner, phantom: PhantomData })
} else {
None
}
}
}
pub fn new_custom(custom: &mut dyn RandGen) -> RandState<'_> {
let b: Box<&mut dyn RandGen> = Box::new(custom);
let r_ptr: *mut &mut dyn RandGen = Box::into_raw(b);
let inner = MpRandState {
seed: gmp::mpz_t {
alloc: 0,
size: 0,
d: r_ptr as *mut gmp::limb_t,
},
alg: 0,
algdata: &CUSTOM_FUNCS as *const Funcs as *mut c_void,
};
RandState {
inner: unsafe { mem::transmute(inner) },
phantom: PhantomData,
}
}
pub fn new_custom_boxed(custom: Box<dyn RandGen>) -> RandState<'static> {
let b: Box<Box<dyn RandGen>> = Box::new(custom);
let r_ptr: *mut Box<dyn RandGen> = Box::into_raw(b);
let inner = MpRandState {
seed: gmp::mpz_t {
alloc: 0,
size: 0,
d: r_ptr as *mut gmp::limb_t,
},
alg: 0,
algdata: &CUSTOM_BOXED_FUNCS as *const Funcs as *mut c_void,
};
RandState {
inner: unsafe { mem::transmute(inner) },
phantom: PhantomData,
}
}
#[inline]
pub unsafe fn from_raw(raw: randstate_t) -> RandState<'static> {
RandState { inner: raw, phantom: PhantomData }
}
#[inline]
pub fn into_raw(self) -> randstate_t {
let ret = self.inner;
mem::forget(self);
ret
}
#[inline]
pub fn as_raw(&self) -> *const randstate_t {
&self.inner
}
#[inline]
pub fn as_raw_mut(&mut self) -> *mut randstate_t {
&mut self.inner
}
#[inline]
pub fn seed(&mut self, seed: &Integer) {
unsafe {
gmp::randseed(self.as_raw_mut(), seed.as_raw());
}
}
#[inline]
pub fn bits(&mut self, bits: u32) -> u32 {
assert!(bits <= 32, "bits out of range");
unsafe { gmp::urandomb_ui(self.as_raw_mut(), bits.into()) as u32 }
}
#[inline]
pub fn below(&mut self, bound: u32) -> u32 {
assert_ne!(bound, 0, "cannot be below zero");
unsafe { gmp::urandomm_ui(self.as_raw_mut(), bound.into()) as u32 }
}
}
pub trait RandGen: Send + Sync {
fn gen(&mut self) -> u32;
fn gen_bits(&mut self, bits: u32) -> u32 {
let gen = self.gen();
match bits {
0 => 0,
1..=32 => gen >> (32 - bits),
_ => gen,
}
}
#[inline]
fn seed(&mut self, seed: &Integer) {
let _ = seed;
}
#[inline]
fn boxed_clone(&self) -> Option<Box<dyn RandGen>> {
None
}
}
#[repr(C)]
struct MpRandState {
seed: gmp::mpz_t,
alg: c_int,
algdata: *mut c_void,
}
fn _static_assertions() {
static_assert_size!(RandState<'_>, randstate_t);
static_assert_size!(MpRandState, randstate_t);
}
#[repr(C)]
struct Funcs {
seed: Option<unsafe extern "C" fn(*mut randstate_t, *const gmp::mpz_t)>,
get: Option<
unsafe extern "C" fn(*mut randstate_t, *mut gmp::limb_t, c_ulong),
>,
clear: Option<unsafe extern "C" fn(*mut randstate_t)>,
iset: Option<unsafe extern "C" fn(*mut randstate_t, *const randstate_t)>,
}
#[cfg(not(ffi_panic_aborts))]
macro_rules! c_callback {
($(fn $func:ident($($param:tt)*) $body:block)*) => { $(
unsafe extern "C" fn $func($($param)*) {
panic::catch_unwind(AssertUnwindSafe(|| $body))
.unwrap_or_else(|_| process::abort())
}
)* };
}
#[cfg(ffi_panic_aborts)]
macro_rules! c_callback {
($(fn $func:ident($($param:tt)*) $body:block)*) => { $(
unsafe extern "C" fn $func($($param)*) $body
)* };
}
c_callback! {
fn abort_seed(_: *mut randstate_t, _: *const gmp::mpz_t) {
process::abort();
}
fn abort_get(_: *mut randstate_t, _: *mut gmp::limb_t, _: c_ulong) {
process::abort();
}
fn abort_clear(_: *mut randstate_t) {
process::abort();
}
fn abort_iset(_: *mut randstate_t, _: *const randstate_t) {
process::abort();
}
fn custom_seed(s: *mut randstate_t, seed: *const gmp::mpz_t) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut &mut dyn RandGen;
(*r_ptr).seed(&*cast_ptr!(seed, Integer));
}
fn custom_get(s: *mut randstate_t, limb: *mut gmp::limb_t, bits: c_ulong) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut &mut dyn RandGen;
gen_bits(*r_ptr, limb, bits);
}
fn custom_clear(s: *mut randstate_t) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut &mut dyn RandGen;
drop(Box::from_raw(r_ptr));
}
fn custom_iset(dst: *mut randstate_t, src: *const randstate_t) {
let src_ptr = cast_ptr!(src, MpRandState);
let r_ptr = (*src_ptr).seed.d as *const &mut dyn RandGen;
gen_copy(*r_ptr, dst);
}
fn custom_boxed_seed(s: *mut randstate_t, seed: *const gmp::mpz_t) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut Box<dyn RandGen>;
(*r_ptr).seed(&*cast_ptr!(seed, Integer));
}
fn custom_boxed_get(
s: *mut randstate_t,
limb: *mut gmp::limb_t,
bits: c_ulong,
) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut Box<dyn RandGen>;
gen_bits(&mut **r_ptr, limb, bits);
}
fn custom_boxed_clear(s: *mut randstate_t) {
let s_ptr = cast_ptr_mut!(s, MpRandState);
let r_ptr = (*s_ptr).seed.d as *mut Box<dyn RandGen>;
drop(Box::from_raw(r_ptr));
}
fn custom_boxed_iset(dst: *mut randstate_t, src: *const randstate_t) {
let src_ptr = cast_ptr!(src, MpRandState);
let r_ptr = (*src_ptr).seed.d as *const Box<dyn RandGen>;
gen_copy(&**r_ptr, dst);
}
}
#[cfg(gmp_limb_bits_64)]
unsafe fn gen_bits(
gen: &mut dyn RandGen,
limb: *mut gmp::limb_t,
bits: c_ulong,
) {
let (limbs, rest) = (bits / 64, bits % 64);
let limbs: isize = cast(limbs);
for i in 0..limbs {
let n = u64::from(gen.gen()) | u64::from(gen.gen()) << 32;
*limb.offset(i) = cast(n);
}
if rest >= 32 {
let mut n = u64::from(gen.gen());
if rest > 32 {
let mask = !(!0 << (rest - 32));
n |= u64::from(gen.gen_bits(cast(rest - 32)) & mask) << 32;
}
*limb.offset(limbs) = cast(n);
} else if rest > 0 {
let mask = !(!0 << rest);
let n = u64::from(gen.gen_bits(cast(rest)) & mask);
*limb.offset(limbs) = cast(n);
}
}
#[cfg(gmp_limb_bits_32)]
unsafe fn gen_bits(gen: &mut RandGen, limb: *mut gmp::limb_t, bits: c_ulong) {
let (limbs, rest) = (bits / 32, bits % 32);
let limbs: isize = cast(limbs);
for i in 0..limbs {
*limb.offset(i) = cast(gen.gen());
}
if rest > 0 {
let mask = !(!0 << rest);
*limb.offset(limbs) = cast(gen.gen_bits(cast(rest)) & mask);
}
}
unsafe fn gen_copy(gen: &dyn RandGen, dst: *mut randstate_t) {
let other = gen.boxed_clone();
let (dst_r_ptr, funcs) = if let Some(other) = other {
let b: Box<Box<dyn RandGen>> = Box::new(other);
let dst_r_ptr: *mut Box<dyn RandGen> = Box::into_raw(b);
let funcs = &CUSTOM_BOXED_FUNCS as *const Funcs as *mut c_void;
(dst_r_ptr, funcs)
} else {
(ptr::null_mut(), &ABORT_FUNCS as *const Funcs as *mut c_void)
};
let dst_ptr = cast_ptr_mut!(dst, MpRandState);
*dst_ptr = MpRandState {
seed: gmp::mpz_t {
alloc: 0,
size: 0,
d: dst_r_ptr as *mut gmp::limb_t,
},
alg: 0,
algdata: funcs,
};
}
const ABORT_FUNCS: Funcs = Funcs {
seed: Some(abort_seed),
get: Some(abort_get),
clear: Some(abort_clear),
iset: Some(abort_iset),
};
const CUSTOM_FUNCS: Funcs = Funcs {
seed: Some(custom_seed),
get: Some(custom_get),
clear: Some(custom_clear),
iset: Some(custom_iset),
};
const CUSTOM_BOXED_FUNCS: Funcs = Funcs {
seed: Some(custom_boxed_seed),
get: Some(custom_boxed_get),
clear: Some(custom_boxed_clear),
iset: Some(custom_boxed_iset),
};
#[cfg(test)]
mod tests {
use crate::rand::{RandGen, RandState};
struct SimpleGenerator {
seed: u64,
}
impl RandGen for SimpleGenerator {
fn gen(&mut self) -> u32 {
self.seed = self
.seed
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1);
(self.seed >> 32) as u32
}
fn boxed_clone(&self) -> Option<Box<dyn RandGen>> {
let other = SimpleGenerator { seed: self.seed };
let boxed = Box::new(other);
Some(boxed)
}
}
#[test]
fn check_custom_clone() {
let mut gen = SimpleGenerator { seed: 1 };
let third2;
{
let mut rand1 = RandState::new_custom(&mut gen);
let mut rand2 = rand1.clone();
let first1 = rand1.bits(32);
let first2 = rand2.bits(32);
assert_eq!(first1, first2);
let second1 = rand1.bits(32);
let second2 = rand2.bits(32);
assert_eq!(second1, second2);
assert_ne!(first1, second1);
third2 = rand2.bits(32);
assert_ne!(second2, third2);
}
let mut rand3 = RandState::new_custom_boxed(Box::new(gen));
let mut rand4 = rand3.clone();
let third3 = rand3.bits(32);
let third4 = rand4.bits(32);
assert_eq!(third2, third3);
assert_eq!(third2, third4);
}
struct NoCloneGenerator;
impl RandGen for NoCloneGenerator {
fn gen(&mut self) -> u32 {
0
}
}
#[test]
#[should_panic(expected = "`RandGen::boxed_clone` returned `None`")]
fn check_custom_no_clone() {
let mut gen = NoCloneGenerator;
let rand1 = RandState::new_custom(&mut gen);
let _ = rand1.clone();
}
}