diff --git a/Cargo.toml b/Cargo.toml index 94fe60b..ada7cc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" thiserror = "2.0" flume = { version = "0.11", optional = true } crossbeam-channel = { version = "0.5", optional = true } +arc-swap = "1.7.1" [features] default = ["all"] diff --git a/src/channel.rs b/src/channel.rs index 531c0bc..06e1dfe 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; use crate::error::Error; use crate::sync::traits::{ChannelReceiver, ChannelSender}; @@ -68,8 +69,7 @@ where { /// Set a fixed value pub fn set_static(&self, value: T) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Static(value); + self.state.swap(Arc::new(ValueSource::Static(value))); Ok(()) } @@ -78,15 +78,16 @@ where where F: Fn() -> T + Send + Sync + 'static, { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Dynamic(Box::new(closure)); + self.state + .swap(Arc::new(ValueSource::Dynamic(Mutex::new(Box::new( + closure, + ))))); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Cleared; + self.state.swap(Arc::new(ValueSource::Cleared)); Ok(()) } @@ -103,8 +104,7 @@ where } Ok(Request::Close) => { // Close channel - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Cleared; + self.close()?; break; } Err(_) => { @@ -117,12 +117,13 @@ where } fn handle_get_value(&self) -> Result, Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; + let state = self.state.load(); - match &mut *state { + match &**state { ValueSource::Static(value) => Ok(Response::Value(value.clone())), ValueSource::Dynamic(closure) => { - let value = self.execute_closure_safely(closure); + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); match value { Ok(v) => Ok(Response::Value(v)), Err(_) => Ok(Response::NoSource), // Closure execution failed diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index 99c6a9c..4e72624 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-crossbeam")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use crossbeam_channel; type CrossbeamSucker = @@ -62,10 +65,10 @@ impl CrossbeamSuck { let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/flume.rs b/src/sync/flume.rs index a493513..78ec2a5 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-flume")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use flume; type FlumeSucker = @@ -62,10 +65,11 @@ impl FlumeSuck { let (request_tx, request_rx) = FlumeChannel::create_request_channel(); let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/std.rs b/src/sync/std.rs index f58a845..8184a2b 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -1,5 +1,8 @@ +use arc_swap::ArcSwap; + use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use std::sync::Arc; #[cfg(feature = "sync-std")] use std::sync::mpsc; @@ -60,10 +63,11 @@ impl StdSuck { let (request_tx, request_rx) = StdChannel::create_request_channel(); let (response_tx, response_rx) = StdChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/types.rs b/src/types.rs index 32eb625..d20e022 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; + +use arc_swap::ArcSwap; /// Request messages sent from consumer to producer pub enum Request { @@ -16,10 +18,10 @@ pub enum Response { /// Represents the source of values: either static or dynamic pub(crate) enum ValueSource { Static(T), - Dynamic(Box T + Send + Sync + 'static>), + Dynamic(Mutex T + Send + Sync + 'static>>), None, // Never set Cleared, // Was set but cleared (closed) } /// Internal channel state shared between producer and consumer -pub(crate) type ChannelState = Arc>>; +pub(crate) type ChannelState = ArcSwap>;