diff --git a/cf2cdm/cfcoords.py b/cf2cdm/cfcoords.py index 4278b80..f9d6088 100644 --- a/cf2cdm/cfcoords.py +++ b/cf2cdm/cfcoords.py @@ -19,22 +19,25 @@ import functools import logging -import typing as T +from typing import Any, Callable, Dict, Hashable, List, Mapping import xarray as xr from . import cfunits -CoordModelType = T.Dict[str, T.Dict[str, str]] -CoordTranslatorType = T.Callable[[str, xr.Dataset, CoordModelType], xr.Dataset] +# force ruff to keep comment types +_ = [Any, Mapping, Hashable, List] + +CoordModelType = Dict[str, Dict[str, str]] +CoordTranslatorType = Callable[[str, xr.Dataset, CoordModelType], xr.Dataset] COORD_MODEL: CoordModelType = {} -COORD_TRANSLATORS: T.Dict[str, CoordTranslatorType] = {} +COORD_TRANSLATORS: Dict[str, CoordTranslatorType] = {} LOG = logging.getLogger(__name__) def match_values(match_value_func, mapping): - # type: (T.Callable[[T.Any], bool], T.Mapping[T.Hashable, T.Any]) -> T.List[str] + # type: (Callable[[Any], bool], Mapping[Hashable, Any]) -> List[str] matched_names = [] for name, value in mapping.items(): if match_value_func(value): @@ -60,7 +63,7 @@ def coord_translator( default_out_name: str, default_units: str, default_direction: str, - is_cf_type: T.Callable[[xr.IndexVariable], bool], + is_cf_type: Callable[[xr.IndexVariable], bool], cf_type: str, data: xr.Dataset, coord_model: CoordModelType = COORD_MODEL, @@ -206,7 +209,7 @@ def is_forecast_month(coord: xr.IndexVariable) -> bool: def translate_coords( data, coord_model=COORD_MODEL, errors="warn", coord_translators=COORD_TRANSLATORS ): - # type: (xr.Dataset, CoordModelType, str, T.Dict[str, CoordTranslatorType]) -> xr.Dataset + # type: (xr.Dataset, CoordModelType, str, Dict[str, CoordTranslatorType]) -> xr.Dataset for cf_name, translator in coord_translators.items(): try: data = translator(cf_name, data, coord_model) diff --git a/pyproject.toml b/pyproject.toml index fd5140a..fdb9069 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering" ] +dependencies = ["xarray"] description = "Translate xarray dataset to a custom data model" dynamic = ["version"] license = {file = "LICENSE"} diff --git a/tests/test_20_cfcoords.py b/tests/test_20_cfcoords.py index b4fb186..29e98c3 100644 --- a/tests/test_20_cfcoords.py +++ b/tests/test_20_cfcoords.py @@ -1,4 +1,5 @@ import sys +from typing import Any, Dict, Hashable import numpy as np import pytest @@ -77,7 +78,7 @@ def da3() -> xr.Dataset: def test_match_values() -> None: - mapping = {"callable": len, "int": 1} # type: T.Dict[T.Hashable, T.Any] + mapping: Dict[Hashable, Any] = {"callable": len, "int": 1} res = cfcoords.match_values(callable, mapping) assert res == ["callable"] diff --git a/tests/test_50_datamodels.py b/tests/test_50_datamodels.py index 1442e21..f075d8d 100644 --- a/tests/test_50_datamodels.py +++ b/tests/test_50_datamodels.py @@ -1,6 +1,6 @@ import os.path -from cfgrib import xarray_store +import xarray as xr from cf2cdm import cfcoords, datamodels @@ -10,7 +10,7 @@ def test_cds() -> None: - ds = xarray_store.open_dataset(TEST_DATA1) + ds = xr.open_dataset(TEST_DATA1) res = cfcoords.translate_coords(ds, coord_model=datamodels.CDS) @@ -31,7 +31,7 @@ def test_cds() -> None: "time", } - ds = xarray_store.open_dataset(TEST_DATA2) + ds = xr.open_dataset(TEST_DATA2) res = cfcoords.translate_coords(ds, coord_model=datamodels.CDS) @@ -47,7 +47,7 @@ def test_cds() -> None: def test_ecmwf() -> None: - ds = xarray_store.open_dataset(TEST_DATA1) + ds = xr.open_dataset(TEST_DATA1) res = cfcoords.translate_coords(ds, coord_model=datamodels.ECMWF) @@ -62,7 +62,7 @@ def test_ecmwf() -> None: "valid_time", } - ds = xarray_store.open_dataset(TEST_DATA2) + ds = xr.open_dataset(TEST_DATA2) res = cfcoords.translate_coords(ds, coord_model=datamodels.ECMWF)