[go: up one dir, main page]

sett 0.3.0

Rust port of sett (data compression, encryption and transfer tool).
Documentation
//! Decrypt workflow

use std::{
    collections::BTreeMap,
    io::{self, Write as _},
    path::{Path, PathBuf},
};

use anyhow::Context as _;
use flate2::read::GzDecoder;
use sequoia_openpgp::{
    parse::{stream::DecryptorBuilder, Parse},
    policy::StandardPolicy,
};
use tracing::{debug, info, instrument};
use walkdir::WalkDir;

use crate::{
    filesystem::get_combined_file_size,
    package::{CompressionAlgorithm, Package, CHECKSUM_FILE, DATA_FILE},
    task::{Mode, Status},
    utils::{Progress, ProgressReader},
};

/// Size in bytes (`1 << 22`, i.e. 4 MiB) of the heap buffers used while unpacking:
/// the capacity of the per-file `BufWriter` and of the chunks handed to the
/// checksum and write threads.
const HEAP_BUFFER_SIZE: usize = 1 << 22;

/// Options required by the decrypt workflow.
///
/// `T` is the type of the optional progress reporter and `F` the type of the
/// password provider. The struct itself places no bounds on them; [`decrypt`]
/// requires `T: Progress` and `F` to be a callback that maps a password hint
/// to a password.
pub struct DecryptOpts<T, F> {
    /// Input file for decryption.
    pub package: Package,
    /// Private OpenPGP key store (used for decrypting data).
    pub key_store: crate::openpgp::keystore::KeyStore,
    /// Public OpenPGP certificate store (used for verifying signatures).
    pub cert_store: crate::openpgp::certstore::CertStore<'static>,
    /// Password for decrypting recipients' keys.
    pub password: F,
    /// Output path for the decrypted data.
    ///
    /// When `None`, the current working directory is used as the base
    /// directory. In both cases the data ends up in a sub-directory named
    /// after the package.
    pub output: Option<PathBuf>,
    /// Decrypt data without unpacking it.
    pub decrypt_only: bool,
    /// Run the workflow or only perform a check.
    pub mode: Mode,
    /// Report decryption progress using this callback.
    pub progress: Option<T>,
}

/// Manual `Debug` implementation that prints only the `package`, `output`,
/// `decrypt_only` and `mode` fields.
///
/// The key store, certificate store, password provider and progress reporter
/// are left out, so no `Debug` bound is needed on `T` or `F`.
impl<T, F> std::fmt::Debug for DecryptOpts<T, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let output = self.output.as_deref().map(Path::display);
        let mut repr = f.debug_struct("DecryptOpts");
        repr.field("package", &self.package);
        repr.field("output", &output);
        repr.field("decrypt_only", &self.decrypt_only);
        repr.field("mode", &self.mode);
        repr.finish()
    }
}

/// Decrypts and (optionally) decompresses a data package.
///
/// The workflow is:
///
/// 1. Resolve the output directory: the (canonicalized) `opts.output`, or the
///    current working directory, joined with a not-yet-existing directory name
///    derived from the package name.
/// 2. Verify the package against the certificate store and read its metadata.
/// 3. On a blocking thread, set up the OpenPGP decryptor and then either
///    - return [`Status::Checked`] without writing anything (`Mode::Check`), or
///    - create the output directory and write the decrypted stream as-is
///      (`decrypt_only`) or unpack the contained tar archive, which also
///      verifies the per-file checksums.
///
/// # Errors
///
/// Returns an error if the output path cannot be canonicalized, if package
/// verification, decryption, unpacking or checksum verification fails, or if
/// the blocking task panics or is cancelled.
#[instrument(err(Debug, level=tracing::Level::ERROR))]
pub async fn decrypt<T, F>(mut opts: DecryptOpts<T, F>) -> anyhow::Result<Status>
where
    T: Progress + Send + 'static,
    F: Fn(crate::openpgp::types::PasswordHint) -> crate::openpgp::types::Password + Send + 'static,
{
    // `canonicalize` fails if the user-provided path does not exist.
    if let Some(path) = opts.output {
        opts.output = Some(path.canonicalize()?);
    }
    let output = get_output_path(opts.output, opts.package.name())?;
    let policy = StandardPolicy::new();
    let package = opts.package.verify(&opts.cert_store).await?;
    let metadata = package.metadata().await?;
    let (data_reader, data_size) = package.data().await?;
    // The decryptor needs a blocking `Read`, so bridge the async reader and
    // run the rest of the workflow on a blocking thread.
    let mut data_reader = tokio_util::io::SyncIoBridge::new(data_reader);
    let status = tokio::task::spawn_blocking(move || -> anyhow::Result<_> {
        // The decryptor is built in check mode as well, so problems detected
        // while setting it up are reported before `Status::Checked` is returned.
        let mut decryptor = DecryptorBuilder::from_reader(&mut data_reader)?.with_policy(
            &policy,
            None,
            crate::openpgp::crypto::DecryptionHelper {
                cert_store: &opts.cert_store,
                key_store: &mut opts.key_store,
                password: opts.password,
            },
        )?;
        let status = if let Mode::Check = opts.mode {
            Status::Checked {
                destination: output.to_string_lossy().to_string(),
                source_size: data_size,
            }
        } else {
            std::fs::create_dir_all(&output)?;
            if let Some(mut pg) = opts.progress {
                // Progress is reported against the size of the package data.
                pg.set_length(data_size);
                let mut p = ProgressReader::new(decryptor, |len| {
                    pg.inc(len.try_into()?);
                    Ok(())
                });
                if opts.decrypt_only {
                    write_to_file(&mut p, &output)?;
                } else {
                    unpack(&mut p, &output, metadata.compression_algorithm)?;
                }
                pg.finish();
            } else if opts.decrypt_only {
                write_to_file(&mut decryptor, &output)?;
            } else {
                unpack(&mut decryptor, &output, metadata.compression_algorithm)?;
            }

            // Collect all regular files below the output directory to compute
            // the total size of the written data. Entries that fail to be
            // read during the walk are skipped by `flatten()`.
            let output_files = WalkDir::new(&output)
                .into_iter()
                .flatten()
                .filter(|entry| entry.file_type().is_file())
                .map(|entry| entry.into_path());

            Status::Completed {
                source_size: data_size,
                destination_size: get_combined_file_size(output_files)?,
                destination: output.to_string_lossy().to_string(),
                metadata,
            }
        };
        Ok(status)
    })
    // First `?`: the blocking task panicked or was cancelled.
    // Second `?`: the workflow itself returned an error.
    .await??;
    match &status {
        Status::Checked {
            destination,
            source_size,
        } => {
            debug!(destination, source_size, "Checked decryption task input");
        }
        Status::Completed {
            destination,
            source_size,
            destination_size,
            metadata,
        } => {
            info!(
                destination,
                source_size,
                destination_size,
                metadata = metadata.to_json_or_debug(),
                "Successfully decrypted data package"
            )
        }
    }
    Ok(status)
}

/// Picks a not-yet-existing output directory for a package.
///
/// The base directory is `output` when given, otherwise the current working
/// directory. The directory name is the part of `pkg_file_name` before its
/// first `.`; if that path already exists, the suffixes `_1`, `_2`, ... are
/// tried in order until a free name is found.
///
/// # Errors
///
/// Fails if the current working directory cannot be determined.
fn get_output_path(output: Option<PathBuf>, pkg_file_name: &str) -> anyhow::Result<PathBuf> {
    let base = match output {
        Some(dir) => dir,
        None => std::env::current_dir()?,
    };
    let mut name_parts = pkg_file_name.split('.');
    let stem = name_parts
        .next()
        .context("Package file has no extension")?;

    let mut candidate = base.join(stem);
    let mut suffix = 1;
    loop {
        if !candidate.exists() {
            return Ok(candidate);
        }
        candidate = base.join(format!("{stem}_{suffix}"));
        suffix += 1;
    }
}

/// Decompresses `source` according to `compression_algorithm` and unpacks the
/// resulting tar stream into the `output` directory.
///
/// # Errors
///
/// Fails if the Zstandard decoder cannot be created or if unpacking the tar
/// archive fails.
#[instrument(skip(source))]
fn unpack<R: io::Read + Send>(
    source: &mut R,
    output: &Path,
    compression_algorithm: CompressionAlgorithm,
) -> anyhow::Result<()> {
    match compression_algorithm {
        // No compression: the stream is the tar archive itself.
        CompressionAlgorithm::Stored => {
            let mut archive = tar::Archive::new(source);
            unpack_tar(&mut archive, output)
        }
        CompressionAlgorithm::Gzip(_) => {
            let mut archive = tar::Archive::new(GzDecoder::new(source));
            unpack_tar(&mut archive, output)
        }
        CompressionAlgorithm::Zstandard(_) => {
            let decoder = zstd::stream::read::Decoder::new(source)?;
            let mut archive = tar::Archive::new(decoder);
            unpack_tar(&mut archive, output)
        }
    }
}

/// Returns the destination path for a file extracted from a tar archive.
///
/// The archive path is reduced to its normal components and appended to
/// `dest`, so that the result always lies lexically below `dest`:
///
/// - prefixes, root directories and `.` components are dropped;
/// - a `..` component is rejected with an error;
/// - a path without any normal component is rejected with an error.
///
/// The check is purely lexical: the file system is not consulted and
/// symbolic links are not resolved.
///
/// Note: the component handling follows the sanitization in the `tar` crate.
fn sanitize_path(dest: &Path, path: &Path) -> anyhow::Result<PathBuf> {
    use std::path::Component;

    let mut relative = PathBuf::new();
    for component in path.components() {
        match component {
            Component::Normal(name) => relative.push(name),

            // Refuse '..' to prevent directory traversal. See, e.g.:
            // CVE-2001-1267, CVE-2002-0399, CVE-2005-1918, CVE-2007-4131
            Component::ParentDir => anyhow::bail!("file path contains a relative part"),

            // Leading '/' characters, root paths and '.' components carry no
            // information for a path relative to `dest`.
            Component::Prefix(_) | Component::RootDir | Component::CurDir => {}
        }
    }
    // An empty path has no parent; any path with at least one component does.
    if relative.parent().is_none() {
        anyhow::bail!("empty file path");
    }
    Ok(dest.join(relative))
}

/// Message sent from the tar entry loop to the checksum and write threads.
///
/// For every archive entry the sequence is `Init`, zero or more `Payload`s,
/// then `Finalize`.
enum Message {
    /// Starts a new file. Carries the path in the archive for the checksum
    /// thread and the destination path on disk for the write thread.
    Init(PathBuf),
    /// A chunk of the current file's content.
    Payload(bytes::Bytes),
    /// Marks the end of the current file.
    Finalize,
}

fn unpack_tar(archive: &mut tar::Archive<impl io::Read>, dest: &Path) -> anyhow::Result<()> {
    let (tx_checksum, rx_checksum) = std::sync::mpsc::sync_channel(8);
    let (tx_write, rx_write) = std::sync::mpsc::sync_channel(8);

    let checksum_handle = std::thread::spawn(move || -> anyhow::Result<_> {
        use sequoia_openpgp::{crypto::hash::Digest as _, types::HashAlgorithm::SHA256};
        let mut hasher = SHA256.context()?;
        let mut path = None;
        let mut checksums = BTreeMap::new();
        while let Ok(message) = rx_checksum.recv() {
            match message {
                Message::Init(p) => {
                    path = Some(p);
                }
                Message::Payload(buf) => hasher.update(&buf),
                Message::Finalize => {
                    checksums.insert(
                        std::mem::take(&mut path).expect("path is initialized"),
                        crate::utils::to_hex_string(
                            &std::mem::replace(&mut hasher, SHA256.context()?).into_digest()?,
                        ),
                    );
                }
            }
        }
        Ok(checksums)
    });

    let write_handle = std::thread::spawn(move || -> io::Result<()> {
        let mut writer = None;
        while let Ok(message) = rx_write.recv() {
            match message {
                Message::Init(p) => {
                    if let Some(parent) = p.parent() {
                        if !parent.exists() {
                            std::fs::create_dir_all(parent)?;
                        }
                    }
                    writer = Some(io::BufWriter::with_capacity(
                        HEAP_BUFFER_SIZE,
                        std::fs::File::create(&p)?,
                    ));
                }
                Message::Payload(buf) => writer
                    .as_mut()
                    .expect("writer is initialized")
                    .write_all(&buf)?,
                Message::Finalize => {
                    writer = None;
                }
            }
        }
        Ok(())
    });

    for entry in archive.entries()? {
        let mut entry = entry?;
        let archive_path = entry.path()?.into_owned();
        let output_path = match sanitize_path(dest, &archive_path) {
            Ok(p) => p,
            Err(e) => {
                tracing::warn!("{:?}: {}", archive_path, e);
                continue;
            }
        };
        tx_checksum.send(Message::Init(archive_path))?;
        tx_write.send(Message::Init(output_path))?;
        copy_to_channels(&mut entry, [&tx_checksum, &tx_write])?;
        tx_checksum.send(Message::Finalize)?;
        tx_write.send(Message::Finalize)?;
    }
    drop(tx_checksum);
    drop(tx_write);

    anyhow::ensure!(
        write_handle.join().is_ok(),
        "write thread in decrypt panicked"
    );
    let mut checksums = checksum_handle
        .join()
        .map_err(|_| anyhow::anyhow!("checksum thread in decrypt panicked"))??;
    checksums.remove(Path::new(CHECKSUM_FILE));
    verify_checksums(&checksums, &read_checksum_file(dest.join(CHECKSUM_FILE))?)?;
    Ok(())
}

/// Reads `reader` to its end and sends its content, in large chunks, to every
/// channel in `tx`.
///
/// Data is read in 8 KiB pieces and accumulated in a buffer of
/// `HEAP_BUFFER_SIZE` bytes. Whenever the next piece would not fit, the buffer
/// is frozen and sent as a `Message::Payload` to all channels (the chunk is
/// shared between them, not copied) and a fresh buffer is started. Any
/// remaining data is sent once the reader is exhausted. No `Init` or
/// `Finalize` message is sent by this function.
///
/// # Errors
///
/// Fails if reading from `reader` fails (interrupted reads are retried) or if
/// one of the receiving ends has been closed.
fn copy_to_channels<const N: usize>(
    reader: &mut impl io::Read,
    tx: [&std::sync::mpsc::SyncSender<Message>; N],
) -> anyhow::Result<()> {
    let mut buf = [0u8; 8192];
    let mut bigbuf = bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE);
    // Swaps `bigbuf` for the given replacement and broadcasts the old content.
    macro_rules! exchange {
        ($buffer:expr) => {{
            let b = std::mem::replace(&mut bigbuf, $buffer).freeze();
            for tx in tx {
                tx.send(Message::Payload(b.clone()))?;
            }
        }};
    }
    loop {
        let n = match reader.read(&mut buf) {
            Ok(n) => n,
            // `Interrupted` is non-fatal: the read should simply be retried.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        if n == 0 {
            if !bigbuf.is_empty() {
                exchange!(bytes::BytesMut::new());
            }
            break;
        }
        if bigbuf.len() + n > bigbuf.capacity() {
            exchange!(bytes::BytesMut::with_capacity(HEAP_BUFFER_SIZE));
        }
        bigbuf.extend_from_slice(&buf[..n]);
    }
    Ok(())
}

/// Streams `source` into the file named `DATA_FILE` inside the `output`
/// directory.
///
/// The file is created, or truncated if it already exists.
///
/// # Errors
///
/// Fails if the file cannot be created or if copying the data fails.
fn write_to_file<R: io::Read, P: AsRef<Path>>(source: &mut R, output: P) -> anyhow::Result<()> {
    let target = output.as_ref().join(DATA_FILE);
    let mut file = std::fs::File::create(target)?;
    io::copy(source, &mut file)?;
    Ok(())
}

/// Parses the checksum file of a data package.
///
/// Returns a map from the file path inside the data package to its checksum.
/// The checksum file has the following structure:
///
/// ```text
/// <checksum1> <file 1 path inside the data package>
/// <checksum2> <file 2 path inside the data package>
/// ...
/// ```
///
/// Each line is trimmed and then split at its first whitespace character.
///
/// # Errors
///
/// Fails if the file cannot be opened or read, or if a line contains no
/// whitespace separator after trimming (this includes blank lines).
fn read_checksum_file(path: impl AsRef<Path>) -> anyhow::Result<BTreeMap<PathBuf, String>> {
    use std::io::BufRead as _;

    let file = std::fs::File::open(path)?;
    let mut entries = BTreeMap::new();
    for line in io::BufReader::new(file).lines() {
        let line = line?;
        let (checksum, file_path) = line
            .trim()
            .split_once(char::is_whitespace)
            .context("Unable to parse the checksum file")?;
        entries.insert(PathBuf::from(file_path), checksum.to_string());
    }
    Ok(entries)
}

fn verify_checksums(
    source: &BTreeMap<PathBuf, String>,
    reference: &BTreeMap<PathBuf, String>,
) -> anyhow::Result<()> {
    for (path, checksum) in source {
        let expected = reference
            .get(path)
            .with_context(|| format!("unable to find checksum for file: {path:?}"))?;
        if !checksum.eq_ignore_ascii_case(expected) {
            anyhow::bail!("wrong checksum for {path:?} (expected {expected}, computed {checksum})");
        }
    }
    Ok(())
}