use serde::Serialize;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter, Write};
use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::{future, Stream, TryStream, TryStreamExt};
use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
use hyper::Body;
use pin_project::pin_project;
use serde_json::{self, Error};
use tokio::time::{self, Sleep};
use self::sealed::SseError;
use super::header;
use crate::filter::One;
use crate::reply::Response;
use crate::{Filter, Rejection, Reply};
/// Payload stored in an [`Event`]'s `data` field.
#[derive(Debug)]
enum DataType {
    /// Plain text set by `Event::data`; may span several lines and is written as one
    /// `data:` line per line of text when the event is formatted.
    Text(String),
    /// JSON text produced by `Event::json_data`; written as a single `data:` line.
    Json(String),
}
/// A single Server-Sent Event.
///
/// Build one from `Event::default()` with the builder methods on `Event`. Its `Display`
/// impl renders the `text/event-stream` wire format.
#[derive(Default, Debug)]
pub struct Event {
    // NOTE(review): `name` is neither set by a builder method nor read by `Display` in this
    // file — confirm it is still needed.
    name: Option<String>,
    /// Value of the `id:` field.
    id: Option<String>,
    /// Payload of the `data:` field(s).
    data: Option<DataType>,
    /// Value of the `event:` field (the event type).
    event: Option<String>,
    /// Comment text, written on `:`-prefixed line(s).
    comment: Option<String>,
    /// Reconnection delay, written as the `retry:` field in milliseconds.
    retry: Option<Duration>,
}
impl Event {
    /// Sets the event's `data` to `data`, sent as plain text.
    ///
    /// Replaces any data set earlier through `data` or `json_data`. Text containing line
    /// breaks is written as several `data:` lines when the event is formatted.
    pub fn data<T: Into<String>>(mut self, data: T) -> Event {
        self.data = Some(DataType::Text(data.into()));
        self
    }

    /// Sets the event's `data` to the JSON serialization of `data`.
    ///
    /// Replaces any data set earlier through `data` or `json_data`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `data` cannot be serialized. The event being
    /// built is consumed in that case.
    pub fn json_data<T: Serialize>(mut self, data: T) -> Result<Event, Error> {
        self.data = Some(DataType::Json(serde_json::to_string(&data)?));
        Ok(self)
    }

    /// Sets a comment, written on `:`-prefixed line(s) ahead of the other fields.
    pub fn comment<T: Into<String>>(mut self, comment: T) -> Event {
        self.comment = Some(comment.into());
        self
    }

    /// Sets the `event:` field (the event type).
    pub fn event<T: Into<String>>(mut self, event: T) -> Event {
        self.event = Some(event.into());
        self
    }

    /// Sets the `retry:` field.
    ///
    /// It is written as a whole number of milliseconds; sub-millisecond precision is dropped.
    pub fn retry(mut self, duration: Duration) -> Event {
        // `duration` already has the stored type; the former `.into()` was an identity conversion.
        self.retry = Some(duration);
        self
    }

    /// Sets the `id:` field.
    pub fn id<T: Into<String>>(mut self, id: T) -> Event {
        self.id = Some(id.into());
        self
    }
}
impl Display for Event {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
if let Some(ref comment) = &self.comment {
":".fmt(f)?;
comment.fmt(f)?;
f.write_char('\n')?;
}
if let Some(ref event) = &self.event {
"event:".fmt(f)?;
event.fmt(f)?;
f.write_char('\n')?;
}
match self.data {
Some(DataType::Text(ref data)) => {
for line in data.split('\n') {
"data:".fmt(f)?;
line.fmt(f)?;
f.write_char('\n')?;
}
}
Some(DataType::Json(ref data)) => {
"data:".fmt(f)?;
data.fmt(f)?;
f.write_char('\n')?;
}
None => {}
}
if let Some(ref id) = &self.id {
"id:".fmt(f)?;
id.fmt(f)?;
f.write_char('\n')?;
}
if let Some(ref duration) = &self.retry {
"retry:".fmt(f)?;
let secs = duration.as_secs();
let millis = duration.subsec_millis();
if secs > 0 {
secs.fmt(f)?;
if millis < 10 {
f.write_str("00")?;
} else if millis < 100 {
f.write_char('0')?;
}
}
millis.fmt(f)?;
f.write_char('\n')?;
}
f.write_char('\n')?;
Ok(())
}
}
/// Filter that extracts the optional `Last-Event-ID` request header as a `T`.
///
/// Parsing via [`FromStr`], and any rejection on a malformed value, are delegated to the
/// crate's `header::optional` filter, which is defined elsewhere.
pub fn last_event_id<T>() -> impl Filter<Extract = One<Option<T>>, Error = Rejection> + Copy
where
    T: FromStr + Send + Sync + 'static,
{
    const HEADER_NAME: &str = "last-event-id";
    header::optional(HEADER_NAME)
}
pub fn reply<S>(event_stream: S) -> impl Reply
where
S: TryStream<Ok = Event> + Send + 'static,
S::Error: StdError + Send + Sync + 'static,
{
SseReply { event_stream }
}
/// [`Reply`] returned by `reply`; holds the caller's event stream until it is turned into
/// a response.
// `S` is not required to implement `Debug`, so no `Debug` impl is provided.
#[allow(missing_debug_implementations)]
struct SseReply<S> {
    event_stream: S,
}
impl<S> Reply for SseReply<S>
where
S: TryStream<Ok = Event> + Send + 'static,
S::Error: StdError + Send + Sync + 'static,
{
#[inline]
fn into_response(self) -> Response {
let body_stream = self
.event_stream
.map_err(|error| {
log::error!("sse stream error: {}", error);
SseError
})
.into_stream()
.and_then(|event| future::ready(Ok(event.to_string())));
let mut res = Response::new(Body::wrap_stream(body_stream));
res.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
res.headers_mut()
.insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
res
}
}
/// Settings for keep-alive comments sent on an otherwise idle SSE stream.
///
/// Create one with `keep_alive()`, adjust it with `interval` / `text`, then wrap an event
/// stream with `KeepAlive::stream`.
#[derive(Debug)]
pub struct KeepAlive {
    /// Text of the keep-alive comment event.
    comment_text: Cow<'static, str>,
    /// Longest period without an event before a keep-alive comment is emitted.
    max_interval: Duration,
}
impl KeepAlive {
pub fn interval(mut self, time: Duration) -> Self {
self.max_interval = time;
self
}
pub fn text(mut self, text: impl Into<Cow<'static, str>>) -> Self {
self.comment_text = text.into();
self
}
pub fn stream<S>(
self,
event_stream: S,
) -> impl TryStream<Ok = Event, Error = impl StdError + Send + Sync + 'static> + Send + 'static
where
S: TryStream<Ok = Event> + Send + 'static,
S::Error: StdError + Send + Sync + 'static,
{
let alive_timer = time::sleep(self.max_interval);
SseKeepAlive {
event_stream,
comment_text: self.comment_text,
max_interval: self.max_interval,
alive_timer,
}
}
}
/// Stream adapter produced by `KeepAlive::stream`.
///
/// Forwards items from `event_stream`. Whenever `alive_timer` fires while the wrapped
/// stream has nothing ready, it yields a comment-only event built from `comment_text`.
// `S` is not required to implement `Debug`, so no `Debug` impl is provided.
#[allow(missing_debug_implementations)]
#[pin_project]
struct SseKeepAlive<S> {
    /// The wrapped event stream, polled through the pinned projection.
    #[pin]
    event_stream: S,
    /// Text of the keep-alive comment.
    comment_text: Cow<'static, str>,
    /// Silence allowed before a keep-alive; also used to re-arm `alive_timer`.
    max_interval: Duration,
    /// Deadline timer, reset after each forwarded event and each keep-alive.
    #[pin]
    alive_timer: Sleep,
}
/// Creates a [`KeepAlive`] with an empty comment text and a 15-second interval.
pub fn keep_alive() -> KeepAlive {
    let max_interval = Duration::from_secs(15);
    let comment_text = Cow::from("");
    KeepAlive {
        comment_text,
        max_interval,
    }
}
impl<S> Stream for SseKeepAlive<S>
where
    S: TryStream<Ok = Event> + Send + 'static,
    S::Error: StdError + Send + Sync + 'static,
{
    type Item = Result<Event, SseError>;

    /// Polls the wrapped stream first; the keep-alive timer is consulted only when the
    /// wrapped stream is pending.
    ///
    /// - An event from the wrapped stream re-arms the timer and is forwarded.
    /// - A timer expiry re-arms the timer and yields a comment-only event.
    /// - An error is logged and replaced by the opaque `SseError`.
    /// - End of the wrapped stream ends this stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut pin = self.project();
        match pin.event_stream.try_poll_next(cx) {
            // `alive_timer` is already a pinned projection, so reborrow it rather than
            // wrapping it in a second `Pin`.
            Poll::Pending => match pin.alive_timer.as_mut().poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(()) => {
                    pin.alive_timer
                        .reset(tokio::time::Instant::now() + *pin.max_interval);
                    let comment_str = pin.comment_text.clone();
                    let event = Event::default().comment(comment_str);
                    Poll::Ready(Some(Ok(event)))
                }
            },
            Poll::Ready(Some(Ok(event))) => {
                pin.alive_timer
                    .reset(tokio::time::Instant::now() + *pin.max_interval);
                Poll::Ready(Some(Ok(event)))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Err(error))) => {
                log::error!("sse::keep error: {}", error);
                Poll::Ready(Some(Err(SseError)))
            }
        }
    }
}
mod sealed {
    use super::*;

    /// Opaque error that replaces the underlying stream error once it has been logged.
    #[derive(Debug)]
    pub struct SseError;

    impl Display for SseError {
        fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
            f.write_str("sse error")
        }
    }

    impl StdError for SseError {}
}