diff --git a/docker-compose.yml b/docker-compose.yml index 5f756e7..1a2a183 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -31,15 +31,15 @@ services: profiles: [ollama-docker] search-agent-server: - image: dria-searching-agent:server + image: firstbatch/dria-searching-agent:latest build: context: . dockerfile: Dockerfile ports: - 5000:5000 environment: - AGENT_MODEL_PROVIDER: Ollama - AGENT_MODEL: phi3:latest + AGENT_MODEL_PROVIDER: ${AGENT_MODEL_PROVIDER} + AGENT_MODEL_NAME: ${AGENT_MODEL_NAME} ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY} OPENAI_API_KEY: ${OPENAI_API_KEY} diff --git a/src/dria_searching_agent/config/config.py b/src/dria_searching_agent/config/config.py index 9fedc55..b7c41eb 100644 --- a/src/dria_searching_agent/config/config.py +++ b/src/dria_searching_agent/config/config.py @@ -5,8 +5,8 @@ class Config: def __init__(self): load_dotenv() # This loads environment variables from a .env file if it exists - self.agent_model_provider = os.getenv('AGENT_MODEL_PROVIDER') - self.agent_model = os.getenv('AGENT_MODEL') + self.agent_model_provider = os.getenv('AGENT_MODEL_PROVIDER', "ollama") + self.agent_model_name = os.getenv('AGENT_MODEL_NAME', "gpt-4o") self.agent_max_iter = os.getenv('AGENT_MAX_ITER', 10) self.anthropic_key = os.getenv('ANTHROPIC_KEY') @@ -28,8 +28,8 @@ def load_config(): def AGENT_MODEL_PROVIDER(): return config.agent_model_provider -def AGENT_MODEL(): - return config.agent_model +def AGENT_MODEL_NAME(): + return config.agent_model_name def AGENT_MAX_ITER(): return config.agent_max_iter diff --git a/src/dria_searching_agent/main.py b/src/dria_searching_agent/main.py index c99f36c..cc9811c 100644 --- a/src/dria_searching_agent/main.py +++ b/src/dria_searching_agent/main.py @@ -164,11 +164,11 @@ def __create_agents(self): def __get_model(self): if config.AGENT_MODEL_PROVIDER().lower() == "anthropic": - return ChatAnthropic(model=config.AGENT_MODEL(), api_key=config.ANTHROPIC_KEY()) + return ChatAnthropic(model=config.AGENT_MODEL_NAME(), api_key=config.ANTHROPIC_KEY()) elif config.AGENT_MODEL_PROVIDER().lower() == "openai": - return ChatOpenAI(model=config.AGENT_MODEL(), api_key=config.OPENAI_API_KEY()) + return ChatOpenAI(model=config.AGENT_MODEL_NAME(), api_key=config.OPENAI_API_KEY()) elif config.AGENT_MODEL_PROVIDER().lower() == "ollama": - return ollama.Ollama(model=config.AGENT_MODEL(), base_url=config.OLLAMA_URL()) + return ollama.Ollama(model=config.AGENT_MODEL_NAME(), base_url=config.OLLAMA_URL()) _research_crew_instance = None def GetResearchCrew():