diff --git a/src/channel.rs b/src/channel.rs index 8452487..158e21a 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,5 +1,4 @@ use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex}; use crate::error::Error; use crate::sync::traits::{ChannelReceiver, ChannelSender}; @@ -25,7 +24,7 @@ where { pub(crate) request_rx: SR, pub(crate) response_tx: ST, - pub(crate) state: Arc>>, + pub(crate) state: ChannelState, pub(crate) _phantom: std::marker::PhantomData, } @@ -38,10 +37,7 @@ where /// Set a fixed value pub fn set_static(&self, value: T) -> Result<(), Error> { let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Static(value); + *state = ValueSource::Static(value); Ok(()) } @@ -51,18 +47,14 @@ where F: Fn() -> T + Send + Sync + 'static, { let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Err(Error::ChannelClosed); - } - state.source = ValueSource::Dynamic(Box::new(closure)); + *state = ValueSource::Dynamic(Box::new(closure)); Ok(()) } /// Close the channel pub fn close(&self) -> Result<(), Error> { let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + *state = ValueSource::Cleared; Ok(()) } @@ -80,8 +72,7 @@ where Ok(Request::Close) => { // Close channel let mut state = self.state.lock().map_err(|_| Error::InternalError)?; - state.closed = true; - state.source = ValueSource::None; + *state = ValueSource::Cleared; break; } Err(_) => { @@ -95,11 +86,8 @@ where fn handle_get_value(&self) -> Result, Error> { let state = self.state.lock().map_err(|_| Error::InternalError)?; - if state.closed { - return Ok(Response::Closed); - } - match &state.source { + match &*state { ValueSource::Static(value) => Ok(Response::Value(value.clone())), ValueSource::Dynamic(closure) => { let value = self.execute_closure_safely(closure); @@ -108,7 +96,8 @@ where Err(_) => Ok(Response::NoSource), // Closure execution failed } } - ValueSource::None => Ok(Response::NoSource), + ValueSource::None => Ok(Response::NoSource), // No source was ever set + ValueSource::Cleared => Ok(Response::Closed), // Channel was closed (source was set then cleared) } } diff --git a/src/sync/crossbeam.rs b/src/sync/crossbeam.rs index 20bc162..35a99f4 100644 --- a/src/sync/crossbeam.rs +++ b/src/sync/crossbeam.rs @@ -64,10 +64,7 @@ impl CrossbeamSuck { let (request_tx, request_rx) = CrossbeamChannel::create_request_channel(); let (response_tx, response_rx) = CrossbeamChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); let sucker = crate::Sucker { request_tx, diff --git a/src/sync/flume.rs b/src/sync/flume.rs index 5208cf5..14a0a79 100644 --- a/src/sync/flume.rs +++ b/src/sync/flume.rs @@ -64,10 +64,7 @@ impl FlumeSuck { let (request_tx, request_rx) = FlumeChannel::create_request_channel(); let (response_tx, response_rx) = FlumeChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); let sucker = crate::Sucker { request_tx, diff --git a/src/sync/std.rs b/src/sync/std.rs index d735ebb..2038543 100644 --- a/src/sync/std.rs +++ b/src/sync/std.rs @@ -61,10 +61,7 @@ impl StdSuck { let (request_tx, request_rx) = StdChannel::create_request_channel(); let (response_tx, response_rx) = StdChannel::create_response_channel::(); - let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ChannelState { - source: crate::types::ValueSource::None, - closed: false, - })); + let state = std::sync::Arc::new(std::sync::Mutex::new(crate::types::ValueSource::None)); let sucker = crate::Sucker { request_tx, diff --git a/src/types.rs b/src/types.rs index 0e82f9c..203a743 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, Mutex}; + /// Request messages sent from consumer to producer pub enum Request { GetValue, @@ -15,11 +17,9 @@ pub enum Response { pub(crate) enum ValueSource { Static(T), Dynamic(Box T + Send + Sync + 'static>), - None, + None, // Never set + Cleared, // Was set but cleared (closed) } /// Internal channel state shared between producer and consumer -pub(crate) struct ChannelState { - pub(crate) source: ValueSource, - pub(crate) closed: bool, -} +pub(crate) type ChannelState = Arc>>;