use async_lock::Mutex;
use futures_core::{future::BoxFuture, stream};
use futures_util::{
future::FutureExt,
stream::{unfold, StreamExt},
};
use once_cell::sync::OnceCell;
use slotmap::{new_key_type, SlotMap};
use std::{
borrow::Cow,
convert::{TryFrom, TryInto},
future::ready,
pin::Pin,
task::{Context, Poll},
};
use async_io::block_on;
use zvariant::{ObjectPath, OwnedValue, Value};
use crate::{azync::Connection, Error, Message, Result};
use crate::fdo::{self, AsyncIntrospectableProxy, AsyncPropertiesProxy};
/// Boxed, type-erased signal callback.
///
/// It is invoked with a reference to the received signal [`Message`] and returns a boxed
/// future that may borrow that message for as long as the future runs.
type SignalHandler =
    Box<dyn for<'m> FnMut(&'m Message) -> BoxFuture<'m, Result<()>> + Send>;

new_key_type! {
    /// Opaque handle returned by `Proxy::connect_signal`; pass it to
    /// `Proxy::disconnect_signal` to remove the handler again.
    pub struct SignalHandlerId;
}

/// A registered signal handler together with the signal member name it listens for.
struct SignalHandlerInfo {
    /// The callback to run for each matching signal.
    handler: SignalHandler,
    /// Member name of the signal this handler was registered for.
    signal_name: &'static str,
}
/// Asynchronous proxy for one interface of one object exposed by a given destination.
///
/// Method calls, property access and signal subscriptions are all performed over the
/// wrapped [`Connection`].
#[derive(derivative::Derivative)]
#[derivative(Debug)]
pub struct Proxy<'a> {
    /// Connection and addressing information; cloned into each `SignalStream`.
    core: ProxyCore<'a>,
    /// Handlers registered through `connect_signal`, keyed by the ID handed back to the
    /// caller. Excluded from `Debug` because the boxed closures do not implement it.
    #[derivative(Debug = "ignore")]
    sig_handlers: Mutex<SlotMap<SignalHandlerId, SignalHandlerInfo>>,
}
/// The cloneable part of a [`Proxy`]: everything needed to address the remote object.
#[derive(Clone, Debug)]
struct ProxyCore<'a> {
    /// Connection that messages are sent and received on.
    conn: Connection,
    /// Bus name of the peer; may be a unique (`:`-prefixed) or a well-known name.
    destination: Cow<'a, str>,
    /// Object path of the remote object.
    path: ObjectPath<'a>,
    /// Interface name used for method calls, properties and signal matching.
    interface: Cow<'a, str>,
    /// Unique name that `destination` resolves to. Filled lazily by `Proxy::resolve_name`
    /// and compared against the sender of incoming signals.
    dest_unique_name: OnceCell<String>,
}
impl<'a> Proxy<'a> {
pub fn new<E>(
conn: &Connection,
destination: &'a str,
path: impl TryInto<ObjectPath<'a>, Error = E>,
interface: &'a str,
) -> Result<Self>
where
Error: From<E>,
{
Ok(Self {
core: ProxyCore {
conn: conn.clone(),
destination: Cow::from(destination),
path: path.try_into()?,
interface: Cow::from(interface),
dest_unique_name: OnceCell::new(),
},
sig_handlers: Mutex::new(SlotMap::with_key()),
})
}
pub fn new_owned<E>(
conn: Connection,
destination: String,
path: impl TryInto<ObjectPath<'static>, Error = E>,
interface: String,
) -> Result<Self>
where
Error: From<E>,
{
Ok(Self {
core: ProxyCore {
conn,
destination: Cow::from(destination),
path: path.try_into()?,
interface: Cow::from(interface),
dest_unique_name: OnceCell::new(),
},
sig_handlers: Mutex::new(SlotMap::with_key()),
})
}
pub fn connection(&self) -> &Connection {
&self.core.conn
}
pub fn destination(&self) -> &str {
&self.core.destination
}
pub fn path(&self) -> &ObjectPath<'_> {
&self.core.path
}
pub fn interface(&self) -> &str {
&self.core.interface
}
pub async fn introspect(&self) -> fdo::Result<String> {
AsyncIntrospectableProxy::new_for(&self.core.conn, &self.core.destination, &self.core.path)?
.introspect()
.await
}
pub async fn get_property<T>(&self, property_name: &str) -> fdo::Result<T>
where
T: TryFrom<OwnedValue>,
{
AsyncPropertiesProxy::new_for(&self.core.conn, &self.core.destination, &self.core.path)?
.get(&self.core.interface, property_name)
.await?
.try_into()
.map_err(|_| Error::InvalidReply.into())
}
pub async fn set_property<'t, T: 't>(&self, property_name: &str, value: T) -> fdo::Result<()>
where
T: Into<Value<'t>>,
{
AsyncPropertiesProxy::new_for(&self.core.conn, &self.core.destination, &self.core.path)?
.set(&self.core.interface, property_name, &value.into())
.await
}
pub async fn call_method<B>(&self, method_name: &str, body: &B) -> Result<Message>
where
B: serde::ser::Serialize + zvariant::Type,
{
let reply = self
.core
.conn
.call_method(
Some(&self.core.destination),
self.core.path.as_str(),
Some(&self.core.interface),
method_name,
body,
)
.await;
match reply {
Ok(mut reply) => {
reply.disown_fds();
Ok(reply)
}
Err(e) => Err(e),
}
}
pub async fn call<B, R>(&self, method_name: &str, body: &B) -> Result<R>
where
B: serde::ser::Serialize + zvariant::Type,
R: serde::de::DeserializeOwned + zvariant::Type,
{
Ok(self.call_method(method_name, body).await?.body()?)
}
pub async fn receive_signal(&self, signal_name: &'static str) -> Result<SignalStream<'a>> {
let subscription_id = if self.core.conn.is_bus() {
let id = self
.core
.conn
.subscribe_signal(
self.destination(),
self.path().clone(),
self.interface(),
signal_name,
)
.await?;
Some(id)
} else {
None
};
self.resolve_name().await?;
let proxy = self.core.clone();
let stream = unfold((proxy, signal_name), |(proxy, signal_name)| async move {
proxy
.conn
.receive_specific(|msg| {
let hdr = match msg.header() {
Ok(hdr) => hdr,
Err(_) => return ready(Ok(false)).boxed(),
};
let expected_sender = proxy.dest_unique_name.get().map(|s| s.as_str());
ready(Ok(hdr.primary().msg_type() == crate::MessageType::Signal
&& hdr.interface() == Ok(Some(&proxy.interface))
&& hdr.sender() == Ok(expected_sender)
&& hdr.path() == Ok(Some(&proxy.path))
&& hdr.member() == Ok(Some(signal_name))))
.boxed()
})
.await
.ok()
.map(|msg| (msg, (proxy, signal_name)))
});
Ok(SignalStream {
stream: stream.boxed(),
conn: self.core.conn.clone(),
subscription_id,
})
}
pub async fn connect_signal<H>(
&self,
signal_name: &'static str,
handler: H,
) -> fdo::Result<SignalHandlerId>
where
for<'msg> H: FnMut(&'msg Message) -> BoxFuture<'msg, Result<()>> + Send + 'static,
{
let id = self.sig_handlers.lock().await.insert(SignalHandlerInfo {
signal_name,
handler: Box::new(handler),
});
if self.core.conn.is_bus() {
let _ = self
.core
.conn
.subscribe_signal(
self.destination(),
self.path().clone(),
self.interface(),
signal_name,
)
.await?;
}
Ok(id)
}
pub async fn disconnect_signal(&self, handler_id: SignalHandlerId) -> fdo::Result<bool> {
match self.sig_handlers.lock().await.remove(handler_id) {
Some(handler_info) => {
if self.core.conn.is_bus() {
let _ = self
.core
.conn
.unsubscribe_signal(
self.destination(),
self.path().clone(),
self.interface(),
handler_info.signal_name,
)
.await?;
}
Ok(true)
}
None => Ok(false),
}
}
pub async fn next_signal(&self) -> Result<Option<Message>> {
let msg = {
let handlers = self.sig_handlers.lock().await;
let signals: Vec<&str> = handlers.values().map(|info| info.signal_name).collect();
self.resolve_name().await?;
self.core
.conn
.receive_specific(move |msg| {
let ret = match msg.header() {
Err(_) => false,
Ok(hdr) => match hdr.member() {
Ok(None) | Err(_) => false,
Ok(Some(member)) => {
let expected_sender = self.destination_unique_name();
hdr.interface() == Ok(Some(self.interface()))
&& hdr.path() == Ok(Some(self.path()))
&& hdr.sender() == Ok(expected_sender)
&& hdr.message_type() == Ok(crate::MessageType::Signal)
&& signals.contains(&member)
}
},
};
ready(Ok(ret)).boxed()
})
.await?
};
if self.handle_signal(&msg).await? {
Ok(None)
} else {
Ok(Some(msg))
}
}
pub async fn handle_signal(&self, msg: &Message) -> Result<bool> {
let mut handlers = self.sig_handlers.lock().await;
if handlers.is_empty() {
return Ok(false);
}
let hdr = msg.header()?;
if let Some(name) = hdr.member()? {
let mut handled = false;
for info in handlers
.values_mut()
.filter(|info| info.signal_name == name)
{
(*info.handler)(&msg).await?;
if !handled {
handled = true;
}
}
return Ok(handled);
}
Ok(false)
}
pub(crate) async fn has_signal_handler(&self, signal_name: &str) -> bool {
self.sig_handlers
.lock()
.await
.values()
.any(|info| info.signal_name == signal_name)
}
pub(crate) async fn resolve_name(&self) -> Result<()> {
if self.core.dest_unique_name.get().is_some() {
return Ok(());
}
let destination = &self.core.destination;
let unique_name = if destination.starts_with(':') || destination == "org.freedesktop.DBus" {
destination.to_string()
} else {
fdo::AsyncDBusProxy::new(&self.core.conn)?
.get_name_owner(destination)
.await?
};
self.core
.dest_unique_name
.set(unique_name)
.expect("Attempted to set dest_unique_name twice");
Ok(())
}
pub(crate) fn destination_unique_name(&self) -> Option<&str> {
self.core.dest_unique_name.get().map(|s| s.as_str())
}
}
/// Stream of signal messages returned by `Proxy::receive_signal`.
///
/// Closing the stream with [`SignalStream::close`], or dropping it, removes the bus
/// subscription that was added for it, if any.
#[derive(derivative::Derivative)]
#[derivative(Debug)]
pub struct SignalStream<'s> {
    /// Underlying boxed message stream; excluded from `Debug` since `dyn Stream` has no
    /// `Debug` impl.
    #[derivative(Debug = "ignore")]
    stream: stream::BoxStream<'s, Message>,
    /// Connection used to remove the subscription on close/drop.
    conn: Connection,
    /// ID of the bus subscription to remove; `None` when the connection is not a bus or
    /// the subscription has already been removed.
    subscription_id: Option<u64>,
}
impl SignalStream<'_> {
    /// Close the stream, removing its bus subscription if one is active.
    ///
    /// Unlike simply dropping the stream, this reports errors from unsubscribing.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while removing the subscription.
    pub async fn close(mut self) -> Result<()> {
        self.close_().await
    }

    /// Remove the bus subscription if one is still active; otherwise do nothing.
    ///
    /// The ID is taken out of `subscription_id` before unsubscribing, so calling this
    /// more than once (e.g. `close` followed by `Drop`) unsubscribes at most once.
    async fn close_(&mut self) -> Result<()> {
        match self.subscription_id.take() {
            Some(id) => {
                let _ = self.conn.unsubscribe_signal_by_id(id).await?;
                Ok(())
            }
            None => Ok(()),
        }
    }
}
impl stream::Stream for SignalStream<'_> {
    type Item = Message;

    /// Delegate polling to the boxed inner stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.stream.poll_next_unpin(cx)
    }
}
impl std::ops::Drop for SignalStream<'_> {
    /// Best-effort removal of the bus subscription.
    ///
    /// `drop` cannot be async, so the unsubscribe future is driven with `block_on` and any
    /// error is discarded; call `SignalStream::close` to observe it instead.
    /// NOTE(review): blocking here could stall an executor thread if unsubscribing needs
    /// that same thread to make progress — confirm this is acceptable for the runtimes used.
    fn drop(&mut self) {
        if self.subscription_id.is_none() {
            // Never subscribed (not a bus connection) or already closed.
            return;
        }
        let _ = block_on(self.close_());
    }
}
impl<'azync, 'sync: 'azync> From<crate::Proxy<'sync>> for Proxy<'azync> {
fn from(proxy: crate::Proxy<'sync>) -> Self {
proxy.into_inner()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use enumflags2::BitFlags;
    use futures_util::future::FutureExt;
    use std::{future::ready, sync::Arc};

    #[test]
    fn signal_stream() {
        block_on(test_signal_stream()).unwrap();
    }

    /// Requests a well-known name. Then checks that the `NameOwnerChanged` and
    /// `NameAcquired` streams obtained via `receive_signal` each yield the matching signal.
    async fn test_signal_stream() -> Result<()> {
        let conn = Connection::new_session().await?;
        let unique_name = conn.unique_name().unwrap();
        let proxy = Proxy::new(
            &conn,
            "org.freedesktop.DBus",
            "/org/freedesktop/DBus",
            "org.freedesktop.DBus",
        )
        .unwrap();
        let well_known = "org.freedesktop.zbus.async.ProxySignalStreamTest";
        let owner_changed_stream = proxy
            .receive_signal("NameOwnerChanged")
            .await?
            .filter(|msg| {
                if let Ok((name, _, new_owner)) = msg.body::<(&str, &str, &str)>() {
                    return ready(new_owner == unique_name && name == well_known);
                }
                ready(false)
            });
        let name_acquired_stream = proxy.receive_signal("NameAcquired").await?.filter(|msg| {
            if let Ok(name) = msg.body::<&str>() {
                return ready(name == well_known);
            }
            ready(false)
        });

        let reply = proxy
            .call_method(
                "RequestName",
                &(
                    well_known,
                    BitFlags::from(fdo::RequestNameFlags::ReplaceExisting),
                ),
            )
            .await
            .unwrap();
        let reply: fdo::RequestNameReply = reply.body().unwrap();
        assert_eq!(reply, fdo::RequestNameReply::PrimaryOwner);

        let (changed_signal, acquired_signal) = futures_util::join!(
            owner_changed_stream.into_future(),
            name_acquired_stream.into_future()
        );
        let changed_signal = changed_signal.0.unwrap();
        let (acquired_name, _, new_owner) = changed_signal.body::<(&str, &str, &str)>().unwrap();
        assert_eq!(acquired_name, well_known);
        assert_eq!(new_owner, unique_name);
        let acquired_signal = acquired_signal.0.unwrap();
        assert_eq!(acquired_signal.body::<&str>().unwrap(), well_known);
        Ok(())
    }

    #[test]
    fn signal_connect() {
        block_on(test_signal_connect()).unwrap();
    }

    /// Registers handlers via `connect_signal` (two for the same signal), requests a
    /// well-known name, pumps `next_signal` until all handlers fired, then checks that
    /// `disconnect_signal` reports removal correctly (including a repeated disconnect).
    async fn test_signal_connect() -> Result<()> {
        let conn = Connection::new_session().await?;
        let owner_change_signaled = Arc::new(Mutex::new(false));
        let name_acquired_signaled = Arc::new(Mutex::new(false));
        let name_acquired_signaled2 = Arc::new(Mutex::new(false));
        let proxy = Proxy::new(
            &conn,
            "org.freedesktop.DBus",
            "/org/freedesktop/DBus",
            "org.freedesktop.DBus",
        )?;
        let well_known = "org.freedesktop.zbus.async.ProxySignalConnectTest";
        let unique_name = conn.unique_name().unwrap().to_string();

        let name_owner_changed_id = {
            let signaled = owner_change_signaled.clone();
            proxy
                .connect_signal("NameOwnerChanged", move |m| {
                    let signaled = signaled.clone();
                    let unique_name = unique_name.clone();
                    async move {
                        let (name, _, new_owner) = m.body::<(&str, &str, &str)>()?;
                        if name != well_known {
                            return Ok(());
                        }
                        assert_eq!(new_owner, unique_name);
                        *signaled.lock().await = true;
                        Ok(())
                    }
                    .boxed()
                })
                .await?
        };
        let name_acquired_id = {
            let signaled = name_acquired_signaled.clone();
            proxy
                .connect_signal("NameAcquired", move |m| {
                    let signaled = signaled.clone();
                    async move {
                        if m.body::<&str>()? == well_known {
                            *signaled.lock().await = true;
                        }
                        Ok(())
                    }
                    .boxed()
                })
                .await?
        };
        let name_acquired_id2 = {
            let signaled = name_acquired_signaled2.clone();
            proxy
                .connect_signal("NameAcquired", move |m| {
                    let signaled = signaled.clone();
                    async move {
                        if m.body::<&str>()? == well_known {
                            *signaled.lock().await = true;
                        }
                        Ok(())
                    }
                    .boxed()
                })
                .await?
        };

        fdo::DBusProxy::new(&crate::Connection::from(conn))
            .unwrap()
            .request_name(well_known, fdo::RequestNameFlags::ReplaceExisting.into())
            .unwrap();

        loop {
            proxy.next_signal().await?;
            if *owner_change_signaled.lock().await
                && *name_acquired_signaled.lock().await
                && *name_acquired_signaled2.lock().await
            {
                break;
            }
        }

        assert!(proxy.disconnect_signal(name_owner_changed_id).await?);
        // A second disconnect of the same ID must report that nothing was removed.
        assert!(!proxy.disconnect_signal(name_owner_changed_id).await?);
        assert!(proxy.disconnect_signal(name_acquired_id).await?);
        assert!(proxy.disconnect_signal(name_acquired_id2).await?);
        Ok(())
    }
}