From fb2baf27713a9d06b21447321e8547e4c9db1e3b Mon Sep 17 00:00:00 2001
From: Mark-Kim
Date: Mon, 29 Apr 2024 11:06:57 +0900
Subject: [PATCH 1/2] huggingface model support added

---
 src/vanna/hf/__init__.py |   1 +
 src/vanna/hf/hf.py       | 118 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 119 insertions(+)
 create mode 100644 src/vanna/hf/__init__.py
 create mode 100644 src/vanna/hf/hf.py

diff --git a/src/vanna/hf/__init__.py b/src/vanna/hf/__init__.py
new file mode 100644
index 00000000..aa3e9a8f
--- /dev/null
+++ b/src/vanna/hf/__init__.py
@@ -0,0 +1 @@
+from .hf import Hf
diff --git a/src/vanna/hf/hf.py b/src/vanna/hf/hf.py
new file mode 100644
index 00000000..feb7ea5a
--- /dev/null
+++ b/src/vanna/hf/hf.py
@@ -0,0 +1,118 @@
+import re
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from ..base import VannaBase
+
+
+class Hf(VannaBase):
+    """Vanna LLM backend that generates text with a Hugging Face causal language model."""
+
+    def __init__(self, config=None):
+        """
+        Loads the tokenizer and model named by the "model_name" config entry.
+
+        Args:
+        - config (dict): Settings for this backend. Must contain "model_name",
+          e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+
+        Raises:
+        - ValueError: If no model name is configured.
+        """
+        # Prefer the argument over self.config, which only exists when another
+        # initializer has already run; fall back to it when no config is passed.
+        if config is None:
+            config = getattr(self, "config", None) or {}
+        model_name = config.get("model_name", None)
+        if not model_name:
+            raise ValueError("config must contain a Hugging Face 'model_name'")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="auto",
+            device_map="auto",
+        )
+
+    def system_message(self, message: str) -> dict:
+        """Wraps message as a chat entry with the "system" role."""
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> dict:
+        """Wraps message as a chat entry with the "user" role."""
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> dict:
+        """Wraps message as a chat entry with the "assistant" role."""
+        return {"role": "assistant", "content": message}
+
+    def extract_sql_query(self, text):
+        """
+        Extracts the first SQL statement starting at the word 'select', ignoring case,
+        matches until the first semicolon, three backticks, or the end of the string,
+        and removes three backticks if they exist in the extracted string.
+
+        Args:
+        - text (str): The string to search within for an SQL statement.
+
+        Returns:
+        - str: The first SQL statement found, with three backticks removed, or
+          the original text unchanged if no match is found.
+        """
+        # Find 'select' as a whole word (so e.g. 'selected' does not start a match),
+        # ignoring case, and capture until ';', '```', or end of string
+        pattern = re.compile(r"\bselect\b.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
+
+        match = pattern.search(text)
+        if match:
+            # Remove three backticks from the matched string if they exist
+            return match.group(0).replace("```", "")
+        return text
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """
+        Generates SQL for the question and reduces the model output to a bare query.
+
+        Args:
+        - question (str): The question to generate SQL for.
+        - **kwargs: Passed through to the parent generate_sql.
+
+        Returns:
+        - str: The result of extract_sql_query on the backslash-free response.
+        """
+        # Use the super generate_sql
+        sql = super().generate_sql(question, **kwargs)
+
+        # Drop every backslash; this also turns an escaped "\_" into "_"
+        sql = sql.replace("\\", "")
+
+        return self.extract_sql_query(sql)
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        """
+        Runs the chat prompt through the model and returns the decoded reply.
+
+        Args:
+        - prompt: Chat messages in the format accepted by the tokenizer's apply_chat_template.
+        - **kwargs: Accepted but not used by this implementation.
+
+        Returns:
+        - str: The newly generated text, without the prompt tokens or special tokens.
+        """
+        input_ids = self.tokenizer.apply_chat_template(
+            prompt, add_generation_prompt=True, return_tensors="pt"
+        ).to(self.model.device)
+
+        outputs = self.model.generate(
+            input_ids,
+            max_new_tokens=512,
+            eos_token_id=self.tokenizer.eos_token_id,
+            do_sample=True,
+            temperature=1,
+            top_p=0.9,
+        )
+        # Keep only the tokens generated after the prompt
+        response = outputs[0][input_ids.shape[-1] :]
+        response = self.tokenizer.decode(response, skip_special_tokens=True)
+        self.log(response)
+
+        return response

From 6bb70ac1faf3b9b1afaec66b819bb71b0d3a9915 Mon Sep 17 00:00:00 2001
From: Zain Hoda <7146154+zainhoda@users.noreply.github.com>
Date: Tue, 30 Apr 2024 17:26:17 -0400
Subject: [PATCH 2/2] add import tests

---
 pyproject.toml        | 3 ++-
 tests/test_imports.py | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index d020f241..aae6d9fe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
 duckdb = ["duckdb"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
-all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl"]
+all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"]
 test = ["tox"]
 chromadb = ["chromadb"]
 openai = ["openai"]
@@ -45,3 +45,4 @@ ollama = ["ollama", "httpx"]
 qdrant = ["qdrant-client"]
 vllm = ["vllm"]
 opensearch = ["opensearch-py", "opensearch-dsl"]
+hf = ["transformers"]
diff --git a/tests/test_imports.py b/tests/test_imports.py
index 65d915cf..3141d37e 100644
--- a/tests/test_imports.py
+++ b/tests/test_imports.py
@@ -4,6 +4,7 @@ def test_regular_imports():
     from vanna.anthropic.anthropic_chat import Anthropic_Chat
     from vanna.base.base import VannaBase
     from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
+    from vanna.hf.hf import Hf
     from vanna.local import LocalContext_OpenAI
     from vanna.marqo.marqo import Marqo_VectorStore
     from vanna.mistral.mistral import Mistral
@@ -20,6 +21,7 @@ def test_shortcut_imports():
     from vanna.anthropic import Anthropic_Chat
     from vanna.base import VannaBase
     from vanna.chromadb import ChromaDB_VectorStore
+    from vanna.hf import Hf
     from vanna.marqo import Marqo_VectorStore
     from vanna.mistral import Mistral
     from vanna.ollama import Ollama