Skip to content

Commit

Permalink
Merge pull request #108 from deepghs/dev/gradio
Browse files Browse the repository at this point in the history
dev(narugo): add quick gradio demo for classifiers/yolos
  • Loading branch information
narugo1992 authored Sep 18, 2024
2 parents 7508461 + dbdba35 commit 4d530a2
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/api_doc/generic/classify.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
:members: __init__, predict_score, predict, clear
:members: __init__, predict_score, predict, clear, make_ui, launch_demo



Expand Down
2 changes: 1 addition & 1 deletion docs/source/api_doc/generic/yolo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ YOLOModel
-----------------------------------------

.. autoclass:: YOLOModel
:members: __init__, predict, clear
:members: __init__, predict, clear, make_ui, launch_demo



Expand Down
112 changes: 112 additions & 0 deletions imgutils/generic/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,37 @@

import numpy as np
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model

try:
import gradio as gr
except (ImportError, ModuleNotFoundError):
gr = None

__all__ = [
'ClassifyModel',
'classify_predict_score',
'classify_predict',
]


def _check_gradio_env():
"""
Check if the Gradio library is installed and available.
:raises EnvironmentError: If Gradio is not installed.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
f'Please install it with `pip install dghs-imgutils[demo]`.')


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
"""
Expand Down Expand Up @@ -287,6 +305,100 @@ def clear(self):
self._models.clear()
self._labels.clear()

def make_ui(self, default_model_name: Optional[str] = None):
"""
Create the user interface components for the classifier model demo.
This method sets up the Gradio UI components including an image input, model selection dropdown,
submit button, and output label. It also configures the interaction between these components.
:param default_model_name: The name of the default model to be selected in the dropdown.
If None, the most recently updated model will be selected.
:type default_model_name: Optional[str]
:raises ImportError: If Gradio is not installed or properly configured.
:Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.make_ui(default_model_name="model_v1")
"""

# demo for classifier model
_check_gradio_env()
model_list = self.model_names
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
gr_submit = gr.Button(value='Submit', variant='primary')

with gr.Column():
gr_output = gr.Label(label='Prediction')

gr_submit.click(
self.predict_score,
inputs=[
gr_input_image,
gr_model,
],
outputs=[gr_output],
)

def launch_demo(self, default_model_name: Optional[str] = None,
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
"""
Launch the Gradio demo for the classifier model.
This method creates a Gradio Blocks interface, sets up the UI components using make_ui(),
and launches the demo server.
:param default_model_name: The name of the default model to be selected in the dropdown.
:type default_model_name: Optional[str]
:param server_name: The name of the server to run the demo on. Defaults to None.
:type server_name: Optional[str]
:param server_port: The port number to run the demo on. Defaults to None.
:type server_port: Optional[int]
:param kwargs: Additional keyword arguments to pass to the Gradio launch method.
:raises ImportError: If Gradio is not installed or properly configured.
:Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860)
"""

_check_gradio_env()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). '
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(default_model_name=default_model_name)

demo.launch(
server_name=server_name,
server_port=server_port,
**kwargs,
)


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
Expand Down
180 changes: 180 additions & 0 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,61 @@

import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model

try:
import gradio as gr
except (ImportError, ModuleNotFoundError):
gr = None

__all__ = [
'YOLOModel',
'yolo_predict',
]


def _check_gradio_env():
"""
Check if the Gradio library is installed and available.
:raises EnvironmentError: If Gradio is not installed.
"""
if gr is None:
raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
f'Please install it with `pip install dghs-imgutils[demo]`.')


def _v_fix(v):
"""
Round and convert a float value to an integer.
:param v: The float value to be rounded and converted.
:type v: float
:return: The rounded integer value.
:rtype: int
"""
return int(round(v))


def _bbox_fix(bbox):
"""
Fix the bounding box coordinates by rounding them to integers.
:param bbox: The bounding box coordinates.
:type bbox: tuple
:return: A tuple of fixed (rounded to integer) bounding box coordinates.
:rtype: tuple
"""
return tuple(map(_v_fix, bbox))


def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.
Expand Down Expand Up @@ -403,9 +446,146 @@ def predict(self, image: ImageTyping, model_name: str,
def clear(self):
"""
Clear cached model and metadata.
This method removes all cached models and their associated metadata from memory.
It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
"""
self._models.clear()

def make_ui(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
"""
Create a Gradio-based user interface for object detection.
This method sets up an interactive UI that allows users to upload images,
select models, and adjust detection parameters. It uses the Gradio library
to create the interface.
:param default_model_name: The name of the default model to use.
If None, the most recently updated model is selected.
:type default_model_name: Optional[str]
:param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
:type default_conf_threshold: float
:param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
:type default_iou_threshold: float
:raises ImportError: If Gradio is not installed in the environment.
:Example:
>>> model = YOLOModel("username/repo_name")
>>> model.make_ui(default_model_name="yolov5s")
"""
_check_gradio_env()
model_list = self.model_names
if not default_model_name:
hf_client = get_hf_client(hf_token=self._get_hf_token())
selected_model_name, selected_time = None, None
for fileitem in hf_client.get_paths_info(
repo_id=self.repo_id,
repo_type='model',
paths=[f'{model_name}/model.onnx' for model_name in model_list],
expand=True,
):
if not selected_time or fileitem.last_commit.date > selected_time:
selected_model_name = os.path.dirname(fileitem.path)
selected_time = fileitem.last_commit.date
default_model_name = selected_model_name

def _gr_detect(image: ImageTyping, model_name: str,
iou_threshold: float = 0.7, score_threshold: float = 0.25) \
-> gr.AnnotatedImage:
_, _, labels = self._open_model(model_name=model_name)
_colors = list(map(str, rnd_colors(len(labels))))
_color_map = dict(zip(labels, _colors))
return gr.AnnotatedImage(
value=(image, [
(_bbox_fix(bbox), label)
for bbox, label, _ in self.predict(
image=image,
model_name=model_name,
iou_threshold=iou_threshold,
conf_threshold=score_threshold,
)
]),
color_map=_color_map,
label='Labeled',
)

with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
with gr.Row():
gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')

gr_submit = gr.Button(value='Submit', variant='primary')

with gr.Column():
gr_output_image = gr.AnnotatedImage(label="Labeled")

gr_submit.click(
_gr_detect,
inputs=[
gr_input_image,
gr_model,
gr_iou_threshold,
gr_score_threshold,
],
outputs=[gr_output_image],
)

def launch_demo(self, default_model_name: Optional[str] = None,
default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
"""
Launch a Gradio demo for object detection.
This method creates and launches a Gradio demo that allows users to interactively
perform object detection on uploaded images using the YOLO model.
:param default_model_name: The name of the default model to use.
If None, the most recently updated model is selected.
:type default_model_name: Optional[str]
:param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
:type default_conf_threshold: float
:param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
:type default_iou_threshold: float
:param server_name: The name of the server to run the demo on. Default is None.
:type server_name: Optional[str]
:param server_port: The port to run the demo on. Default is None.
:type server_port: Optional[int]
:param kwargs: Additional keyword arguments to pass to gr.Blocks.launch().
:raises EnvironmentError: If Gradio is not installed in the environment.
Example:
>>> model = YOLOModel("username/repo_name")
>>> model.launch_demo(default_model_name="yolov5s", server_name="0.0.0.0", server_port=7860)
"""
_check_gradio_env()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
gr.HTML(f'<h2 style="text-align: center;">YOLO Demo For {self.repo_id}</h2>')
gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). '
f'Powered by `dghs-imgutils`\'s quick demo module.')

with gr.Row():
self.make_ui(
default_model_name=default_model_name,
default_conf_threshold=default_conf_threshold,
default_iou_threshold=default_iou_threshold,
)

demo.launch(
server_name=server_name,
server_port=server_port,
**kwargs,
)


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:
Expand Down
1 change: 1 addition & 0 deletions requirements-demo.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gradio>=4.44.0

0 comments on commit 4d530a2

Please sign in to comment.