Skip to content

Commit

Permalink
fix load fails
Browse files Browse the repository at this point in the history
  • Loading branch information
leeeizhang committed Sep 7, 2024
1 parent 37e4f04 commit e579a4f
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def load_model(project_dir: str, model_name: str):
project_dir (str): The project directory.
model_name (str): The model name.
"""
config = get_config()
config = get_config(project_dir)
if config['platform'] == MODEL_OPENAI:
return OpenAIModel(api_key=config['api_key'], model=model_name)
if config['platform'] == MODEL_CLAUDE:
Expand Down
2 changes: 1 addition & 1 deletion mle/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self.collection_name = 'memory'
self.client = chromadb.PersistentClient(path=os.path.join(project_path, self.db_name))

config = get_config()
config = get_config(project_path)
# use the OpenAI embedding function if the openai section is set in the configuration.
if config['platform'] == 'OpenAI':
self.client.get_or_create_collection(
Expand Down
5 changes: 3 additions & 2 deletions mle/utils/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ def check_config(console: Optional[Console] = None):
return True


def get_config() -> Optional[Dict[str, Any]]:
def get_config(workdir: str = None) -> Optional[Dict[str, Any]]:
"""
Get the configuration file.
:workdir: the project directory.
:return: the configuration file.
"""
config_dir = os.path.join(os.getcwd(), '.mle')
config_dir = os.path.join(workdir or os.getcwd(), '.mle')
config_path = os.path.join(config_dir, 'project.yml')
if not os.path.exists(config_path):
return None
Expand Down

0 comments on commit e579a4f

Please sign in to comment.