Drop support for Python 3.9 #232

Merged · 2 commits · Aug 15, 2024
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
@@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v4
@@ -55,7 +55,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v4
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -11,15 +11,14 @@ description = "Batch generation from Xarray objects"
readme = "README.rst"
license = {text = "Apache"}
authors = [{name = "xbatcher Developers", email = "[email protected]"}]
requires-python = ">=3.9"
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -62,7 +61,7 @@ fallback_version = "999"


[tool.ruff]
target-version = "py39"
target-version = "py310"
extend-include = ["*.ipynb"]


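Raising `requires-python` to `>=3.10` (together with dropping the 3.9 trove classifier) is what keeps Python 3.9 users on working releases: pip honours that metadata and will resolve them to the last version that still supported 3.9 rather than installing this one. Bumping `tool.ruff.target-version` to `py310` also lets ruff's pyupgrade rules assume 3.10 syntax, which is presumably what produced the `X | Y` annotation rewrites in the rest of the diff. The PR does not add a runtime interpreter check; purely as an illustration, an explicit guard would look like this hypothetical sketch:

```python
import sys

# Hypothetical guard mirroring the new requires-python floor.
# xbatcher does not ship anything like this; pip's handling of
# requires-python keeps Python 3.9 environments on older releases.
if sys.version_info < (3, 10):
    raise RuntimeError("xbatcher requires Python >= 3.10")
```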
10 changes: 5 additions & 5 deletions xbatcher/accessors.py
@@ -1,11 +1,11 @@
from typing import Any, Union
from typing import Any

import xarray as xr

from .generators import BatchGenerator


def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
def _as_xarray_dataarray(xr_obj: xr.Dataset | xr.DataArray) -> xr.DataArray:
"""
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
be converted into a Tensor object.
@@ -19,7 +19,7 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray
@xr.register_dataarray_accessor('batch')
@xr.register_dataset_accessor('batch')
class BatchAccessor:
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
"""
Batch accessor returning a BatchGenerator object via the `generator method`
"""
@@ -42,7 +42,7 @@ def generator(self, *args, **kwargs) -> BatchGenerator:
@xr.register_dataarray_accessor('tf')
@xr.register_dataset_accessor('tf')
class TFAccessor:
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
self._obj = xarray_obj

def to_tensor(self) -> Any:
@@ -57,7 +57,7 @@ def to_tensor(self) -> Any:
@xr.register_dataarray_accessor('torch')
@xr.register_dataset_accessor('torch')
class TorchAccessor:
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
self._obj = xarray_obj

def to_tensor(self) -> Any:
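Every annotation change in accessors.py is the same pattern: `Union[xr.Dataset, xr.DataArray]` becomes `xr.Dataset | xr.DataArray`, and the now-unused `Union` import is dropped. PEP 604 `X | Y` annotations are only valid at runtime on Python 3.10+ (short of adding `from __future__ import annotations` everywhere), so this style is exactly what the new floor buys. A minimal before/after sketch, not taken from this repository:

```python
import xarray as xr

# Python 3.9 style:
#     from typing import Optional, Union
#     def as_dataarray(obj: Union[xr.Dataset, xr.DataArray], name: Optional[str] = None) -> xr.DataArray: ...

# Python 3.10+ style adopted throughout this PR:
def as_dataarray(obj: xr.Dataset | xr.DataArray, name: str | None = None) -> xr.DataArray:
    """Hypothetical helper mirroring the _as_xarray_dataarray signature style."""
    if isinstance(obj, xr.Dataset):
        # Stack data variables along a new "variable" dimension.
        return obj.to_dataarray(name=name)
    return obj
```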
48 changes: 22 additions & 26 deletions xbatcher/generators.py
@@ -3,9 +3,9 @@
import itertools
import json
import warnings
from collections.abc import Hashable, Iterator, Sequence
from collections.abc import Callable, Hashable, Iterator, Sequence
from operator import itemgetter
from typing import Any, Callable, Optional, Union
from typing import Any

import numpy as np
import xarray as xr
@@ -55,10 +55,10 @@ class BatchSchema:

def __init__(
self,
ds: Union[xr.Dataset, xr.DataArray],
ds: xr.Dataset | xr.DataArray,
input_dims: dict[Hashable, int],
input_overlap: Optional[dict[Hashable, int]] = None,
batch_dims: Optional[dict[Hashable, int]] = None,
input_overlap: dict[Hashable, int] | None = None,
batch_dims: dict[Hashable, int] | None = None,
concat_input_bins: bool = True,
preload_batch: bool = True,
):
@@ -91,9 +91,7 @@ def __init__(
)
self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

def _gen_batch_selectors(
self, ds: Union[xr.DataArray, xr.Dataset]
) -> BatchSelectorSet:
def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet:
"""
Create batch selectors dict, which can be used to create a batch
from an Xarray data object.
@@ -106,9 +104,7 @@ def _gen_patch_selectors(
else: # Each patch gets its own batch
return {ind: [value] for ind, value in enumerate(patch_selectors)}

def _gen_patch_selectors(
self, ds: Union[xr.DataArray, xr.Dataset]
) -> PatchGenerator:
def _gen_patch_selectors(self, ds: xr.DataArray | xr.Dataset) -> PatchGenerator:
"""
Create an iterator that can be used to index an Xarray Dataset/DataArray.
"""
@@ -127,7 +123,7 @@ def _gen_patch_selectors(
return all_slices

def _combine_patches_into_batch(
self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
) -> BatchSelectorSet:
"""
Combine the patch selectors to form a batch
@@ -169,7 +165,7 @@ def _combine_patches_grouped_by_batch_dims(
return dict(enumerate(batch_selectors))

def _combine_patches_grouped_by_input_and_batch_dims(
self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator
) -> BatchSelectorSet:
"""
Combine patches with multiple slices along ``batch_dims`` grouped into
@@ -197,7 +193,7 @@ def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
n_batches = np.prod(list(self._n_batches_per_dim.values()))
return {k: [] for k in range(n_batches)}

def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
def _gen_patch_numbers(self, ds: xr.DataArray | xr.Dataset):
"""
Calculate the number of patches per dimension and the number of patches
in each batch per dimension.
@@ -214,7 +210,7 @@ def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
for dim, length in self._all_sliced_dims.items()
}

def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset):
"""
Calculate the number of batches per dimension
"""
@@ -324,7 +320,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[slice]:


def _iterate_through_dimensions(
ds: Union[xr.Dataset, xr.DataArray],
ds: xr.Dataset | xr.DataArray,
*,
dims: dict[Hashable, int],
overlap: dict[Hashable, int] = {},
@@ -350,10 +346,10 @@ def _iterate_through_dimensions(


def _drop_input_dims(
ds: Union[xr.Dataset, xr.DataArray],
ds: xr.Dataset | xr.DataArray,
input_dims: dict[Hashable, int],
suffix: str = '_input',
) -> Union[xr.Dataset, xr.DataArray]:
) -> xr.Dataset | xr.DataArray:
# remove input_dims coordinates from datasets, rename the dimensions
# then put input_dims back in as coordinates
out = ds.copy()
@@ -368,9 +364,9 @@ def _drop_input_dims(


def _maybe_stack_batch_dims(
ds: Union[xr.Dataset, xr.DataArray],
ds: xr.Dataset | xr.DataArray,
input_dims: Sequence[Hashable],
) -> Union[xr.Dataset, xr.DataArray]:
) -> xr.Dataset | xr.DataArray:
batch_dims = [d for d in ds.sizes if d not in input_dims]
if len(batch_dims) < 2:
return ds
@@ -424,14 +420,14 @@ class BatchGenerator:

def __init__(
self,
ds: Union[xr.Dataset, xr.DataArray],
ds: xr.Dataset | xr.DataArray,
input_dims: dict[Hashable, int],
input_overlap: dict[Hashable, int] = {},
batch_dims: dict[Hashable, int] = {},
concat_input_dims: bool = False,
preload_batch: bool = True,
cache: Optional[dict[str, Any]] = None,
cache_preprocess: Optional[Callable] = None,
cache: dict[str, Any] | None = None,
cache_preprocess: Callable | None = None,
):
self.ds = ds
self.cache = cache
@@ -466,14 +462,14 @@ def concat_input_dims(self):
def preload_batch(self):
return self._batch_selectors.preload_batch

def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
def __iter__(self) -> Iterator[xr.DataArray | xr.Dataset]:
for idx in self._batch_selectors.selectors:
yield self[idx]

def __len__(self) -> int:
return len(self._batch_selectors.selectors)

def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray:
if not isinstance(idx, int):
raise NotImplementedError(
f'{type(self).__name__}.__getitem__ currently requires a single integer key'
@@ -532,7 +528,7 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
def _batch_in_cache(self, idx: int) -> bool:
return self.cache is not None and f'{idx}/.zgroup' in self.cache

def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None:
def _cache_batch(self, idx: int, batch: xr.Dataset | xr.DataArray) -> None:
batch.to_zarr(self.cache, group=str(idx), mode='a')

def _get_cached_batch(self, idx: int) -> xr.Dataset:
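generators.py carries the bulk of the modernization: every `Union`/`Optional` annotation becomes a `|` union, and `Callable` now comes from `collections.abc` instead of `typing` (the `typing` aliases are documented as deprecated since Python 3.9). The public `BatchGenerator` interface is unchanged, so existing user code keeps working; a minimal usage sketch with made-up dimension names and sizes:

```python
import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

# Toy dataset: 100 time steps on a 16-point grid.
ds = xr.Dataset(
    {"air": (("time", "x"), np.random.rand(100, 16))},
    coords={"time": np.arange(100), "x": np.arange(16)},
)

# Yield patches of 10 consecutive time steps, overlapping by 2.
bgen = BatchGenerator(ds, input_dims={"time": 10}, input_overlap={"time": 2})

for batch in bgen:
    assert batch.sizes["time"] == 10  # each batch is an xr.Dataset patch
    break
```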
7 changes: 4 additions & 3 deletions xbatcher/loaders/keras.py
@@ -1,4 +1,5 @@
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any

try:
import tensorflow as tf
@@ -21,8 +22,8 @@ def __init__(
X_generator,
y_generator,
*,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transform: Callable | None = None,
target_transform: Callable | None = None,
) -> None:
"""
Keras Dataset adapter for Xbatcher
7 changes: 4 additions & 3 deletions xbatcher/loaders/torch.py
@@ -1,4 +1,5 @@
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any

try:
import torch
@@ -24,8 +25,8 @@ def __init__(
self,
X_generator,
y_generator,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transform: Callable | None = None,
target_transform: Callable | None = None,
) -> None:
"""
PyTorch Dataset adapter for Xbatcher
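Both loader adapters get the same two-line change: `Callable` moves from `typing` to `collections.abc`, and `Optional[Callable]` becomes `Callable | None`; behaviour is untouched. A hedged sketch of how the `transform`/`target_transform` hooks are passed, assuming the class in the torch hunk above is `MapDataset` (its name sits outside the visible diff context) and that torch is installed:

```python
import numpy as np
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Made-up features and targets sharing a "sample" dimension.
ds = xr.Dataset(
    {
        "x": (("sample", "feature"), np.random.rand(100, 4)),
        "y": (("sample",), np.random.rand(100)),
    }
)
X_gen = BatchGenerator(ds["x"], input_dims={"sample": 10})
y_gen = BatchGenerator(ds["y"], input_dims={"sample": 10})

def identity(batch):
    """Made-up no-op transform; a real one might normalize or cast dtypes."""
    return batch

dataset = MapDataset(X_gen, y_gen, transform=identity, target_transform=identity)
x0, y0 = dataset[0]  # one (features, target) batch pair
```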
3 changes: 1 addition & 2 deletions xbatcher/testing.py
@@ -1,5 +1,4 @@
from collections.abc import Hashable
from typing import Union
from unittest import TestCase

import numpy as np
@@ -170,7 +169,7 @@ def get_batch_dimensions(generator: BatchGenerator) -> dict[Hashable, int]:


def validate_batch_dimensions(
*, expected_dims: dict[Hashable, int], batch: Union[xr.Dataset, xr.DataArray]
*, expected_dims: dict[Hashable, int], batch: xr.Dataset | xr.DataArray
) -> None:
"""
Raises an AssertionError if the shape and dimensions of a batch do not