use std::fmt;
use std::marker::PhantomData;
use http::header;
use http::uri::Scheme;
use http::{
HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, Uri, Version,
};
use crate::body::calculate_max_input;
use crate::ext::{HeaderIterExt, MethodExt, StatusExt};
use crate::parser::try_parse_response;
use crate::util::ArrayVec;
use crate::{BodyMode, Error};
use super::holder::CallHolder;
#[doc(hidden)]
pub mod state {
    /// Gives every typestate marker a readable name; `Flow`'s `Debug` output prints it.
    pub(crate) trait Named {
        fn name() -> &'static str;
    }

    /// Declares one zero-sized, externally unconstructible marker struct per listed state
    /// and implements [`Named`] for each, using the struct's identifier as its name.
    macro_rules! flow_states {
        ($($state:ident),* $(,)?) => {
            $(
                #[doc(hidden)]
                pub struct $state(());

                impl Named for $state {
                    fn name() -> &'static str {
                        stringify!($state)
                    }
                }
            )*
        };
    }

    flow_states!(
        Prepare,
        SendRequest,
        Await100,
        SendBody,
        RecvResponse,
        RecvBody,
        Redirect,
        Cleanup,
    );
}
use self::state::*;
/// A client request/response flow tagged with a typestate `State` (one of the marker types
/// in the `state` module).
///
/// Each state has its own `impl` block exposing only the operations valid in that phase;
/// `proceed` consumes the flow and hands the shared data on to the next state.
pub struct Flow<B, State> {
    // Data carried unchanged across every state transition.
    inner: Inner<B>,
    // Zero-sized marker recording the current state at the type level.
    _ph: PhantomData<State>,
}
/// State shared by every typestate of `Flow`; it is moved from one `Flow<B, _>` into the
/// next on each transition.
#[derive(Debug)]
pub(crate) struct Inner<B> {
    /// The underlying call. Its `CallHolder` variant tracks the phase of the flow.
    pub call: CallHolder<B>,
    /// Reasons the connection must be closed after this call, in the order discovered.
    /// Only the first one is reported by `close_reason()`.
    ///
    /// Capacity equals the number of `CloseReason` variants. When the flow is driven as
    /// intended, each reason is recorded at most once:
    /// - `Http10` and `ClientConnectionClose` in `Flow::new`
    /// - `Not100Continue` in `Await100`
    /// - `ServerConnectionClose` in `RecvResponse::try_response`
    /// - `CloseDelimitedBody` in `RecvResponse::proceed`
    ///
    /// All five can occur in the same flow, so a smaller capacity could overflow.
    pub close_reason: ArrayVec<CloseReason, 5>,
    /// Whether a request body will be sent. Initialized from the method, forced on by
    /// `send_body_despite_method`, and cleared when the server declines `100-continue`.
    pub should_send_body: bool,
    /// Whether the flow is still waiting for a `100 Continue` interim response.
    pub await_100_continue: bool,
    /// Status of the (non-interim) response, once it has been read.
    pub status: Option<StatusCode>,
    /// The last `Location` header of the response, if any.
    pub location: Option<HeaderValue>,
}
impl<B> Inner<B> {
    /// True when the recorded status is a 3xx other than `304 Not Modified`.
    /// Returns false while no status has been recorded.
    fn is_redirect(&self) -> bool {
        self.status
            .map(|code| code.is_redirection() && code != StatusCode::NOT_MODIFIED)
            .unwrap_or(false)
    }
}
/// Why the connection cannot be reused once the current call is finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseReason {
    /// The request was made with HTTP/1.0.
    Http10,
    /// The request carried `Connection: close`.
    ClientConnectionClose,
    /// The response carried `Connection: close`.
    ServerConnectionClose,
    /// A non-100 response arrived while waiting for `100 Continue`, so the body was not sent.
    Not100Continue,
    /// The response body is delimited by the connection closing.
    CloseDelimitedBody,
}

impl CloseReason {
    /// A short human-readable explanation of this reason.
    fn explain(&self) -> &'static str {
        match *self {
            Self::Http10 => "version is http1.0",
            Self::ClientConnectionClose => "client sent Connection: close",
            Self::ServerConnectionClose => "server sent Connection: close",
            Self::Not100Continue => "got non-100 response before sending body",
            Self::CloseDelimitedBody => "response body is close delimited",
        }
    }
}
impl<B, S> Flow<B, S> {
    /// Places `inner` into a flow tagged with state `S` and emits the new flow via `debug!`.
    fn wrap(inner: Inner<B>) -> Flow<B, S>
    where
        S: Named,
    {
        let flow = Self {
            _ph: PhantomData,
            inner,
        };
        debug!("{:?}", flow);
        flow
    }

    /// Shared access to the underlying call.
    fn call(&self) -> &CallHolder<B> {
        let inner = &self.inner;
        &inner.call
    }

    /// Exclusive access to the underlying call.
    fn call_mut(&mut self) -> &mut CallHolder<B> {
        let inner = &mut self.inner;
        &mut inner.call
    }

    /// Test-only view of the shared flow data.
    #[cfg(test)]
    pub(crate) fn inner(&self) -> &Inner<B> {
        &self.inner
    }
}
impl<B> Flow<B, Prepare> {
pub fn new(request: Request<B>) -> Result<Self, Error> {
let mut close_reason = ArrayVec::from_fn(|_| CloseReason::Http10);
if request.version() == Version::HTTP_10 {
close_reason.push(CloseReason::Http10)
}
if request.headers().iter().has(header::CONNECTION, "close") {
close_reason.push(CloseReason::ClientConnectionClose);
}
let should_send_body = request.method().need_request_body();
let await_100_continue = request.headers().iter().has_expect_100();
let call = CallHolder::new(request)?;
let inner = Inner {
call,
close_reason,
should_send_body,
await_100_continue,
status: None,
location: None,
};
Ok(Flow::wrap(inner))
}
pub fn method(&self) -> &Method {
self.call().request().method()
}
pub fn uri(&self) -> &Uri {
self.call().request().uri()
}
pub fn version(&self) -> Version {
self.call().request().version()
}
pub fn headers(&self) -> &HeaderMap {
self.call().request().original_request_headers()
}
pub fn allow_non_standard_methods(&mut self, v: bool) {
self.call_mut().allow_non_standard_methods(v);
}
pub fn header<K, V>(&mut self, key: K, value: V) -> Result<(), Error>
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.call_mut().request_mut().set_header(key, value)
}
pub fn send_body_despite_method(&mut self) {
self.inner.should_send_body = true;
self.inner.call.convert_to_send_body();
}
pub fn proceed(self) -> Flow<B, SendRequest> {
Flow::wrap(self.inner)
}
}
impl<B> Flow<B, SendRequest> {
    /// Writes request data into `output` and returns how many bytes of `output` were used.
    ///
    /// For a request with a body, no body input is supplied here. The body itself is written
    /// later, in the `SendBody` state.
    pub fn write(&mut self, output: &mut [u8]) -> Result<usize, Error> {
        match &mut self.inner.call {
            CallHolder::WithoutBody(v) => v.write(output),
            // Only the output count matters: there is no body input at this stage.
            CallHolder::WithBody(v) => v.write(&[], output).map(|r| r.1),
            _ => unreachable!(),
        }
    }

    /// HTTP method of the request.
    pub fn method(&self) -> &Method {
        self.call().request().method()
    }

    /// Target URI of the request.
    pub fn uri(&self) -> &Uri {
        self.call().request().uri()
    }

    /// HTTP version of the request.
    pub fn version(&self) -> Version {
        self.call().request().version()
    }

    /// Returns the request headers after `analyze_request` has run.
    ///
    /// Repeated header names (for example several `Accept` or `Cookie` entries) are all
    /// kept, in their original order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `analyze_request`.
    pub fn headers_map(&mut self) -> Result<HeaderMap, Error> {
        self.call_mut().analyze_request()?;
        let mut map = HeaderMap::new();
        for (name, value) in self.call().request().headers() {
            // `append`, not `insert`: `insert` replaces any existing values for the name,
            // so every occurrence but the last of a repeated header would be lost.
            map.append(name, value.clone());
        }
        Ok(map)
    }

    /// True once the request has been written far enough to leave this state.
    ///
    /// Without a body, the write must be finished. With a body, the call must have reached
    /// its body phase.
    pub fn can_proceed(&self) -> bool {
        match &self.inner.call {
            CallHolder::WithoutBody(v) => v.is_finished(),
            CallHolder::WithBody(v) => v.is_body(),
            _ => unreachable!(),
        }
    }

    /// Moves to the next state, or returns `Ok(None)` if `can_proceed` is false.
    ///
    /// - A body to send plus `100-continue` leads to `Await100`.
    /// - A body to send without `100-continue` leads to `SendBody`.
    /// - No body leads to `RecvResponse`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `analyze_request` on the `SendBody` path.
    pub fn proceed(mut self) -> Result<Option<SendRequestResult<B>>, Error> {
        if !self.can_proceed() {
            return Ok(None);
        }

        if self.inner.should_send_body {
            if self.inner.await_100_continue {
                Ok(Some(SendRequestResult::Await100(Flow::wrap(self.inner))))
            } else {
                let mut flow = Flow::wrap(self.inner);
                flow.inner.call.analyze_request()?;
                Ok(Some(SendRequestResult::SendBody(flow)))
            }
        } else {
            let call = match self.inner.call {
                CallHolder::WithoutBody(v) => v,
                _ => unreachable!(),
            };
            // `can_proceed` guaranteed the request is fully written.
            let call_recv = call.into_receive().unwrap();
            self.inner.call = CallHolder::RecvResponse(call_recv);
            let flow = Flow::wrap(self.inner);
            Ok(Some(SendRequestResult::RecvResponse(flow)))
        }
    }
}
/// Outcome of `Flow<B, SendRequest>::proceed`.
pub enum SendRequestResult<B> {
    /// A body is to be sent and the request expects `100-continue`: wait for the server first.
    Await100(Flow<B, Await100>),
    /// A body is to be sent right away.
    SendBody(Flow<B, SendBody>),
    /// No body is sent; go straight to reading the response.
    RecvResponse(Flow<B, RecvResponse>),
}
impl<B> Flow<B, Await100> {
pub fn try_read_100(&mut self, input: &[u8]) -> Result<usize, Error> {
match try_parse_response::<0>(input) {
Ok(v) => match v {
Some((input_used, response)) => {
self.inner.await_100_continue = false;
if response.status() == StatusCode::CONTINUE {
assert!(self.inner.should_send_body);
Ok(input_used)
} else {
self.inner.close_reason.push(CloseReason::Not100Continue);
self.inner.should_send_body = false;
Ok(0)
}
}
None => Ok(0),
},
Err(e) => {
self.inner.await_100_continue = false;
if e == Error::HttpParseTooManyHeaders {
self.inner.close_reason.push(CloseReason::Not100Continue);
self.inner.should_send_body = false;
Ok(0)
} else {
Err(e)
}
}
}
}
pub fn can_keep_await_100(&self) -> bool {
self.inner.await_100_continue
}
pub fn proceed(self) -> Result<Await100Result<B>, Error> {
if self.inner.should_send_body {
let mut flow = Flow::wrap(self.inner);
flow.inner.call.analyze_request()?;
Ok(Await100Result::SendBody(flow))
} else {
Ok(Await100Result::RecvResponse(Flow::wrap(self.inner)))
}
}
}
/// Outcome of `Flow<B, Await100>::proceed`.
pub enum Await100Result<B> {
    /// The body should be sent: a `100 Continue` was received, or the caller stopped
    /// waiting before any rejecting response arrived.
    SendBody(Flow<B, SendBody>),
    /// The server answered with something other than `100 Continue`; the body is not sent.
    RecvResponse(Flow<B, RecvResponse>),
}
impl<B> Flow<B, SendBody> {
    /// Feeds body bytes from `input` into `output` through the body-sending call.
    ///
    /// Returns the pair produced by the underlying call (presumably `(input_used,
    /// output_used)` — TODO confirm the order).
    pub fn write(&mut self, input: &[u8], output: &mut [u8]) -> Result<(usize, usize), Error> {
        let body_call = self.inner.call.as_with_body_mut();
        body_call.write(input, output)
    }

    /// Tells the call that `amount` body bytes were written directly by the caller.
    pub fn consume_direct_write(&mut self, amount: usize) -> Result<(), Error> {
        let body_call = self.inner.call.as_with_body_mut();
        body_call.consume_direct_write(amount)
    }

    /// How much input fits into `output_len` bytes of output.
    ///
    /// For a non-chunked body this is `output_len` itself. For a chunked body it is
    /// computed by the crate's `calculate_max_input` helper.
    pub fn calculate_max_input(&mut self, output_len: usize) -> usize {
        if self.is_chunked() {
            calculate_max_input(output_len)
        } else {
            output_len
        }
    }

    /// Whether the body is sent with chunked transfer encoding.
    pub fn is_chunked(&mut self) -> bool {
        self.inner.call.as_with_body_mut().is_chunked()
    }

    /// True once the whole body has been written.
    pub fn can_proceed(&self) -> bool {
        let body_call = self.inner.call.as_with_body();
        body_call.is_finished()
    }

    /// Moves to `RecvResponse` once the body is done, otherwise returns `None`.
    pub fn proceed(mut self) -> Option<Flow<B, RecvResponse>> {
        if !self.can_proceed() {
            return None;
        }

        let receiving = match self.inner.call {
            // `can_proceed` guaranteed the body is finished.
            CallHolder::WithBody(body_call) => body_call.into_receive().unwrap(),
            _ => unreachable!(),
        };
        self.inner.call = CallHolder::RecvResponse(receiving);
        Some(Flow::wrap(self.inner))
    }
}
impl<B> Flow<B, RecvResponse> {
    /// Tries to parse the response head from `input`.
    ///
    /// Returns:
    /// - `(0, None)` when more input is needed.
    /// - `(used, None)` when an awaited `100 Continue` was consumed; the waiting flag is
    ///   cleared and the caller keeps reading.
    /// - `(used, Some(response))` otherwise. The status and last `Location` header are
    ///   recorded, and a `Connection: close` from the server marks the connection for
    ///   closing.
    ///
    /// # Errors
    ///
    /// Propagates parse errors from the underlying call.
    pub fn try_response(
        &mut self,
        input: &[u8],
        allow_partial_redirect: bool,
    ) -> Result<(usize, Option<Response<()>>), Error> {
        let receiving = self.inner.call.as_recv_response_mut();
        let (input_used, response) = match receiving.try_response(input, allow_partial_redirect)? {
            Some(parsed) => parsed,
            None => return Ok((0, None)),
        };

        let status = response.status();
        if self.inner.await_100_continue && status == StatusCode::CONTINUE {
            self.inner.await_100_continue = false;
            return Ok((input_used, None));
        }

        let headers = response.headers();
        let server_closes = headers.iter().has(header::CONNECTION, "close");
        self.inner.status = Some(status);
        self.inner.location = headers.get_all(header::LOCATION).iter().last().cloned();
        if server_closes {
            self.inner
                .close_reason
                .push(CloseReason::ServerConnectionClose);
        }

        Ok((input_used, Some(response)))
    }

    /// True once the response head has been fully read.
    pub fn can_proceed(&self) -> bool {
        let receiving = self.inner.call.as_recv_response();
        receiving.is_finished()
    }

    /// Moves to the next state once the head is read, otherwise returns `None`.
    ///
    /// - A response with a body goes to `RecvBody`. A close-delimited body marks the
    ///   connection for closing.
    /// - A response without a body goes to `Redirect` for a 3xx other than 304, and to
    ///   `Cleanup` otherwise.
    pub fn proceed(mut self) -> Option<RecvResponseResult<B>> {
        if !self.can_proceed() {
            return None;
        }

        let receiving = match self.inner.call {
            CallHolder::RecvResponse(v) => v,
            _ => unreachable!(),
        };
        let has_response_body = receiving.need_response_body();
        let body_call = receiving.do_into_body();

        if has_response_body && body_call.is_close_delimited() {
            self.inner
                .close_reason
                .push(CloseReason::CloseDelimitedBody);
        }
        self.inner.call = CallHolder::RecvBody(body_call);

        let next = if has_response_body {
            RecvResponseResult::RecvBody(Flow::wrap(self.inner))
        } else if self.inner.is_redirect() {
            RecvResponseResult::Redirect(Flow::wrap(self.inner))
        } else {
            RecvResponseResult::Cleanup(Flow::wrap(self.inner))
        };
        Some(next)
    }
}
/// Outcome of `Flow<B, RecvResponse>::proceed`.
pub enum RecvResponseResult<B> {
    /// The response has a body to read.
    RecvBody(Flow<B, RecvBody>),
    /// No body, and the status is a redirect (3xx other than 304).
    Redirect(Flow<B, Redirect>),
    /// No body and not a redirect: the call is complete.
    Cleanup(Flow<B, Cleanup>),
}
impl<B> Flow<B, RecvBody> {
    /// Decodes response body bytes from `input` into `output`.
    ///
    /// Returns the pair produced by the underlying call (presumably `(input_used,
    /// output_used)` — TODO confirm the order).
    pub fn read(&mut self, input: &[u8], output: &mut [u8]) -> Result<(usize, usize), Error> {
        let body_call = self.inner.call.as_recv_body_mut();
        body_call.read(input, output)
    }

    /// Enables or disables stopping reads at chunk boundaries.
    pub fn stop_on_chunk_boundary(&mut self, enabled: bool) {
        let body_call = self.inner.call.as_recv_body_mut();
        body_call.stop_on_chunk_boundary(enabled);
    }

    /// Whether reading is currently positioned on a chunk boundary.
    pub fn is_on_chunk_boundary(&self) -> bool {
        let body_call = self.inner.call.as_recv_body();
        body_call.is_on_chunk_boundary()
    }

    /// How the response body is delimited.
    pub fn body_mode(&self) -> BodyMode {
        self.call().body_mode()
    }

    /// True when the body has ended, or when it is close-delimited.
    ///
    /// A close-delimited body may be left at any point, presumably because its end is only
    /// signalled by the connection closing.
    pub fn can_proceed(&self) -> bool {
        let body_call = self.inner.call.as_recv_body();
        body_call.is_ended() || body_call.is_close_delimited()
    }

    /// Moves to `Redirect` (3xx other than 304) or `Cleanup`, or returns `None` if
    /// `can_proceed` is false.
    pub fn proceed(self) -> Option<RecvBodyResult<B>> {
        self.can_proceed().then(move || {
            if self.inner.is_redirect() {
                RecvBodyResult::Redirect(Flow::wrap(self.inner))
            } else {
                RecvBodyResult::Cleanup(Flow::wrap(self.inner))
            }
        })
    }
}
/// Outcome of `Flow<B, RecvBody>::proceed`.
pub enum RecvBodyResult<B> {
    /// The status is a redirect (3xx other than 304).
    Redirect(Flow<B, Redirect>),
    /// Not a redirect: the call is complete.
    Cleanup(Flow<B, Cleanup>),
}
impl<B> Flow<B, Redirect> {
    /// Builds the follow-up request for this redirect as a new flow in the `Prepare` state.
    ///
    /// The previous request is taken out of this flow (`take_request`). Its method is then
    /// adjusted for the redirect status, and its URI is rebuilt from the `Location` header.
    /// Headers that must not follow the redirect are stripped:
    /// - `Authorization`, unless `redirect_auth_headers` allows keeping it;
    /// - `Cookie`;
    /// - `Content-Length`;
    /// - when the method was rewritten (e.g. POST -> GET), the other content-describing
    ///   and body-framing headers too.
    ///
    /// Returns `Ok(None)`, leaving the previous request in place, when the redirect is not
    /// followed. That happens for a method-retaining status when the method needs a request
    /// body, or when the method is `DELETE`.
    ///
    /// # Errors
    ///
    /// - `Error::NoLocationHeader` if the response had no `Location` header.
    /// - `Error::BadLocationHeader` if the header value cannot be read as a string.
    /// - Errors from building the new URI, creating the new flow, or unsetting headers.
    pub fn as_new_flow(
        &mut self,
        redirect_auth_headers: RedirectAuthHeaders,
    ) -> Result<Option<Flow<B, Prepare>>, Error> {
        let location_header = match &self.inner.location {
            Some(v) => v,
            None => return Err(Error::NoLocationHeader),
        };

        let location = match location_header.to_str() {
            Ok(v) => v,
            Err(_) => {
                return Err(Error::BadLocationHeader(
                    String::from_utf8_lossy(location_header.as_bytes()).to_string(),
                ))
            }
        };

        let previous = self.inner.call.request_mut();

        // Cannot fail: a Redirect flow is only created after `is_redirect` saw a status.
        let status = self.inner.status.unwrap();
        let method = previous.method();

        // New URI built from the previous request and the Location value.
        let uri = previous.new_uri_from_location(location)?;

        let new_method = if status.is_redirect_retaining_status() {
            if method.need_request_body() {
                // A method-retaining redirect of a body-carrying request is not followed.
                return Ok(None);
            } else if method == Method::DELETE {
                // DELETE is deliberately not re-issued on a method-retaining redirect.
                return Ok(None);
            } else {
                method.clone()
            }
        } else if matches!(*method, Method::GET | Method::HEAD) {
            method.clone()
        } else {
            // Any other method is rewritten to GET.
            Method::GET
        };

        // Decided before `take_request`, which needs `previous` mutably.
        let method_changed = new_method != *method;

        let mut request = previous.take_request();
        *request.method_mut() = new_method;

        let mut next = Flow::new(request)?;
        let request = next.inner.call.request_mut();

        // Compare against the URI the request still has, before it is overridden below.
        let keep_auth_header = match redirect_auth_headers {
            RedirectAuthHeaders::Never => false,
            RedirectAuthHeaders::SameHost => can_redirect_auth_header(request.uri(), &uri),
        };

        request.set_uri(uri);

        if !keep_auth_header {
            request.unset_header(header::AUTHORIZATION)?;
        }

        // Cookies are never carried over automatically.
        request.unset_header(header::COOKIE)?;
        request.unset_header(header::CONTENT_LENGTH)?;

        // A rewritten method (e.g. POST -> GET) does not resend the original content, so
        // headers that describe or frame it must go too (RFC 9110 §15.4). A leftover
        // Transfer-Encoding would announce a body that is never sent.
        if method_changed {
            for name in [
                header::CONTENT_TYPE,
                header::CONTENT_ENCODING,
                header::CONTENT_LANGUAGE,
                header::CONTENT_LOCATION,
                header::LAST_MODIFIED,
                header::TRANSFER_ENCODING,
            ] {
                request.unset_header(name)?;
            }
        }

        Ok(Some(next))
    }

    /// Status of the redirect response.
    ///
    /// Does not panic: a Redirect flow only exists once `is_redirect` has seen a status.
    pub fn status(&self) -> StatusCode {
        self.inner.status.unwrap()
    }

    /// Whether the connection must be closed instead of being reused.
    pub fn must_close_connection(&self) -> bool {
        self.close_reason().is_some()
    }

    /// Explanation of the first recorded reason for closing the connection, if any.
    pub fn close_reason(&self) -> Option<&'static str> {
        self.inner.close_reason.first().map(|s| s.explain())
    }

    /// Finishes this call without following the redirect in this flow.
    pub fn proceed(self) -> Flow<B, Cleanup> {
        Flow::wrap(self.inner)
    }
}
/// Whether an `Authorization` header may follow a redirect from `prev` to `next` under
/// `RedirectAuthHeaders::SameHost`.
///
/// Allowed when:
/// - both URIs name the same host, compared ASCII case-insensitively because hosts are
///   case-insensitive (RFC 3986 §3.2.2); two URIs without an authority also count as the
///   same host;
/// - and the scheme is either unchanged or the target is `https`.
///
/// Ports are not compared.
fn can_redirect_auth_header(prev: &Uri, next: &Uri) -> bool {
    let host_prev = prev.authority().map(|a| a.host());
    let host_next = next.authority().map(|a| a.host());
    let same_host = match (host_prev, host_next) {
        (Some(a), Some(b)) => a.eq_ignore_ascii_case(b),
        (None, None) => true,
        _ => false,
    };

    let scheme_prev = prev.scheme();
    let scheme_next = next.scheme();
    same_host && (scheme_prev == scheme_next || scheme_next == Some(&Scheme::HTTPS))
}
/// Policy for keeping the `Authorization` header when following a redirect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RedirectAuthHeaders {
    /// Always strip `Authorization` from the redirected request.
    Never,
    /// Keep `Authorization` only if the new URI has the same host as the previous one, and
    /// either the same scheme or `https`. Ports are not compared.
    SameHost,
}
impl<B> Flow<B, Cleanup> {
    /// Whether the connection must be closed instead of being reused.
    pub fn must_close_connection(&self) -> bool {
        matches!(self.close_reason(), Some(_))
    }

    /// Explanation of the first recorded reason for closing the connection, if any.
    pub fn close_reason(&self) -> Option<&'static str> {
        let first = self.inner.close_reason.first();
        first.map(|reason| reason.explain())
    }
}
impl<B, State: Named> fmt::Debug for Flow<B, State> {
    /// Prints only the typestate, e.g. `Flow<Prepare>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Flow<")?;
        f.write_str(State::name())?;
        f.write_str(">")
    }
}