use crate::command::Command;
use crate::mock::InnerMock;
use crate::request::Request;
use crate::response::{Body as ResponseBody, Chunked as ResponseChunked};
use crate::{Error, ErrorKind, Matcher, Mock};
use futures::stream::{self, StreamExt};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{Body, Request as HyperRequest, Response, StatusCode};
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::thread;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::Mutex;
#[derive(Clone, Debug)]
pub(crate) struct RemoteMock {
pub(crate) inner: InnerMock,
}
impl RemoteMock {
pub(crate) fn new(inner: InnerMock) -> Self {
RemoteMock { inner }
}
async fn matches(&self, other: &mut Request) -> bool {
self.method_matches(other)
&& self.path_matches(other)
&& self.headers_match(other)
&& self.body_matches(other).await
}
fn method_matches(&self, request: &Request) -> bool {
self.inner.method.as_str() == request.method()
}
fn path_matches(&self, request: &Request) -> bool {
self.inner.path.matches_value(request.path_and_query())
}
fn headers_match(&self, request: &Request) -> bool {
self.inner
.headers
.iter()
.all(|&(ref field, ref expected)| expected.matches_values(&request.header(field)))
}
async fn body_matches(&self, request: &mut Request) -> bool {
let body = request.read_body().await;
let safe_body = &String::from_utf8_lossy(body);
self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body)
}
#[allow(clippy::missing_const_for_fn)]
fn is_missing_hits(&self) -> bool {
match (
self.inner.expected_hits_at_least,
self.inner.expected_hits_at_most,
) {
(Some(_at_least), Some(at_most)) => self.inner.hits < at_most,
(Some(at_least), None) => self.inner.hits < at_least,
(None, Some(at_most)) => self.inner.hits < at_most,
(None, None) => self.inner.hits < 1,
}
}
}
/// Mutable server state shared between the connection handlers and the command task.
#[derive(Debug)]
pub(crate) struct State {
    /// Mocks currently known to the server.
    pub(crate) mocks: Vec<RemoteMock>,
    /// Requests for which no mock matched.
    pub(crate) unmatched_requests: Vec<Request>,
}

impl State {
    /// Creates an empty state, with no mocks and no unmatched requests.
    fn new() -> Self {
        Self {
            unmatched_requests: Vec::new(),
            mocks: Vec::new(),
        }
    }
}
#[derive(Debug)]
pub struct Server {
address: String,
state: Arc<Mutex<State>>,
sender: Sender<Command>,
busy: bool,
}
impl Server {
#[allow(clippy::new_ret_no_self)]
#[track_caller]
pub fn new() -> ServerGuard {
Server::try_new().unwrap()
}
pub async fn new_async() -> ServerGuard {
let server = Server::new_with_port_async(0).await;
ServerGuard::new(server)
}
pub(crate) fn try_new() -> Result<ServerGuard, Error> {
crate::RUNTIME.block_on(async { Server::try_new_async().await })
}
pub(crate) async fn try_new_async() -> Result<ServerGuard, Error> {
let server = Server::try_new_with_port_async(0)
.await
.map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
Ok(ServerGuard::new(server))
}
#[track_caller]
pub fn new_with_port(port: u16) -> Server {
Server::try_new_with_port(port).unwrap()
}
pub async fn new_with_port_async(port: u16) -> Server {
Server::try_new_with_port_async(port).await.unwrap()
}
pub(crate) fn try_new_with_port(port: u16) -> Result<Server, Error> {
crate::RUNTIME.block_on(async { Server::try_new_with_port_async(port).await })
}
pub(crate) async fn try_new_with_port_async(port: u16) -> Result<Server, Error> {
let state = Arc::new(Mutex::new(State::new()));
let address = SocketAddr::from(([127, 0, 0, 1], port));
let listener = tokio::net::TcpListener::bind(address)
.await
.map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
let address = listener
.local_addr()
.map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;
let mutex = state.clone();
let server = async move {
while let Ok((stream, _)) = listener.accept().await {
let mutex = mutex.clone();
tokio::spawn(async move {
Http::new()
.serve_connection(
stream,
service_fn(move |request: HyperRequest<Body>| {
handle_request(request, mutex.clone())
}),
)
.await
.unwrap();
});
}
};
thread::spawn(move || crate::RUNTIME.block_on(server));
let (sender, receiver) = mpsc::channel(32);
let mut server = Server {
address: address.to_string(),
state,
sender,
busy: true,
};
server.accept_commands(receiver).await;
Ok(server)
}
pub fn mock<P: Into<Matcher>>(&mut self, method: &str, path: P) -> Mock {
Mock::new(self.sender.clone(), method, path)
}
pub fn url(&self) -> String {
format!("http://{}", self.address)
}
pub fn host_with_port(&self) -> String {
self.address.clone()
}
pub fn reset(&mut self) {
crate::RUNTIME.block_on(async { self.reset_async().await });
}
pub async fn reset_async(&mut self) {
let state = self.state.clone();
let mut state = state.lock().await;
state.mocks.clear();
state.unmatched_requests.clear();
}
#[allow(dead_code)]
pub(crate) fn busy(&self) -> bool {
let state = self.state.clone();
let locked = state.try_lock().is_err();
let sender_busy = self.sender.try_send(Command::Noop).is_err();
self.busy || locked || sender_busy
}
pub(crate) fn set_busy(&mut self, busy: bool) {
self.busy = busy;
}
async fn accept_commands(&mut self, mut receiver: Receiver<Command>) {
let state = self.state.clone();
tokio::spawn(async move {
while let Some(cmd) = receiver.recv().await {
let state = state.lock().await;
Command::handle(cmd, state).await;
}
});
log::debug!("Server is accepting commands");
}
}
/// Type wrapped by [`ServerGuard`].
type GuardType = Server;

/// Owning handle to a [`Server`] that keeps it marked busy while the guard is alive.
///
/// Dereferences to the wrapped [`Server`], so all server methods are callable on the
/// guard. Dropping the guard clears the server's busy flag.
#[derive(Debug)]
pub struct ServerGuard {
    server: GuardType,
}

impl ServerGuard {
    /// Wraps `server` and marks it busy.
    pub(crate) fn new(mut server: GuardType) -> ServerGuard {
        server.set_busy(true);
        ServerGuard { server }
    }
}

impl Deref for ServerGuard {
    type Target = Server;

    fn deref(&self) -> &Self::Target {
        &self.server
    }
}

impl DerefMut for ServerGuard {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.server
    }
}

impl Drop for ServerGuard {
    /// Clears the wrapped server's busy flag.
    fn drop(&mut self) {
        self.server.set_busy(false);
    }
}
async fn handle_request(
hyper_request: HyperRequest<Body>,
state: Arc<Mutex<State>>,
) -> Result<Response<Body>, Error> {
let mut request = Request::new(hyper_request);
log::debug!("Request received: {}", request.to_string().await);
let mutex = state.clone();
let mut state = mutex.lock().await;
let mut mocks_stream = stream::iter(&mut state.mocks);
let mut matching_mocks: Vec<&mut RemoteMock> = vec![];
while let Some(mock) = mocks_stream.next().await {
if mock.matches(&mut request).await {
matching_mocks.push(mock);
}
}
let maybe_missing_hits = matching_mocks.iter_mut().find(|m| m.is_missing_hits());
let mock = match maybe_missing_hits {
Some(m) => Some(m),
None => matching_mocks.last_mut(),
};
if let Some(mock) = mock {
log::debug!("Mock found");
mock.inner.hits += 1;
respond_with_mock(request, mock).await
} else {
log::debug!("Mock not found");
state.unmatched_requests.push(request);
respond_with_mock_not_found()
}
}
async fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
let status: StatusCode = mock.inner.response.status;
let mut response = Response::builder().status(status);
for (name, value) in mock.inner.response.headers.iter() {
response = response.header(name, value);
}
let body = if request.method() != "HEAD" {
match &mock.inner.response.body {
ResponseBody::Bytes(bytes) => {
if !request.has_header("content-length") {
response = response.header("content-length", bytes.len());
}
Body::from(bytes.clone())
}
ResponseBody::Fn(body_fn) => {
let mut chunked = ResponseChunked::new();
body_fn(&mut chunked)
.map_err(|_| Error::new(ErrorKind::ResponseBodyFailure))
.unwrap();
chunked.finish();
Body::wrap_stream(chunked)
}
}
} else {
Body::empty()
};
let response: Response<Body> = response
.body(body)
.map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
Ok(response)
}
fn respond_with_mock_not_found() -> Result<Response<Body>, Error> {
let response: Response<Body> = Response::builder()
.status(StatusCode::NOT_IMPLEMENTED)
.body(Body::empty())
.map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;
Ok(response)
}