Skip to content

Commit

Permalink
Merge pull request #488 from zeenix/immutable-message
Browse files Browse the repository at this point in the history
🚸 zb: `Message` now immutable
  • Loading branch information
zeenix authored Oct 7, 2023
2 parents 3ebf075 + aec0618 commit 4bbdba4
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 203 deletions.
23 changes: 6 additions & 17 deletions zbus/src/blocking/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use enumflags2::BitFlags;
use event_listener::EventListener;
use static_assertions::assert_impl_all;
use std::{io, num::NonZeroU32, ops::Deref, sync::Arc};
use std::{io, ops::Deref, sync::Arc};
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
use zvariant::ObjectPath;

Expand Down Expand Up @@ -62,12 +62,8 @@ impl Connection {
}

/// Send `msg` to the peer.
///
/// The connection sets a unique serial number on the message before sending it off.
///
/// On successfully sending off `msg`, the assigned serial number is returned.
pub fn send_message(&self, msg: Message) -> Result<NonZeroU32> {
block_on(self.inner.send_message(msg))
pub fn send(&self, msg: &Message) -> Result<()> {
block_on(self.inner.send(msg))
}

/// Send a method call.
Expand Down Expand Up @@ -138,9 +134,7 @@ impl Connection {
///
/// Given an existing message (likely a method call), send a reply back to the caller with the
/// given `body`.
///
/// Returns the message serial number.
pub fn reply<B>(&self, call: &Message, body: &B) -> Result<NonZeroU32>
pub fn reply<B>(&self, call: &Message, body: &B) -> Result<()>
where
B: serde::ser::Serialize + zvariant::DynamicType,
{
Expand All @@ -153,12 +147,7 @@ impl Connection {
/// with the given `error_name` and `body`.
///
/// Returns the message serial number.
pub fn reply_error<'e, E, B>(
&self,
call: &Message,
error_name: E,
body: &B,
) -> Result<NonZeroU32>
pub fn reply_error<'e, E, B>(&self, call: &Message, error_name: E, body: &B) -> Result<()>
where
B: serde::ser::Serialize + zvariant::DynamicType,
E: TryInto<ErrorName<'e>>,
Expand All @@ -177,7 +166,7 @@ impl Connection {
&self,
call: &zbus::message::Header<'_>,
err: impl DBusError,
) -> Result<NonZeroU32> {
) -> Result<()> {
block_on(self.inner.reply_dbus_error(call, err))
}

Expand Down
100 changes: 13 additions & 87 deletions zbus/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ use std::{
num::NonZeroU32,
ops::Deref,
pin::Pin,
sync::{
atomic::{AtomicU32, Ordering::SeqCst},
Arc, Weak,
},
sync::{Arc, Weak},
task::{Context, Poll},
};
use tracing::{debug, info_span, instrument, trace, trace_span, warn, Instrument};
Expand Down Expand Up @@ -62,9 +59,6 @@ pub(crate) struct ConnectionInner {
activity_event: Arc<Event>,
socket_write: Mutex<Box<dyn socket::WriteHalf>>,

// Serial number for next outgoing message
serial: AtomicU32,

// Our executor
executor: Executor<'static>,

Expand Down Expand Up @@ -114,10 +108,7 @@ pub(crate) type MsgBroadcaster = Broadcaster<Result<Arc<Message>>>;
/// [`crate::blocking::MessageIterator`] instances are continuously polled and iterated on,
/// respectively.
///
/// For sending messages you can either use [`Connection::send_message`] or [`Connection::send`]
/// method. While the former sets the serial numbers (cookies) on the messages for you, the latter
/// does not. You can manually assign unique serial numbers to messages using the
/// [`Connection::assign_serial_num`] method when using `send` method to send them off.
/// For sending messages you can either use [`Connection::send`] method.
///
/// [method calls]: struct.Connection.html#method.call_method
/// [signals]: struct.Connection.html#method.emit_signal
Expand Down Expand Up @@ -279,34 +270,12 @@ impl OrderedFuture for PendingMethodCall {

impl Connection {
/// Send `msg` to the peer.
///
/// Unlike [`Connection::send`], this method sets a unique (to this connection) serial number on
/// the message before sending it off, for you.
///
/// On successfully sending off `msg`, the assigned serial number is returned.
pub async fn send_message(&self, mut msg: Message) -> Result<NonZeroU32> {
let serial = self.assign_serial_num(&mut msg)?;

trace!("Sending message: {:?}", msg);
self.send(&msg).await?;
trace!("Sent message with serial: {}", serial);

Ok(serial)
}

/// Send `msg` to the peer.
///
/// Same as [`Connection::send_message`] except it doesn't sets the unique serial number on the
/// message for you. It expects a serial number to be already set on the message.
pub async fn send(&self, msg: &Message) -> Result<()> {
#[cfg(unix)]
if !msg.fds().is_empty() && !self.inner.cap_unix_fd {
return Err(Error::Unsupported);
}
let serial = msg
.primary_header()
.serial_num()
.ok_or(Error::InvalidSerial)?;
let serial = msg.primary_header().serial_num();

trace!("Sending message: {:?}", msg);
self.inner.activity_event.notify(usize::MAX);
Expand Down Expand Up @@ -419,7 +388,8 @@ impl Connection {
None,
self,
));
let serial = self.send_message(msg).await?;
let serial = msg.primary_header().serial_num();
self.send(&msg).await?;
if flags.contains(Flags::NoReplyExpected) {
Ok(None)
} else {
Expand Down Expand Up @@ -458,16 +428,14 @@ impl Connection {
}
let m = b.build(body)?;

self.send_message(m).await.map(|_| ())
self.send(&m).await
}

/// Reply to a message.
///
/// Given an existing message (likely a method call), send a reply back to the caller with the
/// given `body`.
///
/// Returns the message serial number.
pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<NonZeroU32>
pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<()>
where
B: serde::ser::Serialize + zvariant::DynamicType,
{
Expand All @@ -476,21 +444,14 @@ impl Connection {
b = b.sender(sender)?;
}
let m = b.build(body)?;
self.send_message(m).await
self.send(&m).await
}

/// Reply an error to a message.
///
/// Given an existing message (likely a method call), send an error reply back to the caller
/// with the given `error_name` and `body`.
///
/// Returns the message serial number.
pub async fn reply_error<'e, E, B>(
&self,
call: &Message,
error_name: E,
body: &B,
) -> Result<NonZeroU32>
pub async fn reply_error<'e, E, B>(&self, call: &Message, error_name: E, body: &B) -> Result<()>
where
B: serde::ser::Serialize + zvariant::DynamicType,
E: TryInto<ErrorName<'e>>,
Expand All @@ -501,22 +462,20 @@ impl Connection {
b = b.sender(sender)?;
}
let m = b.build(body)?;
self.send_message(m).await
self.send(&m).await
}

/// Reply an error to a message.
///
/// Given an existing message (likely a method call), send an error reply back to the caller
/// using one of the standard interface reply types.
///
/// Returns the message serial number.
pub async fn reply_dbus_error(
&self,
call: &zbus::message::Header<'_>,
err: impl DBusError,
) -> Result<NonZeroU32> {
let m = err.create_reply(call);
self.send_message(m?).await
) -> Result<()> {
let m = err.create_reply(call)?;
self.send(&m).await
}

/// Register a well-known name for this connection.
Expand Down Expand Up @@ -818,19 +777,6 @@ impl Connection {
self.inner.bus_conn
}

/// Assigns a serial number to `msg` that is unique to this connection.
///
/// This method can fail if `msg` is corrupted.
pub fn assign_serial_num(&self, msg: &mut Message) -> Result<NonZeroU32> {
let serial = self
.next_serial()
.try_into()
.map_err(|_| Error::InvalidSerial)?;
msg.set_serial_num(serial)?;

Ok(serial)
}

/// The unique name of the connection, if set/applicable.
///
/// The unique name is assigned by the message bus or set manually using
Expand Down Expand Up @@ -1230,7 +1176,6 @@ impl Connection {
#[cfg(unix)]
cap_unix_fd,
bus_conn: bus_connection,
serial: AtomicU32::new(1),
unique_name: OnceCell::new(),
subscriptions,
object_server: OnceCell::new(),
Expand All @@ -1247,10 +1192,6 @@ impl Connection {
Ok(connection)
}

fn next_serial(&self) -> u32 {
self.inner.serial.fetch_add(1, SeqCst)
}

/// Create a `Connection` to the session/user message bus.
pub async fn session() -> Result<Self> {
Builder::session()?.build().await
Expand Down Expand Up @@ -1585,21 +1526,6 @@ mod tests {
)
}

#[test]
#[timeout(15000)]
fn serial_monotonically_increases() {
crate::utils::block_on(test_serial_monotonically_increases());
}

async fn test_serial_monotonically_increases() {
let c = Connection::session().await.unwrap();
let serial = c.next_serial() + 1;

for next in serial..serial + 10 {
assert_eq!(next, c.next_serial());
}
}

#[cfg(all(windows, feature = "windows-gdbus"))]
#[test]
fn connect_gdbus_session_bus() {
Expand Down
3 changes: 1 addition & 2 deletions zbus/src/fdo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,13 +1007,12 @@ mod tests {

#[test]
fn error_from_zerror() {
let mut m = Message::method("/", "foo")
let m = Message::method("/", "foo")
.unwrap()
.destination(":1.2")
.unwrap()
.build(&())
.unwrap();
m.set_serial_num(1.try_into().unwrap()).unwrap();
let m = Message::method_error(&m, "org.freedesktop.DBus.Error.TimedOut")
.unwrap()
.build(&("so long"))
Expand Down
22 changes: 6 additions & 16 deletions zbus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ mod tests {
use crate::{
blocking::{self, MessageIterator},
fdo::{RequestNameFlags, RequestNameReply},
message::{Flags, Message},
message::Message,
object_server::SignalContext,
Connection, Result,
};
Expand All @@ -231,7 +231,7 @@ mod tests {

#[test]
fn msg() {
let mut m = Message::method("/org/freedesktop/DBus", "GetMachineId")
let m = Message::method("/org/freedesktop/DBus", "GetMachineId")
.unwrap()
.destination("org.freedesktop.DBus")
.unwrap()
Expand All @@ -243,16 +243,6 @@ mod tests {
assert_eq!(hdr.path().unwrap(), "/org/freedesktop/DBus");
assert_eq!(hdr.interface().unwrap(), "org.freedesktop.DBus.Peer");
assert_eq!(hdr.member().unwrap(), "GetMachineId");
m.modify_primary_header(|primary| {
primary.set_flags(BitFlags::from(Flags::NoAutoStart));
primary.set_serial_num(11.try_into().unwrap());

Ok(())
})
.unwrap();
let primary = m.primary_header();
assert!(primary.serial_num().unwrap().get() == 11);
assert!(primary.flags() == Flags::NoAutoStart);
}

#[test]
Expand Down Expand Up @@ -333,7 +323,6 @@ mod tests {
.unwrap();

let fd: Fd = reply.body().unwrap();
let _fds = reply.take_fds();
assert!(fd.as_raw_fd() >= 0);
let f = unsafe { File::from_raw_fd(fd.as_raw_fd()) };
f.metadata().unwrap();
Expand Down Expand Up @@ -584,7 +573,8 @@ mod tests {
.unwrap()
.build(&())
.unwrap();
let serial = client_conn.send_message(msg).unwrap();
let serial = msg.primary_header().serial_num();
client_conn.send(&msg).unwrap();

crate::blocking::fdo::DBusProxy::new(&conn)
.unwrap()
Expand All @@ -594,7 +584,7 @@ mod tests {
for m in stream {
let msg = m.unwrap();

if msg.primary_header().serial_num().unwrap() == serial {
if msg.primary_header().serial_num() == serial {
break;
}
}
Expand Down Expand Up @@ -731,7 +721,7 @@ mod tests {
.unwrap()
.build(&())
.unwrap();
conn.send_message(msg).unwrap();
conn.send(&msg).unwrap();

child.join().unwrap();
}
Expand Down
9 changes: 3 additions & 6 deletions zbus/src/message/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use std::io::{Cursor, Write};
#[cfg(unix)]
use crate::message::Fds;
#[cfg(unix)]
use std::{
os::unix::io::RawFd,
sync::{Arc, RwLock},
};
use std::{os::unix::io::RawFd, sync::Arc};

use enumflags2::BitFlags;
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
Expand Down Expand Up @@ -179,7 +176,7 @@ impl<'a> Builder<'a> {
}

fn reply_to(mut self, reply_to: &Header<'_>) -> Result<Self> {
let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
let serial = reply_to.primary().serial_num();
self.header.fields_mut().replace(Field::ReplySerial(serial));

if let Some(sender) = reply_to.sender() {
Expand Down Expand Up @@ -334,7 +331,7 @@ impl<'a> Builder<'a> {
bytes,
body_offset,
#[cfg(unix)]
fds: Arc::new(RwLock::new(Fds::Raw(fds))),
fds: Arc::new(Fds::Raw(fds)),
recv_seq: Sequence::default(),
})
}
Expand Down
Loading

0 comments on commit 4bbdba4

Please sign in to comment.