Skip to content

Commit

Permalink
Merge pull request #3 from SBU-BMI/dev
Browse files Browse the repository at this point in the history
Add schema and expand command line program
  • Loading branch information
kaczmarj authored Jun 21, 2023
2 parents bed3749 + 2750ab6 commit a35f863
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 122 deletions.
8 changes: 7 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ python_requires = >= 3.7
install_requires =
click>=8.0,<9
huggingface_hub
jsonschema
requests
tabulate

[options.extras_require]
dev =
Expand All @@ -45,8 +47,12 @@ dev =

[options.entry_points]
console_scripts =
wsinfer_zoo = wsinfer_zoo.cli:cli
wsinfer-zoo = wsinfer_zoo.cli:cli

[options.package_data]
wsinfer =
py.typed
schemas/*.json

[flake8]
max-line-length = 88
Expand Down
10 changes: 10 additions & 0 deletions wsinfer-zoo-registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
"description": "Prostate tumor",
"hf_repo_id": "kaczmarj/prostate-tumor-resnet34.tcga-prad",
"hf_revision": "main"
},
"lymphnodes-tiatoolbox-resnet50.patchcamelyon": {
"description": "Lymph node metastasis (PatchCamelyon)",
"hf_repo_id": "kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon",
"hf_revision": "main"
},
"colorectal-tiatoolbox-resnet50.kather100k": {
"description": "Colorectal cancer tissue classification (Kather100K)",
"hf_repo_id": "kaczmarj/colorectal-tiatoolbox-resnet50.kather100k",
"hf_revision": "main"
}
}
}
34 changes: 0 additions & 34 deletions wsinfer-zoo-registry.schema.json

This file was deleted.

30 changes: 29 additions & 1 deletion wsinfer_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
import os as _os

import json
import jsonschema
from wsinfer_zoo import _version
from wsinfer_zoo.client import _download_registry_if_necessary
from wsinfer_zoo.client import (
_download_registry_if_necessary,
ModelRegistry,
WSINFER_ZOO_REGISTRY_DEFAULT_PATH,
InvalidRegistryConfiguration,
validate_model_zoo_json,
)

__version__ = _version.get_versions()["version"]

# Call the registry download helper at import time unless the
# WSINFER_ZOO_NO_UPDATE_REGISTRY environment variable is set. Any value counts
# as "set" (even an empty string), because only an unset variable yields None.
if _os.getenv("WSINFER_ZOO_NO_UPDATE_REGISTRY") is None:
    _download_registry_if_necessary()


# Importing the package fails with FileNotFoundError if the registry file is
# not present at the default path.
if not WSINFER_ZOO_REGISTRY_DEFAULT_PATH.exists():
    raise FileNotFoundError(
        f"registry file not found: {WSINFER_ZOO_REGISTRY_DEFAULT_PATH}"
    )
with open(WSINFER_ZOO_REGISTRY_DEFAULT_PATH) as f:
    d = json.load(f)
# Validate the parsed registry before building the ModelRegistry. A validation
# failure is re-raised with a message pointing to the issue tracker, and the
# original exception is chained as the cause.
try:
    validate_model_zoo_json(d)
except InvalidRegistryConfiguration as e:
    raise InvalidRegistryConfiguration(
        "Registry schema is invalid. Please contact the developer by"
        " creating a new issue on our GitHub page:"
        " https://github.com/SBU-BMI/wsinfer-zoo/issues/new."
    ) from e

# Package-level registry object built from the validated registry contents.
registry = ModelRegistry.from_dict(d)

# Remove helper names so they are not exposed as attributes of the package.
del d, json, jsonschema
210 changes: 193 additions & 17 deletions wsinfer_zoo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@

import dataclasses
import json
import sys
from pathlib import Path

import click
import huggingface_hub
import requests
import tabulate

from wsinfer_zoo.client import WSINFER_ZOO_REGISTRY_DEFAULT_PATH, ModelRegistry
from wsinfer_zoo.client import (
WSINFER_ZOO_REGISTRY_DEFAULT_PATH,
HF_CONFIG_NAME,
HF_WEIGHTS_SAFETENSORS_NAME,
HF_WEIGHTS_PICKLE_NAME,
HF_TORCHSCRIPT_NAME,
InvalidModelConfiguration,
InvalidRegistryConfiguration,
ModelRegistry,
validate_model_zoo_json,
validate_config_json,
)


@click.group()
Expand All @@ -24,46 +39,207 @@ def cli(ctx: click.Context, *, registry_file: Path):
raise click.ClickException(f"registry file not found: {registry_file}")
with open(registry_file) as f:
d = json.load(f)
# Raise an error if validation fails.
try:
validate_model_zoo_json(d)
except InvalidRegistryConfiguration as e:
raise InvalidRegistryConfiguration(
"Registry schema is invalid. Please contact the developer by"
" creating a new issue on our GitHub page:"
" https://github.com/SBU-BMI/wsinfer-zoo/issues/new."
) from e
registry = ModelRegistry.from_dict(d)
ctx.ensure_object(dict)
ctx.obj["registry"] = registry


@cli.command()
@click.option("--json", "as_json", is_flag=True, help="Print as JSON lines")
@click.pass_context
def ls(ctx: click.Context, *, as_json: bool):
    """List registered models.

    If not a TTY, only model names are printed. If a TTY, a pretty table
    of models is printed.
    """
    # The registry is stored on the click context under the "registry" key by
    # the parent `cli` group.
    registry: ModelRegistry = ctx.obj["registry"]

    if as_json:
        # JSON lines: one JSON object per registered model.
        for m in registry.models.values():
            click.echo(json.dumps(dataclasses.asdict(m)))
    elif sys.stdout.isatty():
        # Interactive terminal: render a grid table.
        info = [
            [m.name, m.description, m.hf_repo_id, m.hf_revision]
            for m in registry.models.values()
        ]
        click.echo(
            tabulate.tabulate(
                info,
                headers=["Name", "Description", "HF Repo ID", "Rev"],
                tablefmt="grid",
                maxcolwidths=[None, 24, 30, None],
            )
        )
    else:
        # You're being piped or redirected: one entry per line, taken from
        # iterating `registry.models` directly.
        click.echo("\n".join(str(m) for m in registry.models))


@cli.command()
@click.option(
    "--model-name",
    required=True,
    help="Name of the model to get. See `ls` to list model names",
)
@click.pass_context
def get(ctx: click.Context, *, model_name: str):
    """Retrieve a model and its configuration.

    Outputs JSON with model configuration, path to the model, and origin of the model.
    This downloads the pretrained model if necessary.
    The pretrained model is downloaded to a cache and reused if it is already present.
    """
    # The registry is stored on the click context under the "registry" key by
    # the parent `cli` group.
    registry: ModelRegistry = ctx.obj["registry"]

    # Unknown names produce a ClickException that lists the available names.
    if model_name not in registry.models:
        raise click.ClickException(
            f"'{model_name}' not found, available models are"
            # This fragment needs its own f-prefix; without it the braces are
            # printed literally instead of the list of model names.
            f" {list(registry.models.keys())}. Use `wsinfer_zoo ls` to list all"
            " models."
        )

    registered_model = registry.get_model_by_name(model_name)
    model = registered_model.load_model_torchscript()
    # Serialize the loaded model (a dataclass instance) as one JSON document.
    model_dict = dataclasses.asdict(model)
    click.echo(json.dumps(model_dict))


@cli.command()
@click.argument("input", type=click.File("r"))
def validate_config(*, input):
    """Validate a model configuration file against the JSON schema.

    INPUT is the config file to validate.
    Use a dash - to read standard input.
    """
    # `input` shadows the builtin, but it must match the click argument name.
    try:
        config = json.load(input)
    except Exception as e:
        # CLI boundary: report any read or parse failure as a clean click error
        # while keeping the original exception as the cause.
        raise click.ClickException(
            f"Unable to read JSON file. Original error: {e}"
        ) from e

    # Raise an error if the schema is not valid.
    # `validate_repo` in this file expects `validate_config_json` to signal
    # failure with InvalidModelConfiguration; InvalidRegistryConfiguration is
    # still caught so the previous behaviour is preserved either way.
    try:
        validate_config_json(config)
    except (InvalidModelConfiguration, InvalidRegistryConfiguration) as e:
        raise InvalidModelConfiguration(
            "The configuration is invalid. Please see the traceback above for details."
        ) from e
    click.secho("Configuration file is VALID", fg="green")


@cli.command()
@click.argument("huggingface_repo_id")
@click.option("-r", "--revision", help="Revision to validate", default="main")
def validate_repo(*, huggingface_repo_id: str, revision: str):
    """Validate a repository on HuggingFace.

    This checks that the repository contains all of the necessary files and that
    the configuration JSON file is valid.
    """
    # Exit codes: 1 when the repository, revision, required files, or the
    # configuration are invalid; 2 when the request to the Hub itself fails.
    repo_id = huggingface_repo_id
    del huggingface_repo_id

    try:
        files_in_repo = list(
            huggingface_hub.list_files_info(repo_id=repo_id, revision=revision)
        )
    except huggingface_hub.utils.RepositoryNotFoundError:
        click.secho(
            f"Error: huggingface_repo_id '{repo_id}' not found on the HuggingFace Hub",
            fg="red",
        )
        sys.exit(1)
    except huggingface_hub.utils.RevisionNotFoundError:
        click.secho(
            f"Error: revision {revision} not found for repository {repo_id}",
            fg="red",
        )
        sys.exit(1)
    except requests.RequestException as e:
        # The f-prefix is required here; without it the literal text "{e}" is
        # printed and the actual request error is never shown.
        click.echo(f"Error with request: {e}")
        click.echo("Please try again.")
        sys.exit(2)

    # Index the listed files by their path within the repository.
    file_info = {f.rfilename: f for f in files_in_repo}

    repo_url = f"https://huggingface.co/{repo_id}/tree/{revision}"

    # Each required file is paired with the guidance printed when it is missing.
    filenames_and_help = [
        (
            HF_CONFIG_NAME,
            "This file is a JSON file with the configuration of the model and includes"
            " necessary information for how to apply this model to new data. You can"
            " validate this file with the command 'wsinfer_zoo validate-config'.",
        ),
        (
            HF_TORCHSCRIPT_NAME,
            "This file is a TorchScript representation of the model and can be made"
            " with 'torch.jit.script(model)' followed by 'torch.jit.save'. This file"
            " contains the pre-trained weights as well as a graph of the model."
            " Importantly, it does not require a Python runtime to be used."
            f" Then, upload the file to the HuggingFace model repo at {repo_url}",
        ),
        (
            HF_WEIGHTS_PICKLE_NAME,
            "This file contains the weights of the pre-trained model in normal PyTorch"
            " format. Once you have a trained model, create this file with"
            f'\n\n torch.save(model.state_dict(), "{HF_WEIGHTS_PICKLE_NAME}")'
            f"\n\n Then, upload the file to the HuggingFace model repo at {repo_url}",
        ),
        (
            HF_WEIGHTS_SAFETENSORS_NAME,
            "This file contains the weights of the pre-trained model in SafeTensors"
            " format. The advantage of this file is that it does not have security"
            " concerns that Pickle files (pytorch default) have. To create the file:"
            "\n\n from safetensors.torch import save_file"
            f'\n save_file(model.state_dict(), "{HF_WEIGHTS_SAFETENSORS_NAME}")'
            f"\n\n Then, upload the file to the HuggingFace model repo at {repo_url}",
        ),
    ]

    # Report every missing file before exiting, so the user sees all problems
    # in a single run.
    invalid = False
    for name, help_msg in filenames_and_help:
        if name not in file_info:
            click.secho(
                f"Required file '{name}' not found in HuggingFace model repo"
                f" '{repo_id}'",
                fg="red",
            )
            click.echo(f" {help_msg}")
            click.echo("-" * 40)
            invalid = True

    if invalid:
        click.secho(
            f"Model repository {repo_id} is invalid. See above for details.", fg="red"
        )
        sys.exit(1)

    # All required files are present; fetch the configuration file and validate
    # its contents.
    config_path = huggingface_hub.hf_hub_download(
        repo_id, HF_CONFIG_NAME, revision=revision
    )
    with open(config_path) as f:
        config_dict = json.load(f)
    try:
        validate_config_json(config_dict)
    except InvalidModelConfiguration:
        click.secho(
            "Model configuration JSON file is invalid. Use 'wsinfer_zoo validate-config'"
            " with the configuration file to debug this further.",
            fg="red",
        )
        click.secho(
            f"Model repository {repo_id} is invalid. See above for details.", fg="red"
        )
        sys.exit(1)

    click.secho(f"Repository {repo_id} is VALID.", fg="green")
Loading

0 comments on commit a35f863

Please sign in to comment.