Merge pull request #192 from suryavanshi/mistral
huangyz0918 authored Sep 10, 2024
2 parents 13968d7 + 9e01150 commit 6c9812c
Showing 6 changed files with 142 additions and 9 deletions.
9 changes: 6 additions & 3 deletions mle/agents/advisor.py
@@ -3,7 +3,7 @@
from rich.console import Console

from mle.function import *
from mle.utils import get_config, print_in_box
from mle.utils import get_config, print_in_box, clean_json_string


def process_report(requirement: str, suggestions: dict):
@@ -118,7 +118,10 @@ def suggest(self, requirement):
        )

        self.chat_history.append({"role": "assistant", "content": text})
        suggestions = json.loads(text)
        try:
            suggestions = json.loads(text)
        except json.JSONDecodeError as e:
            suggestions = clean_json_string(text)

        return process_report(requirement, suggestions)

@@ -185,7 +188,7 @@ def clarify_dataset(self, dataset: str):
        text = self.model.query(chat_history)
        chat_history.append({"role": "assistant", "content": text})
        if "yes" in text.lower():
            return
            return dataset

        # recommend some datasets based on the users' description
        user_prompt = f"""
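The fallback pattern introduced here (and mirrored in planner.py below) is easy to exercise in isolation: strict json.loads runs first, and clean_json_string only steps in when the model wraps its reply in a Markdown code fence. A minimal sketch, assuming the mle package is importable and using an invented fenced reply:

import json

from mle.utils import clean_json_string

# Invented example of a model reply wrapped in a Markdown fence,
# which strict json.loads cannot parse.
text = '```json\n{"suggested_tasks": ["collect data", "train model"]}\n```'

try:
    suggestions = json.loads(text)          # strict parse first
except json.JSONDecodeError:
    suggestions = clean_json_string(text)   # strip the fence, then parse

print(suggestions["suggested_tasks"])  # ['collect data', 'train model']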
5 changes: 2 additions & 3 deletions mle/agents/planner.py
@@ -3,7 +3,7 @@
import questionary
from rich.console import Console

from mle.utils import print_in_box
from mle.utils import print_in_box, clean_json_string


def process_plan(plan_dict: dict):
@@ -100,8 +100,7 @@ def plan(self, user_prompt):
        try:
            return json.loads(text)
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON response: {e}")
            sys.exit(1)
            return clean_json_string(text)

    def interact(self, user_prompt):
        """
9 changes: 7 additions & 2 deletions mle/cli.py
@@ -193,7 +193,7 @@ def new(name):

    platform = questionary.select(
        "Which language model platform do you want to use?",
        choices=['OpenAI', 'Ollama', 'Claude']
        choices=['OpenAI', 'Ollama', 'Claude', 'MistralAI']
    ).ask()

    api_key = None
@@ -208,6 +208,12 @@ def new(name):
        if not api_key:
            console.log("API key is required. Aborted.")
            return

    elif platform == 'MistralAI':
        api_key = questionary.password("What is your MistralAI API key?").ask()
        if not api_key:
            console.log("API key is required. Aborted.")
            return

    search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask()
    if search_api_key:
@@ -223,7 +229,6 @@ def new(name):
            'api_key': api_key,
            'search_key': search_api_key
        }, outfile, default_flow_style=False)

    # init the memory
    Memory(project_dir)

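For context, the yaml.dump call above persists the chosen platform alongside the keys, and load_model in mle/model.py later dispatches on that platform field. A small sketch of reading the file back, assuming PyYAML; the project.yml filename is an invented placeholder, since the diff does not show where outfile points:

import yaml

# Hypothetical path: the config file name is not shown in this diff.
with open("project.yml") as f:
    config = yaml.safe_load(f)

# load_model() branches on this value, e.g. 'MistralAI'.
print(config["platform"])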
114 changes: 113 additions & 1 deletion mle/model.py
@@ -8,6 +8,7 @@
MODEL_OLLAMA = 'Ollama'
MODEL_OPENAI = 'OpenAI'
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'

class Model(ABC):

@@ -262,7 +263,6 @@ def query(self, chat_history, **kwargs):
            stream=False,
            tools=tools,
        )

        if completion.stop_reason == "tool_use":
            for func in completion.content:
                if func.type != "tool_use":
@@ -309,6 +309,116 @@ def stream(self, chat_history, **kwargs):
            for chunk in stream.text_stream:
                yield chunk

class MistralModel(Model):
    def __init__(self, api_key, model, temperature=0.7):
        """
        Initialize the Mistral model.
        Args:
            api_key (str): The Mistral API key.
            model (str): The model with version.
            temperature (float): The temperature value.
        """
        super().__init__()

        dependency = "mistralai"
        spec = importlib.util.find_spec(dependency)
        if spec is not None:
            self.mistral = importlib.import_module(dependency).Mistral
        else:
            raise ImportError(
                "It seems the mistralai package is not installed. To enable the Mistral AI client features, "
                "please make sure the mistralai Python package is installed. "
                "For more information, please refer to: https://github.com/mistralai/client-python"
            )

        self.model = model if model else 'mistral-large-latest'
        self.model_type = MODEL_MISTRAL
        self.temperature = temperature
        self.client = self.mistral(api_key=api_key)
        self.func_call_history = []

    def _convert_functions_to_tools(self, functions):
        """
        Convert OpenAI-style functions to Mistral-style tools.
        """
        tools = []
        for func in functions:
            tool = {
                "type": "function",
                "function": {
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "parameters": func["parameters"]
                }
            }
            tools.append(tool)
        return tools

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions", [])
        tools = self._convert_functions_to_tools(functions)
        tool_choice = kwargs.get('tool_choice', 'any')
        parameters = kwargs
        completion = self.client.chat.complete(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=False,
            tools=tools,
            tool_choice=tool_choice,
        )
        resp = completion.choices[0].message
        if resp.tool_calls:
            for tool_call in resp.tool_calls:
                chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                function_name = process_function_name(tool_call.function.name)
                arguments = json.loads(tool_call.function.arguments)
                print("[MLE FUNC CALL]: ", function_name)
                self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid repeated search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    parameters['tool_choice'] = "none"
                result = get_function(function_name)(**arguments)
                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id": tool_call.id})
            return self.query(chat_history, **parameters)
        else:
            return resp.content

    def stream(self, chat_history, **kwargs):
        """
        Stream the output from the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions", [])
        tools = self._convert_functions_to_tools(functions)
        tool_choice = kwargs.get('tool_choice', 'any')
        for chunk in self.client.chat.complete(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=True,
            tools=tools,
            tool_choice=tool_choice
        ):
            if chunk.choices[0].delta.tool_calls:
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function.name:
                    chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                    function_name = process_function_name(tool_call.function.name)
                    arguments = json.loads(tool_call.function.arguments)
                    result = get_function(function_name)(**arguments)
                    chat_history.append({"role": "tool", "content": result, "name": function_name})
                    yield from self.stream(chat_history, **kwargs)
            else:
                yield chunk.choices[0].delta.content

def load_model(project_dir: str, model_name: str):
"""
@@ -324,4 +434,6 @@ def load_model(project_dir: str, model_name: str):
        return ClaudeModel(api_key=config['api_key'], model=model_name)
    if config['platform'] == MODEL_OLLAMA:
        return OllamaModel(model=model_name)
    if config['platform'] == MODEL_MISTRAL:
        return MistralModel(api_key=config['api_key'], model=model_name)
    return None
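As a usage sketch, the new backend slots in like the other Model implementations. The function spec below follows the OpenAI style that _convert_functions_to_tools expects; the API key and the get_weather function are invented for illustration, and a real tool name would have to resolve through mle.function.get_function:

from mle.model import MistralModel

# Invented key; requires the `mistralai` package to be installed.
model = MistralModel(api_key="sk-...", model="mistral-large-latest")

# OpenAI-style function spec, converted internally to a Mistral tool.
functions = [{
    "name": "get_weather",  # invented; must be registered in mle.function
    "description": "Look up the current weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}]

reply = model.query(
    [{"role": "user", "content": "What is the weather in Paris?"}],
    functions=functions,
)
print(reply)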
1 change: 1 addition & 0 deletions mle/utils/__init__.py
@@ -1,3 +1,4 @@
from .system import *
from .cache import *
from .memory import *
from .data import *
13 changes: 13 additions & 0 deletions mle/utils/data.py
@@ -0,0 +1,13 @@
import json
import re


def clean_json_string(input_string):
    """
    Strip Markdown code fences from a JSON string and parse it.
    :param input_string: the raw string, possibly wrapped in ```json fences
    :return: the parsed JSON object
    """
    cleaned = input_string.strip()
    cleaned = re.sub(r'^```\s*(?:json)?\s*', '', cleaned)
    cleaned = re.sub(r'\s*```\s*$', '', cleaned)
    parsed_json = json.loads(cleaned)
    return parsed_json
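A quick sanity check of the helper against the kind of fenced reply an LLM typically produces:

from mle.utils.data import clean_json_string

fenced = '```json\n{"status": "ok"}\n```'
print(clean_json_string(fenced))  # {'status': 'ok'}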
