[go: up one dir, main page]

half/
rand_distr.rs

1use crate::{bf16, f16};
2
3use rand::{distr::Distribution, Rng};
4use rand_distr::uniform::UniformFloat;
5
6macro_rules! impl_distribution_via_f32 {
7    ($Ty:ty, $Distr:ty) => {
8        impl Distribution<$Ty> for $Distr {
9            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $Ty {
10                <$Ty>::from_f32(<Self as Distribution<f32>>::sample(self, rng))
11            }
12        }
13    };
14}
15
16impl_distribution_via_f32!(f16, rand_distr::StandardUniform);
17impl_distribution_via_f32!(f16, rand_distr::StandardNormal);
18impl_distribution_via_f32!(f16, rand_distr::Exp1);
19impl_distribution_via_f32!(f16, rand_distr::Open01);
20impl_distribution_via_f32!(f16, rand_distr::OpenClosed01);
21
22impl_distribution_via_f32!(bf16, rand_distr::StandardUniform);
23impl_distribution_via_f32!(bf16, rand_distr::StandardNormal);
24impl_distribution_via_f32!(bf16, rand_distr::Exp1);
25impl_distribution_via_f32!(bf16, rand_distr::Open01);
26impl_distribution_via_f32!(bf16, rand_distr::OpenClosed01);
27
28impl rand::distr::weighted::Weight for f16 {
29    const ZERO: Self = Self::ZERO;
30
31    fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
32        // Floats have an explicit representation for overflow
33        *self += v;
34        Ok(())
35    }
36}
37
38impl rand::distr::weighted::Weight for bf16 {
39    const ZERO: Self = Self::ZERO;
40
41    fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
42        // Floats have an explicit representation for overflow
43        *self += v;
44        Ok(())
45    }
46}
47
48#[derive(Debug, Clone, Copy)]
49pub struct Float16Sampler(UniformFloat<f32>);
50
51impl rand_distr::uniform::SampleUniform for f16 {
52    type Sampler = Float16Sampler;
53}
54
55impl rand_distr::uniform::UniformSampler for Float16Sampler {
56    type X = f16;
57    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
58    where
59        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
60        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
61    {
62        Ok(Self(UniformFloat::new(
63            low.borrow().to_f32(),
64            high.borrow().to_f32(),
65        )?))
66    }
67    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
68    where
69        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
70        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
71    {
72        Ok(Self(UniformFloat::new_inclusive(
73            low.borrow().to_f32(),
74            high.borrow().to_f32(),
75        )?))
76    }
77    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
78        f16::from_f32(self.0.sample(rng))
79    }
80}
81
82#[derive(Debug, Clone, Copy)]
83pub struct BFloat16Sampler(UniformFloat<f32>);
84
85impl rand_distr::uniform::SampleUniform for bf16 {
86    type Sampler = BFloat16Sampler;
87}
88
89impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
90    type X = bf16;
91    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
92    where
93        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
94        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
95    {
96        Ok(Self(UniformFloat::new(
97            low.borrow().to_f32(),
98            high.borrow().to_f32(),
99        )?))
100    }
101    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
102    where
103        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
104        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
105    {
106        Ok(Self(UniformFloat::new_inclusive(
107            low.borrow().to_f32(),
108            high.borrow().to_f32(),
109        )?))
110    }
111    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
112        bf16::from_f32(self.0.sample(rng))
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[allow(unused_imports)]
    use rand::{rng, Rng};
    use rand_distr::{StandardNormal, StandardUniform, Uniform};

    /// Smoke test: each supported distribution produces an `f16` without
    /// panicking. The sampled values themselves are not inspected.
    #[test]
    fn test_sample_f16() {
        let zero = f16::from_f32(0.0);
        let one = f16::from_f32(1.0);
        let mut generator = rng();

        let _: f16 = generator.sample(StandardUniform);
        let _: f16 = generator.sample(StandardNormal);

        let unit_range = Uniform::new(zero, one).unwrap();
        let _: f16 = generator.sample(unit_range);

        // `Normal` is only exercised when the `num-traits` feature is on.
        #[cfg(feature = "num-traits")]
        let _: f16 = generator.sample(rand_distr::Normal::new(zero, one).unwrap());
    }

    /// Smoke test: each supported distribution produces a `bf16` without
    /// panicking. The sampled values themselves are not inspected.
    #[test]
    fn test_sample_bf16() {
        let zero = bf16::from_f32(0.0);
        let one = bf16::from_f32(1.0);
        let mut generator = rng();

        let _: bf16 = generator.sample(StandardUniform);
        let _: bf16 = generator.sample(StandardNormal);

        let unit_range = Uniform::new(zero, one).unwrap();
        let _: bf16 = generator.sample(unit_range);

        // `Normal` is only exercised when the `num-traits` feature is on.
        #[cfg(feature = "num-traits")]
        let _: bf16 = generator.sample(rand_distr::Normal::new(zero, one).unwrap());
    }
}