[go: up one dir, main page]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
use crate::{Error, Headers, Method, Result};

/// CORS configuration that can be written onto response headers with [`Cors::apply_headers`].
///
/// Build one with [`Cors::new`] (or [`Default::default`]) followed by the `with_*` builder
/// methods. The default configuration is empty — `credentials` is `false`, `max_age` is
/// `None` and every list is empty — so applying it sets no headers at all.
#[derive(Debug, Clone, Default)]
pub struct Cors {
    /// When `true`, `Access-Control-Allow-Credentials: true` is emitted.
    credentials: bool,
    /// Value emitted as `Access-Control-Max-Age`, if set.
    max_age: Option<u32>,
    /// Values joined with `,` into `Access-Control-Allow-Origin`.
    origins: Vec<String>,
    /// Values joined with `,` into `Access-Control-Allow-Methods`.
    methods: Vec<Method>,
    /// Values joined with `,` into `Access-Control-Allow-Headers`.
    allowed_headers: Vec<String>,
    /// Values joined with `,` into the expose-headers response header.
    exposed_headers: Vec<String>,
}

impl Cors {
    /// Creates an empty CORS configuration; equivalent to `Cors::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets whether `Access-Control-Allow-Credentials: true` is emitted.
    pub fn with_credentials(mut self, credentials: bool) -> Self {
        self.credentials = credentials;
        self
    }

    /// Sets the `Access-Control-Max-Age` value: how long (in seconds, per the Fetch
    /// standard) a browser may cache the preflight response.
    pub fn with_max_age(mut self, max_age: u32) -> Self {
        self.max_age = Some(max_age);
        self
    }

    /// Replaces the origins emitted in `Access-Control-Allow-Origin`.
    ///
    /// All configured origins are joined with `,` into a single header value. Note that the
    /// Fetch standard only permits a single origin (or `*`) in this header, so configuring
    /// more than one origin produces a value that browsers will reject.
    pub fn with_origins<S: Into<String>, V: IntoIterator<Item = S>>(mut self, origins: V) -> Self {
        self.origins = Self::collect_strings(origins);
        self
    }

    /// Replaces the methods emitted in `Access-Control-Allow-Methods`.
    pub fn with_methods<V: IntoIterator<Item = Method>>(mut self, methods: V) -> Self {
        self.methods = methods.into_iter().collect();
        self
    }

    /// Replaces the request headers emitted in `Access-Control-Allow-Headers`.
    pub fn with_allowed_headers<S: Into<String>, V: IntoIterator<Item = S>>(
        mut self,
        headers: V,
    ) -> Self {
        self.allowed_headers = Self::collect_strings(headers);
        self
    }

    /// Replaces the response headers the client may read, emitted in
    /// `Access-Control-Expose-Headers`.
    pub fn with_exposed_headers<S: Into<String>, V: IntoIterator<Item = S>>(
        mut self,
        headers: V,
    ) -> Self {
        self.exposed_headers = Self::collect_strings(headers);
        self
    }

    /// Writes the configured CORS headers into `headers`.
    ///
    /// Only configured values are written: `Access-Control-Allow-Credentials` when
    /// credentials are enabled, `Access-Control-Max-Age` when a max age is set, and each
    /// list header (`Allow-Origin`, `Allow-Methods`, `Allow-Headers`, `Expose-Headers`) only
    /// when its list is non-empty, with the values joined by `,`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Headers::set`.
    pub fn apply_headers(&self, headers: &mut Headers) -> Result<()> {
        if self.credentials {
            headers.set("Access-Control-Allow-Credentials", "true")?;
        }
        if let Some(max_age) = self.max_age {
            headers.set("Access-Control-Max-Age", &max_age.to_string())?;
        }
        Self::set_list_header(
            headers,
            "Access-Control-Allow-Origin",
            self.origins.as_slice(),
        )?;
        Self::set_list_header(
            headers,
            "Access-Control-Allow-Methods",
            self.methods.as_slice(),
        )?;
        Self::set_list_header(
            headers,
            "Access-Control-Allow-Headers",
            self.allowed_headers.as_slice(),
        )?;
        Self::set_list_header(
            headers,
            "Access-Control-Expose-Headers",
            self.exposed_headers.as_slice(),
        )?;
        Ok(())
    }

    /// Sets `name` to the comma-joined `values`; leaves `headers` untouched when `values`
    /// is empty (an empty list has no meaningful header value).
    fn set_list_header<S: AsRef<str>>(
        headers: &mut Headers,
        name: &str,
        values: &[S],
    ) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }
        headers.set(name, &concat_vec_to_string(values)?)?;
        Ok(())
    }

    /// Converts every item of `items` into an owned `String`.
    fn collect_strings<S: Into<String>, V: IntoIterator<Item = S>>(items: V) -> Vec<String> {
        items.into_iter().map(Into::into).collect()
    }
}

fn concat_vec_to_string<S: AsRef<str>>(vec: &[S]) -> Result<String> {
    let str = vec.iter().fold("".to_owned(), |mut init, item| {
        init.push(',');
        init.push_str(item.as_ref());
        init
    });
    if !str.is_empty() {
        Ok(str[1..].to_string())
    } else {
        Err(Error::RustError(
            "Tried to concat header values without values.".to_string(),
        ))
    }
}