use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::{cmp::Ordering, fmt, str::FromStr, time};
/// A signed duration that is parsed from and formatted to Go-style duration strings
/// (e.g. `"1h30m"`, `"-2.5s"`).
///
/// `std::time::Duration` cannot represent negative values, so the magnitude and the sign
/// are stored separately.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
pub struct Duration {
    /// Absolute value of the duration.
    duration: time::Duration,
    /// Set when the duration is negative. A leading `-` sets this even for a zero magnitude,
    /// so `"-0"` and `"0"` are distinct values.
    is_negative: bool,
}
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum ParseError {
#[error("invalid unit: {}", EXPECTED_UNITS)]
InvalidUnit,
#[error("missing a unit: {}", EXPECTED_UNITS)]
NoUnit,
#[error("invalid floating-point number: {}", .0)]
NotANumber(#[from] std::num::ParseFloatError),
}
const EXPECTED_UNITS: &str = "expected one of 'ns', 'us', '\u{00b5}s', 'ms', 's', 'm', or 'h'";
impl From<time::Duration> for Duration {
fn from(duration: time::Duration) -> Self {
Self {
duration,
is_negative: false,
}
}
}
impl From<Duration> for time::Duration {
fn from(Duration { duration, .. }: Duration) -> Self {
duration
}
}
impl Duration {
#[inline]
#[must_use]
pub fn is_negative(&self) -> bool {
self.is_negative
}
}
impl fmt::Debug for Duration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::fmt::Write;
if self.is_negative {
f.write_char('-')?;
}
fmt::Debug::fmt(&self.duration, f)
}
}
impl fmt::Display for Duration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::fmt::Write;
if self.is_negative {
f.write_char('-')?;
}
fmt::Debug::fmt(&self.duration, f)
}
}
impl FromStr for Duration {
type Err = ParseError;
fn from_str(mut s: &str) -> Result<Self, Self::Err> {
const MINUTE: time::Duration = time::Duration::from_secs(60);
let is_negative = s.starts_with('-');
s = s.trim_start_matches('+').trim_start_matches('-');
let mut total = time::Duration::from_secs(0);
while !s.is_empty() && s != "0" {
let unit_start = s.find(|c: char| c.is_alphabetic()).ok_or(ParseError::NoUnit)?;
let (val, rest) = s.split_at(unit_start);
let val = val.parse::<f64>()?;
let unit = if let Some(next_numeric_start) = rest.find(|c: char| !c.is_alphabetic()) {
let (unit, rest) = rest.split_at(next_numeric_start);
s = rest;
unit
} else {
s = "";
rest
};
let base = match unit {
"ns" => time::Duration::from_nanos(1),
"us" | "\u{00b5}s" | "\u{03bc}s" => time::Duration::from_micros(1),
"ms" => time::Duration::from_millis(1),
"s" => time::Duration::from_secs(1),
"m" => MINUTE,
"h" => MINUTE * 60,
_ => return Err(ParseError::InvalidUnit),
};
total += base.mul_f64(val);
}
Ok(Duration {
duration: total,
is_negative,
})
}
}
impl Serialize for Duration {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.collect_str(self)
}
}
impl<'de> Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor;
impl de::Visitor<'_> for Visitor {
type Value = Duration;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("a string in Go `time.Duration.String()` format")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let val = value.parse::<Duration>().map_err(de::Error::custom)?;
Ok(val)
}
}
deserializer.deserialize_str(Visitor)
}
}
impl PartialEq<time::Duration> for Duration {
    /// A negative value never equals a `std::time::Duration`; otherwise the magnitudes are
    /// compared.
    fn eq(&self, other: &time::Duration) -> bool {
        !self.is_negative && self.duration == *other
    }
}
impl PartialEq<time::Duration> for &'_ Duration {
fn eq(&self, other: &time::Duration) -> bool {
if self.is_negative {
return false;
}
self.duration == *other
}
}
impl PartialEq<Duration> for time::Duration {
    /// Mirror of `Duration == std::time::Duration`: false for any negative `Duration`.
    fn eq(&self, other: &Duration) -> bool {
        if other.is_negative {
            false
        } else {
            other.duration == *self
        }
    }
}
impl PartialEq<Duration> for &'_ time::Duration {
    /// Reference form of `std::time::Duration == Duration`: false for any negative
    /// `Duration`, otherwise a magnitude comparison.
    fn eq(&self, other: &Duration) -> bool {
        let lhs: time::Duration = **self;
        !other.is_negative && lhs == other.duration
    }
}
impl PartialOrd for Duration {
    /// Delegates to the total order from [`Ord`], so the two can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for Duration {
    /// Orders durations as signed values.
    ///
    /// Any negative value sorts before any non-negative one, so `-0` sorts before `0`.
    /// Among negative values a larger magnitude sorts first; among non-negative values the
    /// magnitudes are compared directly.
    fn cmp(&self, other: &Self) -> Ordering {
        // `false < true`, so comparing the flags in reverse puts negatives first.
        other.is_negative.cmp(&self.is_negative).then_with(|| {
            let magnitude = self.duration.cmp(&other.duration);
            if self.is_negative {
                magnitude.reverse()
            } else {
                magnitude
            }
        })
    }
}
impl PartialOrd<time::Duration> for Duration {
    /// A negative value is always less than any `std::time::Duration`; otherwise the
    /// magnitudes are compared.
    fn partial_cmp(&self, other: &time::Duration) -> Option<Ordering> {
        let ordering = if self.is_negative {
            Ordering::Less
        } else {
            self.duration.cmp(other)
        };
        Some(ordering)
    }
}
#[cfg(feature = "schema")]
impl schemars::JsonSchema for Duration {
    fn is_referenceable() -> bool {
        // Inline the schema at each use site rather than emitting a shared `$ref` definition.
        false
    }

    fn schema_name() -> String {
        String::from("Duration")
    }

    /// Describes the serialized form: a plain JSON string with no `format` annotation.
    fn json_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
        use schemars::schema::{InstanceType, Schema, SchemaObject};

        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            ..SchemaObject::default()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every input must parse to exactly the listed duration. The inputs are modeled on
    /// the examples accepted by Go's `time.ParseDuration`.
    #[test]
    fn parses_the_same_as_go() {
        const MINUTE: time::Duration = time::Duration::from_secs(60);
        const HOUR: time::Duration = time::Duration::from_secs(60 * 60);

        // Expected value for inputs with a leading '-'.
        let negative = |duration: time::Duration| Duration {
            duration,
            is_negative: true,
        };

        let cases: &[(&str, Duration)] = &[
            // simple
            ("0", time::Duration::from_secs(0).into()),
            ("5s", time::Duration::from_secs(5).into()),
            ("30s", time::Duration::from_secs(30).into()),
            ("1478s", time::Duration::from_secs(1478).into()),
            // sign
            ("-5s", negative(time::Duration::from_secs(5))),
            ("+5s", time::Duration::from_secs(5).into()),
            ("-0", negative(time::Duration::from_secs(0))),
            ("+0", time::Duration::from_secs(0).into()),
            // decimal
            ("5.0s", time::Duration::from_secs(5).into()),
            (
                "5.6s",
                (time::Duration::from_secs(5) + time::Duration::from_millis(600)).into(),
            ),
            ("5.s", time::Duration::from_secs(5).into()),
            (".5s", time::Duration::from_millis(500).into()),
            ("1.0s", time::Duration::from_secs(1).into()),
            ("1.00s", time::Duration::from_secs(1).into()),
            (
                "1.004s",
                (time::Duration::from_secs(1) + time::Duration::from_millis(4)).into(),
            ),
            (
                "1.0040s",
                (time::Duration::from_secs(1) + time::Duration::from_millis(4)).into(),
            ),
            (
                "100.00100s",
                (time::Duration::from_secs(100) + time::Duration::from_millis(1)).into(),
            ),
            // every unit; U+00B5 (micro sign) and U+03BC (Greek mu) both mean microseconds
            ("10ns", time::Duration::from_nanos(10).into()),
            ("11us", time::Duration::from_micros(11).into()),
            ("12\u{00b5}s", time::Duration::from_micros(12).into()),
            ("12\u{03bc}s", time::Duration::from_micros(12).into()),
            ("13ms", time::Duration::from_millis(13).into()),
            ("14s", time::Duration::from_secs(14).into()),
            ("15m", (15 * MINUTE).into()),
            ("16h", (16 * HOUR).into()),
            // composite
            ("3h30m", (3 * HOUR + 30 * MINUTE).into()),
            (
                "10.5s4m",
                (4 * MINUTE + time::Duration::from_secs(10) + time::Duration::from_millis(500))
                    .into(),
            ),
            (
                "-2m3.4s",
                negative(
                    2 * MINUTE + time::Duration::from_secs(3) + time::Duration::from_millis(400),
                ),
            ),
            (
                "1h2m3s4ms5us6ns",
                (HOUR
                    + 2 * MINUTE
                    + time::Duration::from_secs(3)
                    + time::Duration::from_millis(4)
                    + time::Duration::from_micros(5)
                    + time::Duration::from_nanos(6))
                .into(),
            ),
            (
                "39h9m14.425s",
                (39 * HOUR
                    + 9 * MINUTE
                    + time::Duration::from_secs(14)
                    + time::Duration::from_millis(425))
                .into(),
            ),
            // large values and long fractions
            ("52763797000ns", time::Duration::from_nanos(52763797000).into()),
            ("0.3333333333333333333h", (20 * MINUTE).into()),
            (
                "9007199254740993ns",
                time::Duration::from_nanos((1 << 53) + 1).into(),
            ),
            ("0.100000000000000000000h", (6 * MINUTE).into()),
        ];

        for &(input, expected) in cases {
            let parsed = input
                .parse::<Duration>()
                .unwrap_or_else(|err| panic!("failed to parse {:?}: {}", input, err));
            assert_eq!(parsed, expected, "wrong result for {:?}", input);
        }
    }
}