From f4da1035f6141b045b8736719b35a20072da3e4d Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com" <leizhang.real@gmail.com>
Date: Sat, 21 Sep 2024 08:50:37 -0500
Subject: [PATCH 1/2] inject the project info into chat system prompt

---
 mle/cli.py         | 11 ++++++++++-
 mle/utils/cache.py |  2 +-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/mle/cli.py b/mle/cli.py
index 21fad3a..6c10c26 100644
--- a/mle/cli.py
+++ b/mle/cli.py
@@ -17,7 +17,7 @@
 from mle.model import load_model
 from mle.agents import CodeAgent
 import mle.workflow as workflow
-from mle.utils import Memory
+from mle.utils import Memory, WorkflowCache
 from mle.utils.system import (
     get_config,
     write_config,
@@ -147,8 +147,17 @@ def chat():
         return

     model = load_model(os.getcwd())
+    cache = WorkflowCache(os.getcwd())
     coder = CodeAgent(model)

+    # read the project information
+    dataset = cache(step=1).resume("dataset")
+    ml_requirement = cache(step=2).resume("ml_requirement")
+    advisor_report = cache(step=3).resume("advisor_report")
+
+    # inject the project information into prompts
+    coder.read_requirement(advisor_report or ml_requirement or dataset)
+
     while True:
         try:
             user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()

diff --git a/mle/utils/cache.py b/mle/utils/cache.py
index c175520..fd2bb1b 100644
--- a/mle/utils/cache.py
+++ b/mle/utils/cache.py
@@ -128,7 +128,7 @@ def _store_cache_buffer(self) -> None:
         """
         write_config(self.buffer)

-    def __call__(self, step: int, name: str) -> WorkflowCacheOperator:
+    def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator:
         """
         Initialize the cache content for a given step and name.


From 083dd6910fb8ce9b1d5fb7160a2b1bc73ff65631 Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com" <leizhang.real@gmail.com>
Date: Sat, 21 Sep 2024 09:09:23 -0500
Subject: [PATCH 2/2] update the interface

---
 mle/cli.py         |  6 +++---
 mle/utils/cache.py | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/mle/cli.py b/mle/cli.py
index 6c10c26..25aea2b 100644
--- a/mle/cli.py
+++ b/mle/cli.py
@@ -151,9 +151,9 @@ def chat():
     coder = CodeAgent(model)

     # read the project information
-    dataset = cache(step=1).resume("dataset")
-    ml_requirement = cache(step=2).resume("ml_requirement")
-    advisor_report = cache(step=3).resume("advisor_report")
+    dataset = cache.resume_variable("dataset")
+    ml_requirement = cache.resume_variable("ml_requirement")
+    advisor_report = cache.resume_variable("advisor_report")

     # inject the project information into prompts
     coder.read_requirement(advisor_report or ml_requirement or dataset)

diff --git a/mle/utils/cache.py b/mle/utils/cache.py
index fd2bb1b..d9d737f 100644
--- a/mle/utils/cache.py
+++ b/mle/utils/cache.py
@@ -110,6 +110,26 @@ def current_step(self) -> int:
         """
         return max(self.cache.keys()) if self.cache else 0

+    def resume_variable(self, key: str, step: Optional[int] = None):
+        """
+        Resume a cached variable.
+
+        Args:
+            key (str): The key of the value to be resumed.
+            step (Optional[int]): The step to resume from; None scans all steps.
+
+        Returns:
+            object: The resumed value, or None if the key does not exist.
+        """
+        if step is not None:
+            return self.__call__(step).resume(key)
+        else:
+            for step in range(1, self.current_step() + 1):
+                value = self.resume_variable(key, step)
+                if value is not None:
+                    return value
+            return None
+
     def _load_cache_buffer(self) -> Dict[str, Any]:
         """
         Load the cache buffer from the configuration.
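
Note: for reference, a minimal sketch of the lookup order resume_variable is meant to
implement. The WorkflowCache internals are simplified to a plain dict here; fake_cache
and the literal values in it are hypothetical stand-ins, not part of this patch:

    # Hypothetical stand-in for the per-step cache: step number -> {name: value}.
    fake_cache = {
        1: {"dataset": "data/train.csv"},
        2: {"ml_requirement": None},  # the step ran but cached nothing usable
        3: {"advisor_report": "Try a gradient-boosted baseline."},
    }

    def resume_variable(cache, key, step=None):
        """Mimic the patched method: one step if given, else scan steps 1..max."""
        if step is not None:
            return cache.get(step, {}).get(key)
        for s in range(1, max(cache, default=0) + 1):
            value = resume_variable(cache, key, s)
            if value is not None:
                return value
        return None

    print(resume_variable(fake_cache, "advisor_report"))  # Try a gradient-boosted baseline.
    print(resume_variable(fake_cache, "missing"))         # None

Scanning from step 1 upward means an earlier step's value wins if the same key exists
at multiple steps; callers that want a specific step's value should pass step explicitly.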
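
Note: one behavior worth flagging in chat(): the prompt takes the first truthy value in
the chain advisor_report or ml_requirement or dataset, so a later-stage report wins over
the raw requirement or dataset, and an empty string falls through rather than blanking
the prompt. The values below are made up for illustration:

    advisor_report = None          # step 3 never ran
    ml_requirement = ""            # step 2 cached an empty string
    dataset = "data/train.csv"     # step 1 output

    requirement = advisor_report or ml_requirement or dataset
    print(requirement)  # data/train.csv

If all three resolve to None, read_requirement receives None; guarding that case may be
worth a follow-up.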