diff --git a/Cargo.toml b/Cargo.toml index 94fe60b..ada7cc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" thiserror = "2.0" flume = { version = "0.11", optional = true } crossbeam-channel = { version = "0.5", optional = true } +arc-swap = "1.7.1" [features] default = ["all"] diff --git a/flake.nix b/flake.nix index a6030b1..06a1c34 100644 --- a/flake.nix +++ b/flake.nix @@ -20,11 +20,17 @@ flake-utils.lib.eachDefaultSystem (system: let overlays = [(import rust-overlay)]; pkgs = import nixpkgs {inherit system overlays;}; - rustToolchain = pkgs.pkgsBuildHost.rust-bin.stable.latest.default; + rustToolchain = pkgs.pkgsBuildHost.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml; + rustToolchainNightly = pkgs.pkgsBuildHost.rust-bin.nightly.latest.default; tools = with pkgs; [cargo-nextest]; - nativeBuildInputs = with pkgs; [rustToolchain pkg-config] ++ tools; + nativeBuildInputs = with pkgs; [rustToolchain rustToolchainNightly pkg-config] ++ tools; in with pkgs; { - devShells.default = mkShell {inherit nativeBuildInputs;}; + devShells.default = mkShell { + inherit nativeBuildInputs; + shellHook = '' + export CARGO_NIGHTLY="${rustToolchainNightly}/bin/cargo" + ''; + }; }); } diff --git a/src/channel.rs b/src/channel.rs index 8da830b..c3c8159 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,3 +1,4 @@ +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use crate::error::Error; @@ -10,10 +11,26 @@ where ST: ChannelSender, SR: ChannelReceiver>, { - pub(crate) request_tx: ST, - pub(crate) response_rx: SR, - pub(crate) closed: Mutex, - pub(crate) _phantom: std::marker::PhantomData, + request_tx: ST, + response_rx: SR, + closed: AtomicBool, + _phantom: std::marker::PhantomData, +} + +impl Sucker +where + ST: ChannelSender, + SR: ChannelReceiver>, +{ + /// Create a new Sucker instance + pub(crate) fn new(request_tx: ST, response_rx: SR) -> Self { + Self { + request_tx, + response_rx, + closed: AtomicBool::new(false), + _phantom: std::marker::PhantomData, + } + } } /// The producer side of the channel that provides values @@ -22,46 +39,71 @@ where SR: ChannelReceiver, ST: ChannelSender>, { - pub(crate) request_rx: SR, - pub(crate) response_tx: ST, - pub(crate) state: Arc>>, - pub(crate) _phantom: std::marker::PhantomData, + request_rx: SR, + response_tx: ST, + state: ChannelState, + _phantom: std::marker::PhantomData, } impl Sourcer where - T: Clone + Send + 'static, + SR: ChannelReceiver, + ST: ChannelSender>, +{ + /// Create a new Sourcer instance + pub(crate) fn new(request_rx: SR, response_tx: ST, state: ChannelState) -> Self { + Self { + request_rx, + response_tx, + state, + _phantom: std::marker::PhantomData, + } + } +} + +impl Sourcer +where + T: Send + 'static, SR: ChannelReceiver, ST: ChannelSender>, { /// Set a fixed value - pub fn set_static(&self, value: T) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Static(value); + pub fn set_static(&self, val: T) -> Result<(), Error> + where + T: Clone, + { + self.state.swap(Arc::new(ValueSource::Static { + val, + clone: T::clone, + })); Ok(()) } - /// Set a closure + /// Set a closure that implements [Fn] pub fn set(&self, closure: F) -> Result<(), Error> where F: Fn() -> T + Send + Sync + 'static, { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Dynamic(Box::new(closure)); + self.state + .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); + Ok(()) + } + + /// Set a closure that implements [FnMut] + pub fn set_mut(&self, closure: F) -> Result<(), Error> + where + F: FnMut() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( + closure, + ))))); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + self.state.swap(Arc::new(ValueSource::Cleared)); Ok(()) } @@ -78,9 +120,7 @@ where } Ok(Request::Close) => { // Close channel - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + self.close()?; break; } Err(_) => { @@ -93,27 +133,39 @@ where } fn handle_get_value(&self) -> Result, Error> { - let state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Ok(Response::Closed); - } + let state = self.state.load(); - match &state.source { - ValueSource::Static(value) => Ok(Response::Value(value.clone())), - ValueSource::Dynamic(closure) => { - let value = self.execute_closure_safely(closure); + match &**state { + ValueSource::Static { val, clone } => { + let value = self.execute_closure_safely(&mut || clone(val)); match value { Ok(v) => Ok(Response::Value(v)), Err(_) => Ok(Response::NoSource), // Closure execution failed } } - ValueSource::None => Ok(Response::NoSource), + ValueSource::Dynamic(closure) => { + let value = self.execute_closure_safely(&mut || closure()); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } + ValueSource::DynamicMut(closure) => { + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } + ValueSource::None => Ok(Response::NoSource), // No source was ever set + ValueSource::Cleared => Ok(Response::Closed), // Channel was closed (source was set then cleared) } } fn execute_closure_safely( &self, - closure: &dyn Fn() -> T, + closure: &mut dyn FnMut() -> T, ) -> Result> { std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) } @@ -127,7 +179,7 @@ where /// Get the current value from the producer pub fn get(&self) -> Result { // Check if locally marked as closed - if *self.closed.lock().unwrap() { + if self.closed.load(Ordering::Acquire) { return Err(Error::ChannelClosed); } @@ -152,7 +204,7 @@ where /// Close the channel from the consumer side pub fn close(&self) -> Result<(), Error> { // Mark locally as closed - *self.closed.lock().unwrap() = true; + self.closed.store(true, Ordering::Release); // Send close request self.request_tx diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index c20217c..4e72624 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-crossbeam")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use crossbeam_channel; type CrossbeamSucker = @@ -62,24 +65,10 @@ impl CrossbeamSuck { let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); - let sucker = crate::Sucker { - request_tx, - response_rx, - closed: std::sync::Mutex::new(false), - _phantom: std::marker::PhantomData, - }; - - let sourcer = crate::Sourcer { - request_rx, - response_tx, - state: std::sync::Arc::clone(&state), - _phantom: std::marker::PhantomData, - }; + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/flume.rs b/src/sync/flume.rs index 9176b14..78ec2a5 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-flume")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use flume; type FlumeSucker = @@ -62,24 +65,11 @@ impl FlumeSuck { let (request_tx, request_rx) = FlumeChannel::create_request_channel(); let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); - let sucker = crate::Sucker { - request_tx, - response_rx, - closed: std::sync::Mutex::new(false), - _phantom: std::marker::PhantomData, - }; - - let sourcer = crate::Sourcer { - request_rx, - response_tx, - state: std::sync::Arc::clone(&state), - _phantom: std::marker::PhantomData, - }; + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/std.rs b/src/sync/std.rs index d0458a1..8184a2b 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -1,5 +1,8 @@ +use arc_swap::ArcSwap; + use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use std::sync::Arc; #[cfg(feature = "sync-std")] use std::sync::mpsc; @@ -60,24 +63,11 @@ impl StdSuck { let (request_tx, request_rx) = StdChannel::create_request_channel(); let (response_tx, response_rx) = StdChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); - let sucker = crate::Sucker { - request_tx, - response_rx, - closed: std::sync::Mutex::new(false), - _phantom: std::marker::PhantomData, - }; - - let sourcer = crate::Sourcer { - request_rx, - response_tx, - state: std::sync::Arc::clone(&state), - _phantom: std::marker::PhantomData, - }; + let sucker = crate::Sucker::new(request_tx, response_rx); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/types.rs b/src/types.rs index 0e82f9c..244584b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,7 @@ +use std::sync::Mutex; + +use arc_swap::ArcSwap; + /// Request messages sent from consumer to producer pub enum Request { GetValue, @@ -13,13 +17,12 @@ pub enum Response { /// Represents the source of values: either static or dynamic pub(crate) enum ValueSource { - Static(T), + Static { val: T, clone: fn(&T) -> T }, + DynamicMut(Mutex T + Send + Sync + 'static>>), Dynamic(Box T + Send + Sync + 'static>), - None, + None, // Never set + Cleared, // Was set but cleared (closed) } /// Internal channel state shared between producer and consumer -pub(crate) struct ChannelState { - pub(crate) source: ValueSource, - pub(crate) closed: bool, -} +pub(crate) type ChannelState = ArcSwap>;