[WIP] add gemini model #230

Open · wants to merge 2 commits into main

8 changes: 7 additions & 1 deletion mle/cli.py
@@ -180,7 +180,7 @@ def new(name):

    platform = questionary.select(
        "Which language model platform do you want to use?",
-        choices=['OpenAI', 'Ollama', 'Claude', 'MistralAI', 'DeepSeek']
+        choices=['OpenAI', 'Ollama', 'Claude', 'Gemini', 'MistralAI', 'DeepSeek']
    ).ask()

    api_key = None
@@ -208,6 +208,12 @@ def new(name):
            console.log("API key is required. Aborted.")
            return

+    elif platform == 'Gemini':
+        api_key = questionary.password("What is your Gemini API key?").ask()
+        if not api_key:
+            console.log("API key is required. Aborted.")
+            return
+
    search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask()
    if search_api_key:
        os.environ["SEARCH_API_KEY"] = search_api_key
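The values collected here are what `load_model` later reads back through `get_config()` (see the `mle/model/__init__.py` change below). As a hedged sketch, the resulting project config looks roughly like the following; only the `platform` and `api_key` keys are confirmed by the code, the surrounding shape is an assumption:

```python
# Hypothetical illustration, not part of the PR: the config that
# load_model() consumes. Only 'platform' and 'api_key' are actually
# read in mle/model/__init__.py; everything else here is assumed.
config = {
    "platform": "Gemini",            # chosen via questionary.select above
    "api_key": "<your-gemini-key>",  # entered via questionary.password
}
```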
4 changes: 4 additions & 0 deletions mle/model/__init__.py
@@ -3,6 +3,7 @@
from .mistral import *
from .ollama import *
from .openai import *
+from .gemini import *

from mle.utils import get_config

@@ -12,6 +13,7 @@
MODEL_CLAUDE = 'Claude'
MODEL_MISTRAL = 'MistralAI'
MODEL_DEEPSEEK = 'DeepSeek'
+MODEL_GEMINI = 'Gemini'


class ObservableModel:
@@ -64,6 +66,8 @@ def load_model(project_dir: str, model_name: str=None, observable=True):
        model = MistralModel(api_key=config['api_key'], model=model_name)
    if config['platform'] == MODEL_DEEPSEEK:
        model = DeepSeekModel(api_key=config['api_key'], model=model_name)
+    if config['platform'] == MODEL_GEMINI:
+        model = GeminiModel(api_key=config['api_key'], model=model_name)

    if observable:
        return ObservableModel(model)
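With this dispatch in place, a Gemini-backed project loads through the same entry point as every other platform. A minimal sketch, assuming a project directory whose config was written by the CLI change above (the path and model name here are hypothetical):

```python
from mle.model import load_model

# Hypothetical project path; its config must carry platform='Gemini' and a
# valid API key. With observable=True (the default), load_model wraps the
# GeminiModel in an ObservableModel before returning it.
model = load_model("./my-project", model_name="gemini-1.5-flash-002")
```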
124 changes: 124 additions & 0 deletions mle/model/gemini.py
@@ -0,0 +1,124 @@
import importlib

from mle.function import SEARCH_FUNCTIONS, get_function
from mle.model.common import Model


class GeminiModel(Model):
def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the Gemini model.
Args:
api_key (str): The Gemini API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
        super().__init__()

        try:
            # `google-generativeai` lives in the `google` namespace package,
            # so import the full dotted module rather than `google` itself.
            self.gemini = importlib.import_module("google.generativeai")
        except ImportError:
            raise ImportError(
                "It seems you didn't install `google-generativeai`. "
                "In order to enable the Gemini client related features, "
                "please make sure the `google-generativeai` Python package has been installed. "
                "For more information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python"
            )

        self.gemini.configure(api_key=api_key)
        self.model = model if model else 'gemini-1.5-flash-002'
        self.model_type = 'Gemini'
        self.temperature = temperature
        self.func_call_history = []

    @staticmethod
    def _map_roles_from_openai(message):
        """
        Convert a single OpenAI-style message, e.g. {"role": "assistant", "content": ...},
        to the Gemini format {"role": "model", "parts": ...}. Gemini only knows the
        "user" and "model" roles, so "system" and "assistant" both map to "model".
        """
        role_map = {
            "system": "model",
            "user": "user",
            "assistant": "model",
        }
        return {"role": role_map[message["role"]], "parts": message["content"]}

    def query(self, chat_history, **kwargs):
        """
        Query the LLM model.

        Args:
            chat_history: The context (chat history), a list of Gemini-format
                messages of the form {"role": ..., "parts": ...}.
        """
        parameters = kwargs

        tools = None
        if parameters.get("functions") is not None:
            # Gemini accepts the OpenAI-style function schemas as a tool declaration
            tools = {'function_declarations': parameters["functions"]}

client = self.gemini.GenerativeModel(self.model, tools=tools)
chat_handler = client.start_chat(history=chat_history[:-1])
generation_config = self.gemini.types.GenerationConfig(
max_output_tokens=4096,
temperature=self.temperature,
response_mime_type='application/json' \
if parameters.get("response_format") == {'type': 'json_object'} else None,
)

completion = chat_handler.send_message(
chat_history[-1]["parts"],
generation_config=generation_config,
)

        function_outputs = {}
        for part in completion.parts:
            fn = part.function_call
            if fn:
                print("[MLE FUNC CALL]: ", fn.name)
                # record the call so repeated search attempts can be detected
                self.func_call_history.append({"name": fn.name, "arguments": dict(fn.args)})
                # avoid piling up redundant search function calls
                search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
                if len(search_attempts) > 3:
                    parameters['functions'] = None
                result = get_function(fn.name)(**fn.args)
                function_outputs[fn.name] = result

        if function_outputs:
            response_parts = [
                self.gemini.protos.Part(
                    function_response=self.gemini.protos.FunctionResponse(
                        name=fn_name, response={"result": result}
                    )
                )
                for fn_name, result in function_outputs.items()
            ]

completion = chat_handler.send_message(
response_parts,
generation_config=generation_config,
)

return completion.text

def stream(self, chat_history, **kwargs):
"""
Stream the output from the LLM model.
Args:
chat_history: The context (chat history).
"""
client = self.gemini.GenerativeModel(self.model)
chat_handler = client.start_chat(history=chat_history[:-1])
generation_config = self.gemini.types.GenerationConfig(
max_output_tokens=4096,
temperature=self.temperature,
)

completions = chat_handler.send_message(
chat_history[-1]["parts"],
generation_config=generation_config,
stream=True
)

for chunk in completions:
yield chunk.text
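To make the new backend concrete, here is a minimal usage sketch (not part of the PR). It assumes `google-generativeai` is installed and a key is exported as `GEMINI_API_KEY` (the variable name is an assumption); the history shape matches what `query` and `stream` index into (`chat_history[-1]["parts"]`):

```python
import os

from mle.model.gemini import GeminiModel

# Hypothetical driver code; the environment variable name is an assumption.
model = GeminiModel(api_key=os.environ["GEMINI_API_KEY"], model="gemini-1.5-flash-002")

# Gemini-format history: "user"/"model" roles with a "parts" payload.
history = [
    {"role": "user", "parts": "You are a helpful ML engineer."},
    {"role": "model", "parts": "Understood. How can I help?"},
    {"role": "user", "parts": "Suggest a baseline for tabular classification."},
]

print(model.query(history))           # single-shot completion

for chunk in model.stream(history):   # streaming variant
    print(chunk, end="", flush=True)
```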