diff --git a/Cargo.toml b/Cargo.toml index ada7cc2..94fe60b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" thiserror = "2.0" flume = { version = "0.11", optional = true } crossbeam-channel = { version = "0.5", optional = true } -arc-swap = "1.7.1" [features] default = ["all"] diff --git a/flake.nix b/flake.nix index 06a1c34..a6030b1 100644 --- a/flake.nix +++ b/flake.nix @@ -20,17 +20,11 @@ flake-utils.lib.eachDefaultSystem (system: let overlays = [(import rust-overlay)]; pkgs = import nixpkgs {inherit system overlays;}; - rustToolchain = pkgs.pkgsBuildHost.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml; - rustToolchainNightly = pkgs.pkgsBuildHost.rust-bin.nightly.latest.default; + rustToolchain = pkgs.pkgsBuildHost.rust-bin.stable.latest.default; tools = with pkgs; [cargo-nextest]; - nativeBuildInputs = with pkgs; [rustToolchain rustToolchainNightly pkg-config] ++ tools; + nativeBuildInputs = with pkgs; [rustToolchain pkg-config] ++ tools; in with pkgs; { - devShells.default = mkShell { - inherit nativeBuildInputs; - shellHook = '' - export CARGO_NIGHTLY="${rustToolchainNightly}/bin/cargo" - ''; - }; + devShells.default = mkShell {inherit nativeBuildInputs;}; }); } diff --git a/src/channel.rs b/src/channel.rs index c3c8159..8da830b 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,3 @@ -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use crate::error::Error; @@ -11,26 +10,10 @@ where ST: ChannelSender, SR: ChannelReceiver>, { - request_tx: ST, - response_rx: SR, - closed: AtomicBool, - _phantom: std::marker::PhantomData, -} - -impl Sucker -where - ST: ChannelSender, - SR: ChannelReceiver>, -{ - /// Create a new Sucker instance - pub(crate) fn new(request_tx: ST, response_rx: SR) -> Self { - Self { - request_tx, - response_rx, - closed: AtomicBool::new(false), - _phantom: std::marker::PhantomData, - } - } + pub(crate) request_tx: ST, + pub(crate) response_rx: SR, + pub(crate) closed: Mutex, + pub(crate) _phantom: std::marker::PhantomData, } /// The producer side of the channel that provides values @@ -39,71 +22,46 @@ where SR: ChannelReceiver, ST: ChannelSender>, { - request_rx: SR, - response_tx: ST, - state: ChannelState, - _phantom: std::marker::PhantomData, + pub(crate) request_rx: SR, + pub(crate) response_tx: ST, + pub(crate) state: Arc>>, + pub(crate) _phantom: std::marker::PhantomData, } impl Sourcer where - SR: ChannelReceiver, - ST: ChannelSender>, -{ - /// Create a new Sourcer instance - pub(crate) fn new(request_rx: SR, response_tx: ST, state: ChannelState) -> Self { - Self { - request_rx, - response_tx, - state, - _phantom: std::marker::PhantomData, - } - } -} - -impl Sourcer -where - T: Send + 'static, + T: Clone + Send + 'static, SR: ChannelReceiver, ST: ChannelSender>, { /// Set a fixed value - pub fn set_static(&self, val: T) -> Result<(), Error> - where - T: Clone, - { - self.state.swap(Arc::new(ValueSource::Static { - val, - clone: T::clone, - })); + pub fn set_static(&self, value: T) -> Result<(), Error> { + let mut state = self.state.lock().map_err(|_| Error::InternalError)?; + if state.closed { + return Err(Error::ChannelClosed); + } + state.source = ValueSource::Static(value); Ok(()) } - /// Set a closure that implements [Fn] + /// Set a closure pub fn set(&self, closure: F) -> Result<(), Error> where F: Fn() -> T + Send + Sync + 'static, { - self.state - .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); - Ok(()) - } - - /// Set a closure that implements [FnMut] - pub fn set_mut(&self, closure: F) -> Result<(), Error> - where - F: FnMut() -> T + Send + Sync + 'static, - { - self.state - .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( - closure, - ))))); + let mut state = self.state.lock().map_err(|_| Error::InternalError)?; + if state.closed { + return Err(Error::ChannelClosed); + } + state.source = ValueSource::Dynamic(Box::new(closure)); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { - self.state.swap(Arc::new(ValueSource::Cleared)); + let mut state = self.state.lock().map_err(|_| Error::InternalError)?; + state.closed = true; + state.source = ValueSource::None; Ok(()) } @@ -120,7 +78,9 @@ where } Ok(Request::Close) => { // Close channel - self.close()?; + let mut state = self.state.lock().map_err(|_| Error::InternalError)?; + state.closed = true; + state.source = ValueSource::None; break; } Err(_) => { @@ -133,39 +93,27 @@ where } fn handle_get_value(&self) -> Result, Error> { - let state = self.state.load(); + let state = self.state.lock().map_err(|_| Error::InternalError)?; + if state.closed { + return Ok(Response::Closed); + } - match &**state { - ValueSource::Static { val, clone } => { - let value = self.execute_closure_safely(&mut || clone(val)); - match value { - Ok(v) => Ok(Response::Value(v)), - Err(_) => Ok(Response::NoSource), // Closure execution failed - } - } + match &state.source { + ValueSource::Static(value) => Ok(Response::Value(value.clone())), ValueSource::Dynamic(closure) => { - let value = self.execute_closure_safely(&mut || closure()); + let value = self.execute_closure_safely(closure); match value { Ok(v) => Ok(Response::Value(v)), Err(_) => Ok(Response::NoSource), // Closure execution failed } } - ValueSource::DynamicMut(closure) => { - let mut closure = closure.lock().unwrap(); - let value = self.execute_closure_safely(&mut *closure); - match value { - Ok(v) => Ok(Response::Value(v)), - Err(_) => Ok(Response::NoSource), // Closure execution failed - } - } - ValueSource::None => Ok(Response::NoSource), // No source was ever set - ValueSource::Cleared => Ok(Response::Closed), // Channel was closed (source was set then cleared) + ValueSource::None => Ok(Response::NoSource), } } fn execute_closure_safely( &self, - closure: &mut dyn FnMut() -> T, + closure: &dyn Fn() -> T, ) -> Result> { std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) } @@ -179,7 +127,7 @@ where /// Get the current value from the producer pub fn get(&self) -> Result { // Check if locally marked as closed - if self.closed.load(Ordering::Acquire) { + if *self.closed.lock().unwrap() { return Err(Error::ChannelClosed); } @@ -204,7 +152,7 @@ where /// Close the channel from the consumer side pub fn close(&self) -> Result<(), Error> { // Mark locally as closed - self.closed.store(true, Ordering::Release); + *self.closed.lock().unwrap() = true; // Send close request self.request_tx diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index 4e72624..c20217c 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -1,9 +1,6 @@ -use std::sync::Arc; - #[cfg(feature = "sync-crossbeam")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; -use arc_swap::ArcSwap; use crossbeam_channel; type CrossbeamSucker = @@ -65,10 +62,24 @@ impl CrossbeamSuck { let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); - let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; (sucker, sourcer) } diff --git a/src/sync/flume.rs b/src/sync/flume.rs index 78ec2a5..9176b14 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -1,9 +1,6 @@ -use std::sync::Arc; - #[cfg(feature = "sync-flume")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; -use arc_swap::ArcSwap; use flume; type FlumeSucker = @@ -65,11 +62,24 @@ impl FlumeSuck { let (request_tx, request_rx) = FlumeChannel::create_request_channel(); let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); - let state = Arc::new(crate::types::ValueSource::None); - let state = ArcSwap::new(state); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; (sucker, sourcer) } diff --git a/src/sync/std.rs b/src/sync/std.rs index 8184a2b..d0458a1 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -1,8 +1,5 @@ -use arc_swap::ArcSwap; - use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; -use std::sync::Arc; #[cfg(feature = "sync-std")] use std::sync::mpsc; @@ -63,11 +60,24 @@ impl StdSuck { let (request_tx, request_rx) = StdChannel::create_request_channel(); let (response_tx, response_rx) = StdChannel::create_response_channel::(); - let state = Arc::new(crate::types::ValueSource::None); - let state = ArcSwap::new(state); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { + source: crate::types::ValueSource::None, + closed: false, + })); - let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, state); + let sucker = crate::Sucker { + request_tx, + response_rx, + closed: std::sync::Mutex::new(false), + _phantom: std::marker::PhantomData, + }; + + let sourcer = crate::Sourcer { + request_rx, + response_tx, + state: std::sync::Arc::clone(&state), + _phantom: std::marker::PhantomData, + }; (sucker, sourcer) } diff --git a/src/types.rs b/src/types.rs index 244584b..0e82f9c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,3 @@ -use std::sync::Mutex; - -use arc_swap::ArcSwap; - /// Request messages sent from consumer to producer pub enum Request { GetValue, @@ -17,12 +13,13 @@ pub enum Response { /// Represents the source of values: either static or dynamic pub(crate) enum ValueSource { - Static { val: T, clone: fn(&T) -> T }, - DynamicMut(Mutex T + Send + Sync + 'static>>), + Static(T), Dynamic(Box T + Send + Sync + 'static>), - None, // Never set - Cleared, // Was set but cleared (closed) + None, } /// Internal channel state shared between producer and consumer -pub(crate) type ChannelState = ArcSwap>; +pub(crate) struct ChannelState { + pub(crate) source: ValueSource, + pub(crate) closed: bool, +}