From 863ca61215df372313fa1330d61686988341cbdd Mon Sep 17 00:00:00 2001 From: Callum Leslie Date: Wed, 4 Mar 2026 20:51:20 +0000 Subject: [PATCH] feat: implement asynchronous channel support with tokio integration --- Cargo.toml | 8 +- rust-toolchain.toml | 3 - src/async_channel.rs | 196 +++++++++++++++++++++++++++++++++++++ src/asynchronous/mod.rs | 7 ++ src/asynchronous/tokio.rs | 113 +++++++++++++++++++++ src/asynchronous/traits.rs | 27 +++++ src/channel.rs | 2 +- src/lib.rs | 7 ++ src/sync/crossbeam.rs | 2 +- src/sync/flume.rs | 2 +- src/sync/std.rs | 4 +- src/sync/traits.rs | 24 +---- src/traits.rs | 23 +++++ 13 files changed, 385 insertions(+), 33 deletions(-) delete mode 100644 rust-toolchain.toml create mode 100644 src/async_channel.rs create mode 100644 src/asynchronous/mod.rs create mode 100644 src/asynchronous/tokio.rs create mode 100644 src/asynchronous/traits.rs create mode 100644 src/traits.rs diff --git a/Cargo.toml b/Cargo.toml index a4277f3..aec3b50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,20 +18,24 @@ thiserror = "2.0" flume = { version = "0.12", optional = true } crossbeam-channel = { version = "0.5", optional = true } arc-swap = "1.7.1" +tokio = { version = "1.48", features = ["sync", "macros", "rt-multi-thread"], optional = true } +async-trait = { version = "0.1", optional = true } [features] default = ["all"] sync = [] -async = [] +async = ["dep:async-trait"] sync-std = ["sync"] sync-flume = ["sync", "dep:flume"] sync-crossbeam = ["sync", "dep:crossbeam-channel"] +async-tokio = ["async", "dep:tokio"] all-sync = ["sync-std", "sync-flume", "sync-crossbeam"] +all-async = ["async-tokio"] -all = ["all-sync"] +all = ["all-sync", "all-async"] [lib] diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 02cb8fc..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,3 +0,0 @@ -[toolchain] -channel = "stable" -profile = "default" diff --git a/src/async_channel.rs b/src/async_channel.rs new file mode 100644 index 0000000..532cac7 --- /dev/null +++ b/src/async_channel.rs @@ -0,0 +1,196 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use crate::asynchronous::traits::{AsyncChannelReceiver, AsyncChannelSender}; +use crate::error::Error; +use crate::types::{ChannelState, Request, Response, ValueSource}; + +/// The consumer side of the channel that requests values asynchronously. +pub struct AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + request_tx: ST, + response_rx: SR, + closed: AtomicBool, + _phantom: std::marker::PhantomData, +} + +impl AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + pub(crate) fn new(request_tx: ST, response_rx: SR) -> Self { + Self { + request_tx, + response_rx, + closed: AtomicBool::new(false), + _phantom: std::marker::PhantomData, + } + } +} + +/// The producer side of the channel that provides values asynchronously. +pub struct AsyncSourcer +where + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + request_rx: SR, + response_tx: ST, + state: ChannelState, + _phantom: std::marker::PhantomData, +} + +impl AsyncSourcer +where + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + pub(crate) fn new(request_rx: SR, response_tx: ST, state: ChannelState) -> Self { + Self { + request_rx, + response_tx, + state, + _phantom: std::marker::PhantomData, + } + } +} + +impl AsyncSourcer +where + T: Send + 'static, + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + pub fn set_static(&self, val: T) -> Result<(), Error> + where + T: Clone, + { + self.state.swap(Arc::new(ValueSource::Static { + val, + clone: T::clone, + })); + Ok(()) + } + + pub fn set(&self, closure: F) -> Result<(), Error> + where + F: Fn() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); + Ok(()) + } + + pub fn set_mut(&self, closure: F) -> Result<(), Error> + where + F: FnMut() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( + closure, + ))))); + Ok(()) + } + + pub fn close(&self) -> Result<(), Error> { + self.state.swap(Arc::new(ValueSource::Cleared)); + Ok(()) + } + + pub async fn run(self) -> Result<(), Error> { + loop { + match self.request_rx.recv().await { + Ok(Request::GetValue) => { + let response = self.handle_get_value()?; + if self.response_tx.send(response).await.is_err() { + break; + } + } + Ok(Request::Close) => { + self.close()?; + break; + } + Err(_) => break, + } + } + Ok(()) + } + + fn handle_get_value(&self) -> Result, Error> { + let state = self.state.load(); + + match &**state { + ValueSource::Static { val, clone } => { + let value = self.execute_closure_safely(&mut || clone(val)); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::Dynamic(closure) => { + let value = self.execute_closure_safely(&mut || closure()); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::DynamicMut(closure) => { + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::None => Ok(Response::NoSource), + ValueSource::Cleared => Ok(Response::Closed), + } + } + + fn execute_closure_safely( + &self, + closure: &mut dyn FnMut() -> T, + ) -> Result> { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) + } +} + +impl AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + pub async fn get(&self) -> Result { + if self.closed.load(Ordering::Acquire) { + return Err(Error::ChannelClosed); + } + + self.request_tx + .send(Request::GetValue) + .await + .map_err(|_| Error::ProducerDisconnected)?; + + match self.response_rx.recv().await { + Ok(Response::Value(value)) => Ok(value), + Ok(Response::NoSource) => Err(Error::NoSource), + Ok(Response::Closed) => Err(Error::ChannelClosed), + Err(_) => Err(Error::ProducerDisconnected), + } + } + + pub async fn is_closed(&self) -> bool { + self.request_tx.send(Request::GetValue).await.is_err() + } + + pub async fn close(&self) -> Result<(), Error> { + self.closed.store(true, Ordering::Release); + self.request_tx + .send(Request::Close) + .await + .map_err(|_| Error::ProducerDisconnected) + } +} diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs new file mode 100644 index 0000000..fcaa130 --- /dev/null +++ b/src/asynchronous/mod.rs @@ -0,0 +1,7 @@ +pub mod traits; + +#[cfg(feature = "async-tokio")] +pub mod tokio; + +#[cfg(feature = "async-tokio")] +pub use tokio::TokioSuck; diff --git a/src/asynchronous/tokio.rs b/src/asynchronous/tokio.rs new file mode 100644 index 0000000..10d4cf9 --- /dev/null +++ b/src/asynchronous/tokio.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use async_trait::async_trait; +use tokio::sync::{Mutex, mpsc}; + +use crate::asynchronous::traits::{ + AsyncChannelReceiver, AsyncChannelSender, AsyncChannelType, ChannelError, +}; +use crate::types; + +type TokioSucker = + crate::AsyncSucker, TokioReceiver>>; +type TokioSourcer = + crate::AsyncSourcer, TokioSender>>; + +pub struct TokioSender(mpsc::UnboundedSender); +pub struct TokioReceiver(Mutex>); + +#[async_trait] +impl AsyncChannelSender for TokioSender { + async fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +#[async_trait] +impl AsyncChannelReceiver for TokioReceiver { + async fn recv(&self) -> Result { + let mut receiver = self.0.lock().await; + receiver + .recv() + .await + .ok_or(ChannelError::ProducerDisconnected) + } +} + +pub struct TokioChannel; + +impl AsyncChannelType for TokioChannel { + type Sender = TokioSender; + type Receiver = TokioReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = mpsc::unbounded_channel(); + (TokioSender(tx), TokioReceiver(Mutex::new(rx))) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = mpsc::unbounded_channel(); + (TokioSender(tx), TokioReceiver(Mutex::new(rx))) + } +} + +pub struct TokioSuck { + _phantom: std::marker::PhantomData, +} + +impl TokioSuck { + pub fn pair() -> (TokioSucker, TokioSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = TokioChannel::create_request_channel(); + let (response_tx, response_rx) = TokioChannel::create_response_channel::(); + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); + + let sucker = crate::AsyncSucker::new(request_tx, response_rx); + let sourcer = crate::AsyncSourcer::new(request_rx, response_tx, state); + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + + #[tokio::test] + async fn test_pre_computed_value() { + let (sucker, sourcer) = TokioSuck::::pair(); + + let producer = tokio::spawn(async move { + sourcer.set_static(42).unwrap(); + sourcer.run().await.unwrap(); + }); + + let value = sucker.get().await.unwrap(); + assert_eq!(value, 42); + sucker.close().await.unwrap(); + producer.await.unwrap(); + } + + #[tokio::test] + async fn test_no_source_error() { + let (sucker, sourcer) = TokioSuck::::pair(); + + let producer = tokio::spawn(async move { + sourcer.run().await.unwrap(); + }); + + let result = sucker.get().await; + assert!(matches!(result, Err(Error::NoSource))); + sucker.close().await.unwrap(); + producer.await.unwrap(); + } +} diff --git a/src/asynchronous/traits.rs b/src/asynchronous/traits.rs new file mode 100644 index 0000000..67b6ec3 --- /dev/null +++ b/src/asynchronous/traits.rs @@ -0,0 +1,27 @@ +use async_trait::async_trait; + +pub use crate::error::Error as ChannelError; + +#[async_trait] +pub trait AsyncChannelSender: Send + Sync { + async fn send(&self, msg: T) -> Result<(), ChannelError>; +} + +#[async_trait] +pub trait AsyncChannelReceiver: Send + Sync { + async fn recv(&self) -> Result; +} + +pub trait AsyncChannelType { + type Sender: AsyncChannelSender; + type Receiver: AsyncChannelReceiver; + + fn create_request_channel() -> ( + Self::Sender, + Self::Receiver, + ); + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ); +} diff --git a/src/channel.rs b/src/channel.rs index c3c8159..391725b 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -2,7 +2,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use crate::error::Error; -use crate::sync::traits::{ChannelReceiver, ChannelSender}; +use crate::traits::{ChannelReceiver, ChannelSender}; use crate::types::{ChannelState, Request, Response, ValueSource}; /// The consumer side of the channel that requests values diff --git a/src/lib.rs b/src/lib.rs index 867d924..76c3e87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,19 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#[cfg(feature = "async")] +pub mod async_channel; pub mod channel; pub mod error; +pub mod traits; +#[cfg(feature = "async")] +pub mod asynchronous; #[cfg(feature = "sync")] pub mod sync; pub mod types; +#[cfg(feature = "async")] +pub use async_channel::{AsyncSourcer, AsyncSucker}; pub use channel::{Sourcer, Sucker}; pub use error::Error; diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index cb8e6a9..4f2ff02 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -1,7 +1,7 @@ use std::sync::Arc; #[cfg(feature = "sync-crossbeam")] -use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; use arc_swap::ArcSwap; use crossbeam_channel; diff --git a/src/sync/flume.rs b/src/sync/flume.rs index e201d91..2fe4a22 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -1,7 +1,7 @@ use std::sync::Arc; #[cfg(feature = "sync-flume")] -use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; use arc_swap::ArcSwap; use flume; diff --git a/src/sync/std.rs b/src/sync/std.rs index 809d4b7..692055e 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -1,6 +1,6 @@ use arc_swap::ArcSwap; -use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; use std::sync::Arc; #[cfg(feature = "sync-std")] @@ -77,7 +77,7 @@ impl StdSuck { mod tests { use super::*; use crate::Error; - use crate::sync::traits::ChannelType; + use crate::traits::ChannelType; use std::thread; #[derive(Debug)] diff --git a/src/sync/traits.rs b/src/sync/traits.rs index d2d1fa5..5d9bf6c 100644 --- a/src/sync/traits.rs +++ b/src/sync/traits.rs @@ -1,23 +1 @@ -pub use crate::error::Error as ChannelError; - -pub trait ChannelSender { - fn send(&self, msg: T) -> Result<(), ChannelError>; -} - -pub trait ChannelReceiver { - fn recv(&self) -> Result; -} - -pub trait ChannelType { - type Sender: ChannelSender; - type Receiver: ChannelReceiver; - - fn create_request_channel() -> ( - Self::Sender, - Self::Receiver, - ); - fn create_response_channel() -> ( - Self::Sender>, - Self::Receiver>, - ); -} +pub use crate::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; diff --git a/src/traits.rs b/src/traits.rs new file mode 100644 index 0000000..d2d1fa5 --- /dev/null +++ b/src/traits.rs @@ -0,0 +1,23 @@ +pub use crate::error::Error as ChannelError; + +pub trait ChannelSender { + fn send(&self, msg: T) -> Result<(), ChannelError>; +} + +pub trait ChannelReceiver { + fn recv(&self) -> Result; +} + +pub trait ChannelType { + type Sender: ChannelSender; + type Receiver: ChannelReceiver; + + fn create_request_channel() -> ( + Self::Sender, + Self::Receiver, + ); + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ); +}