diff --git a/meme_generator/meme.py b/meme_generator/meme.py index 3b07bac..6830a0a 100644 --- a/meme_generator/meme.py +++ b/meme_generator/meme.py @@ -2,7 +2,7 @@ from datetime import datetime from io import BytesIO from pathlib import Path -from typing import Any, Literal, Optional, Protocol, TypeVar, Union +from typing import Any, Literal, Optional, TypeVar, Union, Callable from arclet.alconna import ArgFlag, Args, Empty, Option from arclet.alconna.action import Action @@ -31,14 +31,7 @@ class MemeArgsModel(BaseModel): ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel) - -class MemeFunction(Protocol): - def __call__( - self, - images: list[BuildImage], - texts: list[str], - args: ArgsModel, # type: ignore - ) -> BytesIO: ... +MemeFunction = Callable[[list[BuildImage], list[str], ArgsModel], BytesIO] class ParserArg(BaseModel): @@ -146,12 +139,11 @@ def __call__( for image in images: if isinstance(image, bytes): image = BytesIO(image) - imgs.append(BuildImage.open(image)) # type: ignore + imgs.append(BuildImage.open(image)) except Exception as e: raise OpenImageFailed(str(e)) - values = {"images": imgs, "texts": texts, "args": model} - return self.function(**values) + return self.function(imgs, texts, model) def generate_preview(self, *, args: dict[str, Any] = {}) -> BytesIO: default_images = [random_image() for _ in range(self.params_type.min_images)] diff --git a/meme_generator/utils.py b/meme_generator/utils.py index 66d5118..953e55c 100644 --- a/meme_generator/utils.py +++ b/meme_generator/utils.py @@ -10,7 +10,7 @@ from functools import partial, wraps from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar import httpx from PIL.Image import Image as IMG @@ -210,12 +210,9 @@ def get_aligned_gif_indexes( return frame_idxs_input, frame_idxs_target -class Maker(Protocol): - def __call__(self, imgs: list[BuildImage]) -> BuildImage: ... +Maker = Callable[[list[BuildImage]], BuildImage] - -class GifMaker(Protocol): - def __call__(self, i: int) -> Maker: ... +GifMaker = Callable[[int], Maker] def merge_gif(imgs: list[BuildImage], func: Maker) -> BytesIO: