[go: up one dir, main page]

bastion/
message.rs

1//!
2//! Dynamic dispatch oriented messaging system
3//!
4//! This system allows:
5//! * Generic communication between mailboxes.
6//! * All message communication relies on at-most-once delivery guarantee.
//! * Messages are not guaranteed to be ordered; all message ordering is causal.
8//!
9use crate::callbacks::CallbackType;
10use crate::children::Children;
11use crate::context::{BastionId, ContextState};
12use crate::envelope::{RefAddr, SignedMessage};
13use crate::supervisor::{SupervisionStrategy, Supervisor};
14
15use futures::channel::oneshot::{self, Receiver};
16use std::any::{type_name, Any};
17use std::fmt::Debug;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use tracing::{debug, trace};
23
/// Marker trait for every type that can be sent as a message.
///
/// It never has to be implemented by hand: any type that is [`Any`] (and
/// therefore `'static`), [`Send`], [`Sync`] and [`Debug`] gets it through the
/// blanket implementation below. It only exists so that those requirements
/// can be written as a single bound.
///
/// [`Any`]: std::any::Any
/// [`Send`]: std::marker::Send
/// [`Sync`]: std::marker::Sync
/// [`Debug`]: std::fmt::Debug
pub trait Message: Any + Send + Sync + Debug {}
impl<T: Any + Send + Sync + Debug> Message for T {}
35
36/// Allows to respond to questions.
37///
38/// This type features the [`respond`] method, that allows to respond to a
39/// question.
40///
41/// [`respond`]: #method.respond
42#[derive(Debug)]
43pub struct AnswerSender(oneshot::Sender<SignedMessage>, RefAddr);
44
45#[derive(Debug)]
46/// A [`Future`] returned when successfully "asking" a
47/// message using [`ChildRef::ask_anonymously`] and which resolves to
48/// a `Result<Msg, ()>` where the [`Msg`] is the message
49/// answered by the child (see the [`msg!`] macro for more
50/// information).
51///
52/// # Example
53///
54/// ```rust
55/// # use bastion::prelude::*;
56/// #
57/// # #[cfg(feature = "tokio-runtime")]
58/// # #[tokio::main]
59/// # async fn main() {
60/// #    run();    
61/// # }
62/// #
63/// # #[cfg(not(feature = "tokio-runtime"))]
64/// # fn main() {
65/// #    run();    
66/// # }
67/// #
68/// # fn run() {
69///     # Bastion::init();
70/// // The message that will be "asked"...
71/// const ASK_MSG: &'static str = "A message containing data (ask).";
72/// // The message the will be "answered"...
73/// const ANSWER_MSG: &'static str = "A message containing data (answer).";
74///
75///     # let children_ref =
76/// // Create a new child...
77/// Bastion::children(|children| {
78///     children.with_exec(|ctx: BastionContext| {
79///         async move {
80///             // ...which will receive the message asked...
81///             msg! { ctx.recv().await?,
82///                 msg: &'static str =!> {
83///                     assert_eq!(msg, ASK_MSG);
84///                     // Handle the message...
85///
86///                     // ...and eventually answer to it...
87///                     answer!(ctx, ANSWER_MSG);
88///                 };
89///                 // This won't happen because this example
90///                 // only "asks" a `&'static str`...
91///                 _: _ => ();
92///             }
93///
94///             Ok(())
95///         }
96///     })
97/// }).expect("Couldn't create the children group.");
98///
99///     # Bastion::children(|children| {
100///         # children.with_exec(move |ctx: BastionContext| {
101///             # let child_ref = children_ref.elems()[0].clone();
102///             # async move {
103/// // Later, the message is "asked" to the child...
104/// let answer: Answer = ctx.ask(&child_ref.addr(), ASK_MSG).expect("Couldn't send the message.");
105///
106/// // ...and the child's answer is received...
107/// msg! { answer.await.expect("Couldn't receive the answer."),
108///     msg: &'static str => {
109///         assert_eq!(msg, ANSWER_MSG);
110///         // Handle the answer...
111///     };
112///     // This won't happen because this example
113///     // only answers a `&'static str`...
114///     _: _ => ();
115/// }
116///                 #
117///                 # Ok(())
118///             # }
119///         # })
120///     # }).unwrap();
121///     #
122///     # Bastion::start();
123///     # Bastion::stop();
124///     # Bastion::block_until_stopped();
125/// # }
126/// ```
127///
128/// [`Future`]: std::future::Future
129/// [`ChildRef::ask_anonymously`]: crate::child_ref::ChildRef::ask_anonymously
130pub struct Answer(Receiver<SignedMessage>);
131
132#[derive(Debug)]
133/// A message returned by [`BastionContext::recv`] or
134/// [`BastionContext::try_recv`] that should be passed to the
135/// [`msg!`] macro to try to match what its real type is.
136///
137/// # Example
138///
139/// ```rust
140/// # use bastion::prelude::*;
141/// #
142/// # #[cfg(feature = "tokio-runtime")]
143/// # #[tokio::main]
144/// # async fn main() {
145/// #    run();    
146/// # }
147/// #
148/// # #[cfg(not(feature = "tokio-runtime"))]
149/// # fn main() {
150/// #    run();    
151/// # }
152/// #
153/// # fn run() {
154///     # Bastion::init();
155/// Bastion::children(|children| {
156///     children.with_exec(|ctx: BastionContext| {
157///         async move {
158///             loop {
159///                 let msg: SignedMessage = ctx.recv().await?;
160///                 msg! { msg,
161///                     // We match a broadcasted `&'static str`s...
162///                     ref msg: &'static str => {
163///                         // Note that `msg` will actually be a `&&'static str`.
164///                         assert_eq!(msg, &"A message containing data.");
165///
166///                         // Handle the message...
167///                     };
168///                     // We match a `&'static str`s "told" to this child...
169///                     msg: &'static str => {
170///                         assert_eq!(msg, "A message containing data.");
171///                         // Handle the message...
172///
173///                         // get message signature
174///                         let sign = signature!();
175///                         ctx.tell(&sign, "A message containing reply").unwrap();
176///                     };
177///                     // We match a `&'static str`s "asked" to this child...
178///                     msg: &'static str =!> {
179///                         assert_eq!(msg, "A message containing data.");
180///                         // Handle the message...
181///
182///                         // ...and eventually answer to it...
183///                         answer!(ctx, "An answer message containing data.");
184///                     };
185///                     // We match a message that wasn't previously matched...
186///                     _: _ => ();
187///                 }
188///             }
189///         }
190///     })
191/// }).expect("Couldn't start the children group.");
192///     #
193///     # Bastion::start();
194///     # Bastion::stop();
195///     # Bastion::block_until_stopped();
196/// # }
197/// ```
198///
199/// [`BastionContext::recv`]: crate::context::BastionContext::recv
200/// [`BastionContext::try_recv`]: crate::context::BastionContext::try_recv
201pub struct Msg(MsgInner);
202
/// Private representation of a [`Msg`]: the three ways a payload can be sent.
///
/// Code outside this module goes through `Msg`'s methods instead of matching
/// on it directly.
#[derive(Debug)]
enum MsgInner {
    /// A broadcasted payload. It lives behind an `Arc`, so `Msg::try_clone`
    /// can duplicate it by cloning the `Arc`.
    Broadcast(Arc<dyn Any + Send + Sync + 'static>),
    /// A payload "told" to its recipient, owned through a `Box`.
    Tell(Box<dyn Any + Send + Sync + 'static>),
    /// A payload "asked" to its recipient, which may answer it.
    Ask {
        /// The question itself.
        msg: Box<dyn Any + Send + Sync + 'static>,
        /// Channel used to send the answer back. It is `Some` when created by
        /// `Msg::ask` and becomes `None` once taken with `Msg::take_sender`.
        sender: Option<AnswerSender>,
    },
}
212
/// Crate-internal message type.
///
/// Besides user-level messages (`Message`), its variants carry what appear to
/// be lifecycle and supervision commands (deployment, restarts, state
/// changes, ...). Values are normally built through the constructor functions
/// of `impl BastionMessage`.
#[derive(Debug)]
pub(crate) enum BastionMessage {
    Start,
    Stop,
    Kill,
    /// A supervisor or children group to deploy. Boxed, presumably to keep
    /// the enum small — confirm before unboxing.
    Deploy(Box<Deployment>),
    Prune {
        id: BastionId,
    },
    SuperviseWith(SupervisionStrategy),
    ApplyCallback(CallbackType),
    InstantiatedChild {
        parent_id: BastionId,
        child_id: BastionId,
        state: Arc<Pin<Box<ContextState>>>,
    },
    /// A user-level message (broadcasted, told or asked).
    Message(Msg),
    RestartRequired {
        id: BastionId,
        parent_id: BastionId,
    },
    FinishedChild {
        id: BastionId,
        parent_id: BastionId,
    },
    RestartSubtree,
    RestoreChild {
        id: BastionId,
        state: Arc<Pin<Box<ContextState>>>,
    },
    DropChild {
        id: BastionId,
    },
    SetState {
        state: Arc<Pin<Box<ContextState>>>,
    },
    Stopped {
        id: BastionId,
    },
    Faulted {
        id: BastionId,
    },
    Heartbeat,
}
257
/// Payload of `BastionMessage::Deploy`: the element being deployed.
#[derive(Debug)]
pub(crate) enum Deployment {
    /// A supervisor (built by `BastionMessage::deploy_supervisor`).
    Supervisor(Supervisor),
    /// A children group (built by `BastionMessage::deploy_children`).
    Children(Children),
}
263
264impl AnswerSender {
265    /// Sends data back to the original sender.
266    ///
267    /// Returns  `Ok` if the data was sent successfully, otherwise returns the
268    /// original data.
269    pub fn reply<M: Message>(self, msg: M) -> Result<(), M> {
270        debug!("{:?}: Sending answer: {:?}", self, msg);
271        let msg = Msg::tell(msg);
272        trace!("{:?}: Sending message: {:?}", self, msg);
273
274        let AnswerSender(sender, sign) = self;
275        sender
276            .send(SignedMessage::new(msg, sign))
277            .map_err(|smsg| smsg.msg.try_unwrap().unwrap())
278    }
279}
280
281impl Msg {
282    pub(crate) fn broadcast<M: Message>(msg: M) -> Self {
283        let inner = MsgInner::Broadcast(Arc::new(msg));
284        Msg(inner)
285    }
286
287    pub(crate) fn tell<M: Message>(msg: M) -> Self {
288        let inner = MsgInner::Tell(Box::new(msg));
289        Msg(inner)
290    }
291
292    pub(crate) fn ask<M: Message>(msg: M, sign: RefAddr) -> (Self, Answer) {
293        let msg = Box::new(msg);
294        let (sender, recver) = oneshot::channel();
295        let sender = AnswerSender(sender, sign);
296        let answer = Answer(recver);
297
298        let sender = Some(sender);
299        let inner = MsgInner::Ask { msg, sender };
300
301        (Msg(inner), answer)
302    }
303
304    #[doc(hidden)]
305    pub fn is_broadcast(&self) -> bool {
306        matches!(self.0, MsgInner::Broadcast(_))
307    }
308
309    #[doc(hidden)]
310    pub fn is_tell(&self) -> bool {
311        matches!(self.0, MsgInner::Tell(_))
312    }
313
314    #[doc(hidden)]
315    pub fn is_ask(&self) -> bool {
316        matches!(self.0, MsgInner::Ask { .. })
317    }
318
319    #[doc(hidden)]
320    pub fn take_sender(&mut self) -> Option<AnswerSender> {
321        debug!("{:?}: Taking sender.", self);
322        if let MsgInner::Ask { sender, .. } = &mut self.0 {
323            sender.take()
324        } else {
325            None
326        }
327    }
328
329    #[doc(hidden)]
330    pub fn is<M: Message>(&self) -> bool {
331        match &self.0 {
332            MsgInner::Tell(msg) => msg.is::<M>(),
333            MsgInner::Ask { msg, .. } => msg.is::<M>(),
334            MsgInner::Broadcast(msg) => msg.is::<M>(),
335        }
336    }
337
338    #[doc(hidden)]
339    pub fn downcast<M: Message>(self) -> Result<M, Self> {
340        trace!("{:?}: Downcasting to {}.", self, type_name::<M>());
341        match self.0 {
342            MsgInner::Tell(msg) => {
343                if msg.is::<M>() {
344                    let msg: Box<dyn Any + 'static> = msg;
345                    Ok(*msg.downcast().unwrap())
346                } else {
347                    let inner = MsgInner::Tell(msg);
348                    Err(Msg(inner))
349                }
350            }
351            MsgInner::Ask { msg, sender } => {
352                if msg.is::<M>() {
353                    let msg: Box<dyn Any + 'static> = msg;
354                    Ok(*msg.downcast().unwrap())
355                } else {
356                    let inner = MsgInner::Ask { msg, sender };
357                    Err(Msg(inner))
358                }
359            }
360            _ => Err(self),
361        }
362    }
363
364    #[doc(hidden)]
365    pub fn downcast_ref<M: Message>(&self) -> Option<Arc<M>> {
366        trace!("{:?}: Downcasting to ref of {}.", self, type_name::<M>());
367        if let MsgInner::Broadcast(msg) = &self.0 {
368            if msg.is::<M>() {
369                return Some(msg.clone().downcast::<M>().unwrap());
370            }
371        }
372
373        None
374    }
375
376    pub(crate) fn try_clone(&self) -> Option<Self> {
377        trace!("{:?}: Trying to clone.", self);
378        if let MsgInner::Broadcast(msg) = &self.0 {
379            let inner = MsgInner::Broadcast(msg.clone());
380            Some(Msg(inner))
381        } else {
382            None
383        }
384    }
385
386    pub(crate) fn try_unwrap<M: Message>(self) -> Result<M, Self> {
387        debug!("{:?}: Trying to unwrap.", self);
388        if let MsgInner::Broadcast(msg) = self.0 {
389            match msg.downcast() {
390                Ok(msg) => match Arc::try_unwrap(msg) {
391                    Ok(msg) => Ok(msg),
392                    Err(msg) => {
393                        let inner = MsgInner::Broadcast(msg);
394                        Err(Msg(inner))
395                    }
396                },
397                Err(msg) => {
398                    let inner = MsgInner::Broadcast(msg);
399                    Err(Msg(inner))
400                }
401            }
402        } else {
403            self.downcast()
404        }
405    }
406}
407
408impl AsRef<dyn Any> for Msg {
409    fn as_ref(&self) -> &dyn Any {
410        match &self.0 {
411            MsgInner::Broadcast(msg) => msg.as_ref(),
412            MsgInner::Tell(msg) => msg.as_ref(),
413            MsgInner::Ask { msg, .. } => msg.as_ref(),
414        }
415    }
416}
417
418impl BastionMessage {
419    pub(crate) fn start() -> Self {
420        BastionMessage::Start
421    }
422
423    pub(crate) fn stop() -> Self {
424        BastionMessage::Stop
425    }
426
427    pub(crate) fn kill() -> Self {
428        BastionMessage::Kill
429    }
430
431    pub(crate) fn deploy_supervisor(supervisor: Supervisor) -> Self {
432        let deployment = Deployment::Supervisor(supervisor);
433
434        BastionMessage::Deploy(deployment.into())
435    }
436
437    pub(crate) fn deploy_children(children: Children) -> Self {
438        let deployment = Deployment::Children(children);
439
440        BastionMessage::Deploy(deployment.into())
441    }
442
443    pub(crate) fn prune(id: BastionId) -> Self {
444        BastionMessage::Prune { id }
445    }
446
447    pub(crate) fn supervise_with(strategy: SupervisionStrategy) -> Self {
448        BastionMessage::SuperviseWith(strategy)
449    }
450
451    pub(crate) fn apply_callback(callback_type: CallbackType) -> Self {
452        BastionMessage::ApplyCallback(callback_type)
453    }
454
455    pub(crate) fn instantiated_child(
456        parent_id: BastionId,
457        child_id: BastionId,
458        state: Arc<Pin<Box<ContextState>>>,
459    ) -> Self {
460        BastionMessage::InstantiatedChild {
461            parent_id,
462            child_id,
463            state,
464        }
465    }
466
467    pub(crate) fn broadcast<M: Message>(msg: M) -> Self {
468        let msg = Msg::broadcast(msg);
469        BastionMessage::Message(msg)
470    }
471
472    pub(crate) fn tell<M: Message>(msg: M) -> Self {
473        let msg = Msg::tell(msg);
474        BastionMessage::Message(msg)
475    }
476
477    pub(crate) fn ask<M: Message>(msg: M, sign: RefAddr) -> (Self, Answer) {
478        let (msg, answer) = Msg::ask(msg, sign);
479        (BastionMessage::Message(msg), answer)
480    }
481
482    pub(crate) fn restart_required(id: BastionId, parent_id: BastionId) -> Self {
483        BastionMessage::RestartRequired { id, parent_id }
484    }
485
486    pub(crate) fn finished_child(id: BastionId, parent_id: BastionId) -> Self {
487        BastionMessage::FinishedChild { id, parent_id }
488    }
489
490    pub(crate) fn restart_subtree() -> Self {
491        BastionMessage::RestartSubtree
492    }
493
494    pub(crate) fn restore_child(id: BastionId, state: Arc<Pin<Box<ContextState>>>) -> Self {
495        BastionMessage::RestoreChild { id, state }
496    }
497
498    pub(crate) fn drop_child(id: BastionId) -> Self {
499        BastionMessage::DropChild { id }
500    }
501
502    pub(crate) fn set_state(state: Arc<Pin<Box<ContextState>>>) -> Self {
503        BastionMessage::SetState { state }
504    }
505
506    pub(crate) fn stopped(id: BastionId) -> Self {
507        BastionMessage::Stopped { id }
508    }
509
510    pub(crate) fn faulted(id: BastionId) -> Self {
511        BastionMessage::Faulted { id }
512    }
513
514    pub(crate) fn heartbeat() -> Self {
515        BastionMessage::Heartbeat
516    }
517
518    pub(crate) fn try_clone(&self) -> Option<Self> {
519        trace!("{:?}: Trying to clone.", self);
520        let clone = match self {
521            BastionMessage::Start => BastionMessage::start(),
522            BastionMessage::Stop => BastionMessage::stop(),
523            BastionMessage::Kill => BastionMessage::kill(),
524            // FIXME
525            BastionMessage::Deploy(_) => unimplemented!(),
526            BastionMessage::Prune { id } => BastionMessage::prune(id.clone()),
527            BastionMessage::SuperviseWith(strategy) => {
528                BastionMessage::supervise_with(strategy.clone())
529            }
530            BastionMessage::ApplyCallback(callback_type) => {
531                BastionMessage::apply_callback(callback_type.clone())
532            }
533            BastionMessage::InstantiatedChild {
534                parent_id,
535                child_id,
536                state,
537            } => BastionMessage::instantiated_child(
538                parent_id.clone(),
539                child_id.clone(),
540                state.clone(),
541            ),
542            BastionMessage::Message(msg) => BastionMessage::Message(msg.try_clone()?),
543            BastionMessage::RestartRequired { id, parent_id } => {
544                BastionMessage::restart_required(id.clone(), parent_id.clone())
545            }
546            BastionMessage::FinishedChild { id, parent_id } => {
547                BastionMessage::finished_child(id.clone(), parent_id.clone())
548            }
549            BastionMessage::RestartSubtree => BastionMessage::restart_subtree(),
550            BastionMessage::RestoreChild { id, state } => {
551                BastionMessage::restore_child(id.clone(), state.clone())
552            }
553            BastionMessage::DropChild { id } => BastionMessage::drop_child(id.clone()),
554            BastionMessage::SetState { state } => BastionMessage::set_state(state.clone()),
555            BastionMessage::Stopped { id } => BastionMessage::stopped(id.clone()),
556            BastionMessage::Faulted { id } => BastionMessage::faulted(id.clone()),
557            BastionMessage::Heartbeat => BastionMessage::heartbeat(),
558        };
559
560        Some(clone)
561    }
562
563    pub(crate) fn into_msg<M: Message>(self) -> Option<M> {
564        if let BastionMessage::Message(msg) = self {
565            msg.try_unwrap().ok()
566        } else {
567            None
568        }
569    }
570}
571
572impl Future for Answer {
573    type Output = Result<SignedMessage, ()>;
574
575    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
576        debug!("{:?}: Polling.", self);
577        Pin::new(&mut self.get_mut().0).poll(ctx).map_err(|_| ())
578    }
579}
580
#[macro_export]
/// Matches the [`Msg`] carried by a `SignedMessage` (as obtained from
/// [`BastionContext::recv`] or [`BastionContext::try_recv`]) against
/// different types.
///
/// Each case is defined as:
/// - an optional `ref` which will make the case only match
///   if the message was broadcasted
/// - a variable name for the message if it matched this case
/// - a colon
/// - a type that the message must be of to match this case
///   (note that if the message was broadcasted, the actual
///   type of the variable will be a reference to this type)
/// - an arrow (`=>`) with an optional bang (`!`) between
///   the equal and greater-than signs which will make the
///   case only match if the message can be answered
/// - code that will be executed if the case matches
///
/// Dispatch depends on how the message was sent:
/// - broadcasted messages are only tested against `ref` cases;
/// - messages that can be answered only against `=!>` cases;
/// - other ("told") messages only against plain `=>` cases.
///
/// A message that matches none of the cases for its kind runs the default
/// case.
///
/// If the message can be answered (when using `=!>` instead
/// of `=>` as said above), an answer can be sent by passing
/// it to the `answer!(ctx, answer)` macro that will be generated
/// for this use.
///
/// In every case, `signature!()` expands to the signature of the
/// matched message.
///
/// A default case is required, which is defined in the same
/// way as any other case but with its type set as `_` (note
/// that it doesn't have the optional `ref` or `=!>`).
///
/// # Example
///
/// ```rust
/// # use bastion::prelude::*;
/// #
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// #    run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// #    run();
/// # }
/// #
/// # fn run() {
///     # Bastion::init();
/// // The message that will be broadcasted...
/// const BCAST_MSG: &'static str = "A message containing data (broadcast).";
/// // The message that will be "told" to the child...
/// const TELL_MSG: &'static str = "A message containing data (tell).";
/// // The message that will be "asked" to the child...
/// const ASK_MSG: &'static str = "A message containing data (ask).";
///
/// Bastion::children(|children| {
///     children.with_exec(|ctx: BastionContext| {
///         async move {
///             # ctx.tell(&ctx.current().addr(), TELL_MSG).unwrap();
///             # ctx.ask(&ctx.current().addr(), ASK_MSG).unwrap();
///             #
///             loop {
///                 msg! { ctx.recv().await?,
///                     // We match broadcasted `&'static str`s...
///                     ref msg: &'static str => {
///                         // Note that `msg` will actually be a `&&'static str`.
///                         assert_eq!(msg, &BCAST_MSG);
///                         // Handle the message...
///                     };
///                     // We match `&'static str`s "told" to this child...
///                     msg: &'static str => {
///                         assert_eq!(msg, TELL_MSG);
///                         // Handle the message...
///                     };
///                     // We match `&'static str`s "asked" to this child...
///                     msg: &'static str =!> {
///                         assert_eq!(msg, ASK_MSG);
///                         // Handle the message...
///
///                         // ...and eventually answer to it...
///                         answer!(ctx, "An answer to the message.");
///                     };
///                     // We are only broadcasting, "telling" and "asking" a
///                     // `&'static str` in this example, so we know that this won't
///                     // happen...
///                     _: _ => ();
///                 }
///             }
///         }
///     })
/// }).expect("Couldn't start the children group.");
///     #
///     # Bastion::start();
///     # Bastion::broadcast(BCAST_MSG).unwrap();
///     # Bastion::stop();
///     # Bastion::block_until_stopped();
/// # }
/// ```
///
/// [`BastionContext::recv`]: crate::context::BastionContext::recv
/// [`BastionContext::try_recv`]: crate::context::BastionContext::try_recv
macro_rules! msg {
    // Entry point: start with empty accumulators for the broadcast (`ref`),
    // tell (`=>`) and ask (`=!>`) cases, then consume the cases one by one.
    ($msg:expr, $($tokens:tt)+) => {
        msg!(@internal $msg, (), (), (), $($tokens)+)
    };

    // `ref var: Type => handler;` — append to the broadcast accumulator.
    (@internal
        $msg:expr,
        ($($bvar:ident, $bty:ty, $bhandle:expr,)*),
        ($($tvar:ident, $tty:ty, $thandle:expr,)*),
        ($($avar:ident, $aty:ty, $ahandle:expr,)*),
        ref $var:ident: $ty:ty => $handle:expr;
        $($rest:tt)+
    ) => {
        msg!(@internal $msg,
            ($($bvar, $bty, $bhandle,)* $var, $ty, $handle,),
            ($($tvar, $tty, $thandle,)*),
            ($($avar, $aty, $ahandle,)*),
            $($rest)+
        )
    };

    // `var: Type => handler;` followed by more cases — append to the tell
    // accumulator.
    (@internal
        $msg:expr,
        ($($bvar:ident, $bty:ty, $bhandle:expr,)*),
        ($($tvar:ident, $tty:ty, $thandle:expr,)*),
        ($($avar:ident, $aty:ty, $ahandle:expr,)*),
        $var:ident: $ty:ty => $handle:expr;
        $($rest:tt)+
    ) => {
        msg!(@internal $msg,
            ($($bvar, $bty, $bhandle,)*),
            ($($tvar, $tty, $thandle,)* $var, $ty, $handle,),
            ($($avar, $aty, $ahandle,)*),
            $($rest)+
        )
    };

    // `var: Type =!> handler;` — append to the ask accumulator.
    (@internal
        $msg:expr,
        ($($bvar:ident, $bty:ty, $bhandle:expr,)*),
        ($($tvar:ident, $tty:ty, $thandle:expr,)*),
        ($($avar:ident, $aty:ty, $ahandle:expr,)*),
        $var:ident: $ty:ty =!> $handle:expr;
        $($rest:tt)+
    ) => {
        msg!(@internal $msg,
            ($($bvar, $bty, $bhandle,)*),
            ($($tvar, $tty, $thandle,)*),
            ($($avar, $aty, $ahandle,)* $var, $ty, $handle,),
            $($rest)+
        )
    };

    // Default case written as `_: _ => handler;` — rewrite it so the message
    // is bound to a (macro-hygienic) `msg` variable.
    (@internal
        $msg:expr,
        ($($bvar:ident, $bty:ty, $bhandle:expr,)*),
        ($($tvar:ident, $tty:ty, $thandle:expr,)*),
        ($($avar:ident, $aty:ty, $ahandle:expr,)*),
        _: _ => $handle:expr;
    ) => {
        msg!(@internal $msg,
            ($($bvar, $bty, $bhandle,)*),
            ($($tvar, $tty, $thandle,)*),
            ($($avar, $aty, $ahandle,)*),
            msg: _ => $handle;
        )
    };

    // Final rule: the default case `var: _ => handler;` ends the case list;
    // generate the dispatch code from the three accumulators.
    (@internal
        $msg:expr,
        ($($bvar:ident, $bty:ty, $bhandle:expr,)*),
        ($($tvar:ident, $tty:ty, $thandle:expr,)*),
        ($($avar:ident, $aty:ty, $ahandle:expr,)*),
        $var:ident: _ => $handle:expr;
    ) => { {
        let mut signed = $msg;

        // Split the signed message into its payload and its signature.
        let (mut $var, sign) = signed.extract();

        // Lets handlers retrieve the signature with `signature!()`.
        macro_rules! signature {
            () => {
                sign
            };
        }

        // Taken up front: a message that still has a sender is treated as a
        // question and only tested against the `=!>` cases.
        let sender = $var.take_sender();
        if $var.is_broadcast() {
            // Dummy head so that every generated case can start with `else if`.
            if false {
                unreachable!();
            }
            $(
                else if $var.is::<$bty>() {
                    // Broadcast payloads are shared, so handlers get a reference.
                    let $bvar = &*$var.downcast_ref::<$bty>().unwrap();
                    { $bhandle }
                }
            )*
            else {
                { $handle }
            }
        } else if sender.is_some() {
            let sender = sender.unwrap();

            // Local `answer!(ctx, answer)`: replies through this message's
            // sender.
            // NOTE(review): the `sign` bound here is never used; it only
            // evaluates `$ctx.signature()` — confirm whether that is needed
            // before removing it.
            macro_rules! answer {
                ($ctx:expr, $answer:expr) => {
                    {
                        let sign = $ctx.signature();
                        sender.reply($answer)
                    }
                };
            }

            if false {
                unreachable!();
            }
            $(
                else if $var.is::<$aty>() {
                    let $avar = $var.downcast::<$aty>().unwrap();
                    { $ahandle }
                }
            )*
            else {
                { $handle }
            }
        } else {
            if false {
                unreachable!();
            }
            $(
                else if $var.is::<$tty>() {
                    let $tvar = $var.downcast::<$tty>().unwrap();
                    { $thandle }
                }
            )*
            else {
                { $handle }
            }
        }
    } };
}
817
#[macro_export]
/// Answers the given "asked" message (a `SignedMessage`, as returned by
/// `BastionContext::recv`) with the given answer.
///
/// Expands to the `Result<(), M>` returned by `AnswerSender::reply`:
/// - `Ok(())` when the answer was sent;
/// - `Err(answer)` handing the answer back.
///
/// # Panics
///
/// Panics if the message has no sender to answer to, e.g. because it was
/// told or broadcasted rather than asked.
///
/// # Example
///
/// ```rust
/// # use bastion::prelude::*;
/// #
/// # #[cfg(feature = "tokio-runtime")]
/// # #[tokio::main]
/// # async fn main() {
/// #    run();
/// # }
/// #
/// # #[cfg(not(feature = "tokio-runtime"))]
/// # fn main() {
/// #    run();
/// # }
/// #
/// # fn run() {
///     # Bastion::init();
///     # let children_ref =
/// // Create a new child...
/// Bastion::children(|children| {
///     children.with_exec(|ctx: BastionContext| {
///         async move {
///             let msg = ctx.recv().await?;
///             answer!(msg, "goodbye").unwrap();
///             Ok(())
///         }
///     })
/// }).expect("Couldn't create the children group.");
///
///     # Bastion::children(|children| {
///         # children.with_exec(move |ctx: BastionContext| {
///             # let child_ref = children_ref.elems()[0].clone();
///             # async move {
/// // now you can ask the child, from another children
/// let answer: Answer = ctx.ask(&child_ref.addr(), "hello").expect("Couldn't send the message.");
///
/// msg! { answer.await.expect("Couldn't receive the answer."),
///     msg: &'static str => {
///         assert_eq!(msg, "goodbye");
///     };
///     _: _ => ();
/// }
///                 #
///                 # Ok(())
///             # }
///         # })
///     # }).unwrap();
///     #
///     # Bastion::start();
///     # Bastion::stop();
///     # Bastion::block_until_stopped();
/// # }
/// ```
macro_rules! answer {
    ($msg:expr, $answer:expr) => {{
        // The signature is not needed to reply; the `_` prefix silences the
        // unused-variable lint while keeping it alive until the block ends.
        let (mut msg, _sign) = $msg.extract();
        let sender = msg.take_sender().expect("failed to take sender");
        sender.reply($answer)
    }};
}
882
883#[derive(Debug)]
884enum MessageHandlerState<O> {
885    Matched(O),
886    Unmatched(SignedMessage),
887}
888
889impl<O> MessageHandlerState<O> {
890    fn take_message(self) -> Result<SignedMessage, O> {
891        match self {
892            MessageHandlerState::Unmatched(msg) => Ok(msg),
893            MessageHandlerState::Matched(output) => Err(output),
894        }
895    }
896
897    fn output_or_else(self, f: impl FnOnce(SignedMessage) -> O) -> O {
898        match self {
899            MessageHandlerState::Matched(output) => output,
900            MessageHandlerState::Unmatched(msg) => f(msg),
901        }
902    }
903}
904
905/// Matches a [`Msg`] (as returned by [`BastionContext::recv`]
906/// or [`BastionContext::try_recv`]) with different types.
907///
908/// This type may replace the [`msg!`] macro in the future.
909///
910/// The [`new`] function creates a new [`MessageHandler`], which is then
911/// matched on with the `on_*` functions.
912///
913/// There are different kind of messages:
914///   - messages that are broadcasted, which can be matched with the
915///     [`on_broadcast`] method,
916///   - messages that can be responded to, which are matched with the
917///     [`on_question`] method,
918///   - messages that can not be responded to, which are matched with
919///     [`on_tell`],
920///   - fallback case, which matches everything, entitled [`on_fallback`].
921///
922/// The closure passed to the functions described previously must return the
923/// same type. This value is retrieved when [`on_fallback`] is invoked.
924///
925/// Questions can be responded to by calling [`reply`] on the provided
926/// sender.
927///
928/// # Example
929///
930/// ```rust
931/// # use bastion::prelude::*;
932/// # use bastion::message::MessageHandler;
933/// #
934/// # #[cfg(feature = "tokio-runtime")]
935/// # #[tokio::main]
936/// # async fn main() {
937/// #    run();    
938/// # }
939/// #
940/// # #[cfg(not(feature = "tokio-runtime"))]
941/// # fn main() {
942/// #    run();    
943/// # }
944/// #
945/// # fn run() {
946///     # Bastion::init();
947/// // The message that will be broadcasted...
948/// const BCAST_MSG: &'static str = "A message containing data (broadcast).";
949/// // The message that will be "told" to the child...
950/// const TELL_MSG: &'static str = "A message containing data (tell).";
951/// // The message that will be "asked" to the child...
952/// const ASK_MSG: &'static str = "A message containing data (ask).";
953///
954/// Bastion::children(|children| {
955///     children.with_exec(|ctx: BastionContext| {
956///         async move {
957///             # ctx.tell(&ctx.current().addr(), TELL_MSG).unwrap();
958///             # ctx.ask(&ctx.current().addr(), ASK_MSG).unwrap();
959///             #
960///             loop {
961///                 MessageHandler::new(ctx.recv().await?)
962///                     // We match on broadcasts of &str
963///                     .on_broadcast(|msg: &&str, _sender_addr| {
964///                         assert_eq!(*msg, BCAST_MSG);
965///                         // Handle the message...
966///                     })
967///                     // We match on messages of &str
968///                     .on_tell(|msg: &str, _sender_addr| {
969///                         assert_eq!(msg, TELL_MSG);
970///                         // Handle the message...
971///                     })
972///                     // We match on questions of &str
973///                     .on_question(|msg: &str, sender| {
974///                         assert_eq!(msg, ASK_MSG);
975///                         // Handle the message...
976///
977///                         // ...and eventually answer to it...
978///                         sender.reply("An answer to the message.");
979///                     })
980///                     // We are only broadcasting, "telling" and "asking" a
981///                     // `&str` in this example, so we know that this won't
982///                     // happen...
983///                     .on_fallback(|msg, _sender_addr| ());
984///             }
985///         }
986///     })
987/// }).expect("Couldn't start the children group.");
988///     #
989///     # Bastion::start();
990///     # Bastion::broadcast(BCAST_MSG).unwrap();
991///     # Bastion::stop();
992///     # Bastion::block_until_stopped();
993/// # }
994/// ```
995///
996/// [`BastionContext::recv`]: crate::context::BastionContext::recv
997/// [`BastionContext::try_recv`]: crate::context::BastionContext::try_recv
998/// [`new`]: Self::new
999/// [`on_broadcast`]: Self::on_broadcast
1000/// [`on_question`]: Self::on_question
1001/// [`on_tell`]: Self::on_tell
1002/// [`on_fallback`]: Self::on_fallback
1003/// [`reply`]: AnswerSender::reply
#[derive(Debug)]
pub struct MessageHandler<O> {
    /// Either the incoming message that no handler has matched yet
    /// (`Unmatched`), or the output produced by the handler that matched it
    /// (`Matched`).
    state: MessageHandlerState<O>,
}
1008
1009impl<O> MessageHandler<O> {
1010    /// Creates a new [`MessageHandler`] with an incoming message.
1011    pub fn new(msg: SignedMessage) -> MessageHandler<O> {
1012        let state = MessageHandlerState::Unmatched(msg);
1013        MessageHandler { state }
1014    }
1015
1016    /// Matches on a question of a specific type.
1017    ///
1018    /// This will consume the inner data and call `f` if the contained message
1019    /// can be replied to.
1020    pub fn on_question<T, F>(self, f: F) -> MessageHandler<O>
1021    where
1022        T: 'static,
1023        F: FnOnce(T, AnswerSender) -> O,
1024    {
1025        match self.try_into_question::<T>() {
1026            Ok((arg, sender)) => {
1027                let val = f(arg, sender);
1028                MessageHandler::matched(val)
1029            }
1030            Err(this) => this,
1031        }
1032    }
1033
1034    /// Calls a fallback function if the message has still not matched yet.
1035    ///
1036    /// This consumes the [`MessageHandler`], so that no matching can be
1037    /// performed anymore.
1038    pub fn on_fallback<F>(self, f: F) -> O
1039    where
1040        F: FnOnce(&dyn Any, RefAddr) -> O,
1041    {
1042        self.state
1043            .output_or_else(|SignedMessage { msg, sign }| f(msg.as_ref(), sign))
1044    }
1045
1046    /// Calls a function if the incoming message is a broadcast and has a
1047    /// specific type.
1048    pub fn on_broadcast<T, F>(self, f: F) -> MessageHandler<O>
1049    where
1050        T: 'static + Send + Sync,
1051        F: FnOnce(&T, RefAddr) -> O,
1052    {
1053        match self.try_into_broadcast::<T>() {
1054            Ok((arg, addr)) => {
1055                let val = f(arg.as_ref(), addr);
1056                MessageHandler::matched(val)
1057            }
1058            Err(this) => this,
1059        }
1060    }
1061
1062    /// Calls a function if the incoming message can't be replied to and has a
1063    /// specific type.
1064    pub fn on_tell<T, F>(self, f: F) -> MessageHandler<O>
1065    where
1066        T: Debug + 'static,
1067        F: FnOnce(T, RefAddr) -> O,
1068    {
1069        match self.try_into_tell::<T>() {
1070            Ok((msg, addr)) => {
1071                let val = f(msg, addr);
1072                MessageHandler::matched(val)
1073            }
1074            Err(this) => this,
1075        }
1076    }
1077
1078    fn matched(output: O) -> MessageHandler<O> {
1079        let state = MessageHandlerState::Matched(output);
1080        MessageHandler { state }
1081    }
1082
1083    fn try_into_question<T: 'static>(self) -> Result<(T, AnswerSender), MessageHandler<O>> {
1084        debug!("try_into_question with type {}", std::any::type_name::<T>());
1085        match self.state.take_message() {
1086            Ok(SignedMessage {
1087                msg:
1088                    Msg(MsgInner::Ask {
1089                        msg,
1090                        sender: Some(sender),
1091                    }),
1092                ..
1093            }) if msg.is::<T>() => {
1094                let msg: Box<dyn Any> = msg;
1095                Ok((*msg.downcast::<T>().unwrap(), sender))
1096            }
1097
1098            Ok(anything) => Err(MessageHandler::new(anything)),
1099            Err(output) => Err(MessageHandler::matched(output)),
1100        }
1101    }
1102
1103    fn try_into_broadcast<T: Send + Sync + 'static>(
1104        self,
1105    ) -> Result<(Arc<T>, RefAddr), MessageHandler<O>> {
1106        debug!(
1107            "try_into_broadcast with type {}",
1108            std::any::type_name::<T>()
1109        );
1110        match self.state.take_message() {
1111            Ok(SignedMessage {
1112                msg: Msg(MsgInner::Broadcast(msg)),
1113                sign,
1114            }) if msg.is::<T>() => {
1115                let msg: Arc<dyn Any + Send + Sync + 'static> = msg;
1116                Ok((msg.downcast::<T>().unwrap(), sign))
1117            }
1118
1119            Ok(anything) => Err(MessageHandler::new(anything)),
1120            Err(output) => Err(MessageHandler::matched(output)),
1121        }
1122    }
1123
1124    fn try_into_tell<T: Debug + 'static>(self) -> Result<(T, RefAddr), MessageHandler<O>> {
1125        debug!("try_into_tell with type {}", std::any::type_name::<T>());
1126        match self.state.take_message() {
1127            Ok(SignedMessage {
1128                msg: Msg(MsgInner::Tell(msg)),
1129                sign,
1130            }) if msg.is::<T>() => {
1131                let msg: Box<dyn Any> = msg;
1132                Ok((*msg.downcast::<T>().unwrap(), sign))
1133            }
1134            Ok(anything) => Err(MessageHandler::new(anything)),
1135            Err(output) => Err(MessageHandler::matched(output)),
1136        }
1137    }
1138}