Skip to content

Commit

Permalink
Fix mypy and add dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
alexamici committed Oct 15, 2023
1 parent a2edab3 commit 3354ff7
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
17 changes: 10 additions & 7 deletions cf2cdm/cfcoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,25 @@

import functools
import logging
import typing as T
from typing import Any, Callable, Dict, Hashable, List, Mapping

import xarray as xr

from . import cfunits

CoordModelType = T.Dict[str, T.Dict[str, str]]
CoordTranslatorType = T.Callable[[str, xr.Dataset, CoordModelType], xr.Dataset]
# force ruff to keep comment types
_ = [Any, Mapping, Hashable, List]

CoordModelType = Dict[str, Dict[str, str]]
CoordTranslatorType = Callable[[str, xr.Dataset, CoordModelType], xr.Dataset]

COORD_MODEL: CoordModelType = {}
COORD_TRANSLATORS: T.Dict[str, CoordTranslatorType] = {}
COORD_TRANSLATORS: Dict[str, CoordTranslatorType] = {}
LOG = logging.getLogger(__name__)


def match_values(match_value_func, mapping):
# type: (T.Callable[[T.Any], bool], T.Mapping[T.Hashable, T.Any]) -> T.List[str]
# type: (Callable[[Any], bool], Mapping[Hashable, Any]) -> List[str]
matched_names = []
for name, value in mapping.items():
if match_value_func(value):
Expand All @@ -60,7 +63,7 @@ def coord_translator(
default_out_name: str,
default_units: str,
default_direction: str,
is_cf_type: T.Callable[[xr.IndexVariable], bool],
is_cf_type: Callable[[xr.IndexVariable], bool],
cf_type: str,
data: xr.Dataset,
coord_model: CoordModelType = COORD_MODEL,
Expand Down Expand Up @@ -206,7 +209,7 @@ def is_forecast_month(coord: xr.IndexVariable) -> bool:
def translate_coords(
data, coord_model=COORD_MODEL, errors="warn", coord_translators=COORD_TRANSLATORS
):
# type: (xr.Dataset, CoordModelType, str, T.Dict[str, CoordTranslatorType]) -> xr.Dataset
# type: (xr.Dataset, CoordModelType, str, Dict[str, CoordTranslatorType]) -> xr.Dataset
for cf_name, translator in coord_translators.items():
try:
data = translator(cf_name, data, coord_model)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering"
]
dependencies = ["xarray"]
description = "Translate xarray dataset to a custom data model"
dynamic = ["version"]
license = {file = "LICENSE"}
Expand Down
3 changes: 2 additions & 1 deletion tests/test_20_cfcoords.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from typing import Any, Dict, Hashable

import numpy as np
import pytest
Expand Down Expand Up @@ -77,7 +78,7 @@ def da3() -> xr.Dataset:


def test_match_values() -> None:
mapping = {"callable": len, "int": 1} # type: T.Dict[T.Hashable, T.Any]
mapping: Dict[Hashable, Any] = {"callable": len, "int": 1}
res = cfcoords.match_values(callable, mapping)

assert res == ["callable"]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_50_datamodels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os.path

from cfgrib import xarray_store
import xarray as xr

from cf2cdm import cfcoords, datamodels

Expand All @@ -10,7 +10,7 @@


def test_cds() -> None:
ds = xarray_store.open_dataset(TEST_DATA1)
ds = xr.open_dataset(TEST_DATA1)

res = cfcoords.translate_coords(ds, coord_model=datamodels.CDS)

Expand All @@ -31,7 +31,7 @@ def test_cds() -> None:
"time",
}

ds = xarray_store.open_dataset(TEST_DATA2)
ds = xr.open_dataset(TEST_DATA2)

res = cfcoords.translate_coords(ds, coord_model=datamodels.CDS)

Expand All @@ -47,7 +47,7 @@ def test_cds() -> None:


def test_ecmwf() -> None:
ds = xarray_store.open_dataset(TEST_DATA1)
ds = xr.open_dataset(TEST_DATA1)

res = cfcoords.translate_coords(ds, coord_model=datamodels.ECMWF)

Expand All @@ -62,7 +62,7 @@ def test_ecmwf() -> None:
"valid_time",
}

ds = xarray_store.open_dataset(TEST_DATA2)
ds = xr.open_dataset(TEST_DATA2)

res = cfcoords.translate_coords(ds, coord_model=datamodels.ECMWF)

Expand Down

0 comments on commit 3354ff7

Please sign in to comment.