1use crate::{bf16, f16};
2
3use rand::{distr::Distribution, Rng};
4use rand_distr::uniform::UniformFloat;
5
/// Implements `Distribution<$Ty>` for the distribution type `$Distr`.
///
/// The value is drawn from the same distribution's `Distribution<f32>` implementation and then
/// narrowed with `<$Ty>::from_f32`.
macro_rules! impl_distribution_via_f32 {
    ($Ty:ty, $Distr:ty) => {
        impl Distribution<$Ty> for $Distr {
            fn sample<R: Rng + ?Sized>(&self, source: &mut R) -> $Ty {
                // Draw at single precision first, then narrow to the 16-bit type.
                let widened: f32 = <$Distr as Distribution<f32>>::sample(self, source);
                <$Ty>::from_f32(widened)
            }
        }
    };
}
15
16impl_distribution_via_f32!(f16, rand_distr::StandardUniform);
17impl_distribution_via_f32!(f16, rand_distr::StandardNormal);
18impl_distribution_via_f32!(f16, rand_distr::Exp1);
19impl_distribution_via_f32!(f16, rand_distr::Open01);
20impl_distribution_via_f32!(f16, rand_distr::OpenClosed01);
21
22impl_distribution_via_f32!(bf16, rand_distr::StandardUniform);
23impl_distribution_via_f32!(bf16, rand_distr::StandardNormal);
24impl_distribution_via_f32!(bf16, rand_distr::Exp1);
25impl_distribution_via_f32!(bf16, rand_distr::Open01);
26impl_distribution_via_f32!(bf16, rand_distr::OpenClosed01);
27
28impl rand::distr::weighted::Weight for f16 {
29 const ZERO: Self = Self::ZERO;
30
31 fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
32 *self += v;
34 Ok(())
35 }
36}
37
38impl rand::distr::weighted::Weight for bf16 {
39 const ZERO: Self = Self::ZERO;
40
41 fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
42 *self += v;
44 Ok(())
45 }
46}
47
48#[derive(Debug, Clone, Copy)]
49pub struct Float16Sampler(UniformFloat<f32>);
50
51impl rand_distr::uniform::SampleUniform for f16 {
52 type Sampler = Float16Sampler;
53}
54
55impl rand_distr::uniform::UniformSampler for Float16Sampler {
56 type X = f16;
57 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
58 where
59 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
60 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
61 {
62 Ok(Self(UniformFloat::new(
63 low.borrow().to_f32(),
64 high.borrow().to_f32(),
65 )?))
66 }
67 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
68 where
69 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
70 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
71 {
72 Ok(Self(UniformFloat::new_inclusive(
73 low.borrow().to_f32(),
74 high.borrow().to_f32(),
75 )?))
76 }
77 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
78 f16::from_f32(self.0.sample(rng))
79 }
80}
81
82#[derive(Debug, Clone, Copy)]
83pub struct BFloat16Sampler(UniformFloat<f32>);
84
85impl rand_distr::uniform::SampleUniform for bf16 {
86 type Sampler = BFloat16Sampler;
87}
88
89impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
90 type X = bf16;
91 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
92 where
93 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
94 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
95 {
96 Ok(Self(UniformFloat::new(
97 low.borrow().to_f32(),
98 high.borrow().to_f32(),
99 )?))
100 }
101 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_distr::uniform::Error>
102 where
103 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
104 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
105 {
106 Ok(Self(UniformFloat::new_inclusive(
107 low.borrow().to_f32(),
108 high.borrow().to_f32(),
109 )?))
110 }
111 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
112 bf16::from_f32(self.0.sample(rng))
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the allow presumably exists because `Rng` can also arrive through
    // `super::*`, making this import redundant in some configurations — confirm before removing.
    #[allow(unused_imports)]
    use rand::{rng, Rng};
    use rand_distr::{StandardNormal, StandardUniform, Uniform};

    /// Smoke test: each distribution implemented for `f16` can be sampled.
    #[test]
    fn test_sample_f16() {
        let mut source = rng();
        let zero = f16::from_f32(0.0);
        let one = f16::from_f32(1.0);

        let _: f16 = source.sample(StandardUniform);
        let _: f16 = source.sample(StandardNormal);
        let unit_range = Uniform::new(zero, one).unwrap();
        let _: f16 = source.sample(unit_range);
        #[cfg(feature = "num-traits")]
        let _: f16 = source.sample(rand_distr::Normal::new(zero, one).unwrap());
    }

    /// Smoke test: each distribution implemented for `bf16` can be sampled.
    #[test]
    fn test_sample_bf16() {
        let mut source = rng();
        let zero = bf16::from_f32(0.0);
        let one = bf16::from_f32(1.0);

        let _: bf16 = source.sample(StandardUniform);
        let _: bf16 = source.sample(StandardNormal);
        let unit_range = Uniform::new(zero, one).unwrap();
        let _: bf16 = source.sample(unit_range);
        #[cfg(feature = "num-traits")]
        let _: bf16 = source.sample(rand_distr::Normal::new(zero, one).unwrap());
    }
}