diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c34f731..d30e8b9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: toolchain: stable - uses: Swatinem/rust-cache@v2 - name: Minimal build - run: cargo build --no-default-features + run: cargo build - name: Clippy run: cargo clippy --all-features --examples -- -D warnings - name: Build all diff --git a/CHANGELOG.md b/CHANGELOG.md index 2341b5e..25e85a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.0.2] - 2025-09-04 + +### 🚀 Features + +- Add multiple channel providers ## [0.0.1] - 2025-09-02 ### 🚀 Features diff --git a/Cargo.toml b/Cargo.toml index fd6d830..ada7cc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "suck" description = "Suck data up through a channel" authors = ["Callum Leslie "] -version = "0.0.1" +version = "0.0.2" edition = "2024" documentation = "https://docs.rs/suck" homepage = "https://github.com/callumio/suck" @@ -15,5 +15,25 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" [dependencies] thiserror = "2.0" +flume = { version = "0.11", optional = true } +crossbeam-channel = { version = "0.5", optional = true } +arc-swap = "1.7.1" + +[features] +default = ["all"] + +sync = [] +async = [] + +sync-std = ["sync"] +sync-flume = ["sync", "dep:flume"] +sync-crossbeam = ["sync", "dep:crossbeam-channel"] + +all-sync = ["sync-std", "sync-flume", "sync-crossbeam"] + +all = ["all-sync"] [lib] + +[package.metadata.docs.rs] +all-features = true diff --git a/README.md b/README.md index dafd0a0..de7abd2 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,11 @@ suck = "*" ## Quick Start ```rust -use suck::SuckPair; +use suck::sync::StdSuck; fn main() -> Result<(), Box> { - // Create a pair - let (sucker, sourcer) = SuckPair::::pair(); + // Create a pair (using default std backend) + let (sucker, sourcer) = StdSuck::::pair(); // Start producer in a thread let producer = std::thread::spawn(move || { diff --git a/flake.nix b/flake.nix index a6030b1..06a1c34 100644 --- a/flake.nix +++ b/flake.nix @@ -20,11 +20,17 @@ flake-utils.lib.eachDefaultSystem (system: let overlays = [(import rust-overlay)]; pkgs = import nixpkgs {inherit system overlays;}; - rustToolchain = pkgs.pkgsBuildHost.rust-bin.stable.latest.default; + rustToolchain = pkgs.pkgsBuildHost.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml; + rustToolchainNightly = pkgs.pkgsBuildHost.rust-bin.nightly.latest.default; tools = with pkgs; [cargo-nextest]; - nativeBuildInputs = with pkgs; [rustToolchain pkg-config] ++ tools; + nativeBuildInputs = with pkgs; [rustToolchain rustToolchainNightly pkg-config] ++ tools; in with pkgs; { - devShells.default = mkShell {inherit nativeBuildInputs;}; + devShells.default = mkShell { + inherit nativeBuildInputs; + shellHook = '' + export CARGO_NIGHTLY="${rustToolchainNightly}/bin/cargo" + ''; + }; }); } diff --git a/src/channel.rs b/src/channel.rs index ef2c298..c3c8159 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,90 +1,109 @@ -use std::sync::mpsc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use crate::error::Error; +use crate::sync::traits::{ChannelReceiver, ChannelSender}; use crate::types::{ChannelState, Request, Response, ValueSource}; /// The consumer side of the channel that requests values -pub struct Sucker { - request_tx: mpsc::Sender, - response_rx: mpsc::Receiver>, - closed: Mutex, -} - -/// The producer side of the channel that provides values -pub struct Sourcer { - request_rx: mpsc::Receiver, - response_tx: mpsc::Sender>, - state: Arc>>, -} - -/// Helper type for creating Sucker and Sourcer instances -pub struct SuckPair { +pub struct Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ + request_tx: ST, + response_rx: SR, + closed: AtomicBool, _phantom: std::marker::PhantomData, } -impl SuckPair { - /// Create a new suck pair - pub fn pair() -> (Sucker, Sourcer) - where - T: Clone + Send + 'static, - { - let (request_tx, request_rx) = mpsc::channel(); - let (response_tx, response_rx) = mpsc::channel(); - - let state = Arc::new(Mutex::new(ChannelState { - source: ValueSource::None, - closed: false, - })); - - let sucker = Sucker { +impl Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ + /// Create a new Sucker instance + pub(crate) fn new(request_tx: ST, response_rx: SR) -> Self { + Self { request_tx, response_rx, - closed: Mutex::new(false), - }; - - let sourcer = Sourcer { - request_rx, - response_tx, - state: Arc::clone(&state), - }; - - (sucker, sourcer) + closed: AtomicBool::new(false), + _phantom: std::marker::PhantomData, + } } } -impl Sourcer +/// The producer side of the channel that provides values +pub struct Sourcer where - T: Clone + Send + 'static, + SR: ChannelReceiver, + ST: ChannelSender>, +{ + request_rx: SR, + response_tx: ST, + state: ChannelState, + _phantom: std::marker::PhantomData, +} + +impl Sourcer +where + SR: ChannelReceiver, + ST: ChannelSender>, +{ + /// Create a new Sourcer instance + pub(crate) fn new(request_rx: SR, response_tx: ST, state: ChannelState) -> Self { + Self { + request_rx, + response_tx, + state, + _phantom: std::marker::PhantomData, + } + } +} + +impl Sourcer +where + T: Send + 'static, + SR: ChannelReceiver, + ST: ChannelSender>, { /// Set a fixed value - pub fn set_static(&self, value: T) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Static(value); + pub fn set_static(&self, val: T) -> Result<(), Error> + where + T: Clone, + { + self.state.swap(Arc::new(ValueSource::Static { + val, + clone: T::clone, + })); Ok(()) } - /// Set a closure + /// Set a closure that implements [Fn] pub fn set(&self, closure: F) -> Result<(), Error> where F: Fn() -> T + Send + Sync + 'static, { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Dynamic(Box::new(closure)); + self.state + .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); + Ok(()) + } + + /// Set a closure that implements [FnMut] + pub fn set_mut(&self, closure: F) -> Result<(), Error> + where + F: FnMut() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( + closure, + ))))); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + self.state.swap(Arc::new(ValueSource::Cleared)); Ok(()) } @@ -101,9 +120,7 @@ where } Ok(Request::Close) => { // Close channel - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + self.close()?; break; } Err(_) => { @@ -116,37 +133,53 @@ where } fn handle_get_value(&self) -> Result, Error> { - let state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Ok(Response::Closed); - } + let state = self.state.load(); - match &state.source { - ValueSource::Static(value) => Ok(Response::Value(value.clone())), - ValueSource::Dynamic(closure) => { - let value = self.execute_closure_safely(closure); + match &**state { + ValueSource::Static { val, clone } => { + let value = self.execute_closure_safely(&mut || clone(val)); match value { Ok(v) => Ok(Response::Value(v)), Err(_) => Ok(Response::NoSource), // Closure execution failed } } - ValueSource::None => Ok(Response::NoSource), + ValueSource::Dynamic(closure) => { + let value = self.execute_closure_safely(&mut || closure()); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } + ValueSource::DynamicMut(closure) => { + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } + ValueSource::None => Ok(Response::NoSource), // No source was ever set + ValueSource::Cleared => Ok(Response::Closed), // Channel was closed (source was set then cleared) } } fn execute_closure_safely( &self, - closure: &dyn Fn() -> T, + closure: &mut dyn FnMut() -> T, ) -> Result> { std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) } } -impl Sucker { +impl Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ /// Get the current value from the producer pub fn get(&self) -> Result { // Check if locally marked as closed - if *self.closed.lock().unwrap() { + if self.closed.load(Ordering::Acquire) { return Err(Error::ChannelClosed); } @@ -171,7 +204,7 @@ impl Sucker { /// Close the channel from the consumer side pub fn close(&self) -> Result<(), Error> { // Mark locally as closed - *self.closed.lock().unwrap() = true; + self.closed.store(true, Ordering::Release); // Send close request self.request_tx diff --git a/src/lib.rs b/src/lib.rs index d5e254e..867d924 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,145 +1,12 @@ #![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] pub mod channel; pub mod error; + +#[cfg(feature = "sync")] +pub mod sync; pub mod types; -// Re-exports -pub use channel::{Sourcer, SuckPair, Sucker}; +pub use channel::{Sourcer, Sucker}; pub use error::Error; -pub use types::ValueSource; - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - - #[test] - fn test_pre_computed_value() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Ensure consumer gets the value - let value = sucker.get().unwrap(); - assert_eq!(value, 42); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_closure_value() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = std::thread::spawn(move || { - let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); - let counter_clone = std::sync::Arc::clone(&counter); - sourcer - .set(move || { - let mut count = counter_clone.lock().unwrap(); - *count += 1; - *count - }) - .unwrap(); - sourcer.run().unwrap(); - }); - - // Ensure consumer gets the value - let value1 = sucker.get().unwrap(); - assert_eq!(value1, 1); - - // Ensure consumer gets the next value - let value2 = sucker.get().unwrap(); - assert_eq!(value2, 2); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_no_source_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.run().unwrap(); - }); - - // Consumer should get NoSource error - let result = sucker.get(); - assert!(matches!(result, Err(Error::NoSource))); - - // Close consumer - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_channel_closed_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Close consumer - sucker.close().unwrap(); - - let result = sucker.get(); - assert!(matches!(result, Err(Error::ChannelClosed))); - - producer_handle.join().unwrap(); - } - - #[test] - fn test_producer_disconnection_error() { - let (sucker, sourcer) = SuckPair::::pair(); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - // Simulate crash - panic!("Producer crashed!"); - }); - - let result = sucker.get(); - assert!(matches!(result, Err(Error::ProducerDisconnected))); - - let _ = producer_handle.join(); - } - - #[test] - fn test_is_closed() { - let (sucker, sourcer) = SuckPair::::pair(); - - assert!(!sucker.is_closed()); - - // Start producer - let producer_handle = thread::spawn(move || { - sourcer.set_static(42).unwrap(); - sourcer.run().unwrap(); - }); - - // Get one value - let _ = sucker.get().unwrap(); - assert!(!sucker.is_closed()); - - // Close and check - sucker.close().unwrap(); - - producer_handle.join().unwrap(); - } -} diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs new file mode 100644 index 0000000..4e72624 --- /dev/null +++ b/src/sync/crossbeam.rs @@ -0,0 +1,211 @@ +use std::sync::Arc; + +#[cfg(feature = "sync-crossbeam")] +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +use arc_swap::ArcSwap; +use crossbeam_channel; + +type CrossbeamSucker = + crate::Sucker, CrossbeamReceiver>>; +type CrossbeamSourcer = + crate::Sourcer, CrossbeamSender>>; + +/// Internal sender type for crossbeam backend +pub struct CrossbeamSender(crossbeam_channel::Sender); +/// Internal receiver type for crossbeam backend +pub struct CrossbeamReceiver(crossbeam_channel::Receiver); + +impl ChannelSender for CrossbeamSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for CrossbeamReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for crossbeam backend +pub struct CrossbeamChannel; + +impl ChannelType for CrossbeamChannel { + type Sender = CrossbeamSender; + type Receiver = CrossbeamReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = crossbeam_channel::unbounded(); + (CrossbeamSender(tx), CrossbeamReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = crossbeam_channel::unbounded(); + (CrossbeamSender(tx), CrossbeamReceiver(rx)) + } +} + +pub struct CrossbeamSuck { + _phantom: std::marker::PhantomData, +} + +impl CrossbeamSuck { + pub fn pair() -> (CrossbeamSucker, CrossbeamSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); + let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); + + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); + + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = CrossbeamSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} diff --git a/src/sync/flume.rs b/src/sync/flume.rs new file mode 100644 index 0000000..78ec2a5 --- /dev/null +++ b/src/sync/flume.rs @@ -0,0 +1,212 @@ +use std::sync::Arc; + +#[cfg(feature = "sync-flume")] +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +use arc_swap::ArcSwap; +use flume; + +type FlumeSucker = + crate::Sucker, FlumeReceiver>>; +type FlumeSourcer = + crate::Sourcer, FlumeSender>>; + +/// Internal sender type for flume backend +pub struct FlumeSender(flume::Sender); +/// Internal receiver type for flume backend +pub struct FlumeReceiver(flume::Receiver); + +impl ChannelSender for FlumeSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for FlumeReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for flume backend +pub struct FlumeChannel; + +impl ChannelType for FlumeChannel { + type Sender = FlumeSender; + type Receiver = FlumeReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = flume::unbounded(); + (FlumeSender(tx), FlumeReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = flume::unbounded(); + (FlumeSender(tx), FlumeReceiver(rx)) + } +} + +pub struct FlumeSuck { + _phantom: std::marker::PhantomData, +} + +impl FlumeSuck { + pub fn pair() -> (FlumeSucker, FlumeSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = FlumeChannel::create_request_channel(); + let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); + + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); + + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = FlumeSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 0000000..8feab0d --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,17 @@ +pub mod traits; + +#[cfg(feature = "sync-crossbeam")] +pub mod crossbeam; +#[cfg(feature = "sync-flume")] +pub mod flume; +#[cfg(feature = "sync-std")] +pub mod std; + +#[cfg(feature = "sync-flume")] +pub use flume::FlumeSuck; + +#[cfg(feature = "sync-crossbeam")] +pub use crossbeam::CrossbeamSuck; + +#[cfg(feature = "sync-std")] +pub use std::StdSuck; diff --git a/src/sync/std.rs b/src/sync/std.rs new file mode 100644 index 0000000..8184a2b --- /dev/null +++ b/src/sync/std.rs @@ -0,0 +1,210 @@ +use arc_swap::ArcSwap; + +use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; +use crate::types; +use std::sync::Arc; +#[cfg(feature = "sync-std")] +use std::sync::mpsc; + +type StdSucker = crate::Sucker, StdReceiver>>; +type StdSourcer = crate::Sourcer, StdSender>>; + +/// Internal sender type for std backend +pub struct StdSender(mpsc::Sender); +/// Internal receiver type for std backend +pub struct StdReceiver(mpsc::Receiver); + +impl ChannelSender for StdSender { + fn send(&self, msg: T) -> Result<(), ChannelError> { + self.0 + .send(msg) + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +impl ChannelReceiver for StdReceiver { + fn recv(&self) -> Result { + self.0 + .recv() + .map_err(|_| ChannelError::ProducerDisconnected) + } +} + +/// Internal channel type for std backend +pub struct StdChannel; + +impl ChannelType for StdChannel { + type Sender = StdSender; + type Receiver = StdReceiver; + + fn create_request_channel() -> (Self::Sender, Self::Receiver) { + let (tx, rx) = mpsc::channel(); + (StdSender(tx), StdReceiver(rx)) + } + + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ) { + let (tx, rx) = mpsc::channel(); + (StdSender(tx), StdReceiver(rx)) + } +} + +pub struct StdSuck { + _phantom: std::marker::PhantomData, +} + +impl StdSuck { + pub fn pair() -> (StdSucker, StdSourcer) + where + T: Clone + Send + 'static, + { + let (request_tx, request_rx) = StdChannel::create_request_channel(); + let (response_tx, response_rx) = StdChannel::create_response_channel::(); + + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); + + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + + (sucker, sourcer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Error; + use std::thread; + + #[test] + fn test_pre_computed_value() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value = sucker.get().unwrap(); + assert_eq!(value, 42); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_closure_value() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = std::thread::spawn(move || { + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + sourcer + .set(move || { + let mut count = counter_clone.lock().unwrap(); + *count += 1; + *count + }) + .unwrap(); + sourcer.run().unwrap(); + }); + + // Ensure consumer gets the value + let value1 = sucker.get().unwrap(); + assert_eq!(value1, 1); + + // Ensure consumer gets the next value + let value2 = sucker.get().unwrap(); + assert_eq!(value2, 2); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_no_source_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.run().unwrap(); + }); + + // Consumer should get NoSource error + let result = sucker.get(); + assert!(matches!(result, Err(Error::NoSource))); + + // Close consumer + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_channel_closed_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Close consumer + sucker.close().unwrap(); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ChannelClosed))); + + producer_handle.join().unwrap(); + } + + #[test] + fn test_producer_disconnection_error() { + let (sucker, sourcer) = StdSuck::::pair(); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + // Simulate crash + panic!("Producer crashed!"); + }); + + let result = sucker.get(); + assert!(matches!(result, Err(Error::ProducerDisconnected))); + + let _ = producer_handle.join(); + } + + #[test] + fn test_is_closed() { + let (sucker, sourcer) = StdSuck::::pair(); + + assert!(!sucker.is_closed()); + + // Start producer + let producer_handle = thread::spawn(move || { + sourcer.set_static(42).unwrap(); + sourcer.run().unwrap(); + }); + + // Get one value + let _ = sucker.get().unwrap(); + assert!(!sucker.is_closed()); + + // Close and check + sucker.close().unwrap(); + + producer_handle.join().unwrap(); + } +} diff --git a/src/sync/traits.rs b/src/sync/traits.rs new file mode 100644 index 0000000..d2d1fa5 --- /dev/null +++ b/src/sync/traits.rs @@ -0,0 +1,23 @@ +pub use crate::error::Error as ChannelError; + +pub trait ChannelSender { + fn send(&self, msg: T) -> Result<(), ChannelError>; +} + +pub trait ChannelReceiver { + fn recv(&self) -> Result; +} + +pub trait ChannelType { + type Sender: ChannelSender; + type Receiver: ChannelReceiver; + + fn create_request_channel() -> ( + Self::Sender, + Self::Receiver, + ); + fn create_response_channel() -> ( + Self::Sender>, + Self::Receiver>, + ); +} diff --git a/src/types.rs b/src/types.rs index f56fb85..244584b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,7 @@ +use std::sync::Mutex; + +use arc_swap::ArcSwap; + /// Request messages sent from consumer to producer pub enum Request { GetValue, @@ -12,14 +16,13 @@ pub enum Response { } /// Represents the source of values: either static or dynamic -pub enum ValueSource { - Static(T), +pub(crate) enum ValueSource { + Static { val: T, clone: fn(&T) -> T }, + DynamicMut(Mutex T + Send + Sync + 'static>>), Dynamic(Box T + Send + Sync + 'static>), - None, + None, // Never set + Cleared, // Was set but cleared (closed) } /// Internal channel state shared between producer and consumer -pub struct ChannelState { - pub source: ValueSource, - pub closed: bool, -} +pub(crate) type ChannelState = ArcSwap>;