diff --git a/CHANGELOG.md b/CHANGELOG.md index 25e85a7..aa68df1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.0.3] - 2025-10-14 + +### 🚀 Features + +- Remove closed flag from ChannelState +- Add internal constructor for `Sucker`/`Sourcer` + +### 🐛 Bug Fixes + +- Correct toolchain in flake ## [0.0.2] - 2025-09-04 ### 🚀 Features diff --git a/Cargo.toml b/Cargo.toml index 94fe60b..0c7bdbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "suck" description = "Suck data up through a channel" authors = ["Callum Leslie "] -version = "0.0.2" +version = "0.0.3" edition = "2024" documentation = "https://docs.rs/suck" homepage = "https://github.com/callumio/suck" @@ -17,6 +17,7 @@ exclude = ["flake.nix", "flake.lock", ".envrc", "cliff.toml", "release-plz.toml" thiserror = "2.0" flume = { version = "0.11", optional = true } crossbeam-channel = { version = "0.5", optional = true } +arc-swap = "1.7.1" [features] default = ["all"] diff --git a/src/channel.rs b/src/channel.rs index 3ce9e7d..c3c8159 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,4 +1,5 @@ use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; use crate::error::Error; use crate::sync::traits::{ChannelReceiver, ChannelSender}; @@ -62,31 +63,47 @@ where impl Sourcer where - T: Clone + Send + 'static, + T: Send + 'static, SR: ChannelReceiver, ST: ChannelSender>, { /// Set a fixed value - pub fn set_static(&self, value: T) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Static(value); + pub fn set_static(&self, val: T) -> Result<(), Error> + where + T: Clone, + { + self.state.swap(Arc::new(ValueSource::Static { + val, + clone: T::clone, + })); Ok(()) } - /// Set a closure + /// Set a closure that implements [Fn] pub fn set(&self, closure: F) -> Result<(), Error> where F: Fn() -> T + Send + Sync + 'static, { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Dynamic(Box::new(closure)); + self.state + .swap(Arc::new(ValueSource::Dynamic(Box::new(closure)))); + Ok(()) + } + + /// Set a closure that implements [FnMut] + pub fn set_mut(&self, closure: F) -> Result<(), Error> + where + F: FnMut() -> T + Send + Sync + 'static, + { + self.state + .swap(Arc::new(ValueSource::DynamicMut(Mutex::new(Box::new( + closure, + ))))); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Cleared; + self.state.swap(Arc::new(ValueSource::Cleared)); Ok(()) } @@ -103,8 +120,7 @@ where } Ok(Request::Close) => { // Close channel - let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - *state = ValueSource::Cleared; + self.close()?; break; } Err(_) => { @@ -117,12 +133,26 @@ where } fn handle_get_value(&self) -> Result, Error> { - let state = self.state.lock().map_err(|_| Error::InternalError)?; + let state = self.state.load(); - match &*state { - ValueSource::Static(value) => Ok(Response::Value(value.clone())), + match &**state { + ValueSource::Static { val, clone } => { + let value = self.execute_closure_safely(&mut || clone(val)); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } ValueSource::Dynamic(closure) => { - let value = self.execute_closure_safely(closure); + let value = self.execute_closure_safely(&mut || closure()); + match value { + Ok(v) => Ok(Response::Value(v)), + Err(_) => Ok(Response::NoSource), // Closure execution failed + } + } + ValueSource::DynamicMut(closure) => { + let mut closure = closure.lock().unwrap(); + let value = self.execute_closure_safely(&mut *closure); match value { Ok(v) => Ok(Response::Value(v)), Err(_) => Ok(Response::NoSource), // Closure execution failed @@ -135,7 +165,7 @@ where fn execute_closure_safely( &self, - closure: &dyn Fn() -> T, + closure: &mut dyn FnMut() -> T, ) -> Result> { std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) } diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index 99c6a9c..4e72624 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-crossbeam")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use crossbeam_channel; type CrossbeamSucker = @@ -62,10 +65,10 @@ impl CrossbeamSuck { let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = ArcSwap::new(Arc::new(crate::types::ValueSource::None)); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/flume.rs b/src/sync/flume.rs index a493513..78ec2a5 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + #[cfg(feature = "sync-flume")] use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use arc_swap::ArcSwap; use flume; type FlumeSucker = @@ -62,10 +65,11 @@ impl FlumeSuck { let (request_tx, request_rx) = FlumeChannel::create_request_channel(); let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/sync/std.rs b/src/sync/std.rs index f58a845..8184a2b 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -1,5 +1,8 @@ +use arc_swap::ArcSwap; + use crate::sync::traits::{ChannelError, ChannelReceiver, ChannelSender, ChannelType}; use crate::types; +use std::sync::Arc; #[cfg(feature = "sync-std")] use std::sync::mpsc; @@ -60,10 +63,11 @@ impl StdSuck { let (request_tx, request_rx) = StdChannel::create_request_channel(); let (response_tx, response_rx) = StdChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); + let state = Arc::new(crate::types::ValueSource::None); + let state = ArcSwap::new(state); let sucker = crate::Sucker::new(request_tx, response_rx); - let sourcer = crate::Sourcer::new(request_rx, response_tx, std::sync::Arc::clone(&state)); + let sourcer = crate::Sourcer::new(request_rx, response_tx, state); (sucker, sourcer) } diff --git a/src/types.rs b/src/types.rs index 203a743..244584b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; + +use arc_swap::ArcSwap; /// Request messages sent from consumer to producer pub enum Request { @@ -15,11 +17,12 @@ pub enum Response { /// Represents the source of values: either static or dynamic pub(crate) enum ValueSource { - Static(T), + Static { val: T, clone: fn(&T) -> T }, + DynamicMut(Mutex T + Send + Sync + 'static>>), Dynamic(Box T + Send + Sync + 'static>), None, // Never set Cleared, // Was set but cleared (closed) } /// Internal channel state shared between producer and consumer -pub(crate) type ChannelState = Arc>>; +pub(crate) type ChannelState = ArcSwap>;