use bytes::{BufMut, BytesMut};
use futures_core::ready;
use http::{
header::{self, HeaderName, HeaderValue},
request::Parts,
Method, Request, Response, StatusCode,
};
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tower_layer::Layer;
use tower_service::Service;
/// Configuration for the [`Cors`] middleware.
///
/// Every setting is stored already rendered as the `HeaderValue` that will be
/// written to responses. A header field that is `None` is simply not sent.
/// For `allow_origin`, `None` means no origin is accepted.
#[derive(Debug, Clone)]
pub struct CorsLayer {
// `Access-Control-Allow-Credentials`; only ever `"true"` when set.
allow_credentials: Option<HeaderValue>,
// `Access-Control-Allow-Headers` (preflight responses only).
allow_headers: Option<HeaderValue>,
// `Access-Control-Allow-Methods`: `*` or a comma-separated method list, also
// used to validate `Access-Control-Request-Method` on preflight requests.
allow_methods: Option<HeaderValue>,
// Which request origins are accepted.
allow_origin: Option<AnyOr<Origin>>,
// `Access-Control-Expose-Headers`.
expose_headers: Option<HeaderValue>,
// `Access-Control-Max-Age`, in whole seconds (preflight responses only).
max_age: Option<HeaderValue>,
}
/// The `*` header value, used wherever a setting was configured with [`Any`].
// Clippy flags consts of types such as `HeaderValue` as possibly containing
// interior mutability; this allow silences that lint for the declaration.
#[allow(clippy::declare_interior_mutable_const)]
const WILDCARD: HeaderValue = HeaderValue::from_static("*");
impl CorsLayer {
pub fn new() -> Self {
Self {
allow_credentials: None,
allow_headers: None,
allow_methods: None,
allow_origin: None,
expose_headers: None,
max_age: None,
}
}
pub fn permissive() -> Self {
Self::new()
.allow_credentials(true)
.allow_headers(Any)
.allow_methods(Any)
.allow_origin(Any)
.expose_headers(Any)
.max_age(Duration::from_secs(60 * 60))
}
pub fn allow_credentials(mut self, allow_credentials: bool) -> Self {
self.allow_credentials = allow_credentials.then(|| HeaderValue::from_static("true"));
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: Into<AnyOr<Vec<HeaderName>>>,
{
self.allow_headers = match headers.into().0 {
AnyOrInner::Any => Some(WILDCARD),
AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)),
};
self
}
pub fn max_age(mut self, max_age: Duration) -> Self {
self.max_age = Some(max_age.as_secs().into());
self
}
pub fn allow_methods<T>(mut self, methods: T) -> Self
where
T: Into<AnyOr<Vec<Method>>>,
{
self.allow_methods = match methods.into().0 {
AnyOrInner::Any => Some(WILDCARD),
AnyOrInner::Value(methods) => separated_by_commas(
methods
.into_iter()
.map(|m| HeaderValue::from_str(m.as_str()).unwrap()),
),
};
self
}
pub fn allow_origin<T>(mut self, origin: T) -> Self
where
T: Into<AnyOr<Origin>>,
{
self.allow_origin = Some(origin.into());
self
}
pub fn expose_headers<I>(mut self, headers: I) -> Self
where
I: Into<AnyOr<Vec<HeaderName>>>,
{
self.expose_headers = match headers.into().0 {
AnyOrInner::Any => Some(WILDCARD),
AnyOrInner::Value(headers) => separated_by_commas(headers.into_iter().map(Into::into)),
};
self
}
}
/// Wildcard marker: passing `Any` to a builder method means "allow any value".
#[derive(Clone, Copy, Debug)]
pub struct Any;

/// Legacy constructor that returns [`Any`].
#[deprecated = "Use Any as a unit struct literal instead"]
pub fn any() -> Any {
    Any {}
}
/// Either the wildcard [`Any`] or a concrete value `T`. This is the argument
/// type accepted by the `CorsLayer` / `Cors` builder methods.
///
/// The field is private, so outside this module an `AnyOr` can only be
/// obtained through its `From` impls.
#[derive(Debug, Clone, Copy)]
pub struct AnyOr<T>(AnyOrInner<T>);
/// Private representation behind [`AnyOr`].
#[derive(Debug, Clone, Copy)]
enum AnyOrInner<T> {
// Accept anything; rendered as `*` where a header value is produced.
Any,
// A specific value.
Value(T),
}
impl From<Origin> for AnyOr<Origin> {
fn from(origin: Origin) -> Self {
AnyOr(AnyOrInner::Value(origin))
}
}
impl<T> From<Any> for AnyOr<T> {
fn from(_: Any) -> Self {
AnyOr(AnyOrInner::Any)
}
}
impl<I> From<I> for AnyOr<Vec<Method>>
where
I: IntoIterator<Item = Method>,
{
fn from(methods: I) -> Self {
AnyOr(AnyOrInner::Value(methods.into_iter().collect()))
}
}
impl<I> From<I> for AnyOr<Vec<HeaderName>>
where
I: IntoIterator<Item = HeaderName>,
{
fn from(headers: I) -> Self {
AnyOr(AnyOrInner::Value(headers.into_iter().collect()))
}
}
/// Joins header values with `,` (no surrounding spaces) into one value.
///
/// Returns `None` when `iter` yields nothing.
fn separated_by_commas<I>(iter: I) -> Option<HeaderValue>
where
    I: Iterator<Item = HeaderValue>,
{
    let mut values = iter;
    let head = values.next()?;
    let joined = values.fold(BytesMut::from(head.as_bytes()), |mut buffer, next| {
        buffer.reserve(next.len() + 1);
        buffer.put_u8(b',');
        buffer.extend_from_slice(next.as_bytes());
        buffer
    });
    // Every piece was already a valid header value and `,` is a legal byte,
    // so the joined bytes are valid as well.
    Some(HeaderValue::from_maybe_shared(joined.freeze()).unwrap())
}
impl Default for CorsLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for CorsLayer {
type Service = Cors<S>;
fn layer(&self, inner: S) -> Self::Service {
Cors {
inner,
layer: self.clone(),
}
}
}
/// Middleware that adds CORS headers to responses of `inner`, answers
/// preflight (`OPTIONS`) requests itself and rejects disallowed origins.
#[derive(Debug, Clone)]
pub struct Cors<S> {
// The wrapped service.
inner: S,
// The CORS configuration applied to every request.
layer: CorsLayer,
}
impl<S> Cors<S> {
pub fn new(inner: S) -> Self {
Self {
inner,
layer: CorsLayer::new(),
}
}
pub fn permissive(inner: S) -> Self {
Self {
inner,
layer: CorsLayer::permissive(),
}
}
define_inner_service_accessors!();
pub fn layer() -> CorsLayer {
CorsLayer::new()
}
pub fn allow_credentials(self, allow_credentials: bool) -> Self {
self.map_layer(|layer| layer.allow_credentials(allow_credentials))
}
pub fn allow_headers<I>(self, headers: I) -> Self
where
I: Into<AnyOr<Vec<HeaderName>>>,
{
self.map_layer(|layer| layer.allow_headers(headers))
}
pub fn max_age(self, max_age: Duration) -> Self {
self.map_layer(|layer| layer.max_age(max_age))
}
pub fn allow_methods<T>(self, methods: T) -> Self
where
T: Into<AnyOr<Vec<Method>>>,
{
self.map_layer(|layer| layer.allow_methods(methods))
}
pub fn allow_origin<T>(self, origin: T) -> Self
where
T: Into<AnyOr<Origin>>,
{
self.map_layer(|layer| layer.allow_origin(origin))
}
pub fn expose_headers<I>(self, headers: I) -> Self
where
I: Into<AnyOr<Vec<HeaderName>>>,
{
self.map_layer(|layer| layer.expose_headers(headers))
}
fn map_layer<F>(mut self, f: F) -> Self
where
F: FnOnce(CorsLayer) -> CorsLayer,
{
self.layer = f(self.layer);
self
}
fn is_valid_origin(&self, origin: &HeaderValue, parts: &Parts) -> bool {
if let Some(allow_origin) = &self.layer.allow_origin {
match &allow_origin.0 {
AnyOrInner::Any => true,
AnyOrInner::Value(allow_origin) => match &allow_origin.0 {
OriginInner::Exact(s) => s == origin,
OriginInner::List(list) => list.contains(origin),
OriginInner::Closure(f) => f(origin, parts),
},
}
} else {
false
}
}
fn is_valid_request_method(&self, method: &HeaderValue) -> bool {
if let Some(allow_methods) = &self.layer.allow_methods {
#[allow(clippy::borrow_interior_mutable_const)]
if allow_methods == WILDCARD {
return true;
}
allow_methods
.as_bytes()
.split(|&byte| byte == b',')
.any(|bytes| bytes == method.as_bytes())
} else {
false
}
}
fn build_preflight_response<B>(&self, origin: HeaderValue) -> Response<B>
where
B: Default,
{
let mut response = Response::new(B::default());
response
.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
if let Some(allow_methods) = &self.layer.allow_methods {
response
.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods.clone());
}
if let Some(allow_headers) = &self.layer.allow_headers {
response
.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, allow_headers.clone());
}
if let Some(max_age) = self.layer.max_age.clone() {
response
.headers_mut()
.insert(header::ACCESS_CONTROL_MAX_AGE, max_age);
}
if let Some(allow_credentials) = self.layer.allow_credentials.clone() {
response
.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, allow_credentials);
}
if let Some(expose_headers) = self.layer.expose_headers.clone() {
response
.headers_mut()
.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers);
}
response
}
}
/// Describes which request origins are accepted. Build one with
/// [`Origin::exact`], [`Origin::list`] or [`Origin::predicate`].
#[derive(Clone)]
pub struct Origin(OriginInner);
impl Origin {
pub fn exact(origin: HeaderValue) -> Self {
Self(OriginInner::Exact(origin))
}
pub fn list<I>(origins: I) -> Self
where
I: IntoIterator<Item = HeaderValue>,
{
let origins = origins.into_iter().collect::<Vec<_>>().into();
Self(OriginInner::List(origins))
}
pub fn predicate<F>(f: F) -> Self
where
F: Fn(&HeaderValue, &Parts) -> bool + Send + Sync + 'static,
{
Self(OriginInner::Closure(Arc::new(f)))
}
}
impl fmt::Debug for Origin {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0 {
OriginInner::Exact(inner) => f.debug_tuple("Exact").field(inner).finish(),
OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
OriginInner::Closure(_) => f.debug_tuple("Closure").finish(),
}
}
}
/// Private representation behind [`Origin`]. `Arc` keeps clones of the
/// configuration cheap for the list and closure variants.
#[derive(Clone)]
enum OriginInner {
// One accepted origin.
Exact(HeaderValue),
// A fixed set of accepted origins.
List(Arc<[HeaderValue]>),
// A user predicate over the origin and the request parts.
Closure(Arc<dyn for<'a> Fn(&'a HeaderValue, &'a Parts) -> bool + Send + Sync + 'static>),
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ResBody: Default,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future, ResBody>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let origin = req.headers().get(&header::ORIGIN).cloned();
let origin = if let Some(origin) = origin {
origin
} else {
return ResponseFuture {
inner: Kind::NonCorsCall {
future: self.inner.call(req),
},
};
};
let (parts, body) = req.into_parts();
let origin = if self.is_valid_origin(&origin, &parts) {
origin
} else {
return ResponseFuture {
inner: Kind::Error {
response: Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(ResBody::default())
.unwrap(),
),
},
};
};
let req = Request::from_parts(parts, body);
if req.method() == Method::OPTIONS {
match req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) {
Some(request_method) if self.is_valid_request_method(request_method) => {}
_ => {
return ResponseFuture {
inner: Kind::Error {
response: Some(
Response::builder()
.status(StatusCode::OK)
.body(ResBody::default())
.unwrap(),
),
},
};
}
}
return ResponseFuture {
inner: Kind::PreflightCall {
response: Some(self.build_preflight_response(origin)),
},
};
}
ResponseFuture {
inner: Kind::CorsCall {
future: self.inner.call(req),
allow_origin: self.layer.allow_origin.clone(),
origin,
allow_credentials: self.layer.allow_credentials.clone(),
expose_headers: self.layer.expose_headers.clone(),
},
}
}
}
pin_project! {
// Response future for `Cors`; all state lives in `inner`.
pub struct ResponseFuture<F, B> {
#[pin]
inner: Kind<F, B>,
}
}
pin_project! {
// The four ways `Cors::call` can resolve a request.
#[project = KindProj]
enum Kind<F, B> {
// No `Origin` header: the inner future's output is returned unchanged.
NonCorsCall {
#[pin]
future: F,
},
// Accepted CORS request: CORS headers are added once `future` completes.
// `allow_origin` is taken (left `None`) when the response is produced.
CorsCall {
#[pin]
future: F,
allow_origin: Option<AnyOr<Origin>>,
origin: HeaderValue,
allow_credentials: Option<HeaderValue>,
expose_headers: Option<HeaderValue>,
},
// Ready-made preflight response; `Vary` headers are added on poll.
PreflightCall {
response: Option<Response<B>>,
},
// Ready-made response returned as-is (disallowed origin or a rejected
// preflight request).
Error {
response: Option<Response<B>>,
},
}
}
impl<F, B, E> Future for ResponseFuture<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = Result<Response<B>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::CorsCall {
future,
allow_origin,
origin,
allow_credentials,
expose_headers,
} => {
let mut response: Response<B> = ready!(future.poll(cx))?;
let headers = response.headers_mut();
headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
response_origin(allow_origin.take().unwrap(), origin),
);
if let Some(allow_credentials) = allow_credentials {
headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
allow_credentials.clone(),
);
}
if let Some(expose_headers) = expose_headers {
headers.insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
expose_headers.clone(),
);
}
apply_vary_headers(headers);
Poll::Ready(Ok(response))
}
KindProj::NonCorsCall { future } => future.poll(cx),
KindProj::PreflightCall { response } => {
let mut response = response.take().unwrap();
apply_vary_headers(response.headers_mut());
Poll::Ready(Ok(response))
}
KindProj::Error { response } => Poll::Ready(Ok(response.take().unwrap())),
}
}
}
/// Adds `Origin`, `Access-Control-Request-Method` and
/// `Access-Control-Request-Headers` to the response's `Vary` header, because
/// the CORS headers this middleware writes depend on those request headers.
///
/// A name is not appended again when an existing `Vary` value already lists it
/// (compared case-insensitively, including inside comma-separated values).
/// Nothing is added when `Vary: *` is present. This keeps responses from an
/// inner service that already sets `Vary` free of duplicate entries.
fn apply_vary_headers(headers: &mut http::HeaderMap) {
    const VARY_HEADERS: [HeaderName; 3] = [
        header::ORIGIN,
        header::ACCESS_CONTROL_REQUEST_METHOD,
        header::ACCESS_CONTROL_REQUEST_HEADERS,
    ];

    /// Whether any existing `Vary` value already covers `name` (or is `*`).
    /// Values that are not valid visible ASCII are ignored.
    fn already_varies_on(headers: &http::HeaderMap, name: &str) -> bool {
        headers
            .get_all(header::VARY)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .map(str::trim)
            .any(|item| item == "*" || item.eq_ignore_ascii_case(name))
    }

    for h in &VARY_HEADERS {
        if !already_varies_on(headers, h.as_str()) {
            headers.append(header::VARY, HeaderValue::from_static(h.as_str()));
        }
    }
}
/// The `Access-Control-Allow-Origin` value for an actual (non-preflight)
/// response: `*` when any origin is allowed, otherwise the request's origin
/// echoed back.
fn response_origin(allow_origin: AnyOr<Origin>, origin: &HeaderValue) -> HeaderValue {
    match allow_origin.0 {
        AnyOrInner::Any => WILDCARD,
        AnyOrInner::Value(_) => origin.clone(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Request parts for `is_valid_origin`; only predicate origins read them.
    fn request_parts() -> Parts {
        Request::new(()).into_parts().0
    }

    #[test]
    fn test_is_valid_request_method() {
        // An explicit list accepts exactly the listed methods.
        let cors = Cors::new(()).allow_methods(vec![Method::GET, Method::POST]);
        assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET")));
        assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST")));
        assert!(!cors.is_valid_request_method(&HeaderValue::from_static("PUT")));
        assert!(!cors.is_valid_request_method(&HeaderValue::from_static("GE")));

        // Nothing configured: every method is rejected.
        let cors = Cors::new(());
        assert!(!cors.is_valid_request_method(&HeaderValue::from_static("GET")));
        assert!(!cors.is_valid_request_method(&HeaderValue::from_static("POST")));
        assert!(!cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS")));

        // `Any` accepts every method.
        let cors = Cors::new(()).allow_methods(Any);
        assert!(cors.is_valid_request_method(&HeaderValue::from_static("GET")));
        assert!(cors.is_valid_request_method(&HeaderValue::from_static("POST")));
        assert!(cors.is_valid_request_method(&HeaderValue::from_static("OPTIONS")));
    }

    #[test]
    fn test_separated_by_commas() {
        let joined = separated_by_commas(
            vec![HeaderValue::from_static("a"), HeaderValue::from_static("b")].into_iter(),
        );
        assert_eq!(joined, Some(HeaderValue::from_static("a,b")));
        assert_eq!(separated_by_commas(std::iter::empty::<HeaderValue>()), None);
    }

    #[test]
    fn test_is_valid_origin() {
        let parts = request_parts();
        let allowed = HeaderValue::from_static("https://a.example");
        let other = HeaderValue::from_static("https://b.example");

        let cors = Cors::new(());
        assert!(!cors.is_valid_origin(&allowed, &parts));

        let cors = Cors::new(()).allow_origin(Any);
        assert!(cors.is_valid_origin(&other, &parts));

        let cors = Cors::new(()).allow_origin(Origin::exact(allowed.clone()));
        assert!(cors.is_valid_origin(&allowed, &parts));
        assert!(!cors.is_valid_origin(&other, &parts));

        let cors = Cors::new(()).allow_origin(Origin::list(vec![allowed.clone()]));
        assert!(cors.is_valid_origin(&allowed, &parts));
        assert!(!cors.is_valid_origin(&other, &parts));

        let cors = Cors::new(()).allow_origin(Origin::predicate(
            |origin: &HeaderValue, _: &Parts| origin.as_bytes().starts_with(b"https://b."),
        ));
        assert!(cors.is_valid_origin(&other, &parts));
        assert!(!cors.is_valid_origin(&allowed, &parts));
    }

    #[test]
    fn test_build_preflight_response() {
        let origin = HeaderValue::from_static("https://a.example");
        let cors = Cors::new(())
            .allow_methods(vec![Method::GET, Method::POST])
            .max_age(Duration::from_secs(600));
        let response: Response<()> = cors.build_preflight_response(origin.clone());
        let headers = response.headers();
        assert_eq!(headers[header::ACCESS_CONTROL_ALLOW_ORIGIN], origin);
        assert_eq!(headers[header::ACCESS_CONTROL_ALLOW_METHODS], "GET,POST");
        assert_eq!(headers[header::ACCESS_CONTROL_MAX_AGE], "600");
        assert!(!headers.contains_key(header::ACCESS_CONTROL_ALLOW_CREDENTIALS));
    }
}