use std::ascii::AsciiExt;
use std::io::Error as IoError;
use std::io::{self, Cursor, Read, Write, ErrorKind};
use std::net::SocketAddr;
use std::fmt;
use std::str::FromStr;
use {Header, HTTPVersion, Method, Response, StatusCode};
use util::EqualReader;
use chunked_transfer::Decoder;
/// An HTTP request received from a client, together with the means to answer it.
///
/// Built by `new_request`. The answer is sent with `respond`, or the raw writer is
/// taken with `into_writer`; if neither happens, dropping the request sends an empty
/// `500` response (see the `Drop` impl).
///
/// NOTE(review): fields are dropped in declaration order, so `data_reader` is dropped
/// before `response_writer`. Keep that in mind before reordering them.
pub struct Request {
/// Body reader selected by `new_request` from the request headers.
data_reader: Option<Box<Read + Send + 'static>>,
/// Writer for the response; set to `None` once it has been taken for responding.
response_writer: Option<Box<Write + Send + 'static>>,
/// Address of the peer that sent the request.
remote_addr: SocketAddr,
/// `secure` flag supplied by the caller of `new_request`.
secure: bool,
/// Request method.
method: Method,
/// Request path, stored exactly as passed to `new_request`.
path: String,
/// HTTP version of the request; also used when writing the response.
http_version: HTTPVersion,
/// All request headers, in the order they were given to `new_request`.
headers: Vec<Header>,
/// Parsed `Content-Length`, or `None` when absent, unparseable, or overridden by
/// `Transfer-Encoding`.
body_length: Option<usize>,
/// `true` until a `100 Continue` interim response has been sent for a request that
/// carried `Expect: 100-continue`.
must_send_continue: bool,
}
/// Error returned by `new_request` when a `Request` cannot be built.
///
/// Derives `Debug` so a `Result<Request, RequestCreationError>` can be inspected
/// with `{:?}`, `unwrap()` or `expect()`.
#[derive(Debug)]
pub enum RequestCreationError {
    /// The client sent an `Expect` header whose value is not `100-continue`.
    ExpectationFailed,
    /// An I/O error; inside `new_request` it comes from eagerly reading a small body.
    CreationIoError(IoError),
}

impl From<IoError> for RequestCreationError {
    /// Wraps the I/O error in `RequestCreationError::CreationIoError`, which lets
    /// `try!` / `?` convert I/O errors automatically.
    fn from(err: IoError) -> RequestCreationError {
        RequestCreationError::CreationIoError(err)
    }
}
/// Reads exactly `len` bytes from `source` into a freshly allocated buffer.
///
/// Reads interrupted by a signal (`ErrorKind::Interrupted`) are retried rather than
/// failing the request. If the stream reaches end-of-file before `len` bytes have been
/// read, fails with `ErrorKind::ConnectionAborted`. Any other error is returned as-is.
fn read_body_into_memory<R: Read>(source: &mut R, len: usize) -> io::Result<Vec<u8>> {
    let mut buffer = vec![0; len];
    let mut offset = 0;
    while offset < len {
        match source.read(&mut buffer[offset..]) {
            Ok(0) => {
                let info = "Connection has been closed before we received enough data";
                return Err(IoError::new(ErrorKind::ConnectionAborted, info));
            }
            Ok(read) => offset += read,
            Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
    Ok(buffer)
}

/// Builds a `Request` from an already-parsed request line and header list.
///
/// The body reader is chosen from the headers, first match wins:
/// - `Connection` containing the `upgrade` token: `source_data` is handed over as-is.
/// - `Content-Length: 0`: an empty reader.
/// - `Content-Length` of at most 1024 bytes without `Expect: 100-continue`: the body is
///   read into memory right here.
/// - any other `Content-Length`: an `EqualReader` limited to that many bytes.
/// - `Transfer-Encoding` present (any value): a chunked `Decoder`.
/// - otherwise: an empty reader.
///
/// # Errors
///
/// - `ExpectationFailed` if an `Expect` header is present and its value is not
///   `100-continue` (ASCII case-insensitive).
/// - `CreationIoError` if eagerly reading a small body fails. This includes the
///   stream ending early, reported as `ErrorKind::ConnectionAborted`.
pub fn new_request<R, W>(secure: bool, method: Method, path: String,
                         version: HTTPVersion, headers: Vec<Header>,
                         remote_addr: SocketAddr, mut source_data: R, writer: W)
                         -> Result<Request, RequestCreationError>
    where R: Read + Send + 'static, W: Write + Send + 'static
{
    // Only the presence of the header matters below, so no need to clone its value.
    // NOTE(review): any Transfer-Encoding value is treated as chunked.
    let has_transfer_encoding = headers.iter()
        .any(|h: &Header| h.field.equiv(&"Transfer-Encoding"));

    // Transfer-Encoding takes precedence over Content-Length (RFC 7230 §3.3.3).
    // An unparseable Content-Length is treated as absent.
    let content_length: Option<usize> = if has_transfer_encoding {
        None
    } else {
        headers.iter()
            .find(|h: &&Header| h.field.equiv(&"Content-Length"))
            .and_then(|h| FromStr::from_str(h.value.as_str()).ok())
    };

    let expects_continue = match headers.iter()
        .find(|h: &&Header| h.field.equiv(&"Expect"))
        .map(|h| AsRef::<str>::as_ref(h.value.as_ref()))
    {
        None => false,
        Some(v) if v.eq_ignore_ascii_case("100-continue") => true,
        // Any other expectation is unsupported: reject the request.
        Some(_) => return Err(RequestCreationError::ExpectationFailed),
    };

    // `Connection` carries a comma-separated list of tokens (e.g. `keep-alive, Upgrade`),
    // so look for the `upgrade` token instead of comparing the whole value.
    let connection_upgrade = headers.iter()
        .find(|h: &&Header| h.field.equiv(&"Connection"))
        .map(|h| AsRef::<str>::as_ref(h.value.as_ref()))
        .map_or(false, |v| {
            v.split(',').any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
        });

    let reader = if connection_upgrade {
        Box::new(source_data) as Box<Read + Send + 'static>
    } else if let Some(content_length) = content_length {
        if content_length == 0 {
            Box::new(io::empty()) as Box<Read + Send + 'static>
        } else if content_length <= 1024 && !expects_continue {
            // Small bodies are buffered up front. With `Expect: 100-continue` the client
            // may hold the body back until `as_reader` sends `100 Continue`, so they are
            // not read eagerly in that case.
            let buffer = read_body_into_memory(&mut source_data, content_length)?;
            Box::new(Cursor::new(buffer)) as Box<Read + Send + 'static>
        } else {
            let (data_reader, _) = EqualReader::new(source_data, content_length);
            Box::new(data_reader) as Box<Read + Send + 'static>
        }
    } else if has_transfer_encoding {
        Box::new(Decoder::new(source_data)) as Box<Read + Send + 'static>
    } else {
        Box::new(io::empty()) as Box<Read + Send + 'static>
    };

    Ok(Request {
        data_reader: Some(reader),
        response_writer: Some(Box::new(writer) as Box<Write + Send + 'static>),
        remote_addr: remote_addr,
        secure: secure,
        method: method,
        path: path,
        http_version: version,
        headers: headers,
        body_length: content_length,
        must_send_continue: expects_continue,
    })
}
impl Request {
#[inline]
pub fn secure(&self) -> bool {
self.secure
}
#[inline]
pub fn method(&self) -> &Method {
&self.method
}
#[inline]
pub fn url(&self) -> &str {
&self.path
}
#[inline]
pub fn headers(&self) -> &[Header] {
&self.headers
}
#[inline]
pub fn http_version(&self) -> &HTTPVersion {
&self.http_version
}
#[inline]
pub fn body_length(&self) -> Option<usize> {
self.body_length
}
#[inline]
pub fn remote_addr(&self) -> &SocketAddr {
&self.remote_addr
}
#[inline]
pub fn as_reader(&mut self) -> &mut Read {
if self.must_send_continue {
let msg = Response::new_empty(StatusCode(100));
msg.raw_print(self.response_writer.as_mut().unwrap().by_ref(),
self.http_version.clone(), &self.headers, true, None).ok();
self.response_writer.as_mut().unwrap().flush().ok();
self.must_send_continue = false;
}
self.data_reader.as_mut().unwrap()
}
#[inline]
pub fn into_writer(mut self) -> Box<Write + Send + 'static> {
self.into_writer_impl()
}
fn into_writer_impl(&mut self) -> Box<Write + Send + 'static> {
use std::mem;
assert!(self.response_writer.is_some());
let mut writer = None;
mem::swap(&mut self.response_writer, &mut writer);
writer.unwrap()
}
#[inline]
pub fn respond<R>(mut self, response: Response<R>) where R: Read {
self.respond_impl(response)
}
fn respond_impl<R>(&mut self, response: Response<R>) where R: Read {
let mut writer = self.into_writer_impl();
let do_not_send_body = self.method == Method::Head;
match response.raw_print(writer.by_ref(),
self.http_version.clone(), &self.headers,
do_not_send_body, None)
{
Ok(_) => (),
Err(ref err) if err.kind() == ErrorKind::BrokenPipe => (),
Err(ref err) if err.kind() == ErrorKind::ConnectionAborted => (),
Err(ref err) if err.kind() == ErrorKind::ConnectionRefused => (),
Err(ref err) if err.kind() == ErrorKind::ConnectionReset => (),
Err(ref err) =>
println!("error while sending answer: {}", err) };
writer.flush().ok();
}
}
impl fmt::Debug for Request {
fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(formatter, "Request({} {} from {})", self.method, self.path, self.remote_addr)
}
}
impl Drop for Request {
    /// Answers with an empty `500` response if the request is dropped while still
    /// holding its writer, i.e. neither `respond` nor `into_writer` was called.
    fn drop(&mut self) {
        if self.response_writer.is_none() {
            // Already answered, or the writer was handed to the caller.
            return;
        }
        self.respond_impl(Response::empty(500));
    }
}
#[cfg(test)]
mod tests {
    use super::Request;

    /// Compile-time guarantee that `Request` can be moved to another thread.
    #[test]
    fn must_be_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Request>();
    }
}