use crate::error::{KeyringError, Result};
use byteorder::{ByteOrder, LittleEndian};
use std::ffi::OsStr;
use std::iter::once;
use std::mem::MaybeUninit;
use std::os::windows::ffi::OsStrExt;
use std::slice;
use std::str;
use winapi::shared::minwindef::FILETIME;
use winapi::shared::winerror::{ERROR_NOT_FOUND, ERROR_NO_SUCH_LOGON_SESSION};
use winapi::um::errhandlingapi::GetLastError;
use winapi::um::wincred::{
CredDeleteW, CredFree, CredReadW, CredWriteW, CREDENTIALW, CRED_PERSIST_ENTERPRISE,
CRED_TYPE_GENERIC, PCREDENTIALW, PCREDENTIAL_ATTRIBUTEW,
};
/// Handle to one password stored in the Windows Credential Manager.
///
/// The handle only borrows its two identifying strings; the credential store itself is
/// touched by the methods on this type, which address the entry by combining the
/// username and service into a single target name.
pub struct Keyring<'a> {
    /// Account name the password belongs to.
    username: &'a str,
    /// Service (application) name the password belongs to.
    service: &'a str,
}
impl<'a> Keyring<'a> {
pub fn new(service: &'a str, username: &'a str) -> Keyring<'a> {
Keyring { service, username }
}
pub fn set_password(&self, password: &str) -> Result<()> {
let flags = 0;
let cred_type = CRED_TYPE_GENERIC;
let target_name: String = [self.username, self.service].join(".");
let mut target_name = to_wstr(&target_name);
let mut empty_str = to_wstr("");
let last_written = FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let blob_u16 = to_wstr_no_null(password);
let mut blob = vec![0; blob_u16.len() * 2];
LittleEndian::write_u16_into(&blob_u16, &mut blob);
let blob_len = blob.len() as u32;
let persist = CRED_PERSIST_ENTERPRISE;
let attribute_count = 0;
let attributes: PCREDENTIAL_ATTRIBUTEW = std::ptr::null_mut();
let mut username = to_wstr(self.username);
let mut credential = CREDENTIALW {
Flags: flags,
Type: cred_type,
TargetName: target_name.as_mut_ptr(),
Comment: empty_str.as_mut_ptr(),
LastWritten: last_written,
CredentialBlobSize: blob_len,
CredentialBlob: blob.as_mut_ptr(),
Persist: persist,
AttributeCount: attribute_count,
Attributes: attributes,
TargetAlias: empty_str.as_mut_ptr(),
UserName: username.as_mut_ptr(),
};
let pcredential: PCREDENTIALW = &mut credential;
match unsafe { CredWriteW(pcredential, 0) } {
0 => Err(KeyringError::WindowsVaultError),
_ => Ok(()),
}
}
pub fn get_password(&self) -> Result<String> {
let mut pcredential = MaybeUninit::uninit();
let target_name: String = [self.username, self.service].join(".");
let target_name = to_wstr(&target_name);
let cred_type = CRED_TYPE_GENERIC;
match unsafe { CredReadW(target_name.as_ptr(), cred_type, 0, pcredential.as_mut_ptr()) } {
0 => unsafe {
match GetLastError() {
ERROR_NOT_FOUND => Err(KeyringError::NoPasswordFound),
ERROR_NO_SUCH_LOGON_SESSION => Err(KeyringError::NoBackendFound),
_ => Err(KeyringError::WindowsVaultError),
}
},
_ => {
let pcredential = unsafe { pcredential.assume_init() };
let credential: CREDENTIALW = unsafe { *pcredential };
let blob_pointer: *const u8 = credential.CredentialBlob;
let blob_len: usize = credential.CredentialBlobSize as usize;
let blob: &[u8] = unsafe { slice::from_raw_parts(blob_pointer, blob_len) };
let mut blob_u16 = vec![0; blob_len / 2];
LittleEndian::read_u16_into(blob, &mut blob_u16);
let password =
String::from_utf16(&blob_u16).map_err(|_| KeyringError::WindowsVaultError);
unsafe {
CredFree(pcredential as *mut _);
}
password
}
}
}
pub fn delete_password(&self) -> Result<()> {
let target_name: String = [self.username, self.service].join(".");
let cred_type = CRED_TYPE_GENERIC;
let target_name = to_wstr(&target_name);
match unsafe { CredDeleteW(target_name.as_ptr(), cred_type, 0) } {
0 => unsafe {
match GetLastError() {
ERROR_NOT_FOUND => Err(KeyringError::NoPasswordFound),
ERROR_NO_SUCH_LOGON_SESSION => Err(KeyringError::NoBackendFound),
_ => Err(KeyringError::WindowsVaultError),
}
},
_ => Ok(()),
}
}
}
/// Encodes `s` as UTF-16 code units followed by a single terminating `0`, the form the
/// wide-character (`…W`) Win32 functions expect for string arguments.
///
/// An interior NUL in `s` is copied through unchanged; Win32 readers will stop at it.
fn to_wstr(s: &str) -> Vec<u16> {
    let mut wide: Vec<u16> = OsStr::new(s).encode_wide().collect();
    wide.extend(once(0));
    wide
}
/// Encodes `s` as UTF-16 code units with no terminator.
///
/// Used here for the password blob, whose size is passed to Windows explicitly.
fn to_wstr_no_null(s: &str) -> Vec<u16> {
    let mut wide = Vec::new();
    wide.extend(OsStr::new(s).encode_wide());
    wide
}
#[cfg(test)]
#[cfg(target_os = "windows")]
mod test {
    use super::*;

    /// Round-trips two passwords (one non-ASCII) through the real credential store,
    /// then checks that deletion makes the entry unreadable.
    #[test]
    fn test_basic() {
        let password_1 = "大根";
        let password_2 = "0xE5A4A7E6A0B9";
        let keyring = Keyring::new("testservice", "testuser");

        keyring.set_password(password_1).unwrap();
        let res_1 = keyring.get_password().unwrap();
        assert_eq!(
            res_1, password_1,
            "Stored and retrieved passwords don't match"
        );

        // A second write to the same target must replace the first value.
        keyring.set_password(password_2).unwrap();
        let res_2 = keyring.get_password().unwrap();
        assert_eq!(
            res_2, password_2,
            "Stored and retrieved passwords don't match"
        );

        keyring.delete_password().unwrap();
        // Require the specific "not found" error so an unrelated backend failure is not
        // mistaken for a successful delete.
        match keyring.get_password() {
            Err(KeyringError::NoPasswordFound) => (),
            Ok(_) => panic!("Able to read a deleted password"),
            Err(e) => panic!(
                "expected Err(KeyringError::NoPasswordFound) after delete, got {}",
                e
            ),
        }
    }

    /// Reading or deleting a credential that was never stored reports `NoPasswordFound`.
    #[test]
    fn test_no_password() {
        let keyring = Keyring::new("testservice", "test-no-password");

        match keyring.get_password() {
            Ok(_) => panic!("expected Err(KeyringError::NoPasswordFound), got Ok"),
            Err(KeyringError::NoPasswordFound) => (),
            Err(e) => panic!("expected Err(KeyringError::NoPasswordFound), got {}", e),
        }

        match keyring.delete_password() {
            Ok(_) => panic!("expected Err(KeyringError::NoPasswordFound), got Ok"),
            Err(KeyringError::NoPasswordFound) => (),
            Err(e) => panic!("expected Err(KeyringError::NoPasswordFound), got {}", e),
        }
    }
}