[go: up one dir, main page]

com 0.2.0

Utilities for implementing COM Client and Servers
Documentation
//! Helpers for registering COM servers

use crate::interfaces::IUnknown;
use crate::sys::{
    GetModuleFileNameA, GetModuleHandleA, RegCloseKey, RegCreateKeyExA, RegDeleteKeyA,
    RegSetValueExA, CLSID, ERROR_SUCCESS, FAILED, GUID, HKEY, HRESULT, IID, LSTATUS,
    SELFREG_E_CLASS, S_OK,
};

use std::convert::TryInto;
use std::ffi::c_void;
use std::ffi::CString;

#[doc(hidden)]
pub struct RegistryKeyInfo {
    key_path: CString,
    key_value_name: CString,
    key_value_data: CString,
}

#[doc(hidden)]
impl RegistryKeyInfo {
    pub fn new(key_path: &str, key_value_name: &str, key_value_data: &str) -> RegistryKeyInfo {
        RegistryKeyInfo {
            key_path: CString::new(key_path).unwrap(),
            key_value_name: CString::new(key_value_name).unwrap(),
            key_value_data: CString::new(key_value_data).unwrap(),
        }
    }
}

#[doc(hidden)]
pub fn register_keys(registry_keys_to_add: &Vec<RegistryKeyInfo>) -> HRESULT {
    for key_info in registry_keys_to_add.iter() {
        let result = add_class_key(&key_info);
        if result as u32 != ERROR_SUCCESS {
            return SELFREG_E_CLASS;
        }
    }

    S_OK
}

#[doc(hidden)]
pub fn unregister_keys(registry_keys_to_remove: &Vec<RegistryKeyInfo>) -> HRESULT {
    let mut hr = S_OK;
    for key_info in registry_keys_to_remove.iter() {
        let result = remove_class_key(&key_info);
        if result as u32 != ERROR_SUCCESS {
            hr = SELFREG_E_CLASS;
        }
    }

    hr
}

const HKEY_CLASSES_ROOT: HKEY = 0x8000_0000 as HKEY;
const KEY_ALL_ACCESS: u32 = 0x000F_003F;
const REG_OPTION_NON_VOLATILE: u32 = 0x00000000;
fn create_class_key(key_info: &RegistryKeyInfo) -> Result<HKEY, LSTATUS> {
    let mut hk_result = std::ptr::null_mut::<c_void>();
    let lp_class = std::ptr::null_mut::<u8>();
    let lp_security_attributes = std::ptr::null_mut::<c_void>();
    let lpdw_disposition = std::ptr::null_mut::<u32>();
    let result = unsafe {
        RegCreateKeyExA(
            HKEY_CLASSES_ROOT,
            key_info.key_path.as_ptr(),
            0,
            lp_class,
            REG_OPTION_NON_VOLATILE,
            KEY_ALL_ACCESS,
            lp_security_attributes,
            &mut hk_result as *mut _,
            lpdw_disposition,
        )
    };
    if result as u32 != ERROR_SUCCESS {
        return Err(result);
    }

    Ok(hk_result)
}

const REG_SZ: u32 = 1;
fn set_class_key(key_handle: HKEY, key_info: &RegistryKeyInfo) -> Result<HKEY, LSTATUS> {
    let result = unsafe {
        RegSetValueExA(
            key_handle,
            key_info.key_value_name.as_ptr(),
            0,
            REG_SZ,
            key_info.key_value_data.as_ptr() as *const u8,
            key_info
                .key_value_data
                .to_bytes_with_nul()
                .len()
                .try_into()
                .unwrap(),
        )
    };
    if result as u32 != ERROR_SUCCESS {
        return Err(result);
    }

    Ok(key_handle)
}

fn add_class_key(key_info: &RegistryKeyInfo) -> LSTATUS {
    let key_handle = match create_class_key(key_info) {
        Ok(key_handle) => key_handle,
        Err(e) => {
            return e;
        }
    };

    let key_handle = match set_class_key(key_handle, key_info) {
        Ok(key_handle) => key_handle,
        Err(e) => {
            return e;
        }
    };

    unsafe { RegCloseKey(key_handle) }
}

fn remove_class_key(key_info: &RegistryKeyInfo) -> LSTATUS {
    unsafe { RegDeleteKeyA(HKEY_CLASSES_ROOT, key_info.key_path.as_ptr()) }
}

#[doc(hidden)]
pub fn get_dll_file_path() -> String {
    unsafe {
        const MAX_FILE_PATH_LENGTH: usize = 260;
        let h_module = GetModuleHandleA(CString::new("server.dll").unwrap().as_ptr());
        let raw_ptr = CString::new(Vec::with_capacity(MAX_FILE_PATH_LENGTH))
            .expect("Failed to create empty string!")
            .into_raw();

        GetModuleFileNameA(h_module, raw_ptr, MAX_FILE_PATH_LENGTH.try_into().unwrap());

        CString::from_raw(raw_ptr).into_string().unwrap()
    }
}

#[doc(hidden)]
pub fn class_key_path(clsid: CLSID) -> String {
    format!("CLSID\\{}", guid_to_string(&clsid))
}

#[doc(hidden)]
pub fn class_inproc_key_path(clsid: CLSID) -> String {
    format!("CLSID\\{}\\InprocServer32", guid_to_string(&clsid))
}

fn guid_to_string(guid: &GUID) -> String {
    format!(
        "{{{:04X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
        guid.data1,
        guid.data2,
        guid.data3,
        guid.data4[0],
        guid.data4[1],
        guid.data4[2],
        guid.data4[3],
        guid.data4[4],
        guid.data4[5],
        guid.data4[6],
        guid.data4[7],
    )
}

#[doc(hidden)]
#[inline]
pub fn initialize_class_object<T: IUnknown>(
    instance: Box<T>,
    riid: *const IID,
    result: *mut *mut c_void,
) -> HRESULT {
    let hr = unsafe {
        instance.add_ref();
        let hr = instance.query_interface(riid, result);
        instance.release();
        hr
    };
    Box::into_raw(instance);

    hr
}

/// Register the supplied keys with the registry
#[doc(hidden)]
#[inline]
pub fn dll_register_server(relevant_keys: &mut Vec<RegistryKeyInfo>) -> HRESULT {
    let hr = register_keys(relevant_keys);
    if FAILED(hr) {
        dll_unregister_server(relevant_keys);
    }

    hr
}

/// Unregister the supplied keys with the registry
#[doc(hidden)]
#[inline]
pub fn dll_unregister_server(relevant_keys: &mut Vec<RegistryKeyInfo>) -> HRESULT {
    relevant_keys.reverse();
    unregister_keys(relevant_keys)
}

/// A macro for declaring a COM server to the COM runtime
///
/// This implements the `DllGetClassObject`, `DllRegisterServer`, and `DllUnregisterServer`
/// functions on behalf of the user.
#[macro_export]
macro_rules! inproc_dll_module {
    (($class_id_one:ident, $class_type_one:ty), $(($class_id:ident, $class_type:ty)),*) => {
        #[no_mangle]
        extern "stdcall" fn DllGetClassObject(class_id: *const com::sys::CLSID, iid: *const com::sys::IID, result: *mut *mut std::ffi::c_void) -> com::sys::HRESULT {
            use com::interfaces::IUnknown;
            use com::registration::initialize_class_object;
            assert!(!class_id.is_null(), "class id passed to DllGetClassObject should never be null");

            let class_id = unsafe { &*class_id };
            if class_id == &$class_id_one {
                let mut instance = <$class_type_one>::get_class_object();
                initialize_class_object(instance, iid, result)
            } $(else if class_id == &$class_id {
                let mut instance = <$class_type>::get_class_object();
                initialize_class_object(instance, iid, result)
            })* else {
                com::sys::CLASS_E_CLASSNOTAVAILABLE
            }
        }

        #[no_mangle]
        extern "stdcall" fn DllRegisterServer() -> com::sys::HRESULT {
            com::registration::dll_register_server(&mut get_relevant_registry_keys())
        }

        #[no_mangle]
        extern "stdcall" fn DllUnregisterServer() -> com::sys::HRESULT {
            com::registration::dll_unregister_server(&mut get_relevant_registry_keys())
        }

        fn get_relevant_registry_keys() -> Vec<com::registration::RegistryKeyInfo> {
            use com::registration::RegistryKeyInfo;
            let file_path = com::registration::get_dll_file_path();
            vec![
                RegistryKeyInfo::new(
                    &com::registration::class_key_path($class_id_one),
                    "",
                    stringify!($class_type_one),
                ),
                RegistryKeyInfo::new(
                    &com::registration::class_inproc_key_path($class_id_one),
                    "",
                    &file_path,
                ),
                $(RegistryKeyInfo::new(
                    &com::registration::class_key_path($class_id),
                    "",
                    stringify!($class_type),
                ),
                RegistryKeyInfo::new(
                    &com::registration::class_inproc_key_path($class_id),
                    "",
                    &file_path,
                )),*
            ]
        }
    };
}