diff --git a/CHANGELOG.md b/CHANGELOG.md index 038967e..90f7b45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,16 +6,32 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] -## [0.0.3] - 2026-01-11 +## [0.0.3] - 2026-03-04 ### ๐Ÿš€ Features - Remove closed flag from ChannelState - Add internal constructor for `Sucker`/`Sourcer` +- Implement asynchronous channel support with tokio integration ### ๐Ÿ› Bug Fixes - Correct toolchain in flake + +### ๐Ÿšœ Refactor + +- Move traits to sync module and update imports +- Reorganize channel modules and implement async/sync structures + +### ๐Ÿงช Testing + +- Set_mut tests +- Increase code coverage of failure paths + +### โš™๏ธ Miscellaneous Tasks + +- Remove unused traits module +- Reorganize module exports for async and sync features ## [0.0.2] - 2025-09-04 ### ๐Ÿš€ Features diff --git a/Cargo.toml b/Cargo.toml index db5ff2d..49df714 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,20 +18,24 @@ thiserror = "2.0" flume = { version = "0.12", optional = true } crossbeam-channel = { version = "0.5", optional = true } arc-swap = "1.7.1" +tokio = { version = "1.48", features = ["sync", "macros", "rt-multi-thread"], optional = true } +async-trait = { version = "0.1", optional = true } [features] default = ["all"] sync = [] -async = [] +async = ["dep:async-trait"] sync-std = ["sync"] sync-flume = ["sync", "dep:flume"] sync-crossbeam = ["sync", "dep:crossbeam-channel"] +async-tokio = ["async", "dep:tokio"] all-sync = ["sync-std", "sync-flume", "sync-crossbeam"] +all-async = ["async-tokio"] -all = ["all-sync"] +all = ["all-sync", "all-async"] [lib] diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 02cb8fc..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,3 +0,0 @@ -[toolchain] -channel = "stable" -profile = "default" diff --git a/src/asynchronous/channel.rs b/src/asynchronous/channel.rs new file mode 100644 index 0000000..532cac7 --- /dev/null +++ b/src/asynchronous/channel.rs @@ -0,0 +1,196 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use crate::asynchronous::traits::{AsyncChannelReceiver, AsyncChannelSender}; +use crate::error::Error; +use crate::types::{ChannelState, Request, Response, ValueSource}; + +/// The consumer side of the channel that requests values asynchronously. +pub struct AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + request_tx: ST, + response_rx: SR, + closed: AtomicBool, + _phantom: std::marker::PhantomData, +} + +impl AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + pub(crate) fn new(request_tx: ST, response_rx: SR) -> Self { + Self { + request_tx, + response_rx, + closed: AtomicBool::new(false), + _phantom: std::marker::PhantomData, + } + } +} + +/// The producer side of the channel that provides values asynchronously. +pub struct AsyncSourcer +where + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + request_rx: SR, + response_tx: ST, + state: ChannelState, + _phantom: std::marker::PhantomData, +} + +impl AsyncSourcer +where + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + pub(crate) fn new(request_rx: SR, response_tx: ST, state: ChannelState) -> Self { + Self { + request_rx, + response_tx, + state, + _phantom: std::marker::PhantomData, + } + } +} + +impl AsyncSourcer +where + T: Send + 'static, + SR: AsyncChannelReceiver, + ST: AsyncChannelSender>, +{ + pub fn set_static(&self, val: T) -> Result<(), Error> + where + T: Clone, + { + self.state.swap(Arc::new(ValueSource::Static { + val, + clone: T::clone, + })); + Ok(()) + } + + pub fn set(&self, closure: F) -> Result<(), Error> + where + F: Fn() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); + Ok(()) + } + + pub fn set_mut(&self, closure: F) -> Result<(), Error> + where + F: FnMut() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( + closure, + ))))); + Ok(()) + } + + pub fn close(&self) -> Result<(), Error> { + self.state.swap(Arc::new(ValueSource::Cleared)); + Ok(()) + } + + pub async fn run(self) -> Result<(), Error> { + loop { + match self.request_rx.recv().await { + Ok(Request::GetValue) => { + let response = self.handle_get_value()?; + if self.response_tx.send(response).await.is_err() { + break; + } + } + Ok(Request::Close) => { + self.close()?; + break; + } + Err(_) => break, + } + } + Ok(()) + } + + fn handle_get_value(&self) -> Result, Error> { + let state = self.state.load(); + + match &**state { + ValueSource::Static { val, clone } => { + let value = self.execute_closure_safely(&mut || clone(val)); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::Dynamic(closure) => { + let value = self.execute_closure_safely(&mut || closure()); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::DynamicMut(closure) => { + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), + } + } + ValueSource::None => Ok(Response::NoSource), + ValueSource::Cleared => Ok(Response::Closed), + } + } + + fn execute_closure_safely( + &self, + closure: &mut dyn FnMut() -> T, + ) -> Result> { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) + } +} + +impl AsyncSucker +where + ST: AsyncChannelSender, + SR: AsyncChannelReceiver>, +{ + pub async fn get(&self) -> Result { + if self.closed.load(Ordering::Acquire) { + return Err(Error::ChannelClosed); + } + + self.request_tx + .send(Request::GetValue) + .await + .map_err(|_| Error::ProducerDisconnected)?; + + match self.response_rx.recv().await { + Ok(Response::Value(value)) => Ok(value), + Ok(Response::NoSource) => Err(Error::NoSource), + Ok(Response::Closed) => Err(Error::ChannelClosed), + Err(_) => Err(Error::ProducerDisconnected), + } + } + + pub async fn is_closed(&self) -> bool { + self.request_tx.send(Request::GetValue).await.is_err() + } + + pub async fn close(&self) -> Result<(), Error> { + self.closed.store(true, Ordering::Release); + self.request_tx + .send(Request::Close) + .await + .map_err(|_| Error::ProducerDisconnected) + } +} diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs new file mode 100644 index 0000000..6145b07 --- /dev/null +++ b/src/asynchronous/mod.rs @@ -0,0 +1,8 @@ +pub mod traits; +pub mod channel; + +#[cfg(feature = "async-tokio")] +pub mod tokio; + +#[cfg(feature = "async-tokio")] +pub use tokio::TokioSuck; diff --git a/src/asynchronous/tokio.rs b/src/asynchronous/tokio.rs new file mode 100644 index 0000000..677b676 --- /dev/null +++ b/src/asynchronous/tokio.rs @@ -0,0 +1,121 @@ +use std::sync::Arc; + +use arc_swap::ArcSwap; +use async_trait::async_trait; +use tokio::sync::{Mutex, mpsc}; + +use crate::asynchronous::traits::{ + AsyncChannelReceiver, AsyncChannelSender, AsyncChannelType, ChannelError, +}; +use crate::types; + +type TokioSucker = + crate::asynchronous::channel::AsyncSucker< + T, + TokioSender, + TokioReceiver>, + >; +type TokioSourcer = + crate::asynchronous::channel::AsyncSourcer< + T, + TokioReceiver, + TokioSender>, + >; + +pub struct TokioSender(mpsc::UnboundedSender); +pub struct TokioReceiver(Mutex>); + +#[async_trait] +impl AsyncChannelSender for TokioSender { + async fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +#[async_trait] +impl AsyncChannelReceiver for TokioReceiver { + async fn recv(&self) -> Result { + let mut receiver = self.0.lock().await; + receiver + .recv() + .await + .ok_or(ChannelError::ProducerDisconnected) + } +} + +pub struct TokioChannel; + +impl AsyncChannelType for TokioChannel { + type Sender = TokioSender; + type Receiver = TokioReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = mpsc::unbounded_channel(); + (TokioSender(tx), TokioReceiver(Mutex::new(rx))) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = mpsc::unbounded_channel(); + (TokioSender(tx), TokioReceiver(Mutex::new(rx))) + } +} + +pub struct TokioSuck { + _phantom: std::marker::PhantomData, +} + +impl TokioSuck { + pub fn pair() -> (TokioSucker, TokioSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = TokioChannel::create_request_channel(); + let (response_tx, response_rx) = TokioChannel::create_response_channel::(); + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); + + let sucker = crate::asynchronous::channel::AsyncSucker::new(request_tx, response_rx); + let sourcer = crate::asynchronous::channel::AsyncSourcer::new(request_rx, response_tx, state); + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + + #[tokio::test] + async fn test_pre_computed_value() { + let (sucker, sourcer) = TokioSuck::::pair(); + + let producer = tokio::spawn(async move { + sourcer.set_static(42).unwrap(); + sourcer.run().await.unwrap(); + }); + + let value = sucker.get().await.unwrap(); + assert_eq!(value, 42); + sucker.close().await.unwrap(); + producer.await.unwrap(); + } + + #[tokio::test] + async fn test_no_source_error() { + let (sucker, sourcer) = TokioSuck::::pair(); + + let producer = tokio::spawn(async move { + sourcer.run().await.unwrap(); + }); + + let result = sucker.get().await; + assert!(matches!(result, Err(Error::NoSource))); + sucker.close().await.unwrap(); + producer.await.unwrap(); + } +} diff --git a/src/asynchronous/traits.rs b/src/asynchronous/traits.rs new file mode 100644 index 0000000..67b6ec3 --- /dev/null +++ b/src/asynchronous/traits.rs @@ -0,0 +1,27 @@ +use async_trait::async_trait; + +pub use crate::error::Error as ChannelError; + +#[async_trait] +pub trait AsyncChannelSender: Send + Sync { + async fn send(&self, msg: T) -> Result<(), ChannelError>; +} + +#[async_trait] +pub trait AsyncChannelReceiver: Send + Sync { + async fn recv(&self) -> Result; +} + +pub trait AsyncChannelType { + type Sender: AsyncChannelSender; + type Receiver: AsyncChannelReceiver; + + fn create_request_channel() -> ( + Self::Sender, + Self::Receiver, + ); + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ); +} diff --git a/src/lib.rs b/src/lib.rs index 867d924..120cbd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -pub mod channel; pub mod error; +#[cfg(feature = "async")] +pub mod asynchronous; #[cfg(feature = "sync")] pub mod sync; +#[cfg(any(feature = "sync", feature = "async"))] pub mod types; -pub use channel::{Sourcer, Sucker}; +#[cfg(feature = "async")] +pub use asynchronous::channel::{AsyncSourcer, AsyncSucker}; +#[cfg(feature = "sync")] +pub use sync::channel::{Sourcer, Sucker}; pub use error::Error; diff --git a/src/channel.rs b/src/sync/channel.rs similarity index 100% rename from src/channel.rs rename to src/sync/channel.rs diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index 4e72624..e5784b0 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -7,9 +7,17 @@ use arc_swap::ArcSwap; use crossbeam_channel; type CrossbeamSucker = - crate::Sucker, CrossbeamReceiver>>; + crate::sync::channel::Sucker< + T, + CrossbeamSender, + CrossbeamReceiver>, + >; type CrossbeamSourcer = - crate::Sourcer, CrossbeamSender>>; + crate::sync::channel::Sourcer< + T, + CrossbeamReceiver, + CrossbeamSender>, + >; /// Internal sender type for crossbeam backend pub struct CrossbeamSender(crossbeam_channel::Sender); @@ -67,8 +75,8 @@ impl CrossbeamSuck { let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::sync::channel::Sucker::new(request_tx, response_rx); + let sourcer = crate::sync::channel::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } @@ -132,6 +140,38 @@ mod tests { producer_handle.join().unwrap(); } + #[test] + fn test_mut_closure_value() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let mut count = 0; + sourcer + .set_mut(move || { + count += 1; + count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets incrementing values from the mutable closure + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + let value3 = sucker.get().unwrap(); + assert_eq!(value3, 3); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + #[test] fn test_no_source_error() { let (sucker, sourcer) = CrossbeamSuck::::pair(); diff --git a/src/sync/flume.rs b/src/sync/flume.rs index 78ec2a5..816054b 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -7,9 +7,9 @@ use arc_swap::ArcSwap; use flume; type FlumeSucker = - crate::Sucker, FlumeReceiver>>; + crate::sync::channel::Sucker, FlumeReceiver>>; type FlumeSourcer = - crate::Sourcer, FlumeSender>>; + crate::sync::channel::Sourcer, FlumeSender>>; /// Internal sender type for flume backend pub struct FlumeSender(flume::Sender); @@ -68,8 +68,8 @@ impl FlumeSuck { let state = Arc::new(crate::types::ValueSource::None); let state = ArcSwap::new(state); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::sync::channel::Sucker::new(request_tx, response_rx); + let sourcer = crate::sync::channel::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } @@ -133,6 +133,38 @@ mod tests { producer_handle.join().unwrap(); } + #[test] + fn test_mut_closure_value() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let mut count = 0; + sourcer + .set_mut(move || { + count += 1; + count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets incrementing values from the mutable closure + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + let value3 = sucker.get().unwrap(); + assert_eq!(value3, 3); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + #[test] fn test_no_source_error() { let (sucker, sourcer) = FlumeSuck::::pair(); diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 8feab0d..dd12595 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -1,4 +1,5 @@ pub mod traits; +pub mod channel; #[cfg(feature = "sync-crossbeam")] pub mod crossbeam; diff --git a/src/sync/std.rs b/src/sync/std.rs index 8184a2b..cf2b06a 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -6,8 +6,10 @@ use std::sync::Arc; #[cfg(feature = "sync-std")] use std::sync::mpsc; -type StdSucker = crate::Sucker, StdReceiver>>; -type StdSourcer = crate::Sourcer, StdSender>>; +type StdSucker = + crate::sync::channel::Sucker, StdReceiver>>; +type StdSourcer = + crate::sync::channel::Sourcer, StdSender>>; /// Internal sender type for std backend pub struct StdSender(mpsc::Sender); @@ -66,8 +68,8 @@ impl StdSuck { let state = Arc::new(crate::types::ValueSource::None); let state = ArcSwap::new(state); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::sync::channel::Sucker::new(request_tx, response_rx); + let sourcer = crate::sync::channel::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } @@ -77,8 +79,18 @@ impl StdSuck { mod tests { use super::*; use crate::Error; + use crate::sync::traits::ChannelType; use std::thread; + #[derive(Debug)] + struct PanicOnClone; + + impl Clone for PanicOnClone { + fn clone(&self) -> Self { + panic!("intentional panic from Clone"); + } + } + #[test] fn test_pre_computed_value() { let (sucker, sourcer) = StdSuck::::pair(); @@ -131,6 +143,38 @@ mod tests { producer_handle.join().unwrap(); } + #[test] + fn test_mut_closure_value() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let mut count = 0; + sourcer + .set_mut(move || { + count += 1; + count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets incrementing values from the mutable closure + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + let value3 = sucker.get().unwrap(); + assert_eq!(value3, 3); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + #[test] fn test_no_source_error() { let (sucker, sourcer) = StdSuck::::pair(); @@ -186,6 +230,110 @@ mod tests { let _ = producer_handle.join(); } + #[test] + fn test_run_breaks_when_response_receiver_is_dropped() { + let (request_tx, request_rx) = StdChannel::create_request_channel(); + let (response_tx, response_rx) = StdChannel::create_response_channel::(); + drop(response_rx); + + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); + let sourcer = crate::sync::channel::Sourcer::new(request_rx, response_tx, state); + sourcer.set_static(42).unwrap(); + + let producer_handle = thread::spawn(move || sourcer.run().unwrap()); + + request_tx.send(crate::types::Request::GetValue).unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_run_breaks_when_request_sender_is_dropped() { + let (request_tx, request_rx) = StdChannel::create_request_channel(); + let (response_tx, _response_rx) = StdChannel::create_response_channel::(); + drop(request_tx); + + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); + let sourcer = crate::sync::channel::Sourcer::new(request_rx, response_tx, state); + + sourcer.run().unwrap(); + } + + #[test] + fn test_static_source_panic_returns_no_source() { + let (sucker, sourcer) = StdSuck::::pair(); + + let producer_handle = thread::spawn(move || { + sourcer.set_static(PanicOnClone).unwrap(); + sourcer.run().unwrap(); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + sucker.close().unwrap(); + producer_handle.join().unwrap(); + } + + #[test] + fn test_dynamic_source_panic_returns_no_source() { + let (sucker, sourcer) = StdSuck::::pair(); + + let producer_handle = thread::spawn(move || { + sourcer + .set(|| -> i32 { + panic!("intentional panic from Fn source"); + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + sucker.close().unwrap(); + producer_handle.join().unwrap(); + } + + #[test] + fn test_dynamic_mut_source_panic_returns_no_source() { + let (sucker, sourcer) = StdSuck::::pair(); + + let producer_handle = thread::spawn(move || { + sourcer + .set_mut(|| -> i32 { + panic!("intentional panic from FnMut source"); + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + sucker.close().unwrap(); + producer_handle.join().unwrap(); + } + + #[test] + fn test_cleared_source_returns_channel_closed() { + let (sucker, sourcer) = StdSuck::::pair(); + + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.close().unwrap(); + sourcer.run().unwrap(); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + sucker.close().unwrap(); + producer_handle.join().unwrap(); + } + #[test] fn test_is_closed() { let (sucker, sourcer) = StdSuck::::pair();