use std::cell::{Cell, RefCell};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::rc::Rc;
use nix::poll::PollFlags;
use once_cell::unsync::OnceCell;
use crate::handshake::{Authenticated, ClientHandshake, ServerHandshake};
use crate::raw::Connection as RawConnection;
use crate::utils::wait_on;
use crate::{fdo, Error, Guid, Message, MessageType, Result};
/// Default upper bound on how many unrelated messages `Connection::call_method` keeps in the
/// incoming queue while it waits for its reply.
const DEFAULT_MAX_QUEUED: usize = 32;
/// Callback installed via `Connection::set_default_message_handler`.
///
/// `Connection::receive_message` passes every message it reads from the socket to this handler.
/// Returning `Some(msg)` hands `msg` back to the caller of `receive_message`; returning `None`
/// means the message was consumed and the next one is read instead.
type MessageHandlerFn = Box<dyn FnMut(Message) -> Option<Message> + 'static>;
/// Shared state behind a [`Connection`]; all mutable parts use `Cell`/`RefCell` because the
/// state is shared through an `Rc`.
#[derive(derivative::Derivative)]
#[derivative(Debug)]
struct ConnectionInner {
/// GUID of the server, taken from the authentication handshake.
server_guid: Guid,
/// Whether Unix file-descriptor passing was negotiated; `send_message` rejects messages that
/// carry fds when this is `false`.
cap_unix_fd: bool,
/// Unique bus name assigned by the bus (set at most once). Stays empty for peer-to-peer
/// connections unless set explicitly via `set_unique_name`.
unique_name: OnceCell<String>,
/// Low-level message reader/writer over the Unix socket.
raw_conn: RefCell<RawConnection<UnixStream>>,
/// Serial number that the next outgoing message will receive.
serial: Cell<u32>,
/// Messages read by `call_method` while waiting for its reply; `receive_message` serves these
/// before reading from the socket.
incoming_queue: RefCell<Vec<Message>>,
/// Cap applied by `call_method` to `incoming_queue`; messages beyond it are dropped.
max_queued: Cell<usize>,
// Boxed closures have no `Debug` impl, so this field is excluded from the derived output.
#[derivative(Debug = "ignore")]
/// Optional handler that gets first look at every message `receive_message` reads.
default_msg_handler: RefCell<Option<MessageHandlerFn>>,
}
/// A D-Bus connection over a Unix socket.
///
/// Cloning is cheap: every clone is a handle to the same `Rc`-shared state (socket, serial
/// counter, queue and handler). Because of the `Rc`, a `Connection` cannot be sent to another
/// thread.
#[derive(Debug, Clone)]
pub struct Connection(Rc<ConnectionInner>);
impl AsRawFd for Connection {
    /// Returns the raw file descriptor of the underlying Unix socket.
    ///
    /// Takes a shared borrow of the raw connection, so it panics if called while that
    /// connection is mutably borrowed.
    fn as_raw_fd(&self) -> RawFd {
        let raw = self.0.raw_conn.borrow();
        raw.socket().as_raw_fd()
    }
}
impl Connection {
    /// Creates a client connection on an already-connected `UnixStream`.
    ///
    /// Runs the client authentication handshake to completion (blocking). When
    /// `bus_connection` is `true`, the peer is treated as a message bus: `Hello` is called and the
    /// returned unique name is recorded. Otherwise the connection is peer-to-peer.
    ///
    /// # Errors
    /// Any handshake error, or `Error::Handshake` if the `Hello` call fails.
    pub fn new_unix_client(stream: UnixStream, bus_connection: bool) -> Result<Self> {
        let auth = ClientHandshake::new(stream).blocking_finish()?;
        Self::from_client_auth(auth, bus_connection)
    }

    /// Connects to the session bus, authenticates and calls `Hello`.
    pub fn new_session() -> Result<Self> {
        ClientHandshake::new_session()?
            .blocking_finish()
            .and_then(Self::new_authenticated_unix_bus)
    }

    /// Connects to the system bus, authenticates and calls `Hello`.
    pub fn new_system() -> Result<Self> {
        ClientHandshake::new_system()?
            .blocking_finish()
            .and_then(Self::new_authenticated_unix_bus)
    }

    /// Connects to the D-Bus `address` and authenticates as a client.
    ///
    /// `bus_connection` has the same meaning as in [`Connection::new_unix_client`].
    pub fn new_for_address(address: &str, bus_connection: bool) -> Result<Self> {
        let auth = ClientHandshake::new_for_address(address)?.blocking_finish()?;
        Self::from_client_auth(auth, bus_connection)
    }

    /// Creates the server side of a peer-to-peer connection on `stream`.
    ///
    /// The peer's uid is read from the socket's peer credentials and handed to the server
    /// handshake. The handshake then runs to completion (blocking), advertising `guid`.
    ///
    /// # Errors
    /// `Error::Handshake` if the peer credentials cannot be read, or any handshake error.
    pub fn new_unix_server(stream: UnixStream, guid: &Guid) -> Result<Self> {
        use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
        let creds = getsockopt(stream.as_raw_fd(), PeerCredentials)
            .map_err(|e| Error::Handshake(format!("Failed to get peer credentials: {}", e)))?;
        let handshake = ServerHandshake::new(stream, guid.clone(), creds.uid());
        handshake
            .blocking_finish()
            .map(Connection::new_authenticated_unix)
    }

    /// Maximum number of unrelated messages `call_method` keeps queued while awaiting its reply.
    pub fn max_queued(&self) -> usize {
        self.0.max_queued.get()
    }

    /// Sets the limit reported by [`Connection::max_queued`] and returns `self` for chaining.
    pub fn set_max_queued(self, max: usize) -> Self {
        self.0.max_queued.set(max);
        self
    }

    /// GUID of the server, as learned during the handshake.
    pub fn server_guid(&self) -> &str {
        self.0.server_guid.as_str()
    }

    /// Unique bus name of this connection.
    ///
    /// Returns `None` for peer-to-peer connections unless one was set with
    /// [`Connection::set_unique_name`].
    pub fn unique_name(&self) -> Option<&str> {
        self.0.unique_name.get().map(|s| s.as_str())
    }

    /// Returns the next incoming message.
    ///
    /// Messages queued by an earlier `call_method` are returned first, oldest first. Otherwise
    /// one read is attempted on the socket. That read may fail with an `Error::Io` whose kind is
    /// `WouldBlock`.
    ///
    /// If a default message handler is installed, every message read from the socket is passed
    /// to it first. A `Some` result is returned to the caller. A `None` result means the message
    /// was consumed, and the next one is read. The handler runs while its `RefCell` is mutably
    /// borrowed, so it must not call `receive_message`, `call_method`, or the
    /// `*_default_message_handler` methods on this connection (that would panic).
    pub fn receive_message(&self) -> Result<Message> {
        // The queue borrow is confined to `pop_queued`, so it is released before any handler runs.
        if let Some(msg) = self.pop_queued() {
            return Ok(msg);
        }
        loop {
            let incoming = self.0.raw_conn.borrow_mut().try_receive_message()?;
            let mut slot = self.0.default_msg_handler.borrow_mut();
            let handler = match slot.as_mut() {
                Some(handler) => handler,
                None => return Ok(incoming),
            };
            if let Some(m) = handler(incoming) {
                return Ok(m);
            }
            // Handler consumed the message: read the next one.
        }
    }

    /// Sends `msg`, assigning it the next serial number, and returns that serial.
    ///
    /// The message is enqueued and a flush is attempted. If the socket would block, the
    /// remainder stays queued (call [`Connection::flush`] later) and this still succeeds.
    ///
    /// # Errors
    /// `Error::Unsupported` if `msg` carries file descriptors but fd passing was not
    /// negotiated. Header errors and non-`WouldBlock` I/O errors are also returned.
    pub fn send_message(&self, mut msg: Message) -> Result<u32> {
        if !msg.fds().is_empty() && !self.0.cap_unix_fd {
            return Err(Error::Unsupported);
        }
        let serial = self.next_serial();
        msg.modify_primary_header(|primary| {
            primary.set_serial_num(serial);
            Ok(())
        })?;
        let mut conn = self.0.raw_conn.borrow_mut();
        conn.enqueue_message(msg);
        if let Err(e) = conn.try_flush() {
            if e.kind() != std::io::ErrorKind::WouldBlock {
                return Err(e.into());
            }
        }
        Ok(serial)
    }

    /// Attempts to write out all queued outgoing messages (non-blocking).
    pub fn flush(&self) -> Result<()> {
        self.0.raw_conn.borrow_mut().try_flush()?;
        Ok(())
    }

    /// Calls a method and blocks until its reply arrives.
    ///
    /// Unrelated messages read while waiting are kept in the incoming queue for later
    /// `receive_message` calls. At most [`Connection::max_queued`] messages are kept; extra ones
    /// are dropped. Queued messages are preserved even if this call fails.
    ///
    /// # Errors
    /// An error reply from the peer is converted into `Err`. Send, I/O and header-parsing
    /// errors are returned as well.
    pub fn call_method<B>(
        &self,
        destination: Option<&str>,
        path: &str,
        iface: Option<&str>,
        method_name: &str,
        body: &B,
    ) -> Result<Message>
    where
        B: serde::ser::Serialize + zvariant::Type,
    {
        let m = Message::method(
            self.unique_name(),
            destination,
            path,
            iface,
            method_name,
            body,
        )?;
        let serial = self.send_message(m)?;
        self.flush_blocking()?;

        // Messages that arrive before the reply are stashed here. They are handed back to the
        // incoming queue on *every* exit path, so an error does not silently discard them.
        let mut stashed = Vec::new();
        let result = self.wait_for_reply(serial, &mut stashed);
        self.0.incoming_queue.borrow_mut().append(&mut stashed);
        result
    }

    /// Emits a signal. Like `send_message`, this does not wait for a would-block flush to finish.
    pub fn emit_signal<B>(
        &self,
        destination: Option<&str>,
        path: &str,
        iface: &str,
        signal_name: &str,
        body: &B,
    ) -> Result<()>
    where
        B: serde::ser::Serialize + zvariant::Type,
    {
        let m = Message::signal(
            self.unique_name(),
            destination,
            path,
            iface,
            signal_name,
            body,
        )?;
        self.send_message(m)?;
        Ok(())
    }

    /// Sends a method-return reply to `call` and returns its serial.
    pub fn reply<B>(&self, call: &Message, body: &B) -> Result<u32>
    where
        B: serde::ser::Serialize + zvariant::Type,
    {
        let m = Message::method_reply(self.unique_name(), call, body)?;
        self.send_message(m)
    }

    /// Sends an error reply named `error_name` to `call` and returns its serial.
    pub fn reply_error<B>(&self, call: &Message, error_name: &str, body: &B) -> Result<u32>
    where
        B: serde::ser::Serialize + zvariant::Type,
    {
        let m = Message::method_error(self.unique_name(), call, error_name, body)?;
        self.send_message(m)
    }

    /// Installs `handler` as the default message handler, replacing any previous one.
    pub fn set_default_message_handler(&mut self, handler: MessageHandlerFn) {
        self.0.default_msg_handler.borrow_mut().replace(handler);
    }

    /// Removes the default message handler, if any.
    pub fn reset_default_message_handler(&mut self) {
        self.0.default_msg_handler.borrow_mut().take();
    }

    /// Wraps an authenticated stream without talking to a bus (peer-to-peer use).
    ///
    /// The first serial handed out is 1 and the unique name is left unset.
    pub fn new_authenticated_unix(auth: Authenticated<UnixStream>) -> Self {
        Self(Rc::new(ConnectionInner {
            raw_conn: RefCell::new(auth.conn),
            server_guid: auth.server_guid,
            cap_unix_fd: auth.cap_unix_fd,
            serial: Cell::new(1),
            unique_name: OnceCell::new(),
            incoming_queue: RefCell::new(vec![]),
            max_queued: Cell::new(DEFAULT_MAX_QUEUED),
            default_msg_handler: RefCell::new(None),
        }))
    }

    /// Sets the unique name once. Returns `Err(name)` if a name was already set.
    pub fn set_unique_name(&self, name: String) -> std::result::Result<(), String> {
        self.0.unique_name.set(name)
    }

    /// Finishes client construction: bus connection (with `Hello`) or peer-to-peer.
    fn from_client_auth(auth: Authenticated<UnixStream>, bus_connection: bool) -> Result<Self> {
        if bus_connection {
            Self::new_authenticated_unix_bus(auth)
        } else {
            Ok(Self::new_authenticated_unix(auth))
        }
    }

    /// Wraps an authenticated stream, calls `Hello` on the bus and records the unique name.
    fn new_authenticated_unix_bus(auth: Authenticated<UnixStream>) -> Result<Self> {
        let connection = Connection::new_authenticated_unix(auth);
        let name = fdo::DBusProxy::new(&connection)?
            .hello()
            .map_err(|e| Error::Handshake(format!("Hello failed: {}", e)))?;
        // The connection was created just above, so nothing else can have set the name yet.
        connection
            .0
            .unique_name
            .set(name)
            .expect("Attempted to set unique_name twice");
        Ok(connection)
    }

    /// Flushes outgoing messages, polling for writability whenever the socket would block.
    fn flush_blocking(&self) -> Result<()> {
        loop {
            match self.flush() {
                Ok(()) => return Ok(()),
                Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    wait_on(self.as_raw_fd(), PollFlags::POLLOUT)?;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Receives one message, polling for readability whenever the socket would block.
    fn receive_message_blocking(&self) -> Result<Message> {
        loop {
            match self.receive_message() {
                Ok(m) => return Ok(m),
                Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    wait_on(self.as_raw_fd(), PollFlags::POLLIN)?;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Reads messages until the reply to `serial` arrives.
    ///
    /// Unrelated messages go into `stashed`, as long as the queue plus the stash stays under
    /// `max_queued`; beyond that they are dropped.
    fn wait_for_reply(&self, serial: u32, stashed: &mut Vec<Message>) -> Result<Message> {
        loop {
            let m = self.receive_message_blocking()?;
            let h = m.header()?;
            if h.reply_serial()? != Some(serial) {
                let queued = self.0.incoming_queue.borrow().len();
                if queued + stashed.len() < self.0.max_queued.get() {
                    stashed.push(m);
                }
                continue;
            }
            match h.message_type()? {
                MessageType::Error => return Err(m.into()),
                MessageType::MethodReturn => return Ok(m),
                // Neither a return nor an error despite matching serial: keep waiting.
                _ => (),
            }
        }
    }

    /// Removes and returns the oldest queued incoming message (FIFO order).
    fn pop_queued(&self) -> Option<Message> {
        let mut queue = self.0.incoming_queue.borrow_mut();
        if queue.is_empty() {
            None
        } else {
            // The queue is bounded by `max_queued`, so the O(n) shift is negligible.
            Some(queue.remove(0))
        }
    }

    /// Returns the current serial and advances the counter.
    fn next_serial(&self) -> u32 {
        let current = self.0.serial.get();
        // D-Bus forbids serial 0. On overflow, wrap to 1 instead of panicking (debug builds)
        // or handing out 0 (release builds).
        self.0.serial.set(current.checked_add(1).unwrap_or(1));
        current
    }
}
#[cfg(test)]
mod tests {
    use std::os::unix::net::UnixStream;
    use std::thread;
    use crate::{Connection, Guid};

    /// Builds a peer-to-peer client connection without needing a running message bus.
    ///
    /// The server side completes its handshake on a helper thread and is then dropped. That is
    /// enough for tests that only exercise client-local state.
    fn p2p_client() -> Connection {
        let guid = Guid::generate();
        let (server_end, client_end) = UnixStream::pair().unwrap();
        let server_thread = thread::spawn(move || {
            Connection::new_unix_server(server_end, &guid).unwrap();
        });
        let client = Connection::new_unix_client(client_end, false).unwrap();
        server_thread.join().expect("failed to join server thread");
        client
    }

    #[test]
    fn unix_p2p() {
        let guid = Guid::generate();
        let (p0, p1) = UnixStream::pair().unwrap();
        let server_thread = thread::spawn(move || {
            let c = Connection::new_unix_server(p0, &guid).unwrap();
            let reply = c
                .call_method(None, "/", Some("org.zbus.p2p"), "Test", &())
                .unwrap();
            assert_eq!(reply.to_string(), "Method return");
            let val: String = reply.body().unwrap();
            val
        });
        let c = Connection::new_unix_client(p1, false).unwrap();
        let m = c.receive_message().unwrap();
        assert_eq!(m.to_string(), "Method call Test");
        c.reply(&m, &("yay")).unwrap();
        let val = server_thread.join().expect("failed to join server thread");
        assert_eq!(val, "yay");
    }

    #[test]
    fn serial_monotonically_increases() {
        // Hermetic: a p2p pair instead of `Connection::new_session()`, which needs a live
        // session bus and fails in sandboxed CI.
        let c = p2p_client();
        let serial = c.next_serial() + 1;
        for next in serial..serial + 10 {
            assert_eq!(next, c.next_serial());
        }
    }
}