Skip to content

Commit

Permalink
Merge pull request #218 from leeeizhang/lei/chat-enhance
Browse files Browse the repository at this point in the history
[MRG] inject the project info into chat system prompt
  • Loading branch information
HuaizhengZhang authored Sep 23, 2024
2 parents 1b637be + 083dd69 commit 54b7a3e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
11 changes: 10 additions & 1 deletion mle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mle.model import load_model
from mle.agents import CodeAgent
import mle.workflow as workflow
from mle.utils import Memory
from mle.utils import Memory, WorkflowCache
from mle.utils.system import (
get_config,
write_config,
Expand Down Expand Up @@ -147,8 +147,17 @@ def chat():
return

model = load_model(os.getcwd())
cache = WorkflowCache(os.getcwd())
coder = CodeAgent(model)

# read the project information
dataset = cache.resume_variable("dataset")
ml_requirement = cache.resume_variable("ml_requirement")
advisor_report = cache.resume_variable("advisor_report")

# inject the project information into prompts
coder.read_requirement(advisor_report or ml_requirement or dataset)

while True:
try:
user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
Expand Down
22 changes: 21 additions & 1 deletion mle/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,26 @@ def current_step(self) -> int:
"""
return max(self.cache.keys()) if self.cache else 0

def resume_variable(self, key: str, step: Optional[int] = None):
    """
    Resume a cached variable, optionally from one specific step.

    Args:
        key (str): The key of the value to be resumed.
        step (Optional[int]): The step to read from. When None, steps are
            probed in ascending order from 0 up to and including
            `current_step()`, and the first non-None value is returned.

    Returns:
        object: The resumed value, or None if no probed step yields one.
    """
    if step is not None:
        return self.__call__(step).resume(key)

    # `current_step()` is the highest cached step, so the upper bound must be
    # inclusive; a bare `range(current_step())` never reads the latest step.
    for candidate in range(self.current_step() + 1):
        value = self.__call__(candidate).resume(key)
        if value is not None:
            return value
    return None

def _load_cache_buffer(self) -> Dict[str, Any]:
"""
Load the cache buffer from the configuration.
Expand All @@ -128,7 +148,7 @@ def _store_cache_buffer(self) -> None:
"""
write_config(self.buffer)

def __call__(self, step: int, name: str) -> WorkflowCacheOperator:
def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator:
"""
Initialize the cache content for a given step and name.
Expand Down

0 comments on commit 54b7a3e

Please sign in to comment.