diff --git a/src/from_stream.rs b/src/from_stream.rs index ef660cb..275317c 100644 --- a/src/from_stream.rs +++ b/src/from_stream.rs @@ -17,6 +17,7 @@ pin_project! { pub struct FromStream { #[pin] stream: S, + limit: Option, } } @@ -26,6 +27,7 @@ where S: Send + Sync, { FromStream { + limit: None, stream: stream.into_stream(), } } @@ -40,4 +42,13 @@ where let this = self.project(); this.stream.poll_next(cx) } + + fn limit(mut self, limit: impl Into>) -> Self { + self.limit = limit.into(); + self + } + + fn get_limit(&self) -> Option { + self.limit + } } diff --git a/src/lib.rs b/src/lib.rs index aa5ce60..af93ecd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,3 +56,5 @@ pub use par_stream::{ForEach, Map, NextFuture, ParallelStream, Take}; pub mod prelude; pub mod vec; + +pub(crate) mod utils; diff --git a/src/par_stream/for_each.rs b/src/par_stream/for_each.rs index 321686c..e971db1 100644 --- a/src/par_stream/for_each.rs +++ b/src/par_stream/for_each.rs @@ -24,7 +24,7 @@ pin_project_lite::pin_project! { impl ForEach { /// Create a new instance of `ForEach`. - pub fn new(mut input: S, mut f: F) -> Self + pub fn new(mut stream: S, mut f: F) -> Self where S: ParallelStream, F: FnMut(S::Item) -> Fut + Send + Sync + Copy + 'static, @@ -33,6 +33,7 @@ impl ForEach { let exhausted = Arc::new(AtomicBool::new(false)); let ref_count = Arc::new(AtomicU64::new(0)); let (sender, receiver): (Sender<()>, Receiver<()>) = sync::channel(1); + let _limit = stream.get_limit(); // Initialize the return type here to prevent borrowing issues. let this = Self { @@ -42,7 +43,7 @@ impl ForEach { }; task::spawn(async move { - while let Some(item) = input.next().await { + while let Some(item) = stream.next().await { let sender = sender.clone(); let exhausted = exhausted.clone(); let ref_count = ref_count.clone(); diff --git a/src/par_stream/map.rs b/src/par_stream/map.rs index 1faf249..e10ca00 100644 --- a/src/par_stream/map.rs +++ b/src/par_stream/map.rs @@ -14,6 +14,7 @@ pin_project_lite::pin_project! { pub struct Map { #[pin] receiver: Receiver, + limit: Option, } } @@ -26,6 +27,7 @@ impl Map { Fut: Future + Send, { let (sender, receiver) = sync::channel(1); + let limit = stream.get_limit(); task::spawn(async move { while let Some(item) = stream.next().await { let sender = sender.clone(); @@ -35,7 +37,7 @@ impl Map { }); } }); - Map { receiver } + Map { receiver, limit } } } @@ -46,6 +48,15 @@ impl ParallelStream for Map { let this = self.project(); this.receiver.poll_next(cx) } + + fn limit(mut self, limit: impl Into>) -> Self { + self.limit = limit.into(); + self + } + + fn get_limit(&self) -> Option { + self.limit + } } #[async_std::test] diff --git a/src/par_stream/mod.rs b/src/par_stream/mod.rs index 6e7fca6..41aa782 100644 --- a/src/par_stream/mod.rs +++ b/src/par_stream/mod.rs @@ -21,6 +21,12 @@ pub trait ParallelStream: Sized + Send + Sync + Unpin + 'static { /// Attempts to receive the next item from the stream. fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + /// Set a max concurrency limit + fn limit(self, limit: impl Into>) -> Self; + + /// Get the max concurrency limit + fn get_limit(&self) -> Option; + /// Applies `f` to each item of this stream in parallel, producing a new /// stream with the results. fn map(self, f: F) -> Map diff --git a/src/par_stream/take.rs b/src/par_stream/take.rs index 75d6f07..b34dbb8 100644 --- a/src/par_stream/take.rs +++ b/src/par_stream/take.rs @@ -17,14 +17,19 @@ pin_project! { #[derive(Clone, Debug)] pub struct Take { #[pin] - pub(crate) stream: S, - pub(crate) remaining: usize, + stream: S, + remaining: usize, + limit: Option, } } -impl Take { +impl Take { pub(super) fn new(stream: S, remaining: usize) -> Self { - Self { stream, remaining } + Self { + limit: stream.get_limit(), + remaining, + stream, + } } } @@ -44,4 +49,13 @@ impl ParallelStream for Take { Poll::Ready(next) } } + + fn limit(mut self, limit: impl Into>) -> Self { + self.limit = limit.into(); + self + } + + fn get_limit(&self) -> Option { + self.limit + } } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..25b0f50 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,55 @@ +// use core::pin::Pin; +// use core::task::{Context, Poll}; + +// use std::sync::atomic::{AtomicUsize, Ordering}; +// use std::sync::Arc; + +// use async_std::stream::Stream; + +// /// A stream that has a max concurrency of N. +// pub(crate) struct LimitStream { +// limit: Option, +// ref_count: Arc, +// } + +// impl LimitStream { +// /// Create a new instance of LimitStream. +// pub(crate) fn new(limit: Option) -> Self { +// Self { +// limit, +// ref_count: Arc::new(AtomicUsize::new(0)), +// } +// } +// } + +// #[derive(Debug)] +// pub(crate) struct Guard { +// limit: Option, +// ref_count: Arc, +// } + +// impl Guard { +// fn new(limit: Option, ref_count: Arc) -> Self { +// Self { limit, ref_count } +// } +// } + +// impl Drop for Guard { +// fn drop(&mut self) { +// if self.limit.is_some() { +// self.ref_count.fetch_sub(1, Ordering::SeqCst); +// } +// } +// } + +// impl Stream for LimitStream { +// type Item = Guard; + +// fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// if self.limit.is_none() { +// let guard = Guard::new(self.limit, self.ref_count.clone()); +// return Poll::Ready(Some(guard)); +// } +// todo!(); +// } +// } diff --git a/src/vec.rs b/src/vec.rs index 6e3c829..b5ec202 100644 --- a/src/vec.rs +++ b/src/vec.rs @@ -17,6 +17,7 @@ pin_project_lite::pin_project! { pub struct IntoParStream { #[pin] stream: FromStream>>, + limit: Option, } } @@ -26,6 +27,15 @@ impl ParallelStream for IntoParStream { let this = self.project(); this.stream.poll_next(cx) } + + fn limit(mut self, limit: impl Into>) -> Self { + self.limit = limit.into(); + self + } + + fn get_limit(&self) -> Option { + self.limit + } } impl IntoParallelStream for Vec { @@ -36,6 +46,7 @@ impl IntoParallelStream for Vec { fn into_par_stream(self) -> Self::IntoParStream { IntoParStream { stream: from_stream(from_iter(self)), + limit: None, } } }