Skip to content

Commit

Permalink
Merge pull request #134 from sot/chandra-models-from-ska-helpers
Browse files Browse the repository at this point in the history
Refactor get_model_spec to use ska_helpers.chandra_models
  • Loading branch information
taldcroft authored Aug 13, 2023
2 parents 10a7ce5 + 3281942 commit 5430f1c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 84 deletions.
99 changes: 16 additions & 83 deletions xija/get_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
import contextlib
import json
import os
import platform
import re
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import List, Optional, Union
from typing import List

import git
import requests
from Ska.File import get_globfiles
from ska_helpers import chandra_models
from ska_helpers.paths import chandra_models_repo_path, xija_models_path

__all__ = [
Expand All @@ -30,6 +27,11 @@
)


# Define local names for API back-compatibility
get_repo_version = chandra_models.get_repo_version
get_github_version = chandra_models.get_github_version


@contextlib.contextmanager
def temp_directory():
"""Get name of a temporary directory that is deleted at the end.
Expand Down Expand Up @@ -80,7 +82,7 @@ def get_xija_model_spec(
----------
model_name : str
Name of model
version : str
version : str, None
Tag, branch or commit of chandra_models to use (default=latest tag from
repo)
repo_path : str, Path
Expand All @@ -100,18 +102,18 @@ def get_xija_model_spec(
if repo_path is None:
repo_path = chandra_models_repo_path()

with temp_directory() as repo_path_local:
repo = git.Repo.clone_from(repo_path, repo_path_local)
if version is not None:
repo.git.checkout(version)
if version is None:
version = os.environ.get("CHANDRA_MODELS_DEFAULT_VERSION")

with chandra_models.get_local_repo(repo_path, version) as (repo, repo_path_local):
model_spec, version = _get_xija_model_spec(
model_name, version, repo_path_local, check_version, timeout
model_name, version, repo_path_local, check_version, timeout, repo=repo
)
return model_spec, version


def _get_xija_model_spec(
model_name, version=None, repo_path=None, check_version=False, timeout=5
model_name, version=None, repo_path=None, check_version=False, timeout=5, repo=None
) -> tuple:
models_path = xija_models_path(repo_path)

Expand All @@ -132,19 +134,10 @@ def _get_xija_model_spec(

# Get version and ensure that repo is clean and tip is at latest tag
if version is None:
version = get_repo_version(repo_path)
version = chandra_models.get_repo_version(repo=repo)

if check_version:
gh_version = get_github_version(timeout=timeout)
if gh_version is None:
warnings.warn(
"Could not verify GitHub chandra_models release tag "
f"due to timeout ({timeout} sec)"
)
elif version != gh_version:
raise ValueError(
f"version mismatch: local repo {version} vs github {gh_version}"
)
chandra_models.assert_latest_version(version, timeout)

return model_spec, version

Expand Down Expand Up @@ -190,63 +183,3 @@ def get_xija_model_names(repo_path=None) -> List[str]:
names = [re.sub(r"_spec\.json", "", Path(fn).name) for fn in sorted(fns)]

return names


def get_repo_version(repo_path: Optional[Path] = None) -> str:
"""Return version (most recent tag) of models repository.
Returns
-------
str
Version (most recent tag) of models repository
"""
if repo_path is None:
repo_path = chandra_models_repo_path()

with temp_directory() as repo_path_local:
if platform.system() == "Windows":
repo = git.Repo.clone_from(repo_path, repo_path_local)
else:
repo = git.Repo(repo_path)

if repo.is_dirty():
raise ValueError("repo is dirty")

tags = sorted(repo.tags, key=lambda tag: tag.commit.committed_datetime)
tag_repo = tags[-1]
if tag_repo.commit != repo.head.commit:
raise ValueError(f"repo tip is not at tag {tag_repo}")

return tag_repo.name


def get_github_version(
url: str = CHANDRA_MODELS_LATEST_URL, timeout: Union[int, float] = 5
) -> Optional[bool]:
"""Get latest chandra_models GitHub repo release tag (version).
This queries GitHub for the latest release of chandra_models.
Parameters
----------
url : str
URL for latest chandra_models release on GitHub API
timeout : int, float
Request timeout (sec, default=5)
Returns
-------
str, None
Tag name (str) of latest chandra_models release on GitHub.
None if the request timed out, indicating indeterminate answer.
"""
try:
req = requests.get(url, timeout=timeout)
except (requests.ConnectTimeout, requests.ReadTimeout):
return None

if req.status_code != requests.codes.ok:
req.raise_for_status()

page_json = req.json()
return page_json["tag_name"]
2 changes: 1 addition & 1 deletion xija/tests/test_get_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_get_model_file_fail():
with pytest.raises(ValueError, match="no models matched xxxyyyzzz"):
get_xija_model_spec("xxxyyyzzz")

with pytest.raises(git.GitCommandError, match="does not exist"):
with pytest.raises(git.exc.NoSuchPathError):
get_xija_model_spec("aca", repo_path="__NOT_A_DIRECTORY__")


Expand Down

0 comments on commit 5430f1c

Please sign in to comment.