From e579a4fb7a535486a85f96f2f7e02cd67af7eda3 Mon Sep 17 00:00:00 2001 From: "leizhang.real@gmail.com" Date: Sat, 7 Sep 2024 07:46:10 -0500 Subject: [PATCH] fix load fails --- mle/model.py | 2 +- mle/utils/memory.py | 2 +- mle/utils/system.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mle/model.py b/mle/model.py index 78f0c70..9f47c46 100644 --- a/mle/model.py +++ b/mle/model.py @@ -310,7 +310,7 @@ def load_model(project_dir: str, model_name: str): project_dir (str): The project directory. model_name (str): The model name. """ - config = get_config() + config = get_config(project_dir) if config['platform'] == MODEL_OPENAI: return OpenAIModel(api_key=config['api_key'], model=model_name) if config['platform'] == MODEL_CLAUDE: diff --git a/mle/utils/memory.py b/mle/utils/memory.py index b33aecf..eef2fcb 100644 --- a/mle/utils/memory.py +++ b/mle/utils/memory.py @@ -29,7 +29,7 @@ def __init__( self.collection_name = 'memory' self.client = chromadb.PersistentClient(path=os.path.join(project_path, self.db_name)) - config = get_config() + config = get_config(project_path) # use the OpenAI embedding function if the openai section is set in the configuration. if config['platform'] == 'OpenAI': self.client.get_or_create_collection( diff --git a/mle/utils/system.py b/mle/utils/system.py index 34dcc09..fffae56 100644 --- a/mle/utils/system.py +++ b/mle/utils/system.py @@ -97,12 +97,13 @@ def check_config(console: Optional[Console] = None): return True -def get_config() -> Optional[Dict[str, Any]]: +def get_config(workdir: str = None) -> Optional[Dict[str, Any]]: """ Get the configuration file. + :workdir: the project directory. :return: the configuration file. """ - config_dir = os.path.join(os.getcwd(), '.mle') + config_dir = os.path.join(workdir or os.getcwd(), '.mle') config_path = os.path.join(config_dir, 'project.yml') if not os.path.exists(config_path): return None