Skip to content

Commit

Permalink
Added boolean flag to control IPFS upload (#470)
Browse files Browse the repository at this point in the history
* Added boolean flag to control IPFS upload

* Fixed test

* ENABLE_UPLOAD_IPFS default False

* Fixed isort | removed optionality

* Tests a bit more robust and clean

* Fixed tests
  • Loading branch information
gabrielfior authored Oct 4, 2024
1 parent 05535b1 commit 20c6229
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ BET_FROM_PRIVATE_KEY=
OPENAI_API_KEY=
GRAPH_API_KEY=
PINATA_API_KEY=
PINATA_API_SECRET=
PINATA_API_SECRET=
ENABLE_IPFS_UPLOAD=
7 changes: 7 additions & 0 deletions prediction_market_agent_tooling/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class APIKeys(BaseSettings):
LANGFUSE_HOST: t.Optional[str] = None
LANGFUSE_DEPLOYMENT_VERSION: t.Optional[str] = None

ENABLE_IPFS_UPLOAD: bool = False
PINATA_API_KEY: t.Optional[SecretStr] = None
PINATA_API_SECRET: t.Optional[SecretStr] = None

Expand Down Expand Up @@ -151,6 +152,12 @@ def default_enable_langfuse(self) -> bool:
and self.LANGFUSE_HOST is not None
)

@property
def enable_ipfs_upload(self) -> bool:
return check_not_none(
self.ENABLE_IPFS_UPLOAD, "ENABLE_IPFS_UPLOAD missing in the environment."
)

@property
def pinata_api_key(self) -> SecretStr:
return check_not_none(
Expand Down
15 changes: 10 additions & 5 deletions prediction_market_agent_tooling/deploy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic import BaseModel, BeforeValidator, computed_field
from typing_extensions import Annotated
from web3 import Web3
from web3.constants import HASH_ZERO

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.deploy.betting_strategy import (
Expand Down Expand Up @@ -302,7 +303,6 @@ def __init__(
) -> None:
super().__init__(enable_langfuse=enable_langfuse)
self.place_bet = place_bet
self.ipfs_handler = IPFSHandler(APIKeys())

def get_betting_strategy(self, market: AgentMarket) -> BettingStrategy:
user_id = market.get_user_id(api_keys=APIKeys())
Expand Down Expand Up @@ -522,16 +522,21 @@ def store_prediction(
if processed_market.answer.reasoning
else ""
)
ipfs_hash = self.ipfs_handler.store_agent_result(
IPFSAgentResult(reasoning=reasoning)
)

ipfs_hash_decoded = HexBytes(HASH_ZERO)
if keys.enable_ipfs_upload:
logger.info("Storing prediction on IPFS.")
ipfs_hash = IPFSHandler(keys).store_agent_result(
IPFSAgentResult(reasoning=reasoning)
)
ipfs_hash_decoded = ipfscidv0_to_byte32(ipfs_hash)

tx_hashes = [
HexBytes(HexStr(i.id)) for i in processed_market.trades if i.id is not None
]
prediction = ContractPrediction(
publisher=keys.public_key,
ipfs_hash=ipfscidv0_to_byte32(ipfs_hash),
ipfs_hash=ipfs_hash_decoded,
tx_hashes=tx_hashes,
estimated_probability_bps=int(processed_market.answer.p_yes * 10000),
)
Expand Down
18 changes: 13 additions & 5 deletions tests_integration_with_local_chain/markets/omen/test_omen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from web3 import Web3
from web3.constants import HASH_ZERO

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.gtypes import (
Expand Down Expand Up @@ -455,25 +456,32 @@ def get_position_balance_by_position_id(
)


def test_add_predictions(local_web3: Web3, test_keys: APIKeys) -> None:
@pytest.mark.parametrize(
"ipfs_hash",
["0x3750ffa211dab39b4d0711eb27b02b56a17fa9d257ee549baa3110725fd1d41b", HASH_ZERO],
)
def test_add_predictions(local_web3: Web3, test_keys: APIKeys, ipfs_hash: str) -> None:
agent_result_mapping = OmenAgentResultMappingContract()
market_address = test_keys.public_key
dummy_transaction_hash = (
"0x3750ffa211dab39b4d0711eb27b02b56a17fa9d257ee549baa3110725fd1d41b"
)
stored_predictions = agent_result_mapping.get_predictions(
market_address, web3=local_web3
)
p = ContractPrediction(
tx_hashes=[HexBytes(dummy_transaction_hash)],
estimated_probability_bps=5454,
ipfs_hash=HexBytes(dummy_transaction_hash),
ipfs_hash=HexBytes(ipfs_hash),
publisher=test_keys.public_key,
)

agent_result_mapping.add_prediction(test_keys, market_address, p, web3=local_web3)
stored_predictions = agent_result_mapping.get_predictions(
updated_stored_predictions = agent_result_mapping.get_predictions(
market_address, web3=local_web3
)
assert len(stored_predictions) == 1
assert stored_predictions[0] == p
assert len(updated_stored_predictions) == len(stored_predictions) + 1
assert updated_stored_predictions[-1] == p


def test_place_bet_with_prev_existing_positions(
Expand Down

0 comments on commit 20c6229

Please sign in to comment.