Skip to content

Commit

Permalink
Implement map transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
npuichigo committed Dec 13, 2023
1 parent c77ee8a commit 59e577f
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
# This file is autogenerated by maturin v1.4.0
# To update, run
#
# maturin generate-ci github
#
name: CI
name: Python Release

on:
push:
Expand Down Expand Up @@ -36,11 +31,12 @@ jobs:
args: --release --out dist --find-interpreter
sccache: 'true'
manylinux: auto
working-directory: ./snake-pyo3
- name: Upload wheels
uses: actions/upload-artifact@v3
with:
name: wheels
path: dist
path: ./snake-pyo3/dist

windows:
runs-on: windows-latest
Expand All @@ -59,11 +55,12 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
sccache: 'true'
working-directory: ./snake-pyo3
- name: Upload wheels
uses: actions/upload-artifact@v3
with:
name: wheels
path: dist
path: ./snake-pyo3/dist

macos:
runs-on: macos-latest
Expand All @@ -81,11 +78,12 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
sccache: 'true'
working-directory: ./snake-pyo3
- name: Upload wheels
uses: actions/upload-artifact@v3
with:
name: wheels
path: dist
path: ./snake-pyo3/dist

sdist:
runs-on: ubuntu-latest
Expand All @@ -96,11 +94,12 @@ jobs:
with:
command: sdist
args: --out dist
working-directory: ./snake-pyo3
- name: Upload sdist
uses: actions/upload-artifact@v3
with:
name: wheels
path: dist
path: ./snake-pyo3/dist

release:
name: Release
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[workspace]
members = [
"snake",
"snake-pyo3",
]
resolver = "2"
Expand All @@ -9,3 +8,4 @@ resolver = "2"
futures = "0.3"
par-stream = { version = "0.10.2", features = ["runtime-tokio"] }
tokio = { version = "1", features = ["full"] }
async-stream = "0.3.5"
4 changes: 1 addition & 3 deletions snake-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,4 @@ pyo3-async = "0.3.2"
futures = { workspace = true }
par-stream = { workspace = true }
tokio = { workspace = true }
rand = "0.8.5"
flume = "0.11.0"
async-stream = "0.3.5"
async-stream = { workspace = true }
12 changes: 8 additions & 4 deletions snake-pyo3/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
[build-system]
requires = ["maturin>=1.4,<2.0"]
build-backend = "maturin"

[project]
name = "snakedata"
requires-python = ">=3.8"
Expand All @@ -12,5 +8,13 @@ classifiers = [
]
dynamic = ["version"]

[project.optional-dependencies]
testing = ["pytest", "pytest-asyncio"]
dev = ["snakedata[testing]"]

[build-system]
requires = ["maturin>=1.4,<2.0"]
build-backend = "maturin"

[tool.maturin]
features = ["pyo3/extension-module"]
26 changes: 22 additions & 4 deletions snake-pyo3/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,34 @@ use std::cell::RefCell;

#[pyclass]
pub(crate) struct Dataset {
stream: RefCell<Option<BoxStream<'static, usize>>>,
stream: RefCell<Option<BoxStream<'static, PyResult<usize>>>>,
}

#[pymethods]
impl Dataset {
#[staticmethod]
fn range(start: usize, end: usize) -> PyResult<Self> {
let stream = stream::iter(start..end).map(|x| Ok(x)).boxed();
Ok(Dataset {
stream: RefCell::new(Some(stream::iter(start..end).boxed())),
stream: RefCell::new(Some(stream)),
})
}

fn map(&self, f: PyObject) -> PyResult<Self> {
let stream = self
.stream
.borrow_mut()
.take()
.ok_or_else(|| PyRuntimeError::new_err("Dataset is already transformed before"))?
.map(move |x| {
Python::with_gil(|py| {
let y = f.call1(py, (x?,))?;
let y = y.extract::<usize>(py)?;
Ok(y)
})
});
Ok(Dataset {
stream: RefCell::new(Some(stream.boxed())),
})
}

Expand All @@ -23,8 +42,7 @@ impl Dataset {
.stream
.borrow_mut()
.take()
.ok_or_else(|| PyRuntimeError::new_err("Stream can only be consumed once"))?
.map(|x| PyResult::Ok(x));
.ok_or_else(|| PyRuntimeError::new_err("Stream can only be consumed once"))?;
Ok(AsyncGenerator::from_stream(stream))
}
}
14 changes: 14 additions & 0 deletions snake-pyo3/tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
import snakedata


class TestDataset:
@pytest.mark.asyncio
async def test_range(self):
dataset = snakedata.Dataset.range(0, 3)
assert [x async for x in dataset] == [0, 1, 2]

@pytest.mark.asyncio
async def test_map(self):
dataset = snakedata.Dataset.range(0, 3).map(lambda x: x * 2)
assert [x async for x in dataset] == [0, 2, 4]
Empty file removed snake/.gitignore
Empty file.
6 changes: 0 additions & 6 deletions snake/Cargo.toml

This file was deleted.

3 changes: 0 additions & 3 deletions snake/src/main.rs

This file was deleted.

0 comments on commit 59e577f

Please sign in to comment.