1use core::{
2 cmp::Ordering,
3 iter::{Product, Sum},
4 num::FpCategory,
5 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
6};
7#[cfg(not(target_arch = "spirv"))]
8use core::{
9 fmt::{
10 Binary,
11 Debug,
12 Display,
13 Error,
14 Formatter,
15 LowerExp,
16 LowerHex,
17 Octal,
18 UpperExp,
19 UpperHex,
20 },
21 num::ParseFloatError,
22 str::FromStr,
23};
24
25use crate::error::TryFromFloatError;
26use crate::try_from::try_from_lossless;
27
28pub(crate) mod convert;
29
/// A 16-bit "brain" floating point number.
///
/// The value is stored as its raw bit pattern: 1 sign bit (`0x8000`), 8 exponent bits
/// (`0x7F80`) and 7 stored mantissa bits (`0x007F`), i.e. the upper half of an `f32`.
/// `Default` yields the all-zero pattern, which is `+0.0`.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Default)]
#[cfg_attr(kani, derive(kani::Arbitrary))]
pub struct bf16(u16);
44
45impl bf16 {
46 #[inline]
48 #[must_use]
49 pub const fn from_bits(bits: u16) -> bf16 {
50 bf16(bits)
51 }
52
53 #[inline]
60 #[must_use]
61 pub fn from_f32(value: f32) -> bf16 {
62 Self::from_f32_const(value)
63 }
64
65 #[inline]
77 #[must_use]
78 pub const fn from_f32_const(value: f32) -> bf16 {
79 bf16(convert::f32_to_bf16(value))
80 }
81
82 #[inline]
110 pub const fn from_f32_lossless(value: f32) -> Option<bf16> {
111 const BF16_MANT_BITS: u32 = bf16::MANTISSA_DIGITS - 1;
116 const F32_MANT_BITS: u32 = f32::MANTISSA_DIGITS - 1;
117 const EXP_MASK: u32 = (f32::MAX_EXP as u32 * 2 - 1) << F32_MANT_BITS;
118 const TRUNCATED: u32 = F32_MANT_BITS - BF16_MANT_BITS;
119 const TRUNC_MASK: u32 = (1 << TRUNCATED) - 1;
120
121 let bits: u32 = unsafe { core::mem::transmute(value) };
123
124 let exp = bits & EXP_MASK;
129 let is_special = exp == EXP_MASK;
130 if is_special || bits & TRUNC_MASK == 0 {
131 Some(Self::from_f32_const(value))
132 } else {
133 None
134 }
135 }
136
137 #[inline]
145 #[must_use]
146 pub fn from_f64(value: f64) -> bf16 {
147 Self::from_f64_const(value)
148 }
149
150 #[inline]
163 #[must_use]
164 pub const fn from_f64_const(value: f64) -> bf16 {
165 bf16(convert::f64_to_bf16(value))
166 }
167
168 #[inline]
192 pub const fn from_f64_lossless(value: f64) -> Option<bf16> {
193 try_from_lossless!(
194 value => value,
195 half => bf16,
196 full => f64,
197 half_bits => u16,
198 full_bits => u64,
199 to_half => from_f64
200 )
201 }
202
203 #[inline]
205 #[must_use]
206 pub const fn to_bits(self) -> u16 {
207 self.0
208 }
209
210 #[inline]
221 #[must_use]
222 pub const fn to_le_bytes(self) -> [u8; 2] {
223 self.0.to_le_bytes()
224 }
225
226 #[inline]
237 #[must_use]
238 pub const fn to_be_bytes(self) -> [u8; 2] {
239 self.0.to_be_bytes()
240 }
241
242 #[inline]
261 #[must_use]
262 pub const fn to_ne_bytes(self) -> [u8; 2] {
263 self.0.to_ne_bytes()
264 }
265
266 #[inline]
277 #[must_use]
278 pub const fn from_le_bytes(bytes: [u8; 2]) -> bf16 {
279 bf16::from_bits(u16::from_le_bytes(bytes))
280 }
281
282 #[inline]
293 #[must_use]
294 pub const fn from_be_bytes(bytes: [u8; 2]) -> bf16 {
295 bf16::from_bits(u16::from_be_bytes(bytes))
296 }
297
298 #[inline]
317 #[must_use]
318 pub const fn from_ne_bytes(bytes: [u8; 2]) -> bf16 {
319 bf16::from_bits(u16::from_ne_bytes(bytes))
320 }
321
322 #[inline]
327 #[must_use]
328 pub fn to_f32(self) -> f32 {
329 self.to_f32_const()
330 }
331
332 #[inline]
342 #[must_use]
343 pub const fn to_f32_const(self) -> f32 {
344 convert::bf16_to_f32(self.0)
345 }
346
347 #[inline(always)]
349 pub fn as_f32(self) -> f32 {
350 self.to_f32_const()
351 }
352
353 #[inline(always)]
355 pub const fn as_f32_const(self) -> f32 {
356 self.to_f32_const()
357 }
358
359 #[inline]
364 #[must_use]
365 pub fn to_f64(self) -> f64 {
366 self.to_f64_const()
367 }
368
369 #[inline]
379 #[must_use]
380 pub const fn to_f64_const(self) -> f64 {
381 convert::bf16_to_f64(self.0)
382 }
383
384 #[inline(always)]
386 pub fn as_f64(self) -> f64 {
387 self.to_f64_const()
388 }
389
390 #[inline(always)]
392 pub const fn as_f64_const(self) -> f64 {
393 self.to_f64_const()
394 }
395
396 #[inline]
410 #[must_use]
411 pub const fn is_nan(self) -> bool {
412 self.0 & Self::NOT_SIGN > Self::EXP_MASK
413 }
414
415 #[must_use]
417 #[inline(always)]
418 pub const fn abs(self) -> Self {
419 Self(self.0 & !Self::SIGN_MASK)
420 }
421
422 #[inline]
441 #[must_use]
442 pub const fn is_infinite(self) -> bool {
443 self.0 & Self::NOT_SIGN == Self::EXP_MASK
444 }
445
446 #[inline]
465 #[must_use]
466 pub const fn is_finite(self) -> bool {
467 self.0 & Self::EXP_MASK != Self::EXP_MASK
468 }
469
470 #[must_use]
474 #[inline(always)]
475 pub const fn is_subnormal(self) -> bool {
476 matches!(self.classify(), FpCategory::Subnormal)
477 }
478
479 #[inline]
502 #[must_use]
503 pub const fn is_normal(self) -> bool {
504 let exp = self.0 & Self::EXP_MASK;
505 exp != Self::EXP_MASK && exp != 0
506 }
507
508 #[inline]
526 #[must_use]
527 pub const fn classify(self) -> FpCategory {
528 let exp = self.0 & Self::EXP_MASK;
529 let man = self.0 & Self::MAN_MASK;
530 match (exp, man) {
531 (0, 0) => FpCategory::Zero,
532 (0, _) => FpCategory::Subnormal,
533 (Self::EXP_MASK, 0) => FpCategory::Infinite,
534 (Self::EXP_MASK, _) => FpCategory::Nan,
535 _ => FpCategory::Normal,
536 }
537 }
538
539 #[inline]
559 #[must_use]
560 pub const fn signum(self) -> bf16 {
561 if self.is_nan() {
562 self
563 } else if self.0 & Self::SIGN_MASK != 0 {
564 Self::NEG_ONE
565 } else {
566 Self::ONE
567 }
568 }
569
570 #[inline]
588 #[must_use]
589 pub const fn is_sign_positive(self) -> bool {
590 self.0 & Self::SIGN_MASK == 0
591 }
592
593 #[inline]
611 #[must_use]
612 pub const fn is_sign_negative(self) -> bool {
613 self.0 & Self::SIGN_MASK != 0
614 }
615
616 #[inline]
637 #[must_use]
638 pub const fn copysign(self, sign: bf16) -> bf16 {
639 bf16((sign.0 & Self::SIGN_MASK) | (self.0 & Self::NOT_SIGN))
640 }
641
642 #[must_use]
644 #[inline(always)]
645 pub fn recip(self) -> Self {
646 Self::ONE / self
647 }
648
649 #[must_use]
651 #[inline(always)]
652 pub fn to_degrees(self) -> Self {
653 self * Self::from(180u8) / Self::PI
654 }
655
656 #[must_use]
658 #[inline(always)]
659 pub fn to_radians(self) -> Self {
660 self * Self::PI / Self::from(180u8)
661 }
662
663 #[inline]
677 #[must_use]
678 pub const fn max(self, other: bf16) -> bf16 {
679 if self.is_nan() || gt(other, self) {
680 other
681 } else {
682 self
683 }
684 }
685
686 #[inline]
700 #[must_use]
701 pub const fn min(self, other: bf16) -> bf16 {
702 if self.is_nan() || lt(other, self) {
703 other
704 } else {
705 self
706 }
707 }
708
709 #[inline]
730 #[must_use]
731 pub const fn clamp(self, min: bf16, max: bf16) -> bf16 {
732 assert!(le(min, max));
733 let mut x = self;
734 if lt(x, min) {
735 x = min;
736 }
737 if gt(x, max) {
738 x = max;
739 }
740 x
741 }
742
743 #[inline]
811 #[must_use]
812 pub fn total_cmp(&self, other: &Self) -> Ordering {
813 let mut left = self.to_bits() as i16;
814 let mut right = other.to_bits() as i16;
815 left ^= (((left >> 15) as u16) >> 1) as i16;
816 right ^= (((right >> 15) as u16) >> 1) as i16;
817 left.cmp(&right)
818 }
819
820 pub const DIGITS: u32 = 2;
822 pub const EPSILON: bf16 = bf16(0x3C00u16);
828 pub const INFINITY: bf16 = bf16(0x7F80u16);
830 pub const MANTISSA_DIGITS: u32 = 8;
832 pub const MAX: bf16 = bf16(0x7F7F);
834 pub const MAX_10_EXP: i32 = 38;
836 pub const MAX_EXP: i32 = 128;
838 pub const MIN: bf16 = bf16(0xFF7F);
840 pub const MIN_10_EXP: i32 = -37;
842 pub const MIN_EXP: i32 = -125;
845 pub const MIN_POSITIVE: bf16 = bf16(0x0080u16);
847 pub const NAN: bf16 = bf16(0x7FC0u16);
849 pub const NEG_INFINITY: bf16 = bf16(0xFF80u16);
851 pub const RADIX: u32 = 2;
853
854 pub const MIN_POSITIVE_SUBNORMAL: bf16 = bf16(0x0001u16);
856 pub const MAX_SUBNORMAL: bf16 = bf16(0x007Fu16);
858
859 pub const ONE: bf16 = bf16(0x3F80u16);
861 pub const ZERO: bf16 = bf16(0x0000u16);
863 pub const NEG_ZERO: bf16 = bf16(0x8000u16);
865 pub const NEG_ONE: bf16 = bf16(0xBF80u16);
867
868 pub const E: bf16 = bf16(0x402Eu16);
870 pub const PI: bf16 = bf16(0x4049u16);
872 pub const FRAC_1_PI: bf16 = bf16(0x3EA3u16);
874 pub const FRAC_1_SQRT_2: bf16 = bf16(0x3F35u16);
876 pub const FRAC_2_PI: bf16 = bf16(0x3F23u16);
878 pub const FRAC_2_SQRT_PI: bf16 = bf16(0x3F90u16);
880 pub const FRAC_PI_2: bf16 = bf16(0x3FC9u16);
882 pub const FRAC_PI_3: bf16 = bf16(0x3F86u16);
884 pub const FRAC_PI_4: bf16 = bf16(0x3F49u16);
886 pub const FRAC_PI_6: bf16 = bf16(0x3F06u16);
888 pub const FRAC_PI_8: bf16 = bf16(0x3EC9u16);
890 pub const LN_10: bf16 = bf16(0x4013u16);
892 pub const LN_2: bf16 = bf16(0x3F31u16);
894 pub const LOG10_E: bf16 = bf16(0x3EDEu16);
896 pub const LOG10_2: bf16 = bf16(0x3E9Au16);
898 pub const LOG2_E: bf16 = bf16(0x3FB9u16);
900 pub const LOG2_10: bf16 = bf16(0x4055u16);
902 pub const SQRT_2: bf16 = bf16(0x3FB5u16);
904
905 pub const SIGN_MASK: u16 = 0x8000;
907 const NOT_SIGN: u16 = !Self::SIGN_MASK;
909
910 pub const EXP_MASK: u16 = 0x7F80;
912
913 pub const HIDDEN_BIT_MASK: u16 = 0x0080;
915
916 pub const MAN_MASK: u16 = 0x007F;
918
919 pub const TINY_BITS: u16 = 0x1;
921
922 pub const NEG_TINY_BITS: u16 = Self::TINY_BITS | Self::SIGN_MASK;
924}
925
926macro_rules! from_int_impl {
927 ($t:ty, $func:ident) => {
928 #[inline(always)]
930 pub const fn $func(value: $t) -> Self {
931 Self::from_f32_const(value as f32)
932 }
933 };
934}
935
936impl bf16 {
937 from_int_impl!(u8, from_u8);
938 from_int_impl!(u16, from_u16);
939 from_int_impl!(u32, from_u32);
940 from_int_impl!(u64, from_u64);
941 from_int_impl!(u128, from_u128);
942 from_int_impl!(i8, from_i8);
943 from_int_impl!(i16, from_i16);
944 from_int_impl!(i32, from_i32);
945 from_int_impl!(i64, from_i64);
946 from_int_impl!(i128, from_i128);
947}
948
949impl From<bf16> for f32 {
950 #[inline]
951 fn from(x: bf16) -> f32 {
952 x.to_f32()
953 }
954}
955
956impl From<bf16> for f64 {
957 #[inline]
958 fn from(x: bf16) -> f64 {
959 x.to_f64()
960 }
961}
962
963impl From<i8> for bf16 {
964 #[inline]
965 fn from(x: i8) -> bf16 {
966 bf16::from_f32(f32::from(x))
968 }
969}
970
971impl From<u8> for bf16 {
972 #[inline]
973 fn from(x: u8) -> bf16 {
974 bf16::from_f32(f32::from(x))
976 }
977}
978
979impl TryFrom<f32> for bf16 {
980 type Error = TryFromFloatError;
981
982 #[inline]
983 fn try_from(x: f32) -> Result<Self, Self::Error> {
984 Self::from_f32_lossless(x).ok_or(TryFromFloatError(()))
985 }
986}
987
988impl TryFrom<f64> for bf16 {
989 type Error = TryFromFloatError;
990
991 #[inline]
992 fn try_from(x: f64) -> Result<Self, Self::Error> {
993 Self::from_f64_lossless(x).ok_or(TryFromFloatError(()))
994 }
995}
996
997impl PartialEq for bf16 {
998 fn eq(&self, other: &bf16) -> bool {
999 eq(*self, *other)
1000 }
1001}
1002
1003impl PartialOrd for bf16 {
1004 fn partial_cmp(&self, other: &bf16) -> Option<Ordering> {
1005 if self.is_nan() || other.is_nan() {
1006 None
1007 } else {
1008 let neg = self.0 & Self::SIGN_MASK != 0;
1009 let other_neg = other.0 & Self::SIGN_MASK != 0;
1010 match (neg, other_neg) {
1011 (false, false) => Some(self.0.cmp(&other.0)),
1012 (false, true) => {
1013 if (self.0 | other.0) & Self::NOT_SIGN == 0 {
1014 Some(Ordering::Equal)
1015 } else {
1016 Some(Ordering::Greater)
1017 }
1018 },
1019 (true, false) => {
1020 if (self.0 | other.0) & Self::NOT_SIGN == 0 {
1021 Some(Ordering::Equal)
1022 } else {
1023 Some(Ordering::Less)
1024 }
1025 },
1026 (true, true) => Some(other.0.cmp(&self.0)),
1027 }
1028 }
1029 }
1030
1031 fn lt(&self, other: &bf16) -> bool {
1032 lt(*self, *other)
1033 }
1034
1035 fn le(&self, other: &bf16) -> bool {
1036 le(*self, *other)
1037 }
1038
1039 fn gt(&self, other: &bf16) -> bool {
1040 gt(*self, *other)
1041 }
1042
1043 fn ge(&self, other: &bf16) -> bool {
1044 ge(*self, *other)
1045 }
1046}
1047
1048#[cfg(not(target_arch = "spirv"))]
1049impl FromStr for bf16 {
1050 type Err = ParseFloatError;
1051
1052 #[inline]
1053 fn from_str(src: &str) -> Result<bf16, ParseFloatError> {
1054 f32::from_str(src).map(bf16::from_f32)
1055 }
1056}
1057
1058#[cfg(not(target_arch = "spirv"))]
1059impl Debug for bf16 {
1060 #[inline]
1061 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1062 Debug::fmt(&self.to_f32(), f)
1063 }
1064}
1065
1066#[cfg(not(target_arch = "spirv"))]
1067impl Display for bf16 {
1068 #[inline]
1069 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1070 Display::fmt(&self.to_f32(), f)
1071 }
1072}
1073
1074#[cfg(not(target_arch = "spirv"))]
1075impl LowerExp for bf16 {
1076 #[inline]
1077 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1078 write!(f, "{:e}", self.to_f32())
1079 }
1080}
1081
1082#[cfg(not(target_arch = "spirv"))]
1083impl UpperExp for bf16 {
1084 #[inline]
1085 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1086 write!(f, "{:E}", self.to_f32())
1087 }
1088}
1089
1090#[cfg(not(target_arch = "spirv"))]
1091impl Binary for bf16 {
1092 #[inline]
1093 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1094 write!(f, "{:b}", self.0)
1095 }
1096}
1097
1098#[cfg(not(target_arch = "spirv"))]
1099impl Octal for bf16 {
1100 #[inline]
1101 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1102 write!(f, "{:o}", self.0)
1103 }
1104}
1105
1106#[cfg(not(target_arch = "spirv"))]
1107impl LowerHex for bf16 {
1108 #[inline]
1109 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1110 write!(f, "{:x}", self.0)
1111 }
1112}
1113
1114#[cfg(not(target_arch = "spirv"))]
1115impl UpperHex for bf16 {
1116 #[inline]
1117 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
1118 write!(f, "{:X}", self.0)
1119 }
1120}
1121
1122impl Neg for bf16 {
1123 type Output = Self;
1124
1125 #[inline]
1126 fn neg(self) -> Self::Output {
1127 Self(self.0 ^ Self::SIGN_MASK)
1128 }
1129}
1130
1131impl Neg for &bf16 {
1132 type Output = <bf16 as Neg>::Output;
1133
1134 #[inline]
1135 fn neg(self) -> Self::Output {
1136 Neg::neg(*self)
1137 }
1138}
1139
1140impl Add for bf16 {
1141 type Output = Self;
1142
1143 #[inline]
1144 fn add(self, rhs: Self) -> Self::Output {
1145 Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs))
1146 }
1147}
1148
1149impl Add<&bf16> for bf16 {
1150 type Output = <bf16 as Add<bf16>>::Output;
1151
1152 #[inline]
1153 fn add(self, rhs: &bf16) -> Self::Output {
1154 self.add(*rhs)
1155 }
1156}
1157
1158impl Add<&bf16> for &bf16 {
1159 type Output = <bf16 as Add<bf16>>::Output;
1160
1161 #[inline]
1162 fn add(self, rhs: &bf16) -> Self::Output {
1163 (*self).add(*rhs)
1164 }
1165}
1166
1167impl Add<bf16> for &bf16 {
1168 type Output = <bf16 as Add<bf16>>::Output;
1169
1170 #[inline]
1171 fn add(self, rhs: bf16) -> Self::Output {
1172 (*self).add(rhs)
1173 }
1174}
1175
1176impl AddAssign for bf16 {
1177 #[inline]
1178 fn add_assign(&mut self, rhs: Self) {
1179 *self = (*self).add(rhs);
1180 }
1181}
1182
1183impl AddAssign<&bf16> for bf16 {
1184 #[inline]
1185 fn add_assign(&mut self, rhs: &bf16) {
1186 *self = (*self).add(rhs);
1187 }
1188}
1189
1190impl Sub for bf16 {
1191 type Output = Self;
1192
1193 #[inline]
1194 fn sub(self, rhs: Self) -> Self::Output {
1195 Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs))
1196 }
1197}
1198
1199impl Sub<&bf16> for bf16 {
1200 type Output = <bf16 as Sub<bf16>>::Output;
1201
1202 #[inline]
1203 fn sub(self, rhs: &bf16) -> Self::Output {
1204 self.sub(*rhs)
1205 }
1206}
1207
1208impl Sub<&bf16> for &bf16 {
1209 type Output = <bf16 as Sub<bf16>>::Output;
1210
1211 #[inline]
1212 fn sub(self, rhs: &bf16) -> Self::Output {
1213 (*self).sub(*rhs)
1214 }
1215}
1216
1217impl Sub<bf16> for &bf16 {
1218 type Output = <bf16 as Sub<bf16>>::Output;
1219
1220 #[inline]
1221 fn sub(self, rhs: bf16) -> Self::Output {
1222 (*self).sub(rhs)
1223 }
1224}
1225
1226impl SubAssign for bf16 {
1227 #[inline]
1228 fn sub_assign(&mut self, rhs: Self) {
1229 *self = (*self).sub(rhs);
1230 }
1231}
1232
1233impl SubAssign<&bf16> for bf16 {
1234 #[inline]
1235 fn sub_assign(&mut self, rhs: &bf16) {
1236 *self = (*self).sub(rhs);
1237 }
1238}
1239
1240impl Mul for bf16 {
1241 type Output = Self;
1242
1243 #[inline]
1244 fn mul(self, rhs: Self) -> Self::Output {
1245 Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs))
1246 }
1247}
1248
1249impl Mul<&bf16> for bf16 {
1250 type Output = <bf16 as Mul<bf16>>::Output;
1251
1252 #[inline]
1253 fn mul(self, rhs: &bf16) -> Self::Output {
1254 self.mul(*rhs)
1255 }
1256}
1257
1258impl Mul<&bf16> for &bf16 {
1259 type Output = <bf16 as Mul<bf16>>::Output;
1260
1261 #[inline]
1262 fn mul(self, rhs: &bf16) -> Self::Output {
1263 (*self).mul(*rhs)
1264 }
1265}
1266
1267impl Mul<bf16> for &bf16 {
1268 type Output = <bf16 as Mul<bf16>>::Output;
1269
1270 #[inline]
1271 fn mul(self, rhs: bf16) -> Self::Output {
1272 (*self).mul(rhs)
1273 }
1274}
1275
1276impl MulAssign for bf16 {
1277 #[inline]
1278 fn mul_assign(&mut self, rhs: Self) {
1279 *self = (*self).mul(rhs);
1280 }
1281}
1282
1283impl MulAssign<&bf16> for bf16 {
1284 #[inline]
1285 fn mul_assign(&mut self, rhs: &bf16) {
1286 *self = (*self).mul(rhs);
1287 }
1288}
1289
1290impl Div for bf16 {
1291 type Output = Self;
1292
1293 #[inline]
1294 fn div(self, rhs: Self) -> Self::Output {
1295 Self::from_f32(Self::to_f32(self) / Self::to_f32(rhs))
1296 }
1297}
1298
1299impl Div<&bf16> for bf16 {
1300 type Output = <bf16 as Div<bf16>>::Output;
1301
1302 #[inline]
1303 fn div(self, rhs: &bf16) -> Self::Output {
1304 self.div(*rhs)
1305 }
1306}
1307
1308impl Div<&bf16> for &bf16 {
1309 type Output = <bf16 as Div<bf16>>::Output;
1310
1311 #[inline]
1312 fn div(self, rhs: &bf16) -> Self::Output {
1313 (*self).div(*rhs)
1314 }
1315}
1316
1317impl Div<bf16> for &bf16 {
1318 type Output = <bf16 as Div<bf16>>::Output;
1319
1320 #[inline]
1321 fn div(self, rhs: bf16) -> Self::Output {
1322 (*self).div(rhs)
1323 }
1324}
1325
1326impl DivAssign for bf16 {
1327 #[inline]
1328 fn div_assign(&mut self, rhs: Self) {
1329 *self = (*self).div(rhs);
1330 }
1331}
1332
1333impl DivAssign<&bf16> for bf16 {
1334 #[inline]
1335 fn div_assign(&mut self, rhs: &bf16) {
1336 *self = (*self).div(rhs);
1337 }
1338}
1339
1340impl Rem for bf16 {
1341 type Output = Self;
1342
1343 fn rem(self, rhs: Self) -> Self::Output {
1344 Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs))
1345 }
1346}
1347
1348impl Rem<&bf16> for bf16 {
1349 type Output = <bf16 as Rem<bf16>>::Output;
1350
1351 #[inline]
1352 fn rem(self, rhs: &bf16) -> Self::Output {
1353 self.rem(*rhs)
1354 }
1355}
1356
1357impl Rem<&bf16> for &bf16 {
1358 type Output = <bf16 as Rem<bf16>>::Output;
1359
1360 #[inline]
1361 fn rem(self, rhs: &bf16) -> Self::Output {
1362 (*self).rem(*rhs)
1363 }
1364}
1365
1366impl Rem<bf16> for &bf16 {
1367 type Output = <bf16 as Rem<bf16>>::Output;
1368
1369 #[inline]
1370 fn rem(self, rhs: bf16) -> Self::Output {
1371 (*self).rem(rhs)
1372 }
1373}
1374
1375impl RemAssign for bf16 {
1376 #[inline]
1377 fn rem_assign(&mut self, rhs: Self) {
1378 *self = (*self).rem(rhs);
1379 }
1380}
1381
1382impl RemAssign<&bf16> for bf16 {
1383 #[inline]
1384 fn rem_assign(&mut self, rhs: &bf16) {
1385 *self = (*self).rem(rhs);
1386 }
1387}
1388
1389impl Product for bf16 {
1390 #[inline]
1391 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
1392 bf16::from_f32(iter.map(|f| f.to_f32()).product())
1393 }
1394}
1395
1396impl<'a> Product<&'a bf16> for bf16 {
1397 #[inline]
1398 fn product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
1399 bf16::from_f32(iter.map(|f| f.to_f32()).product())
1400 }
1401}
1402
1403impl Sum for bf16 {
1404 #[inline]
1405 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
1406 bf16::from_f32(iter.map(|f| f.to_f32()).sum())
1407 }
1408}
1409
1410impl<'a> Sum<&'a bf16> for bf16 {
1411 #[inline]
1412 fn sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
1413 bf16::from_f32(iter.map(|f| f.to_f32()).sum())
1414 }
1415}
1416
1417#[inline]
1418const fn eq(lhs: bf16, rhs: bf16) -> bool {
1419 if lhs.is_nan() || rhs.is_nan() {
1420 false
1421 } else {
1422 (lhs.0 == rhs.0) || ((lhs.0 | rhs.0) & bf16::NOT_SIGN == 0)
1423 }
1424}
1425
1426#[inline]
1427const fn lt(lhs: bf16, rhs: bf16) -> bool {
1428 if lhs.is_nan() || rhs.is_nan() {
1429 false
1430 } else {
1431 let neg = lhs.0 & bf16::SIGN_MASK != 0;
1432 let rhs_neg = rhs.0 & bf16::SIGN_MASK != 0;
1433 match (neg, rhs_neg) {
1434 (false, false) => lhs.0 < rhs.0,
1435 (false, true) => false,
1436 (true, false) => (lhs.0 | rhs.0) & bf16::NOT_SIGN != 0,
1437 (true, true) => lhs.0 > rhs.0,
1438 }
1439 }
1440}
1441
1442#[inline]
1443const fn le(lhs: bf16, rhs: bf16) -> bool {
1444 if lhs.is_nan() || rhs.is_nan() {
1445 false
1446 } else {
1447 let neg = lhs.0 & bf16::SIGN_MASK != 0;
1448 let rhs_neg = rhs.0 & bf16::SIGN_MASK != 0;
1449 match (neg, rhs_neg) {
1450 (false, false) => lhs.0 <= rhs.0,
1451 (false, true) => (lhs.0 | rhs.0) & bf16::NOT_SIGN == 0,
1452 (true, false) => true,
1453 (true, true) => lhs.0 >= rhs.0,
1454 }
1455 }
1456}
1457
1458#[inline]
1459const fn gt(lhs: bf16, rhs: bf16) -> bool {
1460 if lhs.is_nan() || rhs.is_nan() {
1461 false
1462 } else {
1463 let neg = lhs.0 & bf16::SIGN_MASK != 0;
1464 let rhs_neg = rhs.0 & bf16::SIGN_MASK != 0;
1465 match (neg, rhs_neg) {
1466 (false, false) => lhs.0 > rhs.0,
1467 (false, true) => (lhs.0 | rhs.0) & bf16::NOT_SIGN != 0,
1468 (true, false) => false,
1469 (true, true) => lhs.0 < rhs.0,
1470 }
1471 }
1472}
1473
1474#[inline]
1475const fn ge(lhs: bf16, rhs: bf16) -> bool {
1476 if lhs.is_nan() || rhs.is_nan() {
1477 false
1478 } else {
1479 let neg = lhs.0 & bf16::SIGN_MASK != 0;
1480 let rhs_neg = rhs.0 & bf16::SIGN_MASK != 0;
1481 match (neg, rhs_neg) {
1482 (false, false) => lhs.0 >= rhs.0,
1483 (false, true) => true,
1484 (true, false) => (lhs.0 | rhs.0) & bf16::NOT_SIGN == 0,
1485 (true, true) => lhs.0 <= rhs.0,
1486 }
1487 }
1488}
1489
1490#[allow(clippy::cognitive_complexity, clippy::float_cmp, clippy::neg_cmp_op_on_partial_ord)]
1491#[cfg(test)]
1492mod test {
1493 use core::cmp::Ordering;
1494
1495 use super::*;
1496
1497 #[test]
1498 fn test_bf16_consts_from_f32() {
1499 let class="number">1.0);
1500 let zero = bf16::from_f32(0.0);
1501 let neg_zero = bf16::from_f32(-0.0);
1502 let neg_one = bf16::from_f32(-1.0);
1503 let inf = bf16::from_f32(core::f32::INFINITY);
1504 let neg_inf = bf16::from_f32(core::f32::NEG_INFINITY);
1505 let nan = bf16::from_f32(core::f32::NAN);
1506
1507 assert_eq!(bf16::ONE, one);
1508 assert_eq!(bf16::ZERO, zero);
1509 assert!(zero.is_sign_positive());
1510 assert_eq!(bf16::NEG_ZERO, neg_zero);
1511 assert!(neg_zero.is_sign_negative());
1512 assert_eq!(bf16::NEG_ONE, neg_one);
1513 assert!(neg_one.is_sign_negative());
1514 assert_eq!(bf16::INFINITY, inf);
1515 assert_eq!(bf16::NEG_INFINITY, neg_inf);
1516 assert!(nan.is_nan());
1517 assert!(bf16::NAN.is_nan());
1518
1519 let e = bf16::from_f32(core::f32::consts::E);
1520 let pi = bf16::from_f32(core::f32::consts::PI);
1521 let frac_1_pi = bf16::from_f32(core::f32::consts::FRAC_1_PI);
1522 let frac_1_sqrt_2 = bf16::from_f32(core::f32::consts::FRAC_1_SQRT_2);
1523 let frac_2_pi = bf16::from_f32(core::f32::consts::FRAC_2_PI);
1524 let frac_2_sqrt_pi = bf16::from_f32(core::f32::consts::FRAC_2_SQRT_PI);
1525 let frac_pi_2 = bf16::from_f32(core::f32::consts::FRAC_PI_2);
1526 let frac_pi_3 = bf16::from_f32(core::f32::consts::FRAC_PI_3);
1527 let frac_pi_4 = bf16::from_f32(core::f32::consts::FRAC_PI_4);
1528 let frac_pi_6 = bf16::from_f32(core::f32::consts::FRAC_PI_6);
1529 let frac_pi_8 = bf16::from_f32(core::f32::consts::FRAC_PI_8);
1530 let ln_10 = bf16::from_f32(core::f32::consts::LN_10);
1531 let ln_2 = bf16::from_f32(core::f32::consts::LN_2);
1532 let log10_e = bf16::from_f32(core::f32::consts::LOG10_E);
1533 let log10_2 = bf16::from_f32(2f32.log10());
1535 let log2_e = bf16::from_f32(core::f32::consts::LOG2_E);
1536 let log2_10 = bf16::from_f32(10f32.log2());
1538 let sqrt_2 = bf16::from_f32(core::f32::consts::SQRT_2);
1539
1540 assert_eq!(bf16::E, e);
1541 assert_eq!(bf16::PI, pi);
1542 assert_eq!(bf16::FRAC_1_PI, frac_1_pi);
1543 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2);
1544 assert_eq!(bf16::FRAC_2_PI, frac_2_pi);
1545 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi);
1546 assert_eq!(bf16::FRAC_PI_2, frac_pi_2);
1547 assert_eq!(bf16::FRAC_PI_3, frac_pi_3);
1548 assert_eq!(bf16::FRAC_PI_4, frac_pi_4);
1549 assert_eq!(bf16::FRAC_PI_6, frac_pi_6);
1550 assert_eq!(bf16::FRAC_PI_8, frac_pi_8);
1551 assert_eq!(bf16::LN_10, ln_10);
1552 assert_eq!(bf16::LN_2, ln_2);
1553 assert_eq!(bf16::LOG10_E, log10_e);
1554 assert_eq!(bf16::LOG10_2, log10_2);
1555 assert_eq!(bf16::LOG2_E, log2_e);
1556 assert_eq!(bf16::LOG2_10, log2_10);
1557 assert_eq!(bf16::SQRT_2, sqrt_2);
1558 }
1559
1560 #[test]
1561 fn test_bf16_consts_from_f64() {
1562 let class="number">1.0);
1563 let zero = bf16::from_f64(0.0);
1564 let neg_zero = bf16::from_f64(-0.0);
1565 let inf = bf16::from_f64(core::f64::INFINITY);
1566 let neg_inf = bf16::from_f64(core::f64::NEG_INFINITY);
1567 let nan = bf16::from_f64(core::f64::NAN);
1568
1569 assert_eq!(bf16::ONE, one);
1570 assert_eq!(bf16::ZERO, zero);
1571 assert_eq!(bf16::NEG_ZERO, neg_zero);
1572 assert_eq!(bf16::INFINITY, inf);
1573 assert_eq!(bf16::NEG_INFINITY, neg_inf);
1574 assert!(nan.is_nan());
1575 assert!(bf16::NAN.is_nan());
1576
1577 let e = bf16::from_f64(core::f64::consts::E);
1578 let pi = bf16::from_f64(core::f64::consts::PI);
1579 let frac_1_pi = bf16::from_f64(core::f64::consts::FRAC_1_PI);
1580 let frac_1_sqrt_2 = bf16::from_f64(core::f64::consts::FRAC_1_SQRT_2);
1581 let frac_2_pi = bf16::from_f64(core::f64::consts::FRAC_2_PI);
1582 let frac_2_sqrt_pi = bf16::from_f64(core::f64::consts::FRAC_2_SQRT_PI);
1583 let frac_pi_2 = bf16::from_f64(core::f64::consts::FRAC_PI_2);
1584 let frac_pi_3 = bf16::from_f64(core::f64::consts::FRAC_PI_3);
1585 let frac_pi_4 = bf16::from_f64(core::f64::consts::FRAC_PI_4);
1586 let frac_pi_6 = bf16::from_f64(core::f64::consts::FRAC_PI_6);
1587 let frac_pi_8 = bf16::from_f64(core::f64::consts::FRAC_PI_8);
1588 let ln_10 = bf16::from_f64(core::f64::consts::LN_10);
1589 let ln_2 = bf16::from_f64(core::f64::consts::LN_2);
1590 let log10_e = bf16::from_f64(core::f64::consts::LOG10_E);
1591 let log10_2 = bf16::from_f64(2f64.log10());
1593 let log2_e = bf16::from_f64(core::f64::consts::LOG2_E);
1594 let log2_10 = bf16::from_f64(10f64.log2());
1596 let sqrt_2 = bf16::from_f64(core::f64::consts::SQRT_2);
1597
1598 assert_eq!(bf16::E, e);
1599 assert_eq!(bf16::PI, pi);
1600 assert_eq!(bf16::FRAC_1_PI, frac_1_pi);
1601 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2);
1602 assert_eq!(bf16::FRAC_2_PI, frac_2_pi);
1603 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi);
1604 assert_eq!(bf16::FRAC_PI_2, frac_pi_2);
1605 assert_eq!(bf16::FRAC_PI_3, frac_pi_3);
1606 assert_eq!(bf16::FRAC_PI_4, frac_pi_4);
1607 assert_eq!(bf16::FRAC_PI_6, frac_pi_6);
1608 assert_eq!(bf16::FRAC_PI_8, frac_pi_8);
1609 assert_eq!(bf16::LN_10, ln_10);
1610 assert_eq!(bf16::LN_2, ln_2);
1611 assert_eq!(bf16::LOG10_E, log10_e);
1612 assert_eq!(bf16::LOG10_2, log10_2);
1613 assert_eq!(bf16::LOG2_E, log2_e);
1614 assert_eq!(bf16::LOG2_10, log2_10);
1615 assert_eq!(bf16::SQRT_2, sqrt_2);
1616 }
1617
1618 #[test]
1619 fn test_nan_conversion_to_smaller() {
1620 let nan64 = f64::from_bits(0x7FF0_0000_0000_0001u64);
1621 let neg_nan64 = f64::from_bits(0xFFF0_0000_0000_0001u64);
1622 let nan32 = f32::from_bits(0x7F80_0001u32);
1623 let neg_nan32 = f32::from_bits(0xFF80_0001u32);
1624 let nan32_from_64 = nan64 as f32;
1625 let neg_nan32_from_64 = neg_nan64 as f32;
1626 let nan16_from_64 = bf16::from_f64(nan64);
1627 let neg_nan16_from_64 = bf16::from_f64(neg_nan64);
1628 let nan16_from_32 = bf16::from_f32(nan32);
1629 let neg_nan16_from_32 = bf16::from_f32(neg_nan32);
1630
1631 assert!(nan64.is_nan() && nan64.is_sign_positive());
1632 assert!(neg_nan64.is_nan() && neg_nan64.is_sign_negative());
1633 assert!(nan32.is_nan() && nan32.is_sign_positive());
1634 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative());
1635
1636 assert!(neg_nan32_from_64.is_nan());
1638 assert!(nan32_from_64.is_nan());
1639 assert!(nan16_from_64.is_nan());
1640 assert!(neg_nan16_from_64.is_nan());
1641 assert!(nan16_from_32.is_nan());
1642 assert!(neg_nan16_from_32.is_nan());
1643 }
1644
1645 #[test]
1646 fn test_nan_conversion_to_larger() {
1647 let nan16 = bf16::from_bits(0x7F81u16);
1648 let neg_nan16 = bf16::from_bits(0xFF81u16);
1649 let nan32 = f32::from_bits(0x7F80_0001u32);
1650 let neg_nan32 = f32::from_bits(0xFF80_0001u32);
1651 let nan32_from_16 = f32::from(nan16);
1652 let neg_nan32_from_16 = f32::from(neg_nan16);
1653 let nan64_from_16 = f64::from(nan16);
1654 let neg_nan64_from_16 = f64::from(neg_nan16);
1655 let nan64_from_32 = f64::from(nan32);
1656 let neg_nan64_from_32 = f64::from(neg_nan32);
1657
1658 assert!(nan16.is_nan() && nan16.is_sign_positive());
1659 assert!(neg_nan16.is_nan() && neg_nan16.is_sign_negative());
1660 assert!(nan32.is_nan() && nan32.is_sign_positive());
1661 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative());
1662
1663 assert!(nan32_from_16.is_nan());
1665 assert!(neg_nan32_from_16.is_nan());
1666 assert!(nan64_from_16.is_nan());
1667 assert!(neg_nan64_from_16.is_nan());
1668 assert!(nan64_from_32.is_nan());
1669 assert!(neg_nan64_from_32.is_nan());
1670 }
1671
1672 #[test]
1673 fn test_bf16_to_f32() {
1674 let f = bf16::from_f32(7.0);
1675 assert_eq!(f.to_f32(), 7.0f32);
1676
1677 let f = bf16::from_f32(7.1);
1679 let diff = (f.to_f32() - 7.1f32).abs();
1680 assert!(diff <= 4.0 * bf16::EPSILON.to_f32());
1682
1683 let tiny32 = f32::from_bits(0x0001_0000u32);
1684 assert_eq!(bf16::from_bits(0x0001).to_f32(), tiny32);
1685 assert_eq!(bf16::from_bits(0x0005).to_f32(), 5.0 * tiny32);
1686
1687 assert_eq!(bf16::from_bits(0x0001), bf16::from_f32(tiny32));
1688 assert_eq!(bf16::from_bits(0x0005), bf16::from_f32(5.0 * tiny32));
1689 }
1690
1691 #[test]
1692 #[cfg_attr(miri, ignore)]
1693 fn test_bf16_to_f64() {
1694 let f = bf16::from_f64(7.0);
1695 assert_eq!(f.to_f64(), 7.0f64);
1696
1697 let f = bf16::from_f64(7.1);
1699 let diff = (f.to_f64() - 7.1f64).abs();
1700 assert!(diff <= 4.0 * bf16::EPSILON.to_f64());
1702
1703 let tiny64 = 2.0f64.powi(-133);
1704 assert_eq!(bf16::from_bits(0x0001).to_f64(), tiny64);
1705 assert_eq!(bf16::from_bits(0x0005).to_f64(), 5.0 * tiny64);
1706
1707 assert_eq!(bf16::from_bits(0x0001), bf16::from_f64(tiny64));
1708 assert_eq!(bf16::from_bits(0x0005), bf16::from_f64(5.0 * tiny64));
1709 }
1710
1711 #[test]
1712 fn test_comparisons() {
1713 let zero = bf16::from_f64(0.0);
1714 let class="number">1.0);
1715 let neg_zero = bf16::from_f64(-0.0);
1716 let neg_one = bf16::from_f64(-1.0);
1717
1718 assert_eq!(zero.partial_cmp(&neg_zero), Some(Ordering::Equal));
1719 assert_eq!(neg_zero.partial_cmp(&zero), Some(Ordering::Equal));
1720 assert!(zero == neg_zero);
1721 assert!(neg_zero == zero);
1722 assert!(!(zero != neg_zero));
1723 assert!(!(neg_zero != zero));
1724 assert!(!(zero < neg_zero));
1725 assert!(!(neg_zero < zero));
1726 assert!(zero <= neg_zero);
1727 assert!(neg_zero <= zero);
1728 assert!(!(zero > neg_zero));
1729 assert!(!(neg_zero > zero));
1730 assert!(zero >= neg_zero);
1731 assert!(neg_zero >= zero);
1732
1733 assert_eq!(one.partial_cmp(&neg_zero), Some(Ordering::Greater));
1734 assert_eq!(neg_zero.partial_cmp(&one), Some(Ordering::Less));
1735 assert!(!( neg_zero));
1736 assert!(!(neg_zero == one));
1737 assert!(one != neg_zero);
1738 assert!(neg_zero != one);
1739 assert!(!(one < neg_zero));
1740 assert!(neg_zero < one);
1741 assert!(!(one <= neg_zero));
1742 assert!(neg_zero <= one);
1743 assert!(one > neg_zero);
1744 assert!(!(neg_zero > one));
1745 assert!(one >= neg_zero);
1746 assert!(!(neg_zero >= one));
1747
1748 assert_eq!(one.partial_cmp(&neg_one), Some(Ordering::Greater));
1749 assert_eq!(neg_one.partial_cmp(&one), Some(Ordering::Less));
1750 assert!(!( neg_one));
1751 assert!(!(neg_one == one));
1752 assert!(one != neg_one);
1753 assert!(neg_one != one);
1754 assert!(!(one < neg_one));
1755 assert!(neg_one < one);
1756 assert!(!(one <= neg_one));
1757 assert!(neg_one <= one);
1758 assert!(one > neg_one);
1759 assert!(!(neg_one > one));
1760 assert!(one >= neg_one);
1761 assert!(!(neg_one >= one));
1762 }
1763
#[test]
#[allow(clippy::erasing_op, clippy::identity_op)]
#[cfg_attr(miri, ignore)]
fn round_to_even_f32() {
    // The bf16 bit pattern 0x0001 (smallest positive subnormal) must match 2^-133 exactly,
    // in both conversion directions.
    let min_sub = bf16::from_bits(1);
    let min_sub_f = (-133f32).exp2();
    assert_eq!(bf16::from_f32(min_sub_f).to_bits(), min_sub.to_bits());
    assert_eq!(f32::from(min_sub).to_bits(), min_sub_f.to_bits());

    // (multiplier applied to 2^-133, expected number of subnormal steps).
    // Exact ties (x.50) must land on the even neighbour.
    let subnormal_cases: [(f32, u16); 9] = [
        (0.49, 0),
        (0.50, 0),
        (0.51, 1),
        (1.49, 1),
        (1.50, 2),
        (1.51, 2),
        (2.49, 2),
        (2.50, 2),
        (2.51, 3),
    ];
    for &(factor, steps) in subnormal_cases.iter() {
        assert_eq!(bf16::from_f32(min_sub_f * factor).to_bits(), min_sub.to_bits() * steps);
    }

    // (input, value it must round to). The expected values imply consecutive integers are
    // representable here, so x.50 ties go to the even integer.
    let normal_cases: [(f32, f32); 9] = [
        (250.49, 250.0),
        (250.50, 250.0),
        (250.51, 251.0),
        (251.49, 251.0),
        (251.50, 252.0),
        (251.51, 252.0),
        (252.49, 252.0),
        (252.50, 252.0),
        (252.51, 253.0),
    ];
    for &(input, rounded) in normal_cases.iter() {
        assert_eq!(bf16::from_f32(input).to_bits(), bf16::from_f32(rounded).to_bits());
    }
}
1805
#[test]
#[allow(clippy::erasing_op, clippy::identity_op)]
#[cfg_attr(miri, ignore)]
fn round_to_even_f64() {
    // The bf16 bit pattern 0x0001 (smallest positive subnormal) must match 2^-133 exactly,
    // in both conversion directions.
    let min_sub = bf16::from_bits(1);
    let min_sub_f = (-133f64).exp2();
    assert_eq!(bf16::from_f64(min_sub_f).to_bits(), min_sub.to_bits());
    assert_eq!(f64::from(min_sub).to_bits(), min_sub_f.to_bits());

    // (multiplier applied to 2^-133, expected number of subnormal steps).
    // Exact ties (x.50) must land on the even neighbour.
    let subnormal_cases: [(f64, u16); 9] = [
        (0.49, 0),
        (0.50, 0),
        (0.51, 1),
        (1.49, 1),
        (1.50, 2),
        (1.51, 2),
        (2.49, 2),
        (2.50, 2),
        (2.51, 3),
    ];
    for &(factor, steps) in subnormal_cases.iter() {
        assert_eq!(bf16::from_f64(min_sub_f * factor).to_bits(), min_sub.to_bits() * steps);
    }

    // (input, value it must round to). The expected values imply consecutive integers are
    // representable here, so x.50 ties go to the even integer.
    let normal_cases: [(f64, f64); 9] = [
        (250.49, 250.0),
        (250.50, 250.0),
        (250.51, 251.0),
        (251.49, 251.0),
        (251.50, 252.0),
        (251.51, 252.0),
        (252.49, 252.0),
        (252.50, 252.0),
        (252.51, 253.0),
    ];
    for &(input, rounded) in normal_cases.iter() {
        assert_eq!(bf16::from_f64(input).to_bits(), bf16::from_f64(rounded).to_bits());
    }
}
1847
#[test]
fn from_f32_lossless() {
    // Asserts that `bf16::from_f32_lossless(v)` produces exactly `expected`, compared by bit
    // pattern. `bf16 == bf16` treats +0 and -0 as equal and would hide a lost sign. When the
    // conversion succeeds, the *actual* result must also widen back to `v` bit-for-bit.
    let roundtrip = |v: f32, expected: Option<bf16>| {
        let converted = bf16::from_f32_lossless(v);
        assert_eq!(converted.map(|h| h.to_bits()), expected.map(|h| h.to_bits()));
        if let Some(h) = converted {
            assert_eq!(h.to_f32_const().to_bits(), v.to_bits());
        }
    };

    // NaN payloads are not compared; only NaN-ness has to survive.
    assert_eq!(bf16::from_f32_lossless(f32::NAN).map(bf16::is_nan), Some(true));
    roundtrip(f32::INFINITY, Some(bf16::INFINITY));
    roundtrip(f32::NEG_INFINITY, Some(bf16::NEG_INFINITY));

    // Signed zeros must keep their sign.
    roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(bf16(0)));
    roundtrip(
        f32::from_bits(0b1_00000000_00000000000000000000000),
        Some(bf16(bf16::SIGN_MASK)),
    );

    // Normal values: any set bit below the top 7 mantissa bits makes the conversion lossy.
    roundtrip(f32::from_bits(1), None);
    roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
    roundtrip(f32::from_bits(0b1_00001010_10101001010110100101110), None);
    roundtrip(f32::from_bits(0b0_00001010_10101011000000000000000), None);
    roundtrip(
        f32::from_bits(0b0_00001010_10101010000000000000000),
        Some(bf16(0b0_00001010_1010101)),
    );
    roundtrip(
        f32::from_bits(0b1_00001010_10101010000000000000000),
        Some(bf16(0x8555)),
    );

    // Subnormals.
    roundtrip(f32::from_bits(0b0_00000000_10000000000000000000000), Some(bf16(0x40)));
    roundtrip(f32::from_bits(0b0_00000000_00000001000000000000000), None);
    roundtrip(f32::from_bits(0b0_00000000_00000010000000000000000), Some(bf16(1)));
    roundtrip(f32::from_bits(0b0_00000000_00000100000000000000000), Some(bf16(2)));
    roundtrip(f32::from_bits(0b0_00000000_00000110000000000000000), Some(bf16(3)));
    roundtrip(f32::from_bits(0b0_00000000_00000111000000000000000), None);

    // Progressively clearing the low mantissa bits until the value becomes exact.
    roundtrip(f32::from_bits(0b0_00001011_10100111101101101001001), None);
    roundtrip(f32::from_bits(0b0_00001011_10100111100000000000000), None);
    roundtrip(f32::from_bits(0b0_00001011_10100111000000000000000), None);
    roundtrip(f32::from_bits(0b0_00001011_10100110000000000000000), Some(bf16(0x05d3)));
}
1891
#[test]
fn from_f64_lossless() {
    // Asserts that `bf16::from_f64_lossless(v)` produces exactly `expected`, compared by bit
    // pattern. `bf16 == bf16` treats +0 and -0 as equal and would hide a lost sign. When the
    // conversion succeeds, the *actual* result must also widen back to `v` bit-for-bit.
    let roundtrip = |v: f64, expected: Option<bf16>| {
        let converted = bf16::from_f64_lossless(v);
        assert_eq!(converted.map(|h| h.to_bits()), expected.map(|h| h.to_bits()));
        if let Some(h) = converted {
            assert_eq!(h.to_f64_const().to_bits(), v.to_bits());
        }
    };

    // NaN payloads are not compared; only NaN-ness has to survive.
    assert_eq!(bf16::from_f64_lossless(f64::NAN).map(bf16::is_nan), Some(true));
    roundtrip(f64::INFINITY, Some(bf16::INFINITY));
    roundtrip(f64::NEG_INFINITY, Some(bf16::NEG_INFINITY));

    // Signed zeros must keep their sign.
    roundtrip(
        f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000),
        Some(bf16(0)),
    );
    roundtrip(
        f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000),
        Some(bf16(bf16::SIGN_MASK)),
    );

    // Normal values: only the top 7 mantissa bits may be set.
    roundtrip(
        f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111),
        None,
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010100100000000000000000000000000000000000000000000),
        None,
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010100000000000000000000000000000000000000000000000),
        Some(bf16(0x0554)),
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010101000000000000000000000000000000000000000000000),
        Some(bf16(0x0555)),
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010110000000000000000000000000000000000000000000000),
        Some(bf16(0x0556)),
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010111000000000000000000000000000000000000000000000),
        Some(bf16(0x0557)),
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010101100000000000000000000000000000000000000000000),
        None,
    );
    roundtrip(
        f64::from_bits(0b0_01110001010_1010100110000000000000000000000000000000000000000000),
        None,
    );

    // Negative counterparts keep the sign bit.
    roundtrip(
        f64::from_bits(0b1_01110001010_1010100000000000000000000000000000000000000000000000),
        Some(bf16(0x8554)),
    );
    roundtrip(
        f64::from_bits(0b1_01110001010_1010101000000000000000000000000000000000000000000000),
        Some(bf16(0x8555)),
    );

    // Exponent far above the bf16 range.
    roundtrip(
        f64::from_bits(0b1_11110001010_1010101000000000000000000000000000000000000000000000),
        None,
    );

    // Values in the bf16 subnormal range (multiples of 2^-133).
    roundtrip(
        f64::from_bits(0b0_01101111010_0000000000000000000000000000000000000000000000000000),
        Some(bf16(1)),
    );
    roundtrip(
        f64::from_bits(0b0_01101111011_1000000000000000000000000000000000000000000000000000),
        Some(bf16(3)),
    );
    roundtrip(
        f64::from_bits(0b0_01101111011_1100000000000000000000000000000000000000000000000000),
        None,
    );

    // Values between subnormal steps are not representable.
    roundtrip(
        f64::from_bits(0b0_01101111010_0001000000000000000000000000000000000000000000000000),
        None,
    );
    roundtrip(
        f64::from_bits(0b0_01101111010_1000000000000000000000000000000000000000000000000000),
        None,
    );
}
1984
#[test]
fn test_max() {
    let zero = bf16::from_f32(0.0);
    let answer = bf16::from_f32(42.0);
    let nan = bf16::NAN;

    // Two ordinary operands: the larger one is returned regardless of argument order.
    assert_eq!(zero.max(answer), answer);
    assert_eq!(answer.max(zero), answer);

    // A lone NaN operand is skipped in favour of the numeric operand.
    assert_eq!(nan.max(answer), answer);
    assert_eq!(answer.max(nan), answer);

    // NaN comes out only when both operands are NaN.
    assert!(nan.max(nan).is_nan());
}
2007
#[test]
fn test_min() {
    let zero = bf16::from_f32(0.0);
    let answer = bf16::from_f32(42.0);
    let nan = bf16::NAN;

    // Two ordinary operands: the smaller one is returned regardless of argument order.
    assert_eq!(zero.min(answer), zero);
    assert_eq!(answer.min(zero), zero);

    // A lone NaN operand is skipped in favour of the numeric operand.
    assert_eq!(nan.min(answer), answer);
    assert_eq!(answer.min(nan), answer);

    // NaN comes out only when both operands are NaN.
    assert!(nan.min(nan).is_nan());
}
2030}