use std::marker;
use {Rng, Rand};
pub use self::range::Range;
pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
pub use self::normal::{Normal, LogNormal};
pub use self::exponential::Exp;
pub mod range;
pub mod gamma;
pub mod normal;
pub mod exponential;
/// A sampler that yields values of type `Support`.
///
/// The sampling method takes `&mut self`, so an implementation is allowed
/// to update its own state each time a value is drawn.
pub trait Sample<Support> {
    /// Produces one `Support` value, drawing on `rng` as needed.
    fn sample<R>(&mut self, rng: &mut R) -> Support
    where
        R: Rng;
}
/// A `Sample` that can also be drawn from through a shared reference.
///
/// Because `ind_sample` only receives `&self`, an implementation cannot
/// (without interior mutability) change itself between draws.
pub trait IndependentSample<Support>: Sample<Support> {
    /// Produces one `Support` value, drawing on `rng` as needed.
    ///
    /// The parameter is named explicitly: anonymous parameters in trait
    /// methods are deprecated and are rejected by newer editions.
    fn ind_sample<R: Rng>(&self, rng: &mut R) -> Support;
}
/// Sampler for a type `Sup`; the `Sample` / `IndependentSample` impls in
/// this file generate the value through `Sup`'s `Rand` implementation.
///
/// The struct stores no `Sup` value, only a marker naming the type.
pub struct RandSample<Sup> {
    _marker: marker::PhantomData<fn() -> Sup>,
}

// `Copy` and `Clone` are written by hand instead of derived: a derive would
// add `Sup: Copy` / `Sup: Clone` bounds, whereas these impls hold for every
// `Sup`.
impl<Sup> Copy for RandSample<Sup> {}

impl<Sup> Clone for RandSample<Sup> {
    fn clone(&self) -> RandSample<Sup> {
        *self
    }
}
impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup { self.ind_sample(rng) }
}
impl<Sup> IndependentSample<Sup> for RandSample<Sup>
where
    Sup: Rand,
{
    /// Asks `rng` to generate a single `Sup` value and returns it.
    fn ind_sample<R>(&self, rng: &mut R) -> Sup
    where
        R: Rng,
    {
        rng.gen::<Sup>()
    }
}
impl<Sup> RandSample<Sup> {
pub fn new() -> RandSample<Sup> {
RandSample { _marker: marker::PhantomData }
}
}
/// An item paired with the weight used when choosing among several items.
///
/// `Debug` is derived alongside `Copy` and `Clone` so that values can be
/// printed in assertions and logs (available whenever `T: Debug`).
#[derive(Copy, Clone, Debug)]
pub struct Weighted<T> {
    /// Weight of this item. `WeightedChoice::new` overwrites this field in
    /// place with the running total of the weights up to and including it.
    pub weight: u32,
    /// The value handed out (by clone) when this entry is selected.
    pub item: T,
}
/// Chooses among borrowed `Weighted` items according to their weights.
///
/// The slice is borrowed mutably because `WeightedChoice::new` rewrites each
/// entry's `weight` into a running total.
pub struct WeightedChoice<'a, T: 'a> {
    /// The borrowed entries; after `new`, `weight` holds cumulative totals.
    items: &'a mut [Weighted<T>],
    /// Range built as `Range::new(0, total_weight)`; sampled once per draw.
    weight_range: Range<u32>,
}
impl<'a, T: Clone> WeightedChoice<'a, T> {
    /// Builds a `WeightedChoice` over `items`.
    ///
    /// Every entry's `weight` is replaced, in place, by the sum of the
    /// weights up to and including that entry, and the sampling range is
    /// created as `Range::new(0, total)`.
    ///
    /// # Panics
    ///
    /// - if `items` is empty;
    /// - if the sum of the weights does not fit in a `u32` (entries visited
    ///   before the overflow have already been rewritten at that point);
    /// - if the sum of the weights is zero.
    pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
        assert!(!items.is_empty(), "WeightedChoice::new called with no items");

        let mut cumulative: u32 = 0;
        for entry in items.iter_mut() {
            cumulative = cumulative.checked_add(entry.weight).unwrap_or_else(|| {
                panic!("WeightedChoice::new called with a total weight \
                        larger than a u32 can contain")
            });
            // Store the running total so sampling can binary-search it.
            entry.weight = cumulative;
        }

        assert!(cumulative != 0, "WeightedChoice::new called with a total weight of 0");

        WeightedChoice {
            weight_range: Range::new(0, cumulative),
            items: items,
        }
    }
}
impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
}
impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
    /// Draws a weight from `weight_range` and returns a clone of the first
    /// entry whose cumulative weight is strictly greater than it.
    ///
    /// Relies on the cumulative weights stored by `WeightedChoice::new`
    /// being non-decreasing, and on the drawn weight being below the last
    /// cumulative weight (presumably `Range::new(0, total)` excludes
    /// `total` — confirm against `Range`); otherwise the final index would
    /// run past the slice.
    fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
        let target = self.weight_range.ind_sample(rng);
        let entries = &*self.items;

        // Fast path: the very first entry already exceeds the target.
        if target < entries[0].weight {
            return entries[0].item.clone();
        }

        // From here on, `entries[base]` always has cumulative weight
        // <= target. Search for the last such entry; the one after it is
        // the answer.
        let mut base = 0;
        let mut span = entries.len();
        while span > 1 {
            let probe = base + span / 2;
            span = if entries[probe].weight <= target {
                // Still not past the target: continue to the right of
                // `probe`, rounding the remaining span up so a trailing
                // entry is not dropped when `span` is odd.
                base = probe;
                (span + 1) / 2
            } else {
                // Overshot: keep `base` and just halve the span.
                span / 2
            };
        }

        entries[base + 1].item.clone()
    }
}
mod ziggurat_tables;
/// Table-driven rejection sampler shared by the distributions in this module.
///
/// Each iteration draws one `u64` from `rng`: its low 8 bits select a table
/// index `i` (0..=255) and its top 53 bits form a fraction `f` in `[0, 1)`.
/// Both tables are read at `i` and `i + 1`, i.e. up to index 256.
///
/// * `symmetric` - when true, `f` is mapped to `u = 2f - 1` and the candidate
///   is compared by absolute value; otherwise `u = f` is used directly.
/// * `x_tab`, `f_tab` - lookup tables; the candidate is `x = u * x_tab[i]`.
/// * `pdf` - evaluated at the candidate when the quick test fails and
///   `i != 0`.
/// * `zero_case` - called with `rng` and `u` when the quick test fails and
///   `i == 0`; its result is returned as is.
///
/// Loops until a candidate is accepted, so the number of values consumed
/// from `rng` is not fixed.
#[inline(always)]
fn ziggurat<R: Rng, P, Z>(
    rng: &mut R,
    symmetric: bool,
    x_tab: ziggurat_tables::ZigTable,
    f_tab: ziggurat_tables::ZigTable,
    mut pdf: P,
    mut zero_case: Z,
) -> f64
where
    P: FnMut(f64) -> f64,
    Z: FnMut(&mut R, f64) -> f64,
{
    const SCALE: f64 = (1u64 << 53) as f64;

    loop {
        // One random word supplies both the table index and the fraction.
        let word: u64 = rng.gen();
        let layer = (word & 0xff) as usize;
        let fraction = (word >> 11) as f64 / SCALE;

        let u = if symmetric { 2.0 * fraction - 1.0 } else { fraction };
        let x = u * x_tab[layer];
        let magnitude = if symmetric { x.abs() } else { x };

        // Quick accept: the candidate is below the next table entry.
        if magnitude < x_tab[layer + 1] {
            return x;
        }

        // Index 0 is delegated entirely to the caller.
        if layer == 0 {
            return zero_case(rng, u);
        }

        // Slow path: draw a second value and compare against the density.
        // The extra draw happens before `pdf` is evaluated.
        let lower = f_tab[layer + 1];
        let threshold = lower + (f_tab[layer] - lower) * rng.gen::<f64>();
        if threshold < pdf(x) {
            return x;
        }
    }
}
#[cfg(test)]
mod tests {
    use {Rng, Rand};
    use super::{RandSample, WeightedChoice, Weighted, Sample, IndependentSample};

    /// A `Rand` type whose generated value is always `ConstRand(0)`.
    #[derive(PartialEq, Debug)]
    struct ConstRand(usize);

    impl Rand for ConstRand {
        fn rand<R: Rng>(_: &mut R) -> ConstRand {
            ConstRand(0)
        }
    }

    /// Deterministic generator: `next_u32` yields 0, 1, 2, ...
    struct CountingRng { i: u32 }

    impl Rng for CountingRng {
        fn next_u32(&mut self) -> u32 {
            self.i += 1;
            self.i - 1
        }
        fn next_u64(&mut self) -> u64 {
            self.next_u32() as u64
        }
    }

    #[test]
    fn test_rand_sample() {
        let mut rand_sample = RandSample::<ConstRand>::new();

        assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
        assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
    }

    #[test]
    fn test_weighted_choice() {
        // Builds a `WeightedChoice` over `$items` and checks that successive
        // draws from a fresh `CountingRng` produce `$expected`, in order.
        macro_rules! t {
            ($items:expr, $expected:expr) => {{
                let mut items = $items;
                let wc = WeightedChoice::new(&mut items);
                let expected = $expected;

                let mut rng = CountingRng { i: 0 };

                for &val in expected.iter() {
                    assert_eq!(wc.ind_sample(&mut rng), val)
                }
            }}
        }

        t!(vec!(Weighted { weight: 1, item: 10}), [10]);

        // zero-weight entries are never returned
        t!(vec!(Weighted { weight: 0, item: 20},
                Weighted { weight: 2, item: 21},
                Weighted { weight: 0, item: 22},
                Weighted { weight: 1, item: 23}),
           [21,21, 23]);

        // different weights
        t!(vec!(Weighted { weight: 4, item: 30},
                Weighted { weight: 3, item: 31}),
           [30,30,30,30, 31,31,31]);

        // odd-length inputs exercise the round-up step of the search
        t!(vec!(Weighted { weight: 1, item: 40},
                Weighted { weight: 1, item: 41},
                Weighted { weight: 1, item: 42},
                Weighted { weight: 1, item: 43},
                Weighted { weight: 1, item: 44}),
           [40, 41, 42, 43, 44]);
        t!(vec!(Weighted { weight: 1, item: 50},
                Weighted { weight: 1, item: 51},
                Weighted { weight: 1, item: 52},
                Weighted { weight: 1, item: 53},
                Weighted { weight: 1, item: 54},
                Weighted { weight: 1, item: 55},
                Weighted { weight: 1, item: 56}),
           [50, 51, 52, 53, 54, 55, 56]);
    }

    #[test]
    fn test_weighted_clone_initialization() {
        let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
        let clone = initial.clone();
        assert_eq!(initial.weight, clone.weight);
        assert_eq!(initial.item, clone.item);
    }

    // The two tests below check that a clone is independent of its source.
    // They assert the inequality directly instead of expecting `assert_eq!`
    // to panic: a bare `#[should_panic]` would also pass on any unrelated
    // panic, and needed to be ignored on MSVC.
    #[test]
    fn test_weighted_clone_change_weight() {
        let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
        let mut clone = initial.clone();
        clone.weight = 5;
        assert_eq!(initial.weight, 1);
        assert_eq!(clone.weight, 5);
        assert!(initial.weight != clone.weight);
    }

    #[test]
    fn test_weighted_clone_change_item() {
        let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
        let mut clone = initial.clone();
        clone.item = 5;
        assert_eq!(initial.item, 1);
        assert_eq!(clone.item, 5);
        assert!(initial.item != clone.item);
    }

    // The constructor tests pin the expected panic message so that a panic
    // from some other cause does not make them pass.
    #[test]
    #[should_panic(expected = "called with no items")]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_weighted_choice_no_items() {
        WeightedChoice::<isize>::new(&mut []);
    }

    #[test]
    #[should_panic(expected = "called with a total weight of 0")]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_weighted_choice_zero_weight() {
        WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
                                  Weighted { weight: 0, item: 1}]);
    }

    #[test]
    #[should_panic(expected = "larger than a u32 can contain")]
    #[cfg_attr(target_env = "msvc", ignore)]
    fn test_weighted_choice_weight_overflows() {
        // x + 1 + x + 1 is one more than `u32::MAX`, so the sum overflows.
        let x = ::std::u32::MAX / 2;
        WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
                                  Weighted { weight: 1, item: 1 },
                                  Weighted { weight: x, item: 2 },
                                  Weighted { weight: 1, item: 3 }]);
    }
}