use std::io::Error as IoError;
use std::io::{self, Cursor, ErrorKind, Read, Write};
use std::fmt;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::mpsc::Sender;
use chunked_transfer::Decoder;
use util::EqualReader;
use {HTTPVersion, Header, Method, Response, StatusCode};
/// An HTTP request received from a client.
///
/// Built by [`new_request`]. The body is read through [`Request::as_reader`], and the
/// request is answered with [`Request::respond`], [`Request::into_writer`] or
/// [`Request::upgrade`]. If a `Request` is dropped while its response writer is still
/// present, its `Drop` impl writes an empty `500` response.
///
/// NOTE(review): field declaration order is also drop order — `data_reader` is dropped
/// before `response_writer`. Keep that in mind before reordering fields.
pub struct Request {
    // Body reader; becomes `None` once the body is discarded or handed out.
    data_reader: Option<Box<dyn Read + Send + 'static>>,
    // Connection output; becomes `None` once a response was written or the writer handed out.
    response_writer: Option<Box<dyn Write + Send + 'static>>,
    remote_addr: SocketAddr,
    // Flag passed to `new_request` (presumably `true` for TLS connections — confirm with caller).
    secure: bool,
    method: Method,
    // Request target exactly as passed to `new_request`; exposed through `url()`.
    path: String,
    http_version: HTTPVersion,
    headers: Vec<Header>,
    // Parsed `Content-Length`; `None` if missing/unparsable or if `Transfer-Encoding` is present.
    body_length: Option<usize>,
    // `true` while a `100 Continue` still has to be sent for an `Expect: 100-continue` request.
    must_send_continue: bool,
    // Signalled once the request has been answered (or the handed-out writer/stream is dropped).
    notify_when_responded: Option<Sender<()>>,
}
/// Wraps a reader and/or writer and sends `()` on `sender` when the wrapper is dropped.
///
/// Used by [`Request::into_writer`] and [`Request::upgrade`] so that the holder of the
/// receiving end learns when the caller has finished with the handed-out writer/stream.
/// All I/O is forwarded unchanged to `inner`.
struct NotifyOnDrop<R> {
    sender: Sender<()>,
    inner: R,
}

impl<R: Read> Read for NotifyOnDrop<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read(buf)
    }
}

impl<R: Write> Write for NotifyOnDrop<R> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

impl<R> Drop for NotifyOnDrop<R> {
    fn drop(&mut self) {
        // The receiver may already have been dropped (e.g. the connection was torn down).
        // A failed send must not panic here: a panic inside `drop` while the thread is
        // already unwinding aborts the whole process.
        let _ = self.sender.send(());
    }
}
/// Error returned by [`new_request`] when a `Request` cannot be built.
#[derive(Debug)]
pub enum RequestCreationError {
    /// The client sent an `Expect` header whose value is not `100-continue`.
    ExpectationFailed,

    /// An I/O error occurred while building the request, e.g. while eagerly reading a
    /// small body from the client.
    CreationIoError(IoError),
}

impl From<IoError> for RequestCreationError {
    /// Wraps an I/O error as [`RequestCreationError::CreationIoError`], which lets `?` be
    /// applied to `io::Result`s in functions returning this error type.
    fn from(err: IoError) -> RequestCreationError {
        RequestCreationError::CreationIoError(err)
    }
}
/// Builds a [`Request`] from an already-parsed request line and headers, plus the
/// connection's input (`source_data`) and output (`writer`).
///
/// How the request body is exposed depends on the headers, checked in this order:
///
/// * a `Connection` header containing `upgrade` (case-insensitive): `source_data` is kept
///   unwrapped, so everything remaining on the input is readable;
/// * `Content-Length: 0`: an empty reader;
/// * `Content-Length` of at most 1024 bytes and no `Expect: 100-continue`: the body is read
///   into memory immediately;
/// * any other `Content-Length`: an `EqualReader` over `source_data` built with that length;
/// * a `Transfer-Encoding` header (which makes `Content-Length` be ignored): a chunked
///   `Decoder` over `source_data`;
/// * none of the above (including an unparsable `Content-Length`): an empty reader.
///
/// # Errors
///
/// * [`RequestCreationError::ExpectationFailed`] if an `Expect` header is present with any
///   value other than `100-continue` (compared case-insensitively).
/// * [`RequestCreationError::CreationIoError`] if eagerly reading a small body fails, or if
///   `source_data` reaches EOF before `Content-Length` bytes arrived (reported with
///   `ErrorKind::ConnectionAborted`).
pub fn new_request<R, W>(
    secure: bool,
    method: Method,
    path: String,
    version: HTTPVersion,
    headers: Vec<Header>,
    remote_addr: SocketAddr,
    mut source_data: R,
    writer: W,
) -> Result<Request, RequestCreationError>
where
    R: Read + Send + 'static,
    W: Write + Send + 'static,
{
    // Only the presence of `Transfer-Encoding` is considered; its value is not inspected.
    let has_transfer_encoding = headers
        .iter()
        .any(|h| h.field.equiv(&"Transfer-Encoding"));

    // `Transfer-Encoding` takes precedence over `Content-Length`.
    let content_length: Option<usize> = if has_transfer_encoding {
        None
    } else {
        headers
            .iter()
            .find(|h: &&Header| h.field.equiv(&"Content-Length"))
            .and_then(|h| FromStr::from_str(h.value.as_str()).ok())
    };

    // `100-continue` is the only expectation supported; anything else is rejected.
    let expects_continue = match headers
        .iter()
        .find(|h: &&Header| h.field.equiv(&"Expect"))
        .map(|h| h.value.as_str())
    {
        None => false,
        Some(v) if v.eq_ignore_ascii_case("100-continue") => true,
        _ => return Err(RequestCreationError::ExpectationFailed),
    };

    let connection_upgrade = match headers
        .iter()
        .find(|h: &&Header| h.field.equiv(&"Connection"))
        .map(|h| h.value.as_str())
    {
        Some(v) => v.to_ascii_lowercase().contains("upgrade"),
        None => false,
    };

    let reader: Box<dyn Read + Send + 'static> = if connection_upgrade {
        // Keep the whole input so the caller can take over the connection later.
        Box::new(source_data)
    } else if let Some(content_length) = content_length {
        if content_length == 0 {
            Box::new(io::empty())
        } else if content_length <= 1024 && !expects_continue {
            // Small bodies are read into memory right away. This is skipped when the client
            // sent `Expect: 100-continue`, since it may wait for that interim response
            // before sending any body bytes.
            Box::new(Cursor::new(read_body_eagerly(
                &mut source_data,
                content_length,
            )?))
        } else {
            let (data_reader, _) = EqualReader::new(source_data, content_length);
            Box::new(data_reader)
        }
    } else if has_transfer_encoding {
        // A transfer-encoding always implies "chunked" is applied (RFC 2616 §3.6).
        Box::new(Decoder::new(source_data))
    } else {
        // Neither Content-Length nor Transfer-Encoding: assume there is no body.
        Box::new(io::empty())
    };

    Ok(Request {
        data_reader: Some(reader),
        response_writer: Some(Box::new(writer) as Box<dyn Write + Send + 'static>),
        remote_addr,
        secure,
        method,
        path,
        http_version: version,
        headers,
        body_length: content_length,
        must_send_continue: expects_continue,
        notify_when_responded: None,
    })
}

/// Reads exactly `len` bytes of request body from `source` into a new buffer.
///
/// `ErrorKind::Interrupted` errors are retried, as the `Read::read` contract prescribes.
/// If `source` reports EOF before `len` bytes were read, a `ConnectionAborted` error is
/// returned; any other read error is returned as-is.
fn read_body_eagerly<R: Read>(source: &mut R, len: usize) -> Result<Vec<u8>, IoError> {
    let mut buffer = vec![0; len];
    let mut offset = 0;
    while offset != len {
        match source.read(&mut buffer[offset..]) {
            Ok(0) => {
                let info = "Connection has been closed before we received enough data";
                return Err(IoError::new(ErrorKind::ConnectionAborted, info));
            }
            Ok(read) => offset += read,
            Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(buffer)
}
impl Request {
    /// Returns the `secure` flag this request was created with (presumably `true` for
    /// TLS connections — set by the caller of `new_request`).
    #[inline]
    pub fn secure(&self) -> bool {
        self.secure
    }

    /// Returns the method requested by the client (e.g. `GET`, `POST`).
    #[inline]
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// Returns the request target exactly as it was passed to `new_request`.
    #[inline]
    pub fn url(&self) -> &str {
        &self.path
    }

    /// Returns the headers sent by the client.
    #[inline]
    pub fn headers(&self) -> &[Header] {
        &self.headers
    }

    /// Returns the HTTP version of the request.
    #[inline]
    pub fn http_version(&self) -> &HTTPVersion {
        &self.http_version
    }

    /// Returns the parsed `Content-Length` of the body.
    ///
    /// `None` if the header is missing or unparsable, or if `Transfer-Encoding` is present.
    #[inline]
    pub fn body_length(&self) -> Option<usize> {
        self.body_length
    }

    /// Returns the address of the client that sent this request.
    #[inline]
    pub fn remote_addr(&self) -> &SocketAddr {
        &self.remote_addr
    }

    /// Writes `response` (passing `Some(protocol)` to `raw_print`) and then returns the
    /// request's reader and writer combined into a single read/write stream.
    ///
    /// For a request with a `Connection: upgrade` header the reader is the unwrapped
    /// connection input. Errors while writing or flushing the response are ignored. If a
    /// notification sender is attached, it fires when the returned stream is dropped.
    pub fn upgrade<R: Read>(
        mut self,
        protocol: &str,
        response: Response<R>,
    ) -> Box<dyn ReadWrite + Send> {
        use util::CustomStream;

        // Write errors are ignored here; the caller still receives the stream.
        let _ = response.raw_print(
            self.response_writer
                .as_mut()
                .expect("response writer is present until the request is consumed")
                .by_ref(),
            self.http_version.clone(),
            &self.headers,
            false,
            Some(protocol),
        );
        let _ = self
            .response_writer
            .as_mut()
            .expect("response writer is present until the request is consumed")
            .flush();

        let stream = CustomStream::new(self.into_reader_impl(), self.into_writer_impl());
        match self.notify_when_responded.take() {
            Some(sender) => Box::new(NotifyOnDrop {
                sender,
                inner: stream,
            }) as Box<dyn ReadWrite + Send>,
            None => Box::new(stream) as Box<dyn ReadWrite + Send>,
        }
    }

    /// Gives access to the body of the request.
    ///
    /// If the client sent `Expect: 100-continue`, a `100 Continue` response is written and
    /// flushed on the first call (errors while doing so are ignored).
    #[inline]
    pub fn as_reader(&mut self) -> &mut dyn Read {
        if self.must_send_continue {
            let msg = Response::new_empty(StatusCode(100));
            let _ = msg.raw_print(
                self.response_writer
                    .as_mut()
                    .expect("response writer is present until the request is consumed")
                    .by_ref(),
                self.http_version.clone(),
                &self.headers,
                true,
                None,
            );
            let _ = self
                .response_writer
                .as_mut()
                .expect("response writer is present until the request is consumed")
                .flush();
            self.must_send_continue = false;
        }

        self.data_reader
            .as_mut()
            .expect("body reader is present until the request is consumed")
    }

    /// Consumes the request and returns the raw response writer, so the caller can write a
    /// response by hand.
    ///
    /// If a notification sender is attached, it fires when the returned writer is dropped.
    #[inline]
    pub fn into_writer(mut self) -> Box<dyn Write + Send + 'static> {
        let writer = self.into_writer_impl();
        match self.notify_when_responded.take() {
            Some(sender) => Box::new(NotifyOnDrop {
                sender,
                inner: writer,
            }) as Box<dyn Write + Send + 'static>,
            None => writer,
        }
    }

    /// Removes and returns the response writer.
    ///
    /// # Panics
    ///
    /// Panics if the writer has already been taken.
    fn into_writer_impl(&mut self) -> Box<dyn Write + Send + 'static> {
        self.response_writer
            .take()
            .expect("the response writer has already been taken")
    }

    /// Removes and returns the body reader.
    ///
    /// # Panics
    ///
    /// Panics if the reader has already been taken.
    fn into_reader_impl(&mut self) -> Box<dyn Read + Send + 'static> {
        self.data_reader
            .take()
            .expect("the body reader has already been taken")
    }

    /// Sends `response` to the client, consuming the request.
    ///
    /// Errors indicating that the client closed the connection (`BrokenPipe`,
    /// `ConnectionAborted`, `ConnectionRefused`, `ConnectionReset`) are treated as
    /// success; other I/O errors are returned. The notification sender, if any, is
    /// signalled whether or not writing succeeded.
    #[inline]
    pub fn respond<R>(mut self, response: Response<R>) -> Result<(), IoError>
    where
        R: Read,
    {
        let res = self.respond_impl(response);
        if let Some(sender) = self.notify_when_responded.take() {
            // The receiver may already be gone; that is not a reason to panic.
            let _ = sender.send(());
        }
        res
    }

    /// Discards the body reader, then writes and flushes `response` through the (taken)
    /// response writer. For `HEAD` requests, `raw_print` is told not to send the body.
    fn respond_impl<R>(&mut self, response: Response<R>) -> Result<(), IoError>
    where
        R: Read,
    {
        self.data_reader = None;
        let mut writer = self.into_writer_impl();
        let do_not_send_body = self.method == Method::Head;

        Self::ignore_client_closing_errors(response.raw_print(
            writer.by_ref(),
            self.http_version.clone(),
            &self.headers,
            do_not_send_body,
            None,
        ))?;

        Self::ignore_client_closing_errors(writer.flush())
    }

    /// Maps I/O errors that mean "the client went away" to `Ok(())`, and passes every other
    /// result through unchanged.
    fn ignore_client_closing_errors(result: io::Result<()>) -> io::Result<()> {
        result.or_else(|err| match err.kind() {
            ErrorKind::BrokenPipe => Ok(()),
            ErrorKind::ConnectionAborted => Ok(()),
            ErrorKind::ConnectionRefused => Ok(()),
            ErrorKind::ConnectionReset => Ok(()),
            _ => Err(err),
        })
    }

    /// Attaches a sender that is signalled once this request has been answered (or once
    /// the writer/stream handed out by `into_writer`/`upgrade` is dropped).
    pub(crate) fn with_notify_sender(mut self, sender: Sender<()>) -> Self {
        self.notify_when_responded = Some(sender);
        self
    }
}
impl fmt::Debug for Request {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(
formatter,
"Request({} {} from {})",
self.method, self.path, self.remote_addr
)
}
}
impl Drop for Request {
    /// Discards the body reader and, if no response was written yet, answers with an
    /// empty `500`.
    ///
    /// The response writer is still present only if none of `respond`, `into_writer` or
    /// `upgrade` took it. In that case an empty 500 response is written (errors ignored) and
    /// the notification sender, if any, is signalled.
    fn drop(&mut self) {
        self.data_reader = None;
        if self.response_writer.is_some() {
            let response = Response::empty(500);
            // Best effort: there is nobody to report a write error to from `drop`.
            let _ = self.respond_impl(response);
            if let Some(sender) = self.notify_when_responded.take() {
                // Never panic in `drop`: the receiver may already be gone, and a panic
                // raised while unwinding would abort the process.
                let _ = sender.send(());
            }
        }
    }
}
/// A stream that can be both read from and written to, such as the connection returned by
/// [`Request::upgrade`].
///
/// Implemented automatically for every type that implements both `Read` and `Write`.
pub trait ReadWrite: Read + Write {}

impl<S: Read + Write> ReadWrite for S {}
#[cfg(test)]
mod tests {
    use super::Request;

    /// `Request` must be `Send` so it can be handed to another thread; this test fails to
    /// compile otherwise.
    #[test]
    fn must_be_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Request>();
    }
}