[go: up one dir, main page]

uv_auth/
store.rs

1use std::ops::Deref;
2use std::path::{Path, PathBuf};
3
4use fs_err as fs;
5use rustc_hash::FxHashMap;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8use uv_fs::{LockedFile, LockedFileError, LockedFileMode, with_added_extension};
9use uv_preview::{Preview, PreviewFeatures};
10use uv_redacted::DisplaySafeUrl;
11
12use uv_state::{StateBucket, StateStore};
13use uv_static::EnvVars;
14
15use crate::credentials::{Password, Token, Username};
16use crate::realm::Realm;
17use crate::service::Service;
18use crate::{Credentials, KeyringProvider};
19
20/// The storage backend to use in `uv auth` commands.
21#[derive(Debug)]
22pub enum AuthBackend {
23    // TODO(zanieb): Right now, we're using a keyring provider for the system store but that's just
24    // where the native implementation is living at the moment. We should consider refactoring these
25    // into a shared API in the future.
26    System(KeyringProvider),
27    TextStore(TextCredentialStore, LockedFile),
28}
29
30impl AuthBackend {
31    pub async fn from_settings(preview: Preview) -> Result<Self, TomlCredentialError> {
32        // If preview is enabled, we'll use the system-native store
33        if preview.is_enabled(PreviewFeatures::NATIVE_AUTH) {
34            return Ok(Self::System(KeyringProvider::native()));
35        }
36
37        // Otherwise, we'll use the plaintext credential store
38        let path = TextCredentialStore::default_file()?;
39        match TextCredentialStore::read(&path).await {
40            Ok((store, lock)) => Ok(Self::TextStore(store, lock)),
41            Err(err)
42                if err
43                    .as_io_error()
44                    .is_some_and(|err| err.kind() == std::io::ErrorKind::NotFound) =>
45            {
46                Ok(Self::TextStore(
47                    TextCredentialStore::default(),
48                    TextCredentialStore::lock(&path).await?,
49                ))
50            }
51            Err(err) => Err(err),
52        }
53    }
54}
55
56/// Authentication scheme to use.
57#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
58#[serde(rename_all = "lowercase")]
59pub enum AuthScheme {
60    /// HTTP Basic Authentication
61    ///
62    /// Uses a username and password.
63    #[default]
64    Basic,
65    /// Bearer token authentication.
66    ///
67    /// Uses a token provided as `Bearer <token>` in the `Authorization` header.
68    Bearer,
69}
70
71/// Errors that can occur when working with TOML credential storage.
72#[derive(Debug, Error)]
73pub enum TomlCredentialError {
74    #[error(transparent)]
75    Io(#[from] std::io::Error),
76    #[error(transparent)]
77    LockedFile(#[from] LockedFileError),
78    #[error("Failed to parse TOML credential file: {0}")]
79    ParseError(#[from] toml::de::Error),
80    #[error("Failed to serialize credentials to TOML")]
81    SerializeError(#[from] toml::ser::Error),
82    #[error(transparent)]
83    BasicAuthError(#[from] BasicAuthError),
84    #[error(transparent)]
85    BearerAuthError(#[from] BearerAuthError),
86    #[error("Failed to determine credentials directory")]
87    CredentialsDirError,
88    #[error("Token is not valid unicode")]
89    TokenNotUnicode(#[from] std::string::FromUtf8Error),
90}
91
92impl TomlCredentialError {
93    pub fn as_io_error(&self) -> Option<&std::io::Error> {
94        match self {
95            Self::Io(err) => Some(err),
96            Self::LockedFile(err) => err.as_io_error(),
97            Self::ParseError(_)
98            | Self::SerializeError(_)
99            | Self::BasicAuthError(_)
100            | Self::BearerAuthError(_)
101            | Self::CredentialsDirError
102            | Self::TokenNotUnicode(_) => None,
103        }
104    }
105}
106
107#[derive(Debug, Error)]
108pub enum BasicAuthError {
109    #[error("`username` is required with `scheme = basic`")]
110    MissingUsername,
111    #[error("`token` cannot be provided with `scheme = basic`")]
112    UnexpectedToken,
113}
114
115#[derive(Debug, Error)]
116pub enum BearerAuthError {
117    #[error("`token` is required with `scheme = bearer`")]
118    MissingToken,
119    #[error("`username` cannot be provided with `scheme = bearer`")]
120    UnexpectedUsername,
121    #[error("`password` cannot be provided with `scheme = bearer`")]
122    UnexpectedPassword,
123}
124
125/// A single credential entry in a TOML credentials file.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(try_from = "TomlCredentialWire", into = "TomlCredentialWire")]
128struct TomlCredential {
129    /// The service URL for this credential.
130    service: Service,
131    /// The credentials for this entry.
132    credentials: Credentials,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136struct TomlCredentialWire {
137    /// The service URL for this credential.
138    service: Service,
139    /// The username to use. Only allowed with [`AuthScheme::Basic`].
140    username: Username,
141    /// The authentication scheme.
142    #[serde(default)]
143    scheme: AuthScheme,
144    /// The password to use. Only allowed with [`AuthScheme::Basic`].
145    password: Option<Password>,
146    /// The token to use. Only allowed with [`AuthScheme::Bearer`].
147    token: Option<String>,
148}
149
150impl From<TomlCredential> for TomlCredentialWire {
151    fn from(value: TomlCredential) -> Self {
152        match value.credentials {
153            Credentials::Basic { username, password } => Self {
154                service: value.service,
155                username,
156                scheme: AuthScheme::Basic,
157                password,
158                token: None,
159            },
160            Credentials::Bearer { token } => Self {
161                service: value.service,
162                username: Username::new(None),
163                scheme: AuthScheme::Bearer,
164                password: None,
165                token: Some(String::from_utf8(token.into_bytes()).expect("Token is valid UTF-8")),
166            },
167        }
168    }
169}
170
171impl TryFrom<TomlCredentialWire> for TomlCredential {
172    type Error = TomlCredentialError;
173
174    fn try_from(value: TomlCredentialWire) -> Result<Self, Self::Error> {
175        match value.scheme {
176            AuthScheme::Basic => {
177                if value.username.as_deref().is_none() {
178                    return Err(TomlCredentialError::BasicAuthError(
179                        BasicAuthError::MissingUsername,
180                    ));
181                }
182                if value.token.is_some() {
183                    return Err(TomlCredentialError::BasicAuthError(
184                        BasicAuthError::UnexpectedToken,
185                    ));
186                }
187                let credentials = Credentials::Basic {
188                    username: value.username,
189                    password: value.password,
190                };
191                Ok(Self {
192                    service: value.service,
193                    credentials,
194                })
195            }
196            AuthScheme::Bearer => {
197                if value.username.is_some() {
198                    return Err(TomlCredentialError::BearerAuthError(
199                        BearerAuthError::UnexpectedUsername,
200                    ));
201                }
202                if value.password.is_some() {
203                    return Err(TomlCredentialError::BearerAuthError(
204                        BearerAuthError::UnexpectedPassword,
205                    ));
206                }
207                if value.token.is_none() {
208                    return Err(TomlCredentialError::BearerAuthError(
209                        BearerAuthError::MissingToken,
210                    ));
211                }
212                let credentials = Credentials::Bearer {
213                    token: Token::new(value.token.unwrap().into_bytes()),
214                };
215                Ok(Self {
216                    service: value.service,
217                    credentials,
218                })
219            }
220        }
221    }
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize, Default)]
225struct TomlCredentials {
226    /// Array of credential entries.
227    #[serde(rename = "credential")]
228    credentials: Vec<TomlCredential>,
229}
230
231/// A credential store with a plain text storage backend.
232#[derive(Debug, Default)]
233pub struct TextCredentialStore {
234    credentials: FxHashMap<(Service, Username), Credentials>,
235}
236
237impl TextCredentialStore {
238    /// Return the directory for storing credentials.
239    pub fn directory_path() -> Result<PathBuf, TomlCredentialError> {
240        if let Some(dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR)
241            .filter(|s| !s.is_empty())
242            .map(PathBuf::from)
243        {
244            return Ok(dir);
245        }
246
247        Ok(StateStore::from_settings(None)?.bucket(StateBucket::Credentials))
248    }
249
250    /// Return the standard file path for storing credentials.
251    pub fn default_file() -> Result<PathBuf, TomlCredentialError> {
252        let dir = Self::directory_path()?;
253        Ok(dir.join("credentials.toml"))
254    }
255
256    /// Acquire a lock on the credentials file at the given path.
257    pub async fn lock(path: &Path) -> Result<LockedFile, TomlCredentialError> {
258        if let Some(parent) = path.parent() {
259            fs::create_dir_all(parent)?;
260        }
261        let lock = with_added_extension(path, ".lock");
262        Ok(LockedFile::acquire(lock, LockedFileMode::Exclusive, "credentials store").await?)
263    }
264
265    /// Read credentials from a file.
266    fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, TomlCredentialError> {
267        let content = fs::read_to_string(path)?;
268        let credentials: TomlCredentials = toml::from_str(&content)?;
269
270        let credentials: FxHashMap<(Service, Username), Credentials> = credentials
271            .credentials
272            .into_iter()
273            .map(|credential| {
274                let username = match &credential.credentials {
275                    Credentials::Basic { username, .. } => username.clone(),
276                    Credentials::Bearer { .. } => Username::none(),
277                };
278                (
279                    (credential.service.clone(), username),
280                    credential.credentials,
281                )
282            })
283            .collect();
284
285        Ok(Self { credentials })
286    }
287
288    /// Read credentials from a file.
289    ///
290    /// Returns [`TextCredentialStore`] and a [`LockedFile`] to hold if mutating the store.
291    ///
292    /// If the store will not be written to following the read, the lock can be dropped.
293    pub async fn read<P: AsRef<Path>>(path: P) -> Result<(Self, LockedFile), TomlCredentialError> {
294        let lock = Self::lock(path.as_ref()).await?;
295        let store = Self::from_file(path)?;
296        Ok((store, lock))
297    }
298
299    /// Persist credentials to a file.
300    ///
301    /// Requires a [`LockedFile`] from [`TextCredentialStore::lock`] or
302    /// [`TextCredentialStore::read`] to ensure exclusive access.
303    pub fn write<P: AsRef<Path>>(
304        self,
305        path: P,
306        _lock: LockedFile,
307    ) -> Result<(), TomlCredentialError> {
308        let credentials = self
309            .credentials
310            .into_iter()
311            .map(|((service, _username), credentials)| TomlCredential {
312                service,
313                credentials,
314            })
315            .collect::<Vec<_>>();
316
317        let toml_creds = TomlCredentials { credentials };
318        let content = toml::to_string_pretty(&toml_creds)?;
319        fs::create_dir_all(
320            path.as_ref()
321                .parent()
322                .ok_or(TomlCredentialError::CredentialsDirError)?,
323        )?;
324
325        // TODO(zanieb): We should use an atomic write here
326        fs::write(path, content)?;
327        Ok(())
328    }
329
330    /// Get credentials for a given URL and username.
331    ///
332    /// The most specific URL prefix match in the same [`Realm`] is returned, if any.
333    pub fn get_credentials(
334        &self,
335        url: &DisplaySafeUrl,
336        username: Option<&str>,
337    ) -> Option<&Credentials> {
338        let request_realm = Realm::from(url);
339
340        // Perform an exact lookup first
341        // TODO(zanieb): Consider adding `DisplaySafeUrlRef` so we can avoid this clone
342        // TODO(zanieb): We could also return early here if we can't normalize to a `Service`
343        if let Ok(url_service) = Service::try_from(url.clone()) {
344            if let Some(credential) = self.credentials.get(&(
345                url_service.clone(),
346                Username::from(username.map(str::to_string)),
347            )) {
348                return Some(credential);
349            }
350        }
351
352        // If that fails, iterate through to find a prefix match
353        let mut best: Option<(usize, &Service, &Credentials)> = None;
354
355        for ((service, stored_username), credential) in &self.credentials {
356            let service_realm = Realm::from(service.url().deref());
357
358            // Only consider services in the same realm
359            if service_realm != request_realm {
360                continue;
361            }
362
363            // Service path must be a prefix of request path
364            if !url.path().starts_with(service.url().path()) {
365                continue;
366            }
367
368            // If a username is provided, it must match
369            if let Some(request_username) = username {
370                if Some(request_username) != stored_username.as_deref() {
371                    continue;
372                }
373            }
374
375            // Update our best matching credential based on prefix length
376            let specificity = service.url().path().len();
377            if best.is_none_or(|(best_specificity, _, _)| specificity > best_specificity) {
378                best = Some((specificity, service, credential));
379            }
380        }
381
382        // Return the most specific match
383        if let Some((_, _, credential)) = best {
384            return Some(credential);
385        }
386
387        None
388    }
389
390    /// Store credentials for a given service.
391    pub fn insert(&mut self, service: Service, credentials: Credentials) -> Option<Credentials> {
392        let username = match &credentials {
393            Credentials::Basic { username, .. } => username.clone(),
394            Credentials::Bearer { .. } => Username::none(),
395        };
396        self.credentials.insert((service, username), credentials)
397    }
398
399    /// Remove credentials for a given service.
400    pub fn remove(&mut self, service: &Service, username: Username) -> Option<Credentials> {
401        // Remove the specific credential for this service and username
402        self.credentials.remove(&(service.clone(), username))
403    }
404}
405
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::str::FromStr;

    use tempfile::NamedTempFile;

    use super::*;

    #[test]
    fn test_toml_serialization() {
        let credentials = TomlCredentials {
            credentials: vec![
                TomlCredential {
                    service: Service::from_str("https://example.com").unwrap(),
                    credentials: Credentials::Basic {
                        username: Username::new(Some("user1".to_string())),
                        password: Some(Password::new("pass1".to_string())),
                    },
                },
                TomlCredential {
                    service: Service::from_str("https://test.org").unwrap(),
                    credentials: Credentials::Basic {
                        username: Username::new(Some("user2".to_string())),
                        password: Some(Password::new("pass2".to_string())),
                    },
                },
            ],
        };

        let toml_str = toml::to_string_pretty(&credentials).unwrap();
        let parsed: TomlCredentials = toml::from_str(&toml_str).unwrap();

        assert_eq!(parsed.credentials.len(), 2);
        assert_eq!(
            parsed.credentials[0].service.to_string(),
            "https://example.com/"
        );
        assert_eq!(
            parsed.credentials[1].service.to_string(),
            "https://test.org/"
        );
    }

    #[test]
    fn test_credential_store_operations() {
        let mut store = TextCredentialStore::default();
        let credentials = Credentials::basic(Some("user".to_string()), Some("pass".to_string()));

        let service = Service::from_str("https://example.com").unwrap();
        store.insert(service.clone(), credentials.clone());
        let url = DisplaySafeUrl::parse("https://example.com/").unwrap();
        assert!(store.get_credentials(&url, None).is_some());

        let url = DisplaySafeUrl::parse("https://example.com/path").unwrap();
        let retrieved = store.get_credentials(&url, None).unwrap();
        assert_eq!(retrieved.username(), Some("user"));
        assert_eq!(retrieved.password(), Some("pass"));

        assert!(
            store
                .remove(&service, Username::from(Some("user".to_string())))
                .is_some()
        );
        let url = DisplaySafeUrl::parse("https://example.com/").unwrap();
        assert!(store.get_credentials(&url, None).is_none());
    }

    #[tokio::test]
    async fn test_file_operations() {
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(
            temp_file,
            r#"
[[credential]]
service = "https://example.com"
username = "testuser"
scheme = "basic"
password = "testpass"

[[credential]]
service = "https://test.org"
username = "user2"
password = "pass2"
"#
        )
        .unwrap();

        let store = TextCredentialStore::from_file(temp_file.path()).unwrap();

        let url = DisplaySafeUrl::parse("https://example.com/").unwrap();
        assert!(store.get_credentials(&url, None).is_some());
        let url = DisplaySafeUrl::parse("https://test.org/").unwrap();
        assert!(store.get_credentials(&url, None).is_some());

        let url = DisplaySafeUrl::parse("https://example.com").unwrap();
        let cred = store.get_credentials(&url, None).unwrap();
        assert_eq!(cred.username(), Some("testuser"));
        assert_eq!(cred.password(), Some("testpass"));

        // Test saving. The lock must guard the file that is actually being written.
        let temp_output = NamedTempFile::new().unwrap();
        store
            .write(
                temp_output.path(),
                TextCredentialStore::lock(temp_output.path()).await.unwrap(),
            )
            .unwrap();

        let content = fs::read_to_string(temp_output.path()).unwrap();
        assert!(content.contains("example.com"));
        assert!(content.contains("testuser"));
    }

    #[test]
    fn test_prefix_matching() {
        let mut store = TextCredentialStore::default();
        let credentials = Credentials::basic(Some("user".to_string()), Some("pass".to_string()));

        // Store credentials for a specific path prefix
        let service = Service::from_str("https://example.com/api").unwrap();
        store.insert(service.clone(), credentials.clone());

        // Should match URLs whose path starts with the stored service path
        let matching_urls = [
            "https://example.com/api",
            "https://example.com/api/v1",
            "https://example.com/api/v1/users",
        ];

        for url_str in matching_urls {
            let url = DisplaySafeUrl::parse(url_str).unwrap();
            let cred = store.get_credentials(&url, None);
            assert!(cred.is_some(), "Failed to match URL with prefix: {url_str}");
        }

        // Should NOT match URLs whose path does not start with the stored service path
        let non_matching_urls = [
            "https://example.com/different",
            "https://example.com/ap", // Only part of the stored path segment
            "https://example.com",    // Shorter than the stored prefix
        ];

        for url_str in non_matching_urls {
            let url = DisplaySafeUrl::parse(url_str).unwrap();
            let cred = store.get_credentials(&url, None);
            assert!(cred.is_none(), "Should not match non-prefix URL: {url_str}");
        }
    }

    #[test]
    fn test_realm_based_matching() {
        let mut store = TextCredentialStore::default();
        let credentials = Credentials::basic(Some("user".to_string()), Some("pass".to_string()));

        // Store by full URL (realm)
        let service = Service::from_str("https://example.com").unwrap();
        store.insert(service.clone(), credentials.clone());

        // Should match URLs in the same realm
        let matching_urls = [
            "https://example.com",
            "https://example.com/path",
            "https://example.com/different/path",
            "https://example.com:443/path", // Default HTTPS port
        ];

        for url_str in matching_urls {
            let url = DisplaySafeUrl::parse(url_str).unwrap();
            let cred = store.get_credentials(&url, None);
            assert!(
                cred.is_some(),
                "Failed to match URL in same realm: {url_str}"
            );
        }

        // Should NOT match URLs in different realms
        let non_matching_urls = [
            "http://example.com",       // Different scheme
            "https://different.com",    // Different host
            "https://example.com:8080", // Different port
        ];

        for url_str in non_matching_urls {
            let url = DisplaySafeUrl::parse(url_str).unwrap();
            let cred = store.get_credentials(&url, None);
            assert!(
                cred.is_none(),
                "Should not match URL in different realm: {url_str}"
            );
        }
    }

    #[test]
    fn test_most_specific_prefix_matching() {
        let mut store = TextCredentialStore::default();
        let general_cred =
            Credentials::basic(Some("general".to_string()), Some("pass1".to_string()));
        let specific_cred =
            Credentials::basic(Some("specific".to_string()), Some("pass2".to_string()));

        // Store credentials with different prefix lengths
        let general_service = Service::from_str("https://example.com/api").unwrap();
        let specific_service = Service::from_str("https://example.com/api/v1").unwrap();
        store.insert(general_service.clone(), general_cred);
        store.insert(specific_service.clone(), specific_cred);

        // Should match the most specific prefix
        let url = DisplaySafeUrl::parse("https://example.com/api/v1/users").unwrap();
        let cred = store.get_credentials(&url, None).unwrap();
        assert_eq!(cred.username(), Some("specific"));

        // Should match the general prefix for non-specific paths
        let url = DisplaySafeUrl::parse("https://example.com/api/v2").unwrap();
        let cred = store.get_credentials(&url, None).unwrap();
        assert_eq!(cred.username(), Some("general"));
    }

    #[test]
    fn test_username_exact_url_match() {
        let mut store = TextCredentialStore::default();
        let url = DisplaySafeUrl::parse("https://example.com").unwrap();
        let service = Service::from_str("https://example.com").unwrap();
        let user1_creds = Credentials::basic(Some("user1".to_string()), Some("pass1".to_string()));
        store.insert(service.clone(), user1_creds.clone());

        // Should return credentials when username matches
        let result = store.get_credentials(&url, Some("user1"));
        assert!(result.is_some());
        assert_eq!(result.unwrap().username(), Some("user1"));
        assert_eq!(result.unwrap().password(), Some("pass1"));

        // Should not return credentials when username doesn't match
        let result = store.get_credentials(&url, Some("user2"));
        assert!(result.is_none());

        // Should return credentials when no username is specified
        let result = store.get_credentials(&url, None);
        assert!(result.is_some());
        assert_eq!(result.unwrap().username(), Some("user1"));
    }

    #[test]
    fn test_username_prefix_url_match() {
        let mut store = TextCredentialStore::default();

        // Add credentials with different usernames for overlapping URL prefixes
        let general_service = Service::from_str("https://example.com/api").unwrap();
        let specific_service = Service::from_str("https://example.com/api/v1").unwrap();

        let general_creds = Credentials::basic(
            Some("general_user".to_string()),
            Some("general_pass".to_string()),
        );
        let specific_creds = Credentials::basic(
            Some("specific_user".to_string()),
            Some("specific_pass".to_string()),
        );

        store.insert(general_service, general_creds);
        store.insert(specific_service, specific_creds);

        let url = DisplaySafeUrl::parse("https://example.com/api/v1/users").unwrap();

        // Should match specific credentials when username matches
        let result = store.get_credentials(&url, Some("specific_user"));
        assert!(result.is_some());
        assert_eq!(result.unwrap().username(), Some("specific_user"));

        // Should match the general credentials when requesting general_user (falls back to less specific prefix)
        let result = store.get_credentials(&url, Some("general_user"));
        assert!(
            result.is_some(),
            "Should match general_user from less specific prefix"
        );
        assert_eq!(result.unwrap().username(), Some("general_user"));

        // Should match most specific when no username specified
        let result = store.get_credentials(&url, None);
        assert!(result.is_some());
        assert_eq!(result.unwrap().username(), Some("specific_user"));
    }
}