diff --git a/snake-pyo3/src/dataset.rs b/snake-pyo3/src/dataset.rs index e6e3533..5009c68 100644 --- a/snake-pyo3/src/dataset.rs +++ b/snake-pyo3/src/dataset.rs @@ -2,11 +2,14 @@ use futures::stream::{self, BoxStream, StreamExt}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3_async::asyncio::AsyncGenerator; +use pyo3_async::AllowThreads; use std::cell::RefCell; +type SnakeStream = BoxStream<'static, PyResult>; + #[pyclass] pub(crate) struct Dataset { - stream: RefCell>>>, + stream: RefCell>, } #[pymethods] @@ -20,29 +23,42 @@ impl Dataset { } fn map(&self, f: PyObject) -> PyResult { - let stream = self - .stream - .borrow_mut() - .take() - .ok_or_else(|| PyRuntimeError::new_err("Dataset is already transformed before"))? - .map(move |x| { - Python::with_gil(|py| { - let y = f.call1(py, (x?,))?; - let y = y.extract::(py)?; - Ok(y) + self.and_then(|stream| { + stream + .map(move |x| { + Python::with_gil(|py| { + let y = f.call1(py, (x?,))?; + let y = y.extract::(py)?; + Ok(y) + }) }) - }); - Ok(Dataset { - stream: RefCell::new(Some(stream.boxed())), + .boxed() }) } fn __aiter__(slf: PyRef<'_, Self>) -> PyResult { - let stream = slf - .stream - .borrow_mut() - .take() - .ok_or_else(|| PyRuntimeError::new_err("Stream can only be consumed once"))?; - Ok(AsyncGenerator::from_stream(stream)) + Ok(AsyncGenerator::from_stream(AllowThreads( + slf.stream + .borrow_mut() + .take() + .ok_or_else(|| PyRuntimeError::new_err("Stream can only be consumed once"))?, + ))) + } +} + +impl Dataset { + fn and_then(&self, func: F) -> PyResult + where + F: FnOnce(SnakeStream) -> SnakeStream, + { + let stream = func( + self.stream + .borrow_mut() + .take() + .ok_or_else(|| PyRuntimeError::new_err("Dataset can only be transformed once"))?, + ); + Ok(Dataset { + stream: RefCell::new(Some(stream)), + }) } }