#![allow(clippy::enum_variant_names)]
use bytes::{BufMut, BytesMut};
use futures_core::ready;
use http::{
header::{self, HeaderName},
HeaderMap, HeaderValue, Method, Request, Response,
};
use pin_project_lite::pin_project;
use std::{
array,
future::Future,
mem,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
mod allow_credentials;
mod allow_headers;
mod allow_methods;
mod allow_origin;
mod expose_headers;
mod max_age;
mod vary;
pub use self::{
allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
};
/// Layer that wraps services in the [`Cors`] middleware, which adds CORS
/// headers to responses.
///
/// Each field holds the configuration for one CORS-related response header;
/// use the builder methods on this type to set them.
#[derive(Debug, Clone)]
#[must_use]
pub struct CorsLayer {
    /// `Access-Control-Allow-Credentials`; considered for preflight and regular responses.
    allow_credentials: AllowCredentials,
    /// `Access-Control-Allow-Headers`; considered for preflight (`OPTIONS`) responses only.
    allow_headers: AllowHeaders,
    /// `Access-Control-Allow-Methods`; considered for preflight (`OPTIONS`) responses only.
    allow_methods: AllowMethods,
    /// `Access-Control-Allow-Origin`; considered for preflight and regular responses.
    allow_origin: AllowOrigin,
    /// `Access-Control-Expose-Headers`; considered for non-preflight responses only.
    expose_headers: ExposeHeaders,
    /// `Access-Control-Max-Age`; considered for preflight (`OPTIONS`) responses only.
    max_age: MaxAge,
    /// Values emitted in the `Vary` header of every response.
    vary: Vary,
}
/// The `*` header value.
///
/// Not referenced in the code visible here; presumably shared by the sibling
/// header-configuration modules (`allow_origin`, `allow_headers`, ...) — confirm.
// Clippy warns about `const`s whose type has interior mutability, because each
// use of the const creates a fresh copy. A fresh copy per use is fine here,
// so the lint is silenced.
#[allow(clippy::declare_interior_mutable_const)]
const WILDCARD: HeaderValue = HeaderValue::from_static("*");
impl CorsLayer {
    /// Creates a new `CorsLayer` with every setting at its `Default` value.
    ///
    /// Use the builder methods to choose which CORS headers are added to
    /// responses.
    pub fn new() -> Self {
        Self {
            allow_credentials: Default::default(),
            allow_headers: Default::default(),
            allow_methods: Default::default(),
            allow_origin: Default::default(),
            expose_headers: Default::default(),
            max_age: Default::default(),
            vary: Default::default(),
        }
    }

    /// A permissive configuration.
    ///
    /// It uses wildcards for allowed headers, allowed methods, the allowed
    /// origin and exposed headers.
    ///
    /// Credentials are left at their default. Enabling them on top of this
    /// configuration is rejected when the layer is applied, because wildcards
    /// cannot be combined with credentials.
    pub fn permissive() -> Self {
        Self::new()
            .allow_headers(Any)
            .allow_methods(Any)
            .allow_origin(Any)
            .expose_headers(Any)
    }

    /// A very permissive configuration:
    ///
    /// - Credentials are allowed.
    /// - The origin of the request is mirrored back as the allowed origin.
    /// - The method from `Access-Control-Request-Method` is mirrored back as an
    ///   allowed method.
    /// - The header names from `Access-Control-Request-Headers` are mirrored
    ///   back as allowed headers.
    ///
    /// Mirroring (instead of `*`) is what makes enabling credentials valid.
    /// Wildcards combined with `Access-Control-Allow-Credentials: true` are
    /// rejected.
    pub fn very_permissive() -> Self {
        Self::new()
            // Credentials are the point of this preset: without them, mirroring
            // offers essentially nothing over the wildcards of `permissive()`.
            .allow_credentials(true)
            .allow_headers(AllowHeaders::mirror_request())
            .allow_methods(AllowMethods::mirror_request())
            .allow_origin(AllowOrigin::mirror_request())
    }

    /// Sets the `Access-Control-Allow-Credentials` configuration.
    ///
    /// It is considered for both preflight and regular CORS responses.
    ///
    /// Combining credentials with any wildcard setting panics when the layer
    /// is applied or the service is polled.
    pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
    where
        T: Into<AllowCredentials>,
    {
        self.allow_credentials = allow_credentials.into();
        self
    }

    /// Sets the `Access-Control-Allow-Headers` configuration.
    ///
    /// It is only used for preflight (`OPTIONS`) responses.
    pub fn allow_headers<T>(mut self, headers: T) -> Self
    where
        T: Into<AllowHeaders>,
    {
        self.allow_headers = headers.into();
        self
    }

    /// Sets the `Access-Control-Max-Age` configuration.
    ///
    /// It is only used for preflight (`OPTIONS`) responses.
    pub fn max_age<T>(mut self, max_age: T) -> Self
    where
        T: Into<MaxAge>,
    {
        self.max_age = max_age.into();
        self
    }

    /// Sets the `Access-Control-Allow-Methods` configuration.
    ///
    /// It is only used for preflight (`OPTIONS`) responses.
    pub fn allow_methods<T>(mut self, methods: T) -> Self
    where
        T: Into<AllowMethods>,
    {
        self.allow_methods = methods.into();
        self
    }

    /// Sets the `Access-Control-Allow-Origin` configuration.
    ///
    /// It is considered for both preflight and regular CORS responses.
    pub fn allow_origin<T>(mut self, origin: T) -> Self
    where
        T: Into<AllowOrigin>,
    {
        self.allow_origin = origin.into();
        self
    }

    /// Sets the `Access-Control-Expose-Headers` configuration.
    ///
    /// It is only used for non-preflight responses.
    pub fn expose_headers<T>(mut self, headers: T) -> Self
    where
        T: Into<ExposeHeaders>,
    {
        self.expose_headers = headers.into();
        self
    }

    /// Sets the values emitted in the `Vary` header of every response.
    pub fn vary<T>(mut self, headers: T) -> Self
    where
        T: Into<Vary>,
    {
        self.vary = headers.into();
        self
    }
}
/// Wildcard marker accepted by builders such as `CorsLayer::allow_origin`,
/// `CorsLayer::allow_headers`, `CorsLayer::allow_methods` and
/// `CorsLayer::expose_headers` (see `CorsLayer::permissive`).
///
/// Presumably it converts into the `*` setting of each header type; the `From`
/// impls live in the sibling modules — confirm there.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct Any;
/// Returns the [`Any`] wildcard marker.
///
/// Deprecated: write the unit struct `Any` directly instead.
#[deprecated = "Use Any as a unit struct literal instead"]
pub fn any() -> Any {
    Any
}
/// Joins the values yielded by `iter` into one comma-separated header value
/// (no space after the commas).
///
/// Returns `None` when `iter` yields nothing.
fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
where
    I: Iterator<Item = HeaderValue>,
{
    let head = iter.next()?;
    let mut joined = BytesMut::from(head.as_bytes());

    iter.for_each(|value| {
        joined.reserve(1 + value.len());
        joined.put_u8(b',');
        joined.extend_from_slice(value.as_bytes());
    });

    // Every piece was already a valid header value and `,` is a legal header
    // byte, so the concatenation cannot fail validation.
    Some(HeaderValue::from_maybe_shared(joined.freeze()).unwrap())
}
impl Default for CorsLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for CorsLayer {
    type Service = Cors<S>;

    /// Wraps `inner` in a [`Cors`] service that owns a copy of this
    /// configuration.
    ///
    /// # Panics
    ///
    /// Panics if credentials are combined with a wildcard setting.
    fn layer(&self, inner: S) -> Self::Service {
        // Validate before cloning so an invalid configuration fails fast.
        ensure_usable_cors_rules(self);
        let layer = self.clone();
        Cors { inner, layer }
    }
}
/// Middleware that adds CORS headers to the responses of the wrapped service.
///
/// `OPTIONS` requests are answered directly as preflights. All other requests
/// are forwarded to `inner`, and the headers are added to its response.
#[derive(Debug, Clone)]
#[must_use]
pub struct Cors<S> {
    /// The wrapped service.
    inner: S,
    /// CORS configuration applied to every request.
    layer: CorsLayer,
}
impl<S> Cors<S> {
pub fn new(inner: S) -> Self {
Self {
inner,
layer: CorsLayer::new(),
}
}
pub fn permissive(inner: S) -> Self {
Self {
inner,
layer: CorsLayer::permissive(),
}
}
pub fn very_permissive(inner: S) -> Self {
Self {
inner,
layer: CorsLayer::very_permissive(),
}
}
define_inner_service_accessors!();
pub fn layer() -> CorsLayer {
CorsLayer::new()
}
pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
where
T: Into<AllowCredentials>,
{
self.map_layer(|layer| layer.allow_credentials(allow_credentials))
}
pub fn allow_headers<T>(self, headers: T) -> Self
where
T: Into<AllowHeaders>,
{
self.map_layer(|layer| layer.allow_headers(headers))
}
pub fn max_age<T>(self, max_age: T) -> Self
where
T: Into<MaxAge>,
{
self.map_layer(|layer| layer.max_age(max_age))
}
pub fn allow_methods<T>(self, methods: T) -> Self
where
T: Into<AllowMethods>,
{
self.map_layer(|layer| layer.allow_methods(methods))
}
pub fn allow_origin<T>(self, origin: T) -> Self
where
T: Into<AllowOrigin>,
{
self.map_layer(|layer| layer.allow_origin(origin))
}
pub fn expose_headers<T>(self, headers: T) -> Self
where
T: Into<ExposeHeaders>,
{
self.map_layer(|layer| layer.expose_headers(headers))
}
fn map_layer<F>(mut self, f: F) -> Self
where
F: FnOnce(CorsLayer) -> CorsLayer,
{
self.layer = f(self.layer);
self
}
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Default,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Re-validates the CORS configuration, then defers to the inner service.
    ///
    /// The check also lives here, not only in `CorsLayer::layer`, because the
    /// builder methods on `Cors` never go through the layer.
    ///
    /// # Panics
    ///
    /// Panics if credentials are combined with a wildcard setting.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ensure_usable_cors_rules(&self.layer);
        self.inner.poll_ready(cx)
    }

    /// Computes the CORS headers for `req`.
    ///
    /// `OPTIONS` requests are answered as preflights without calling the inner
    /// service. Every other request is forwarded, and the headers are attached
    /// to the inner service's response.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let (parts, body) = req.into_parts();
        let origin = parts.headers.get(&header::ORIGIN);
        let layer = &self.layer;

        // Headers shared by preflight and regular CORS responses.
        let mut headers = HeaderMap::new();
        headers.extend(layer.allow_origin.to_header(origin, &parts));
        headers.extend(layer.allow_credentials.to_header(origin, &parts));

        let mut vary_values = layer.vary.values();
        if let Some(first_value) = vary_values.next() {
            let mut vary_entry = match headers.entry(header::VARY) {
                header::Entry::Vacant(vacant) => vacant.insert_entry(first_value),
                header::Entry::Occupied(_) => {
                    unreachable!("no vary header inserted up to this point")
                }
            };
            vary_values.for_each(|value| vary_entry.append(value));
        }

        let is_preflight = parts.method == Method::OPTIONS;
        let inner = if is_preflight {
            // Preflight-only headers; the inner service is never called.
            headers.extend(layer.allow_methods.to_header(&parts));
            headers.extend(layer.allow_headers.to_header(&parts));
            headers.extend(layer.max_age.to_header(origin, &parts));
            Kind::PreflightCall { headers }
        } else {
            // Regular-request-only header, then forward to the inner service.
            headers.extend(layer.expose_headers.to_header(&parts));
            let future = self.inner.call(Request::from_parts(parts, body));
            Kind::CorsCall { future, headers }
        };

        ResponseFuture { inner }
    }
}
pin_project! {
    /// Response future for [`Cors`].
    pub struct ResponseFuture<F> {
        // Either the pending inner call or a ready-made preflight answer.
        #[pin]
        inner: Kind<F>,
    }
}
pin_project! {
    // State of a `ResponseFuture`.
    #[project = KindProj]
    enum Kind<F> {
        // Regular CORS request: `headers` are added to the response of the
        // inner service's `future` once it resolves.
        CorsCall {
            #[pin]
            future: F,
            headers: HeaderMap,
        },
        // Preflight (`OPTIONS`) request: answered directly with `headers` and
        // an empty body; the inner service is not called.
        PreflightCall {
            headers: HeaderMap,
        },
    }
}
impl<F, B, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
B: Default,
{
type Output = Result<Response<B>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::CorsCall { future, headers } => {
let mut response: Response<B> = ready!(future.poll(cx))?;
response.headers_mut().extend(headers.drain());
Poll::Ready(Ok(response))
}
KindProj::PreflightCall { headers } => {
let mut response = Response::new(B::default());
mem::swap(response.headers_mut(), headers);
Poll::Ready(Ok(response))
}
}
}
}
/// Rejects configurations that pair `Access-Control-Allow-Credentials: true`
/// with a wildcard (`*`).
///
/// The wildcard check covers allowed headers, allowed methods, the allowed
/// origin and exposed headers. Within this module it is called from
/// `CorsLayer::layer` and `Cors::poll_ready`.
///
/// # Panics
///
/// Panics with a message naming the offending header on the first wildcard
/// found, in the order headers, methods, origin, expose-headers.
fn ensure_usable_cors_rules(layer: &CorsLayer) {
    // Wildcards are only a problem when credentials are enabled.
    if !layer.allow_credentials.is_true() {
        return;
    }

    assert!(
        !layer.allow_headers.is_wildcard(),
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
         with `Access-Control-Allow-Headers: *`"
    );
    assert!(
        !layer.allow_methods.is_wildcard(),
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
         with `Access-Control-Allow-Methods: *`"
    );
    assert!(
        !layer.allow_origin.is_wildcard(),
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
         with `Access-Control-Allow-Origin: *`"
    );
    assert!(
        !layer.expose_headers.is_wildcard(),
        "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
         with `Access-Control-Expose-Headers: *`"
    );
}
/// Returns the request headers a CORS preflight depends on:
/// `Origin`, `Access-Control-Request-Method` and `Access-Control-Request-Headers`.
///
/// NOTE(review): presumably used as the default `Vary` list — confirm in `vary.rs`.
// `array::IntoIter::new` is deprecated. By-value iteration over an array via
// `IntoIterator` needs Rust 1.53, so the older API is presumably kept for MSRV.
#[allow(deprecated)]
pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
    let names: [HeaderName; 3] = [
        header::ORIGIN,
        header::ACCESS_CONTROL_REQUEST_METHOD,
        header::ACCESS_CONTROL_REQUEST_HEADERS,
    ];
    array::IntoIter::new(names)
}