From faa5ba23c5ff6cd46f12652fc1620f41657ef6d6 Mon Sep 17 00:00:00 2001 From: Callum Leslie Date: Thu, 4 Sep 2025 09:37:55 +0100 Subject: [PATCH] feat: add multiple channel providers Each provider is enabled via a feature flag. The currently implemented providers are: - std::mpsc - flume - crossbeam_channel --- Cargo.toml | 19 ++++ README.md | 6 +- src/channel.rs | 73 ++++++-------- src/lib.rs | 141 +-------------------------- src/sync/crossbeam.rs | 221 ++++++++++++++++++++++++++++++++++++++++++ src/sync/flume.rs | 221 ++++++++++++++++++++++++++++++++++++++++++ src/sync/mod.rs | 17 ++++ src/sync/std.rs | 221 ++++++++++++++++++++++++++++++++++++++++++ src/sync/traits.rs | 23 +++++ src/types.rs | 8 +- 10 files changed, 761 insertions(+), 189 deletions(-) create mode 100644 src/sync/crossbeam.rs create mode 100644 src/sync/flume.rs create mode 100644 src/sync/mod.rs create mode 100644 src/sync/std.rs create mode 100644 src/sync/traits.rs diff --git a/Cargo.toml b/Cargo.toml index fd6d830..560215f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,5 +15,24 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" [dependencies] thiserror = "2.0" +flume = { version = "0.11", optional = true } +crossbeam-channel = { version = "0.5", optional = true } + +[features] +default = ["all"] + +sync = [] +async = [] + +sync-std = ["sync"] +sync-flume = ["sync", "dep:flume"] +sync-crossbeam = ["sync", "dep:crossbeam-channel"] + +all-sync = ["sync-std", "sync-flume", "sync-crossbeam"] + +all = ["all-sync"] [lib] + +[package.metadata.docs.rs] +all-features = true diff --git a/README.md b/README.md index dafd0a0..de7abd2 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,11 @@ suck = "*" ## Quick Start ```rust -use suck::SuckPair; +use suck::sync::StdSuck; fn main() -> Result<(), Box> { - // Create a pair - let (sucker, sourcer) = SuckPair::::pair(); + // Create a pair (using default std backend) + let (sucker, sourcer) = StdSuck::::pair(); // Start producer in a thread let producer = std::thread::spawn(move || { diff --git a/src/channel.rs b/src/channel.rs index ef2c298..8da830b 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,61 +1,38 @@ -use std::sync::mpsc; use std::sync::{Arc, Mutex}; use crate::error::Error; +use crate::sync::traits::{ChannelReceiver, ChannelSender}; use crate::types::{ChannelState, Request, Response, ValueSource}; /// The consumer side of the channel that requests values -pub struct Sucker { - request_tx: mpsc::Sender, - response_rx: mpsc::Receiver>, - closed: Mutex, +pub struct Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ + pub(crate) request_tx: ST, + pub(crate) response_rx: SR, + pub(crate) closed: Mutex, + pub(crate) _phantom: std::marker::PhantomData, } /// The producer side of the channel that provides values -pub struct Sourcer { - request_rx: mpsc::Receiver, - response_tx: mpsc::Sender>, - state: Arc>>, +pub struct Sourcer +where + SR: ChannelReceiver, + ST: ChannelSender>, +{ + pub(crate) request_rx: SR, + pub(crate) response_tx: ST, + pub(crate) state: Arc>>, + pub(crate) _phantom: std::marker::PhantomData, } -/// Helper type for creating Sucker and Sourcer instances -pub struct SuckPair { - _phantom: std::marker::PhantomData, -} - -impl SuckPair { - /// Create a new suck pair - pub fn pair() -> (Sucker, Sourcer) - where - T: Clone + Send + 'static, - { - let (request_tx, request_rx) = mpsc::channel(); - let (response_tx, response_rx) = mpsc::channel(); - - let state = Arc::new(Mutex::new(ChannelState { - source: ValueSource::None, - closed: false, - })); - - let sucker = Sucker { - request_tx, - response_rx, - closed: Mutex::new(false), - }; - - let sourcer = Sourcer { - request_rx, - response_tx, - state: Arc::clone(&state), - }; - - (sucker, sourcer) - } -} - -impl Sourcer +impl Sourcer where T: Clone + Send + 'static, + SR: ChannelReceiver, + ST: ChannelSender>, { /// Set a fixed value pub fn set_static(&self, value: T) -> Result<(), Error> { @@ -142,7 +119,11 @@ where } } -impl Sucker { +impl Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ /// Get the current value from the producer pub fn get(&self) -> Result { // Check if locally marked as closed diff --git a/src/lib.rs b/src/lib.rs index d5e254e..17942a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,145 +1,14 @@ #![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] pub mod channel; pub mod error; + +#[cfg(feature = "sync")] +pub mod sync; pub mod types; -// Re-exports -pub use channel::{Sourcer, SuckPair, Sucker}; +pub use channel::{Sourcer, Sucker}; pub use error::Error; -pub use types::ValueSource; -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - #[test] - fn test_pre_computed_value() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Ensure consumer gets the value - let value = sucker.get().unwrap(); - assert_eq!(value, 42); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_closure_value() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = std::thread::spawn(move || { - let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); - let counter_clone = std::sync::Arc::clone(&counter); - sourcer - .set(move || { - let mut count = counter_clone.lock().unwrap(); - *count += 1; - *count - }) - .unwrap(); - sourcer.run().unwrap(); - }); - - // Ensure consumer gets the value - let value1 = sucker.get().unwrap(); - assert_eq!(value1, 1); - - // Ensure consumer gets the next value - let value2 = sucker.get().unwrap(); - assert_eq!(value2, 2); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_no_source_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.run().unwrap(); - }); - - // Consumer should get NoSource error - let result = sucker.get(); - assert!(matches!(result, Err(Error::NoSource))); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_channel_closed_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Close consumer - sucker.close().unwrap(); - - let result = sucker.get(); - assert!(matches!(result, Err(Error::ChannelClosed))); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_producer_disconnection_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - // Simulate crash - panic!("Producer crashed!"); - }); - - let result = sucker.get(); - assert!(matches!(result, Err(Error::ProducerDisconnected))); - - let _ = producer_handle.join(); - } - - #[test] - fn test_is_closed() { - let (sucker, sourcer) = SuckPair::::pair(); - - assert!(!sucker.is_closed()); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Get one value - let _ = sucker.get().unwrap(); - assert!(!sucker.is_closed()); - - // Close and check - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } -} diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs new file mode 100644 index 0000000..b4a820e --- /dev/null +++ b/src/sync/crossbeam.rs @@ -0,0 +1,221 @@ +#[cfg(feature = "sync-crossbeam")] +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +use crossbeam_channel; + +type CrossbeamSucker = crate::Sucker, CrossbeamReceiver>>; +type CrossbeamSourcer = crate::Sourcer, CrossbeamSender>>; + +/// Internal sender type for crossbeam backend +pub struct CrossbeamSender(crossbeam_channel::Sender); +/// Internal receiver type for crossbeam backend +pub struct CrossbeamReceiver(crossbeam_channel::Receiver); + +impl ChannelSender for CrossbeamSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for CrossbeamReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for crossbeam backend +pub struct CrossbeamChannel; + +impl ChannelType for CrossbeamChannel { + type Sender = CrossbeamSender; + type Receiver = CrossbeamReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = crossbeam_channel::unbounded(); + (CrossbeamSender(tx), CrossbeamReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = crossbeam_channel::unbounded(); + (CrossbeamSender(tx), CrossbeamReceiver(rx)) + } +} + +pub struct CrossbeamSuck { + _phantom: std::marker::PhantomData, +} + +impl CrossbeamSuck { + pub fn pair() -> (CrossbeamSucker, CrossbeamSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); + let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); + + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); + + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} + diff --git a/src/sync/flume.rs b/src/sync/flume.rs new file mode 100644 index 0000000..8ab4abe --- /dev/null +++ b/src/sync/flume.rs @@ -0,0 +1,221 @@ +#[cfg(feature = "sync-flume")] +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +use flume; + +type FlumeSucker = crate::Sucker, FlumeReceiver>>; +type FlumeSourcer = crate::Sourcer, FlumeSender>>; + +/// Internal sender type for flume backend +pub struct FlumeSender(flume::Sender); +/// Internal receiver type for flume backend +pub struct FlumeReceiver(flume::Receiver); + +impl ChannelSender for FlumeSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for FlumeReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for flume backend +pub struct FlumeChannel; + +impl ChannelType for FlumeChannel { + type Sender = FlumeSender; + type Receiver = FlumeReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = flume::unbounded(); + (FlumeSender(tx), FlumeReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = flume::unbounded(); + (FlumeSender(tx), FlumeReceiver(rx)) + } +} + +pub struct FlumeSuck { + _phantom: std::marker::PhantomData, +} + +impl FlumeSuck { + pub fn pair() -> (FlumeSucker, FlumeSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = FlumeChannel::create_request_channel(); + let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); + + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); + + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} + diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 0000000..8feab0d --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,17 @@ +pub mod traits; + +#[cfg(feature = "sync-crossbeam")] +pub mod crossbeam; +#[cfg(feature = "sync-flume")] +pub mod flume; +#[cfg(feature = "sync-std")] +pub mod std; + +#[cfg(feature = "sync-flume")] +pub use flume::FlumeSuck; + +#[cfg(feature = "sync-crossbeam")] +pub use crossbeam::CrossbeamSuck; + +#[cfg(feature = "sync-std")] +pub use std::StdSuck; diff --git a/src/sync/std.rs b/src/sync/std.rs new file mode 100644 index 0000000..169942d --- /dev/null +++ b/src/sync/std.rs @@ -0,0 +1,221 @@ +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +#[cfg(feature = "sync-std")] +use std::sync::mpsc; + +type StdSucker = crate::Sucker, StdReceiver>>; +type StdSourcer = crate::Sourcer, StdSender>>; + +/// Internal sender type for std backend +pub struct StdSender(mpsc::Sender); +/// Internal receiver type for std backend +pub struct StdReceiver(mpsc::Receiver); + +impl ChannelSender for StdSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for StdReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for std backend +pub struct StdChannel; + +impl ChannelType for StdChannel { + type Sender = StdSender; + type Receiver = StdReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = mpsc::channel(); + (StdSender(tx), StdReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = mpsc::channel(); + (StdSender(tx), StdReceiver(rx)) + } +} + +pub struct StdSuck { + _phantom: std::marker::PhantomData, +} + +impl StdSuck { + pub fn pair() -> (StdSucker, StdSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = StdChannel::create_request_channel(); + let (response_tx, response_rx) = StdChannel::create_response_channel::(); + + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); + + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = StdSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} + diff --git a/src/sync/traits.rs b/src/sync/traits.rs new file mode 100644 index 0000000..d2d1fa5 --- /dev/null +++ b/src/sync/traits.rs @@ -0,0 +1,23 @@ +pub use crate::error::Error as ChannelError; + +pub trait ChannelSender { + fn send(&self, msg: T) -> Result<(), ChannelError>; +} + +pub trait ChannelReceiver { + fn recv(&self) -> Result; +} + +pub trait ChannelType { + type Sender: ChannelSender; + type Receiver: ChannelReceiver; + + fn create_request_channel() -> ( + Self::Sender, + Self::Receiver, + ); + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ); +} diff --git a/src/types.rs b/src/types.rs index f56fb85..0e82f9c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,14 +12,14 @@ pub enum Response { } /// Represents the source of values: either static or dynamic -pub enum ValueSource { +pub(crate) enum ValueSource { Static(T), Dynamic(Box T + Send + Sync + 'static>), None, } /// Internal channel state shared between producer and consumer -pub struct ChannelState { - pub source: ValueSource, - pub closed: bool, +pub(crate) struct ChannelState { + pub(crate) source: ValueSource, + pub(crate) closed: bool, }