
Commit

fix tool use
umnooob committed Sep 14, 2024
1 parent af54745 commit 73098a5
Showing 1 changed file with 48 additions and 39 deletions.
87 changes: 48 additions & 39 deletions mle/model.py
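This commit moves the DeepSeek client from the legacy function_call interface to the tools / tool_calls interface. As a rough sketch of the request shape the new code builds (the example function spec below is hypothetical; only the envelope keys "type", "function", "name", "description", "parameters" and the tools= keyword are taken from the diff):

    # Hypothetical OpenAI-style function spec (not from the diff)
    function_spec = {
        "name": "web_search",
        "description": "Search the web for a query.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }

    # _convert_functions_to_tools wraps each spec in a tools-style envelope
    tool = {
        "type": "function",
        "function": {
            "name": function_spec["name"],
            "description": function_spec.get("description", ""),
            "parameters": function_spec["parameters"],
        },
    }
    # The request then carries tools=[tool], and the reply exposes tool_calls
    # rather than the older function_call field.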
@@ -356,7 +356,7 @@ def _convert_functions_to_tools(self, functions):
            }
            tools.append(tool)
        return tools

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
@@ -448,49 +448,62 @@ def __init__(self, api_key, model, temperature=0.7):
        self.model = model if model else "deepseek-coder"
        self.model_type = MODEL_DEEPSEEK
        self.temperature = temperature
        self.client = self.openai(api_key=api_key, base_url="https://api.deepseek.com/beta")
        self.client = self.openai(
            api_key=api_key, base_url="https://api.deepseek.com/beta"
        )
        self.func_call_history = []

    def _convert_functions_to_tools(self, functions):
        """
        Convert OpenAI-style functions to DeepSeek-style tools.
        """
        tools = []
        for func in functions:
            tool = {
                "type": "function",
                "function": {
                    "name": func["name"],
                    "description": func.get("description", ""),
                    "parameters": func["parameters"],
                },
            }
            tools.append(tool)
        return tools

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.
        Args:
            chat_history: The context (chat history).
        """
        functions = kwargs.get("functions", None)
        tools = self._convert_functions_to_tools(functions) if functions else None
        parameters = kwargs
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=chat_history,
            temperature=self.temperature,
            stream=False,
            tools=tools,
            **parameters,
        )

        resp = completion.choices[0].message
        if resp.function_call:
            function_name = process_function_name(resp.function_call.name)
            arguments = json.loads(resp.function_call.arguments)
            print("[MLE FUNC CALL]: ", function_name)
            self.func_call_history.append(
                {"name": function_name, "arguments": arguments}
            )
            # avoid the multiple search function calls
            search_attempts = [
                item
                for item in self.func_call_history
                if item["name"] in SEARCH_FUNCTIONS
            ]
            if len(search_attempts) > 3:
                parameters["function_call"] = "none"
            result = get_function(function_name)(**arguments)
            chat_history.append(
                {"role": "assistant", "function_call": dict(resp.function_call)}
            )
            chat_history.append(
                {"role": "function", "content": result, "name": function_name}
            )
            return self.query(chat_history, **parameters)
        if resp.tool_calls:
            for tool_call in resp.tool_calls:
                chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                function_name = process_function_name(tool_call.function.name)
                arguments = json.loads(tool_call.function.arguments)
                print("[MLE FUNC CALL]: ", function_name)
                self.func_call_history.append({"name": function_name, "arguments": arguments})
                # avoid the multiple search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    parameters['tool_choice'] = "none"
                result = get_function(function_name)(**arguments)
                chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id": tool_call.id})
            return self.query(chat_history, **parameters)
        else:
            return resp.content
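For orientation, a usage sketch of the reworked query() path above. The wrapper class name DeepSeekModel, the function spec, and the user message are hypothetical stand-ins; the functions= keyword, get_function(), and the recursion on tool_calls come from the diff itself:

    # Hypothetical usage; class and function names are assumed, not from the diff.
    model = DeepSeekModel(api_key="your-api-key", model="deepseek-coder")

    functions = [{
        "name": "web_search",  # get_function() must be able to resolve this name
        "description": "Search the web for a query.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }]

    history = [{"role": "user", "content": "Find recent AutoML papers."}]

    # query() converts the specs to tools, executes any tool_calls the model
    # returns, appends the results as role="tool" messages, and calls itself
    # again until the model replies with plain content.
    answer = model.query(history, functions=functions)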

@@ -509,21 +522,17 @@ def stream(self, chat_history, **kwargs):
            stream=True,
            **kwargs,
        ):
            delta = chunk.choices[0].delta
            if delta.function_call:
                if delta.function_call.name:
                    function_name = process_function_name(delta.function_call.name)
                if delta.function_call.arguments:
                    arguments += delta.function_call.arguments

            if chunk.choices[0].finish_reason == "function_call":
                result = get_function(function_name)(**json.loads(arguments))
                chat_history.append(
                    {"role": "function", "content": result, "name": function_name}
                )
                yield from self.stream(chat_history, **kwargs)
            if chunk.choices[0].delta.tool_calls:
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function.name:
                    chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix": False})
                    function_name = process_function_name(tool_call.function.name)
                    arguments = json.loads(tool_call.function.arguments)
                    result = get_function(function_name)(**arguments)
                    chat_history.append({"role": "tool", "content": result, "name": function_name})
                    yield from self.stream(chat_history, **kwargs)
            else:
                yield delta.content
                yield chunk.choices[0].delta.content

def load_model(project_dir: str, model_name: Optional[str]=None):
    """
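For reference, after one tool-call round trip the new query() path leaves entries shaped roughly like this in chat_history (assembled from the append() calls in the diff; the concrete name, result, and id values are invented):

    # Assistant turn recording the model's tool call (shape taken from the diff)
    {"role": "assistant", "content": "", "tool_calls": [tool_call], "prefix": False}

    # Tool turn carrying the executed result back to the model
    {"role": "tool",
     "content": "(output of get_function(name)(**arguments))",
     "name": "web_search",        # invented example name
     "tool_call_id": "call_0"}    # taken from tool_call.id; value invented

    # query() then re-issues the request with this extended history; once more
    # than three search-type calls have accumulated it sets tool_choice="none"
    # to stop further searching.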
