Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
MeetWq committed Sep 21, 2024
1 parent c8c3c51 commit 8dcb274
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
16 changes: 4 additions & 12 deletions meme_generator/meme.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Any, Literal, Optional, Protocol, TypeVar, Union
from typing import Any, Literal, Optional, TypeVar, Union, Callable

from arclet.alconna import ArgFlag, Args, Empty, Option
from arclet.alconna.action import Action
Expand Down Expand Up @@ -31,14 +31,7 @@ class MemeArgsModel(BaseModel):

ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel)


class MemeFunction(Protocol):
def __call__(
self,
images: list[BuildImage],
texts: list[str],
args: ArgsModel, # type: ignore
) -> BytesIO: ...
MemeFunction = Callable[[list[BuildImage], list[str], ArgsModel], BytesIO]


class ParserArg(BaseModel):
Expand Down Expand Up @@ -146,12 +139,11 @@ def __call__(
for image in images:
if isinstance(image, bytes):
image = BytesIO(image)
imgs.append(BuildImage.open(image)) # type: ignore
imgs.append(BuildImage.open(image))
except Exception as e:
raise OpenImageFailed(str(e))

values = {"images": imgs, "texts": texts, "args": model}
return self.function(**values)
return self.function(imgs, texts, model)

def generate_preview(self, *, args: dict[str, Any] = {}) -> BytesIO:
default_images = [random_image() for _ in range(self.params_type.min_images)]
Expand Down
9 changes: 3 additions & 6 deletions meme_generator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from functools import partial, wraps
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

import httpx
from PIL.Image import Image as IMG
Expand Down Expand Up @@ -210,12 +210,9 @@ def get_aligned_gif_indexes(
return frame_idxs_input, frame_idxs_target


class Maker(Protocol):
def __call__(self, imgs: list[BuildImage]) -> BuildImage: ...
Maker = Callable[[list[BuildImage]], BuildImage]


class GifMaker(Protocol):
def __call__(self, i: int) -> Maker: ...
GifMaker = Callable[[int], Maker]


def merge_gif(imgs: list[BuildImage], func: Maker) -> BytesIO:
Expand Down

0 comments on commit 8dcb274

Please sign in to comment.