use async_io::Async;
use async_lock::{Mutex, MutexGuard, RwLock};
use once_cell::sync::OnceCell;
use std::{
io::{self, ErrorKind},
os::unix::{
io::{AsRawFd, RawFd},
net::UnixStream,
},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures_core::stream;
use futures_util::{sink::SinkExt, stream::TryStreamExt};
use crate::{
azync::Authenticated,
raw::{Connection as RawConnection, Socket},
Error, Guid, Message, MessageType, Result, DEFAULT_MAX_QUEUED,
};
/// Shared state behind every clone of a [`Connection`].
///
/// Receiving and sending use two separate raw connections, each behind its own lock.
/// `Connection::new` builds the outgoing one from a `try_clone` of the authenticated
/// socket, so a task reading messages does not block a task sending them.
#[derive(Debug)]
struct ConnectionInner<S> {
    // GUID of the server, as returned by authentication.
    server_guid: Guid,
    // Unix fd-passing capability reported by authentication; `Sink::start_send` rejects
    // messages that carry fds when this is false.
    cap_unix_fd: bool,
    // Whether this is a message-bus connection (true) or a peer-to-peer one (false).
    bus_conn: bool,
    // Unique name returned by the bus's `Hello` call; stays unset for non-bus connections.
    unique_name: OnceCell<String>,
    // Receiving side; locked by `Connection::stream` and `Connection::receive_specific`.
    raw_in_conn: Mutex<RawConnection<Async<S>>>,
    // Sending side; locked by `Connection::sink`.
    raw_out_conn: Mutex<RawConnection<Async<S>>>,
    // Serial number that the next outgoing message will receive.
    serial: Mutex<u32>,
    // Messages read by `receive_specific` that did not match its predicate, kept for later
    // delivery.
    incoming_queue: Mutex<Vec<Message>>,
    // Maximum length of `incoming_queue`; further non-matching messages are discarded.
    max_queued: RwLock<usize>,
}
/// A handle to a D-Bus connection.
///
/// Cloning is cheap. All clones share the same inner state through an [`Arc`], so they
/// use the same sockets, serial counter and incoming-message queue.
#[derive(Clone, Debug)]
pub struct Connection(Arc<ConnectionInner<Box<dyn Socket>>>);
impl Connection {
pub async fn new_unix_client(stream: UnixStream, bus_connection: bool) -> Result<Self> {
let auth = Authenticated::client(Async::new(Box::new(stream) as Box<dyn Socket>)?).await?;
Self::new(auth, bus_connection).await
}
pub async fn new_unix_server(stream: UnixStream, guid: &Guid) -> Result<Self> {
use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
let creds = getsockopt(stream.as_raw_fd(), PeerCredentials)
.map_err(|e| Error::Handshake(format!("Failed to get peer credentials: {}", e)))?;
let auth = Authenticated::server(
Async::new(Box::new(stream) as Box<dyn Socket>)?,
guid.clone(),
creds.uid(),
)
.await?;
Self::new(auth, false).await
}
pub async fn stream(&self) -> Stream<'_> {
let raw_conn = self.0.raw_in_conn.lock().await;
let incoming_queue = Some(self.0.incoming_queue.lock().await);
Stream {
raw_conn,
incoming_queue,
}
}
pub async fn sink(&self) -> Sink<'_> {
Sink {
raw_conn: self.0.raw_out_conn.lock().await,
cap_unix_fd: self.0.cap_unix_fd,
}
}
pub async fn receive_specific<P>(&self, predicate: P) -> Result<Message>
where
P: Fn(&Message) -> Result<bool>,
{
loop {
let mut queue = self.0.incoming_queue.lock().await;
for (i, msg) in queue.iter().enumerate() {
if predicate(msg)? {
return Ok(queue.remove(i));
}
}
let mut stream = Stream {
raw_conn: self.0.raw_in_conn.lock().await,
incoming_queue: None,
};
let msg = match stream.try_next().await? {
Some(msg) => msg,
None => {
return Err(Error::Io(io::Error::new(
ErrorKind::BrokenPipe,
"socket closed",
)));
}
};
if predicate(&msg)? {
return Ok(msg);
} else if queue.len() < *self.0.max_queued.read().await {
queue.push(msg);
}
}
}
pub async fn send_message(&self, mut msg: Message) -> Result<u32> {
let serial = self.assign_serial_num(&mut msg).await?;
self.sink().await.send(msg).await?;
Ok(serial)
}
pub async fn call_method<B>(
&self,
destination: Option<&str>,
path: &str,
iface: Option<&str>,
method_name: &str,
body: &B,
) -> Result<Message>
where
B: serde::ser::Serialize + zvariant::Type,
{
let m = Message::method(
self.unique_name(),
destination,
path,
iface,
method_name,
body,
)?;
let serial = self.send_message(m).await?;
loop {
match self
.receive_specific(|m| {
let h = m.header()?;
Ok(h.reply_serial()? == Some(serial))
})
.await
{
Ok(m) => match m.header()?.message_type()? {
MessageType::Error => return Err(m.into()),
MessageType::MethodReturn => return Ok(m),
_ => continue,
},
Err(e) => return Err(e),
};
}
}
pub async fn emit_signal<B>(
&self,
destination: Option<&str>,
path: &str,
iface: &str,
signal_name: &str,
body: &B,
) -> Result<()>
where
B: serde::ser::Serialize + zvariant::Type,
{
let m = Message::signal(
self.unique_name(),
destination,
path,
iface,
signal_name,
body,
)?;
self.send_message(m).await.map(|_| ())
}
pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<u32>
where
B: serde::ser::Serialize + zvariant::Type,
{
let m = Message::method_reply(self.unique_name(), call, body)?;
self.send_message(m).await
}
pub async fn reply_error<B>(&self, call: &Message, error_name: &str, body: &B) -> Result<u32>
where
B: serde::ser::Serialize + zvariant::Type,
{
let m = Message::method_error(self.unique_name(), call, error_name, body)?;
self.send_message(m).await
}
pub fn is_bus(&self) -> bool {
self.0.bus_conn
}
pub async fn assign_serial_num(&self, msg: &mut Message) -> Result<u32> {
let serial = self.next_serial().await;
msg.modify_primary_header(|primary| {
primary.set_serial_num(serial);
Ok(())
})?;
Ok(serial)
}
pub fn unique_name(&self) -> Option<&str> {
self.0.unique_name.get().map(|s| s.as_str())
}
pub async fn max_queued(&self) -> usize {
*self.0.max_queued.read().await
}
pub async fn set_max_queued(self, max: usize) -> Self {
*self.0.max_queued.write().await = max;
self
}
pub fn server_guid(&self) -> &str {
self.0.server_guid.as_str()
}
pub async fn as_raw_fd(&self) -> RawFd {
(self.0.raw_in_conn.lock().await.socket()).as_raw_fd()
}
async fn hello_bus(self) -> Result<Self> {
let name: String = self
.call_method(
Some("org.freedesktop.DBus"),
"/org/freedesktop/DBus",
Some("org.freedesktop.DBus"),
"Hello",
&(),
)
.await?
.body()?;
self.0
.unique_name
.set(name)
.expect("Attempted to set unique_name twice");
Ok(self)
}
async fn new(
auth: Authenticated<Async<Box<dyn Socket>>>,
bus_connection: bool,
) -> Result<Self> {
let auth = auth.into_inner();
let out_socket = auth.conn.socket().get_ref().try_clone()?;
let out_conn = RawConnection::wrap(Async::new(out_socket)?);
let connection = Self(Arc::new(ConnectionInner {
raw_in_conn: Mutex::new(auth.conn),
raw_out_conn: Mutex::new(out_conn),
server_guid: auth.server_guid,
cap_unix_fd: auth.cap_unix_fd,
bus_conn: bus_connection,
serial: Mutex::new(1),
unique_name: OnceCell::new(),
incoming_queue: Mutex::new(vec![]),
max_queued: RwLock::new(DEFAULT_MAX_QUEUED),
}));
if !bus_connection {
return Ok(connection);
}
connection.hello_bus().await
}
async fn next_serial(&self) -> u32 {
let mut serial = self.0.serial.lock().await;
let current = *serial;
*serial = current + 1;
current
}
pub async fn new_session() -> Result<Self> {
Self::new(Authenticated::session().await?, true).await
}
pub async fn new_system() -> Result<Self> {
Self::new(Authenticated::system().await?, true).await
}
pub async fn new_for_address(address: &str, bus_connection: bool) -> Result<Self> {
Self::new(Authenticated::for_address(address).await?, bus_connection).await
}
}
/// Sending half of a [`Connection`], obtained from `Connection::sink`.
///
/// It holds the lock on the connection's outgoing raw connection for as long as it lives.
pub struct Sink<'s> {
    // Locked outgoing raw connection. Messages are enqueued on it and then flushed to
    // its socket.
    raw_conn: MutexGuard<'s, RawConnection<Async<Box<dyn Socket>>>>,
    // Copy of the connection's fd-passing capability; messages carrying fds are rejected
    // when false.
    cap_unix_fd: bool,
}
impl Sink<'_> {
fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match self.raw_conn.try_flush() {
Ok(()) => return Poll::Ready(Ok(())),
Err(e) => {
if e.kind() == ErrorKind::WouldBlock {
let poll = self.raw_conn.socket().poll_writable(cx);
match poll {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(_)) => continue,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
}
} else {
return Poll::Ready(Err(Error::Io(e)));
}
}
}
}
}
}
impl futures_sink::Sink<Message> for Sink<'_> {
    type Error = Error;

    /// Always ready: messages are only queued on the raw connection until a flush.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Queues `msg` for sending.
    ///
    /// Fails with [`Error::Unsupported`] if `msg` carries file descriptors and the
    /// connection lacks fd-passing capability.
    fn start_send(self: Pin<&mut Self>, msg: Message) -> Result<()> {
        let this = self.get_mut();
        let carries_fds = !msg.fds().is_empty();

        if carries_fds && !this.cap_unix_fd {
            return Err(Error::Unsupported);
        }

        this.raw_conn.enqueue_message(msg);
        Ok(())
    }

    /// Flushes all queued messages to the socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();
        this.flush(cx)
    }

    /// Flushes all queued messages, then closes the raw connection.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();

        match this.flush(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(this.raw_conn.close()),
            // A pending flush or a flush error is returned unchanged.
            not_flushed => not_flushed,
        }
    }
}
/// Receiving half of a [`Connection`], obtained from `Connection::stream`.
///
/// It holds the lock on the connection's incoming raw connection for as long as it lives.
/// When created by `Connection::stream`, it also holds the incoming-message queue.
pub struct Stream<'s> {
    // Locked incoming raw connection that messages are read from.
    raw_conn: MutexGuard<'s, RawConnection<Async<Box<dyn Socket>>>>,
    // `Some` when created by `Connection::stream`, in which case queued messages are
    // yielded before socket reads. `None` for the temporary stream built inside
    // `Connection::receive_specific`.
    incoming_queue: Option<MutexGuard<'s, Vec<Message>>>,
}
impl<'s> stream::Stream for Stream<'s> {
    type Item = Result<Message>;

    /// Yields the next incoming message.
    ///
    /// If this stream holds the incoming queue, queued messages are delivered first,
    /// oldest first, so they keep their arrival order. After that, messages are read
    /// from the socket:
    /// - On `WouldBlock`, readability is polled and the read is retried.
    /// - A `BrokenPipe` I/O error ends the stream (`None`).
    /// - Any other error is yielded as `Some(Err(_))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        if let Some(queue) = stream.incoming_queue.as_mut() {
            // `receive_specific` appends with `push`, so the oldest message is at the front.
            // Taking from the back (`pop`) would hand messages out newest-first. The queue
            // is bounded by `max_queued`, so `remove(0)` stays cheap.
            if !queue.is_empty() {
                return Poll::Ready(Some(Ok(queue.remove(0))));
            }
        }

        loop {
            match stream.raw_conn.try_receive_message() {
                Ok(m) => return Poll::Ready(Some(Ok(m))),
                Err(Error::Io(e)) if e.kind() == ErrorKind::WouldBlock => {
                    let poll = stream.raw_conn.socket().poll_readable(cx);
                    match poll {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Ok(_)) => continue,
                        Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
                    }
                }
                Err(Error::Io(e)) if e.kind() == ErrorKind::BrokenPipe => return Poll::Ready(None),
                Err(e) => return Poll::Ready(Some(Err(e))),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::os::unix::net::UnixStream;
    use super::*;
    #[test]
    fn unix_p2p() {
        pollster::block_on(test_unix_p2p()).unwrap();
    }
    // End-to-end peer-to-peer exchange over a `UnixStream` pair:
    // - the client calls `Test`;
    // - the server emits a signal and then replies with "yay";
    // - the client gets the reply from `call_method`, then the signal from `stream()`.
    async fn test_unix_p2p() -> Result<()> {
        let guid = Guid::generate();
        let (p0, p1) = UnixStream::pair().unwrap();
        let server = Connection::new_unix_server(p0, &guid);
        let client = Connection::new_unix_client(p1, false);
        // The two handshakes talk to each other over the pair, so they are driven
        // concurrently.
        let (client_conn, server_conn) = futures_util::try_join!(client, server)?;
        let server_future = async {
            let mut method: Option<Message> = None;
            // Read messages until the client's `Test` call arrives.
            while let Some(m) = server_conn.stream().await.try_next().await? {
                if m.to_string() == "Method call Test" {
                    method.replace(m);
                    break;
                }
            }
            let method = method.unwrap();
            // The signal is sent before the reply, so the client's `call_method` reads it
            // first, and `receive_specific` sets it aside in the incoming queue.
            server_conn
                .emit_signal(None, "/", "org.zbus.p2p", "ASignalForYou", &())
                .await?;
            server_conn.reply(&method, &("yay")).await
        };
        let client_future = async {
            let reply = client_conn
                .call_method(None, "/", Some("org.zbus.p2p"), "Test", &())
                .await?;
            assert_eq!(reply.to_string(), "Method return");
            // The signal that arrived before the reply was queued; `stream()` yields it.
            let m = client_conn.stream().await.try_next().await?.unwrap();
            assert_eq!(m.to_string(), "Signal ASignalForYou");
            reply.body::<String>().map_err(|e| e.into())
        };
        let (val, _) = futures_util::try_join!(client_future, server_future)?;
        assert_eq!(val, "yay");
        Ok(())
    }
    #[test]
    fn serial_monotonically_increases() {
        pollster::block_on(test_serial_monotonically_increases());
    }
    // Uses `Connection::new_session()` and unwraps its result, so this test panics if a
    // session bus connection cannot be established.
    async fn test_serial_monotonically_increases() {
        let c = Connection::new_session().await.unwrap();
        // Consume one serial, then check that the next ten are consecutive.
        let serial = c.next_serial().await + 1;
        for next in serial..serial + 10 {
            assert_eq!(next, c.next_serial().await);
        }
    }
}