diff --git a/mle/cli.py b/mle/cli.py index a487a50..7f3bf3a 100644 --- a/mle/cli.py +++ b/mle/cli.py @@ -17,7 +17,7 @@ from mle.model import load_model from mle.agents import CodeAgent import mle.workflow as workflow -from mle.utils import Memory +from mle.utils import Memory, WorkflowCache from mle.utils.system import ( get_config, write_config, @@ -147,8 +147,17 @@ def chat(): return model = load_model(os.getcwd()) + cache = WorkflowCache(os.getcwd()) coder = CodeAgent(model) + # read the project information + dataset = cache.resume_variable("dataset") + ml_requirement = cache.resume_variable("ml_requirement") + advisor_report = cache.resume_variable("advisor_report") + # inject the project information into prompts + coder.read_requirement(advisor_report or ml_requirement or dataset) + while True: try: user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask() diff --git a/mle/utils/cache.py b/mle/utils/cache.py index c175520..d9d737f 100644 --- a/mle/utils/cache.py +++ b/mle/utils/cache.py @@ -110,6 +110,26 @@ def current_step(self) -> int: """ return max(self.cache.keys()) if self.cache else 0 + def resume_variable(self, key: str, step: Optional[int] = None): + """ + Resume the cached variable. + + Args: + key (str): The key of the value to be resumed. + step (int, optional): The cache step to read from; when None, all cached steps are scanned. + + Returns: + object: The resumed value, or None if the key does not exist. + """ + if step is not None: + return self.__call__(step).resume(key) + else: + for i in range(self.current_step() + 1): + value = self.resume_variable(key, i) + if value is not None: + return value + return None + def _load_cache_buffer(self) -> Dict[str, Any]: """ Load the cache buffer from the configuration.
@@ -128,7 +148,7 @@ def _store_cache_buffer(self) -> None: """ write_config(self.buffer) - def __call__(self, step: int, name: str) -> WorkflowCacheOperator: + def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator: """ Initialize the cache content for a given step and name.