use std::{
collections::BTreeMap,
io::{self, Write as _},
path::{Path, PathBuf},
};
use anyhow::Context as _;
use flate2::read::GzDecoder;
use sequoia_openpgp::{
parse::{stream::DecryptorBuilder, Parse},
policy::StandardPolicy,
};
use tracing::{debug, info, instrument};
use walkdir::WalkDir;
use crate::{
filesystem::get_combined_file_size,
package::{CompressionAlgorithm, Package, CHECKSUM_FILE, DATA_FILE},
task::{Mode, Status},
utils::{Progress, ProgressReader},
};
/// Size in bytes (1 << 22 = 4 MiB) of the heap buffers used while unpacking.
/// It sets the capacity of each output file's `BufWriter` and the size of the
/// chunks that `copy_to_channels` broadcasts to the worker threads.
const HEAP_BUFFER_SIZE: usize = 1 << 22;
/// Options for [`decrypt`].
///
/// `T` is the progress-reporter type and `F` the password-callback type; their
/// required bounds are stated on [`decrypt`].
pub struct DecryptOpts<T, F> {
    /// The package to verify, decrypt and (optionally) unpack.
    pub package: Package,
    /// Secret keys available for decryption. Handed to the OpenPGP decryption
    /// helper.
    pub key_store: crate::openpgp::keystore::KeyStore,
    /// Certificates used to verify the package and also handed to the
    /// decryption helper.
    pub cert_store: crate::openpgp::certstore::CertStore<'static>,
    /// Callback that maps a password hint to a password. Handed to the
    /// decryption helper (presumably to unlock protected secret keys; confirm).
    pub password: F,
    /// Directory below which a fresh output directory is created. It is
    /// canonicalized, so it must exist. When `None`, the current working
    /// directory is used.
    pub output: Option<PathBuf>,
    /// When `true`, the decrypted data is written as-is to `DATA_FILE` instead
    /// of being decompressed and unpacked.
    pub decrypt_only: bool,
    /// With `Mode::Check`, only the decryptor is set up and nothing is written.
    pub mode: Mode,
    /// Optional progress reporter, used only when data is actually written.
    pub progress: Option<T>,
}
impl<T, F> std::fmt::Debug for DecryptOpts<T, F> {
    /// Formats the options for diagnostics, such as tracing spans.
    ///
    /// Only `package`, `output`, `decrypt_only` and `mode` are printed. The key
    /// store, certificate store, password callback and progress reporter are
    /// intentionally left out, so `T` and `F` need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            package,
            output,
            decrypt_only,
            mode,
            ..
        } = self;
        let shown_output = output.as_deref().map(Path::display);
        let mut builder = f.debug_struct("DecryptOpts");
        builder.field("package", package);
        builder.field("output", &shown_output);
        builder.field("decrypt_only", decrypt_only);
        builder.field("mode", mode);
        builder.finish()
    }
}
/// Verifies, decrypts and, unless `decrypt_only` is set, unpacks a data package.
///
/// The destination is a fresh directory chosen by [`get_output_path`]. It is
/// placed below `opts.output` (canonicalized) or below the current working
/// directory. The package is first verified against `opts.cert_store`. The
/// encrypted data is then decrypted on a blocking thread:
///
/// - [`Mode::Check`]: only the decryptor is set up. Nothing is written, and
///   [`Status::Checked`] is returned.
/// - Any other mode: the destination directory is created. The decrypted stream
///   is either written verbatim to [`DATA_FILE`] (`decrypt_only`) or unpacked
///   with [`unpack`]. [`Status::Completed`] is returned, including the combined
///   size of all regular files found below the destination.
///
/// When `opts.progress` is set and data is written, its length is set to the
/// encrypted data size. It is then advanced by the byte counts reported by
/// `ProgressReader` as the decryptor is read.
///
/// # Errors
///
/// Fails in any of these cases:
/// - `opts.output` cannot be canonicalized (e.g. it does not exist).
/// - Package verification, decryption, writing or unpacking fails.
/// - The blocking task panics or is cancelled.
// `opts` is recorded through its `Debug` impl, which omits the key store,
// certificate store, password callback and progress reporter.
#[instrument(err(Debug, level=tracing::Level::ERROR))]
pub async fn decrypt<T, F>(mut opts: DecryptOpts<T, F>) -> anyhow::Result<Status>
where
    T: Progress + Send + 'static,
    F: Fn(crate::openpgp::types::PasswordHint) -> crate::openpgp::types::Password + Send + 'static,
{
    // Resolve the user-supplied output directory to an absolute path.
    // `canonicalize` fails if the path does not exist.
    if let Some(path) = opts.output {
        opts.output = Some(path.canonicalize()?);
    }
    let output = get_output_path(opts.output, opts.package.name())?;
    let policy = StandardPolicy::new();
    // The package is verified against the certificate store before any
    // decryption is attempted.
    let package = opts.package.verify(&opts.cert_store).await?;
    let metadata = package.metadata().await?;
    let (data_reader, data_size) = package.data().await?;
    // Sequoia reads through a blocking `Read`. Bridge the async reader, and run
    // all decryption work inside `spawn_blocking`.
    let mut data_reader = tokio_util::io::SyncIoBridge::new(data_reader);
    let status = tokio::task::spawn_blocking(move || -> anyhow::Result<_> {
        // NOTE(review): building the decryptor makes sequoia start parsing the
        // message and consult the helper for keys and passwords. Presumably this
        // is what makes `Mode::Check` a meaningful check; confirm.
        let mut decryptor = DecryptorBuilder::from_reader(&mut data_reader)?.with_policy(
            &policy,
            None,
            crate::openpgp::crypto::DecryptionHelper {
                cert_store: &opts.cert_store,
                key_store: &mut opts.key_store,
                password: opts.password,
            },
        )?;
        let status = if let Mode::Check = opts.mode {
            // Check mode: report the would-be destination without writing anything.
            Status::Checked {
                destination: output.to_string_lossy().to_string(),
                source_size: data_size,
            }
        } else {
            std::fs::create_dir_all(&output)?;
            if let Some(mut pg) = opts.progress {
                pg.set_length(data_size);
                // Every read through `p` advances the progress reporter.
                let mut p = ProgressReader::new(decryptor, |len| {
                    pg.inc(len.try_into()?);
                    Ok(())
                });
                if opts.decrypt_only {
                    write_to_file(&mut p, &output)?;
                } else {
                    unpack(&mut p, &output, metadata.compression_algorithm)?;
                }
                pg.finish();
            } else if opts.decrypt_only {
                write_to_file(&mut decryptor, &output)?;
            } else {
                unpack(&mut decryptor, &output, metadata.compression_algorithm)?;
            }
            // Sum the sizes of all regular files now below the output directory.
            // Walk errors are silently skipped by `flatten()`.
            let output_files = WalkDir::new(&output)
                .into_iter()
                .flatten()
                .filter(|entry| entry.file_type().is_file())
                .map(|entry| entry.into_path());
            Status::Completed {
                source_size: data_size,
                destination_size: get_combined_file_size(output_files)?,
                destination: output.to_string_lossy().to_string(),
                metadata,
            }
        };
        Ok(status)
    })
    // Outer `?`: the blocking task panicked or was cancelled.
    // Inner `?`: the task itself returned an error.
    .await??;
    match &status {
        Status::Checked {
            destination,
            source_size,
        } => {
            debug!(destination, source_size, "Checked decryption task input");
        }
        Status::Completed {
            destination,
            source_size,
            destination_size,
            metadata,
        } => {
            info!(
                destination,
                source_size,
                destination_size,
                metadata = metadata.to_json_or_debug(),
                "Successfully decrypted data package"
            )
        }
    }
    Ok(status)
}
fn get_output_path(output: Option<PathBuf>, pkg_file_name: &str) -> anyhow::Result<PathBuf> {
let base = if let Some(p) = output {
p
} else {
std::env::current_dir()?
};
let pkg_base_name = pkg_file_name
.split('.')
.next()
.context("Package file has no extension")?;
let mut output = base.join(pkg_base_name);
let mut i = 1;
while output.exists() {
output = base.join(format!("{pkg_base_name}_{i}"));
i += 1;
}
Ok(output)
}
/// Decompresses `source` according to `compression_algorithm` and extracts the
/// contained tar archive below `output` (see [`unpack_tar`]).
///
/// After extraction, `source` is read to EOF. The tar reader stops at the
/// end-of-archive marker. Without the final read, the tail of the stream would
/// never be consumed, and anything `source` does on reaching its end would be
/// skipped. For the OpenPGP decryptor this presumably includes the final
/// integrity and signature checks on messages too large to be buffered up
/// front; confirm against the sequoia `Decryptor` docs. Progress reporting
/// would also stop short of 100%.
///
/// # Errors
///
/// Fails if decompression, extraction or checksum verification fails. Also fails
/// if reading the remainder of `source` fails.
#[instrument(skip(source))]
fn unpack<R: io::Read + Send>(
    source: &mut R,
    output: &Path,
    compression_algorithm: CompressionAlgorithm,
) -> anyhow::Result<()> {
    // Reborrow `source` (`&mut *source`) so it is still usable for the final drain.
    match compression_algorithm {
        CompressionAlgorithm::Stored => {
            unpack_tar(&mut tar::Archive::new(&mut *source), output)
        }
        CompressionAlgorithm::Gzip(_) => {
            unpack_tar(&mut tar::Archive::new(GzDecoder::new(&mut *source)), output)
        }
        CompressionAlgorithm::Zstandard(_) => unpack_tar(
            &mut tar::Archive::new(zstd::stream::read::Decoder::new(&mut *source)?),
            output,
        ),
    }?;
    // Consume whatever the tar/decompression readers left unread, so the
    // underlying stream reaches EOF.
    io::copy(source, &mut io::sink())?;
    Ok(())
}
/// Maps an archive entry path to a location inside `dest`.
///
/// Path prefixes, root directories and `.` components are dropped, so absolute
/// paths become relative to `dest`. Only plain name components are kept.
///
/// # Errors
///
/// Fails if the path contains a `..` component, which could escape `dest`. Also
/// fails if no name component is left.
fn sanitize_path(dest: &Path, path: &Path) -> anyhow::Result<PathBuf> {
    use std::path::Component;
    let relative = path.components().try_fold(
        PathBuf::new(),
        |mut acc: PathBuf, component| -> anyhow::Result<PathBuf> {
            match component {
                Component::Normal(segment) => acc.push(segment),
                Component::ParentDir => anyhow::bail!("file path contains a relative part"),
                Component::Prefix(_) | Component::RootDir | Component::CurDir => {}
            }
            Ok(acc)
        },
    )?;
    // An empty path has no parent, while any single name has the empty parent.
    anyhow::ensure!(relative.parent().is_some(), "empty file path");
    Ok(dest.join(relative))
}
/// Message sent from the archive reader to the checksum and write worker
/// threads in `unpack_tar`.
///
/// Each archive entry is sent as one `Init`, then zero or more `Payload`s, then
/// one `Finalize`.
enum Message {
    /// Starts a new file. The checksum thread receives the path as stored in the
    /// archive. The write thread receives the sanitized output path.
    Init(PathBuf),
    /// The next chunk of the current file's contents.
    Payload(bytes::Bytes),
    /// Ends the current file.
    Finalize,
}
fn unpack_tar(archive: &mut tar::Archive<impl io::Read>, dest: &Path) -> anyhow::Result<()> {
let (tx_checksum, rx_checksum) = std::sync::mpsc::sync_channel(8);
let (tx_write, rx_write) = std::sync::mpsc::sync_channel(8);
let checksum_handle = std::thread::spawn(move || -> anyhow::Result<_> {
use sequoia_openpgp::{crypto::hash::Digest as _, types::HashAlgorithm::SHA256};
let mut hasher = SHA256.context()?;
let mut path = None;
let mut checksums = BTreeMap::new();
while let Ok(message) = rx_checksum.recv() {
match message {
Message::Init(p) => {
path = Some(p);
}
Message::Payload(buf) => hasher.update(&buf),
Message::Finalize => {
checksums.insert(
std::mem::take(&mut path).expect("path is initialized"),
crate::utils::to_hex_string(
&std::mem::replace(&mut hasher, SHA256.context()?).into_digest()?,
),
);
}
}
}
Ok(checksums)
});
let write_handle = std::thread::spawn(move || -> io::Result<()> {
let mut writer = None;
while let Ok(message) = rx_write.recv() {
match message {
Message::Init(p) => {
if let Some(parent) = p.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent)?;
}
}
writer = Some(io::BufWriter::with_capacity(
HEAP_BUFFER_SIZE,
std::fs::File::create(&p)?,
));
}
Message::Payload(buf) => writer
.as_mut()
.expect("writer is initialized")
.write_all(&buf)?,
Message::Finalize => {
writer = None;
}
}
}
Ok(())
});
for entry in archive.entries()? {
let mut entry = entry?;
let archive_path = entry.path()?.into_owned();
let output_path = match sanitize_path(dest, &archive_path) {
Ok(p) => p,
Err(e) => {
tracing::warn!("{:?}: {}", archive_path, e);
continue;
}
};
tx_checksum.send(Message::Init(archive_path))?;
tx_write.send(Message::Init(output_path))?;
copy_to_channels(&mut entry, [&tx_checksum, &tx_write])?;
tx_checksum.send(Message::Finalize)?;
tx_write.send(Message::Finalize)?;
}
drop(tx_checksum);
drop(tx_write);
anyhow::ensure!(
write_handle.join().is_ok(),
"write thread in decrypt panicked"
);
let mut checksums = checksum_handle
.join()
.map_err(|_| anyhow::anyhow!("checksum thread in decrypt panicked"))??;
checksums.remove(Path::new(CHECKSUM_FILE));
verify_checksums(&checksums, &read_checksum_file(dest.join(CHECKSUM_FILE))?)?;
Ok(())
}
/// Reads `reader` to EOF and broadcasts its contents to every sender in `tx`.
///
/// Data is accumulated into chunks of up to [`HEAP_BUFFER_SIZE`] bytes before
/// sending. Each receiver therefore gets a few large, cheaply cloned
/// [`bytes::Bytes`] payloads rather than many small ones. Nothing is sent for an
/// empty reader.
///
/// # Errors
///
/// Fails if reading fails, or if any receiver has hung up. Reads interrupted
/// with [`io::ErrorKind::Interrupted`] are retried instead of failing.
fn copy_to_channels<const N: usize>(
    reader: &mut impl io::Read,
    tx: [&std::sync::mpsc::SyncSender<Message>; N],
) -> anyhow::Result<()> {
    let mut buf = [0; 8192];
    let mut bigbuf = bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE);
    // Swap the filled chunk out for `$buffer` and send a shared copy to every receiver.
    macro_rules! exchange {
        ($buffer:expr) => {{
            let b = std::mem::replace(&mut bigbuf, $buffer).freeze();
            for tx in tx {
                tx.send(Message::Payload(b.clone()))?;
            }
        }};
    }
    loop {
        let n = match reader.read(&mut buf) {
            Ok(n) => n,
            // `Interrupted` is transient per the `Read` contract: retry, as
            // `io::copy` does.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        if n == 0 {
            if !bigbuf.is_empty() {
                exchange!(bytes::BytesMut::new());
            }
            break;
        }
        if bigbuf.len() + n > bigbuf.capacity() {
            exchange!(bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE));
        }
        bigbuf.extend_from_slice(&buf[..n]);
    }
    Ok(())
}
/// Copies everything readable from `source` into [`DATA_FILE`] inside the
/// directory `output`. The file is created, or truncated if it already exists.
///
/// # Errors
///
/// Fails if the file cannot be created, or if reading from `source` or writing
/// to the file fails.
fn write_to_file<R: io::Read, P: AsRef<Path>>(source: &mut R, output: P) -> anyhow::Result<()> {
    let data_path = output.as_ref().join(DATA_FILE);
    let mut data_file = std::fs::File::create(data_path)?;
    let _bytes_copied = io::copy(source, &mut data_file)?;
    Ok(())
}
fn read_checksum_file(path: impl AsRef<Path>) -> anyhow::Result<BTreeMap<PathBuf, String>> {
use std::io::BufRead as _;
let mut reader = io::BufReader::new(std::fs::File::open(path)?);
let mut parsed = BTreeMap::new();
let mut buf = String::new();
while reader.read_line(&mut buf)? > 0 {
let (checksum, path) = buf
.trim()
.split_once(char::is_whitespace)
.context("Unable to parse the checksum file")?;
parsed.insert(PathBuf::from(path), checksum.to_string());
buf.clear();
}
Ok(parsed)
}
/// Checks computed checksums against a reference list.
///
/// Every file in `source` must appear in `reference` with the same checksum,
/// compared ASCII case-insensitively. Every file listed in `reference` must also
/// appear in `source`. Otherwise a truncated archive, or one whose entries were
/// skipped during extraction, would pass verification.
///
/// # Errors
///
/// Fails in any of these cases:
/// - A computed file has no reference checksum.
/// - A checksum differs.
/// - A referenced file was never computed.
fn verify_checksums(
    source: &BTreeMap<PathBuf, String>,
    reference: &BTreeMap<PathBuf, String>,
) -> anyhow::Result<()> {
    for (path, checksum) in source {
        let expected = reference
            .get(path)
            .with_context(|| format!("unable to find checksum for file: {path:?}"))?;
        if !checksum.eq_ignore_ascii_case(expected) {
            anyhow::bail!("wrong checksum for {path:?} (expected {expected}, computed {checksum})");
        }
    }
    if let Some(missing) = reference.keys().find(|path| !source.contains_key(*path)) {
        anyhow::bail!("file listed in the checksum file is missing: {missing:?}");
    }
    Ok(())
}