add: StabilityAPIModel
soumik12345 committed Oct 4, 2024
1 parent 0682bd5 commit cbfdbf9
Showing 4 changed files with 88 additions and 9 deletions.
9 changes: 7 additions & 2 deletions hemm/eval_pipelines/__init__.py
@@ -1,4 +1,9 @@
 from .eval_pipeline import EvaluationPipeline
-from .model import BaseDiffusionModel
+from .model import BaseDiffusionModel, FalDiffusionModel, StabilityAPIModel
 
-__all__ = ["BaseDiffusionModel", "EvaluationPipeline"]
+__all__ = [
+    "BaseDiffusionModel",
+    "EvaluationPipeline",
+    "FalDiffusionModel",
+    "StabilityAPIModel",
+]
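
For reference, a minimal import sketch of the newly re-exported symbols (assuming the hemm package is importable in the environment):

    from hemm.eval_pipelines import (
        BaseDiffusionModel,
        EvaluationPipeline,
        FalDiffusionModel,
        StabilityAPIModel,
    )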
68 changes: 63 additions & 5 deletions hemm/eval_pipelines/model.py
@@ -1,6 +1,9 @@
+import io
 import os
 from typing import Any, Dict
 
+import fal_client
+import requests
 import torch
 import weave
 from diffusers import DiffusionPipeline
@@ -10,15 +13,25 @@
 from ..utils import custom_weave_wrapper
 
 
+STABILITY_MODEL_HOST = {
+    "sd3-large": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
+    "sd3-large-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3",
+}
+
+
 class BaseDiffusionModel(weave.Model):
-    """Base `weave.Model` wrapping `diffusers.DiffusionPipeline`.
+    """`weave.Model` wrapping `diffusers.DiffusionPipeline`.
     Args:
         diffusion_model_name_or_path (str): The name or path of the diffusion model.
         enable_cpu_offfload (bool): Enable CPU offload for the diffusion model.
         image_height (int): The height of the generated image.
         image_width (int): The width of the generated image.
         num_inference_steps (int): The number of inference steps.
         disable_safety_checker (bool): Disable safety checker for the diffusion model.
-        configs (Dict[str, Any]): Additional configs.
+        pipeline_configs (Dict[str, Any]): Diffusion pipeline configs.
+        inference_kwargs (Dict[str, Any]): Inference kwargs.
     """
 
     diffusion_model_name_or_path: str
@@ -85,20 +98,65 @@ def predict(self, prompt: str, seed: int) -> Dict[str, Any]:
         return {"image": pipeline_output.images[0]}
 
 
-class FalDiffusionModel(BaseDiffusionModel):
-    model_address: str
+class FalDiffusionModel(weave.Model):
+    """`weave.Model` wrapping [FalAI](https://fal.ai/) calls.
+    Args:
+        model_name (str): FalAI model name.
+        inference_kwargs (Dict[str, Any]): Inference kwargs.
+    """
+
+    model_name: str
+    inference_kwargs: Dict[str, Any] = {}
 
     @weave.op()
     def generate_image(self, prompt: str, seed: int) -> Image.Image:
         result = custom_weave_wrapper(name="fal_client.submit.get")(
             fal_client.submit(
-                self.model_address,
+                self.model_name,
                 arguments={"prompt": prompt, "seed": seed, **self.inference_kwargs},
             ).get
         )()
         return load_image(result["images"][0]["url"])
 
     @weave.op()
     def predict(self, prompt: str, seed: int) -> Image.Image:
-        return self.generate_image(prompt=prompt, seed=seed)
+        return {"image": self.generate_image(prompt=prompt, seed=seed)}
+
+
+class StabilityAPIModel(weave.Model):
+    """`weave.Model` wrapping Stability API calls.
+    Args:
+        model_name (str): Stability model name.
+    """
+
+    model_name: str
+
+    @weave.op()
+    def send_generation_request(self, prompt: str, seed: int):
+        api_key = os.environ["STABILITY_KEY"]
+        headers = {"Accept": "image/*", "Authorization": f"Bearer {api_key}"}
+        response = requests.post(
+            STABILITY_MODEL_HOST[self.model_name],
+            headers=headers,
+            files={"none": ""},
+            data={
+                "prompt": prompt,
+                "negative_prompt": "",
+                "aspect_ratio": "1:1",
+                "seed": seed,
+                "output_format": "png",
+                "model": self.model_name,
+                "mode": "text-to-image",
+            },
+        )
+        if not response.ok:
+            raise Exception(f"HTTP {response.status_code}: {response.text}")
+        return response
+
+    @weave.op()
+    def predict(self, prompt: str, seed: int) -> Image.Image:
+        response = self.send_generation_request(prompt=prompt, seed=seed)
+        image = Image.open(io.BytesIO(response.content))
+        return {"image": image}
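
For illustration, a hypothetical usage sketch of the new StabilityAPIModel (assumes a valid Stability AI key exported as STABILITY_KEY; the model name, prompt, and filename below are placeholders, and FalDiffusionModel works analogously given a FAL_KEY):

    import os

    from hemm.eval_pipelines import StabilityAPIModel

    os.environ["STABILITY_KEY"] = "sk-..."  # placeholder; use a real Stability AI key

    model = StabilityAPIModel(model_name="sd3-large")  # must be a key of STABILITY_MODEL_HOST
    output = model.predict(prompt="a watercolor painting of a lighthouse", seed=42)
    output["image"].save("lighthouse.png")  # predict returns {"image": PIL.Image}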
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ torchmetrics = { extras = ["multimodal"], version = "^1.4.1" }
 mkdocstrings = {version = "^0.25.2", extras = ["python"]}
 sentencepiece = "^0.2.0"
 fal-client = "^0.4.1"
+python-dotenv = "^1.0.1"
 
 [tool.poetry.extras]
 core = [
@@ -48,6 +49,7 @@ core = [
     "fire",
     "fal-client",
     "jsonlines",
+    "python-dotenv",
     "spacy",
     "instructor",
     "torchmetrics",
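
The python-dotenv addition pairs naturally with the STABILITY_KEY lookup above; a hypothetical sketch of loading the key from a local .env file before running an evaluation:

    # Hypothetical: .env contains a line like STABILITY_KEY=sk-...
    from dotenv import load_dotenv

    load_dotenv()  # populates os.environ from .env in the working directory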
