// NOTE(review): the angle-bracketed generic arguments appear to have been stripped from this
// copy of the file (e.g. `TokioSuck::::pair()`, `Mutex>`, `-> Result {`), so it will not
// compile as shown. Restore them from version control; the comments below describe only the
// code that is visible here.

use std::sync::Arc;

use arc_swap::ArcSwap;
use async_trait::async_trait;
use tokio::sync::{Mutex, mpsc};

use crate::asynchronous::traits::{
    AsyncChannelReceiver, AsyncChannelSender, AsyncChannelType, ChannelError,
};
use crate::types;

/// Consumer-side handle returned by [`TokioSuck::pair`], wired to the Tokio channel types
/// defined in this file.
type TokioSucker = crate::AsyncSucker, TokioReceiver>>;

/// Producer-side handle returned by [`TokioSuck::pair`], wired to the Tokio channel types
/// defined in this file.
type TokioSourcer = crate::AsyncSourcer, TokioSender>>;

/// Sending half of an unbounded `tokio::sync::mpsc` channel.
pub struct TokioSender(mpsc::UnboundedSender);

/// Receiving half of an unbounded `tokio::sync::mpsc` channel.
///
/// The receiver sits behind a `tokio::sync::Mutex` because Tokio's `recv` needs mutable
/// access, while `recv` on this type is called through `&self`.
pub struct TokioReceiver(Mutex>);

#[async_trait]
impl AsyncChannelSender for TokioSender {
    /// Queues `msg` on the unbounded channel; nothing here is awaited, so it never waits
    /// for capacity.
    ///
    /// # Errors
    ///
    /// Any failure from the underlying unbounded `send` (Tokio reports one when the
    /// receiving half is closed) is mapped to `ChannelError::ProducerDisconnected`.
    async fn send(&self, msg: T) -> Result<(), ChannelError> {
        // NOTE(review): this sender type is used for both channels (see `TokioSuck::pair`).
        // On the response channel the closed receiver belongs to the sucker, so blaming the
        // "producer" may be the wrong side; confirm whether a consumer-side variant exists.
        self.0
            .send(msg)
            .map_err(|_| ChannelError::ProducerDisconnected)
    }
}

#[async_trait]
impl AsyncChannelReceiver for TokioReceiver {
    /// Waits for the next message, holding the receiver lock for the whole wait.
    ///
    /// # Errors
    ///
    /// Returns `ChannelError::ProducerDisconnected` when Tokio's `recv` yields `None`
    /// (channel closed and no buffered messages left).
    async fn recv(&self) -> Result {
        // Concurrent callers serialize on this lock: only one can await `recv` at a time.
        let mut receiver = self.0.lock().await;
        // NOTE(review): on the request channel this receiver is held by the sourcer, so a
        // closed channel there means the sucker went away; the variant name may mislead.
        receiver
            .recv()
            .await
            .ok_or(ChannelError::ProducerDisconnected)
    }
}

/// Tokio implementation of `AsyncChannelType`; both the request and the response channel
/// it creates are unbounded.
pub struct TokioChannel;

impl AsyncChannelType for TokioChannel {
    type Sender = TokioSender;
    type Receiver = TokioReceiver;

    /// Creates the channel that carries requests from the sucker to the sourcer.
    fn create_request_channel() -> (Self::Sender, Self::Receiver) {
        let (tx, rx) = mpsc::unbounded_channel();
        (TokioSender(tx), TokioReceiver(Mutex::new(rx)))
    }

    /// Creates the channel that carries responses from the sourcer back to the sucker.
    fn create_response_channel() -> (
        Self::Sender>,
        Self::Receiver>,
    ) {
        let (tx, rx) = mpsc::unbounded_channel();
        (TokioSender(tx), TokioReceiver(Mutex::new(rx)))
    }
}

/// Entry point for building a Tokio-backed sucker/sourcer pair, parameterised over the
/// value type (see the `T` bound on `pair`).
///
/// No constructor is defined in this file; the type only serves as the home of `pair()`,
/// and `_phantom` carries the type parameter without storing a value.
pub struct TokioSuck {
    _phantom: std::marker::PhantomData,
}

impl TokioSuck {
    /// Creates a connected `(sucker, sourcer)` pair.
    ///
    /// The sucker gets the request sender and the response receiver. The sourcer gets the
    /// request receiver, the response sender, and an `ArcSwap` state that starts as
    /// `ValueSource::None`, i.e. no value source configured yet.
    pub fn pair() -> (TokioSucker, TokioSourcer)
    where
        T: Clone + Send + 'static,
    {
        let (request_tx, request_rx) = TokioChannel::create_request_channel();
        let (response_tx, response_rx) = TokioChannel::create_response_channel::();
        let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None));
        let sucker = crate::AsyncSucker::new(request_tx, response_rx);
        let sourcer = crate::AsyncSourcer::new(request_rx, response_tx, state);
        (sucker, sourcer)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;

    /// A value installed with `set_static` before `run` is what `get` returns.
    #[tokio::test]
    async fn test_pre_computed_value() {
        let (sucker, sourcer) = TokioSuck::::pair();
        let producer = tokio::spawn(async move {
            sourcer.set_static(42).unwrap();
            sourcer.run().await.unwrap();
        });
        let value = sucker.get().await.unwrap();
        assert_eq!(value, 42);
        // NOTE(review): presumably `close` is what makes `run` return so the join below
        // completes; confirm against `AsyncSourcer::run`.
        sucker.close().await.unwrap();
        producer.await.unwrap();
    }

    /// With no value source configured, `get` fails with `Error::NoSource`.
    #[tokio::test]
    async fn test_no_source_error() {
        let (sucker, sourcer) = TokioSuck::::pair();
        let producer = tokio::spawn(async move {
            sourcer.run().await.unwrap();
        });
        let result = sucker.get().await;
        assert!(matches!(result, Err(Error::NoSource)));
        sucker.close().await.unwrap();
        producer.await.unwrap();
    }
}