Skip to content

Commit

Permalink
Merge pull request #239 from victorpolisetty/dalle-request
Browse files Browse the repository at this point in the history
DALL-E request mech tool
  • Loading branch information
0xArdi authored Jul 2, 2024
2 parents 10fbda7 + 1f5d271 commit f647200
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 4 deletions.
7 changes: 4 additions & 3 deletions packages/gnosis/customs/omen_tools/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ name: omen_tools
author: gnosis
version: 0.1.0
type: custom
description: Collection of tools to prepare requests for interacting with prediction markets on Omen.
description: Collection of tools to prepare requests for interacting with prediction
markets on Omen.
license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeibbn67pnrrm4qm3n3kbelvbs3v7fjlrjniywmw2vbizarippidtvi
prediction_sum_url_content.py: bafybeieywowx265yycgf5735bw4zyabfy6ivwnntl6smxa2hicktipgeby
omen_buy_sell.py: bafybeid3zaursxt2nkm2u7x7u4wlodg2ulzlieu5xxsfxjyxzi3vbcezdm
fingerprint_ignore_patterns: []
entry_point: omen_buy_sell.py
callable: run
Expand All @@ -26,4 +27,4 @@ dependencies:
langchain_community:
version: ==0.2.1
openai:
version: ==1.30.2
version: ==1.30.2
1 change: 1 addition & 0 deletions packages/packages.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"custom/valory/prediction_langchain/0.1.0": "bafybeihhii7veepp6ovkmqjnkp6euhkwm52obabgdltdj34ikisfd7yvqi",
"custom/victorpolisetty/gemini_request/0.1.0": "bafybeig5x6b5jtanet2q5sk7er7fdzpippbvh4q5p7uxmxpriq66omjnaq",
"custom/gnosis/omen_tools/0.1.0": "bafybeifxrawgu6m3dgsxvj7jrhxzr5gwi3zjk2m4gltkr5w3hxjjbla6nu",
"custom/victorpolisetty/dalle_request/0.1.0": "bafybeicgjdvgamkgjebdrowrxdil3aghsbcm7epup6aqidikvjpmvomn6q",
"protocol/valory/acn_data_share/0.1.0": "bafybeih5ydonnvrwvy2ygfqgfabkr47s4yw3uqxztmwyfprulwfsoe7ipq",
"protocol/valory/websocket_client/0.1.0": "bafybeifjk254sy65rna2k32kynzenutujwqndap2r222afvr3zezi27mx4",
"contract/valory/agent_mech/0.1.0": "bafybeiah6b5epo2hlvzg5rr2cydgpp2waausoyrpnoarf7oa7bw33rex34",
Expand Down
2 changes: 1 addition & 1 deletion packages/valory/services/mech/service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license: Apache-2.0
fingerprint:
README.md: bafybeif7ia4jdlazy6745ke2k2x5yoqlwsgwr6sbztbgqtwvs3ndm2p7ba
fingerprint_ignore_patterns: []
agent: valory/mech:0.1.0:bafybeid2hlmwtoze3xhhqqayisdg6xzxzgf42zw7hccdaszur7qfgp5vu4
agent: valory/mech:0.1.0:bafybeih2oex4yt4mmiyarp2ivkqqfscoavyb7metchciiqmqdmv7lhyutq
number_of_agents: 4
deployment:
agent:
Expand Down
19 changes: 19 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2024 Valory AG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ------------------------------------------------------------------------------
18 changes: 18 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/component.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: dalle_request
author: victorpolisetty
version: 0.1.0
type: custom
description: A tool that runs a prompt against the OpenAI DALL-E API.
license: Apache-2.0
aea_version: '>=1.0.0, <2.0.0'
fingerprint:
__init__.py: bafybeicokooiqnmkldoi5tx6zv6svtjoak5ghj5o3vsi4iliq2pvqaa6uy
dalle_request.py: bafybeicagzzicf7o6u6iotbvxfacdvntn5bpa4fptmst2iakjwaz7os2ry
fingerprint_ignore_patterns: []
entry_point: dalle_request.py
callable: run
dependencies:
openai:
version: ==1.30.2
tiktoken:
version: ==0.7.0
122 changes: 122 additions & 0 deletions packages/victorpolisetty/customs/dalle_request/dalle_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import functools
from typing import Any, Dict, Optional, Tuple, Callable
from openai import OpenAI
from tiktoken import encoding_for_model

client: Optional[OpenAI] = None
MechResponse = Tuple[str, Optional[str], Optional[Dict[str, Any]], Any, Any]


def with_key_rotation(func: Callable):
@functools.wraps(func)
def wrapper(*args, **kwargs) -> MechResponse:
api_keys = kwargs["api_keys"]
retries_left: Dict[str, int] = api_keys.max_retries()

def execute() -> MechResponse:
"""Retry the function with a new key."""
try:
result = func(*args, **kwargs)
# Ensure the result is a tuple and has the correct length
if isinstance(result, tuple) and len(result) == 4:
return result + (api_keys,)
else:
raise ValueError("Function did not return a valid MechResponse tuple.")
except openai.error.RateLimitError as e:
# try with a new key again
if retries_left["openai"] <= 0 and retries_left["openrouter"] <= 0:
raise e
retries_left["openai"] -= 1
retries_left["openrouter"] -= 1
api_keys.rotate("openai")
api_keys.rotate("openrouter")
return execute()
except Exception as e:
return str(e), "", None, None, api_keys

mech_response = execute()
return mech_response

return wrapper


class OpenAIClientManager:
"""Client context manager for OpenAI."""

def __init__(self, api_key: str):
self.api_key = api_key

def __enter__(self) -> OpenAI:
global client
if client is None:
client = OpenAI(api_key=self.api_key)
return client

def __exit__(self, exc_type, exc_value, traceback) -> None:
global client
if client is not None:
client.close()
client = None


def count_tokens(text: str, model: str) -> int:
"""Count the number of tokens in a text."""
enc = encoding_for_model(model)
return len(enc.encode(text))


DEFAULT_DALLE_SETTINGS = {
"size": "1024x1024",
"quality": "standard",
"n": 1,
}
PREFIX = "dall-e"
ENGINES = {
"text-to-image": ["-2", "-3"],
}
ALLOWED_MODELS = [PREFIX]
ALLOWED_TOOLS = [PREFIX + value for value in ENGINES["text-to-image"]]
ALLOWED_SIZE = ["1024x1024", "1024x1792", "1792x1024"]
ALLOWED_QUALITY = ["standard", "hd"]


@with_key_rotation
def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any, Any]:
"""Run the task"""
with OpenAIClientManager(kwargs["api_keys"]["openai"]):
tool = kwargs["tool"]
prompt = kwargs["prompt"]
size = kwargs.get("size", DEFAULT_DALLE_SETTINGS["size"])
quality = kwargs.get("quality", DEFAULT_DALLE_SETTINGS["quality"])
n = kwargs.get("n", DEFAULT_DALLE_SETTINGS["n"])
counter_callback = kwargs.get("counter_callback", None)
if tool not in ALLOWED_TOOLS:
return (
f"Tool {tool} is not in the list of supported tools.",
None,
None,
None,
)
if size not in ALLOWED_SIZE:
return (
f"Size {size} is not in the list of supported sizes.",
None,
None,
None,
)
if quality not in ALLOWED_QUALITY:
return (
f"Quality {quality} is not in the list of supported qualities.",
None,
None,
None,
)

response = client.images.generate(
model=tool,
prompt=prompt,
size=size,
quality=quality,
n=n,
)
return response.data[0].url, prompt, None, counter_callback
11 changes: 11 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List, Any

from packages.gnosis.customs.omen_tools import omen_buy_sell
from packages.victorpolisetty.customs.dalle_request import dalle_request
from packages.napthaai.customs.prediction_request_rag import prediction_request_rag
from packages.napthaai.customs.prediction_request_rag_cohere import (
prediction_request_rag_cohere,
Expand Down Expand Up @@ -175,3 +176,13 @@ def _validate_response(self, response: Any) -> None:
super()._validate_response(response)
expected_num_tx_params = 2
assert len(response[2].keys()) == expected_num_tx_params

class TestDALLEGeneration(BaseToolTest):
"""Test DALL-E Generation."""

tools = dalle_request.ALLOWED_TOOLS
models = dalle_request.ALLOWED_MODELS
prompts = [
"Generate an image of a futuristic cityscape."
]
tool_module = dalle_request

0 comments on commit f647200

Please sign in to comment.