-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bug/Model Request]: Load model files from path, not from huggingface cache directory #321
Comments
Hi @satyaloka93 — you want to put the same files as in the HF hub into the cache directory and initialize from them.
Hi, they are the files from the Qdrant HF repo: https://huggingface.co/Qdrant/Splade_PP_en_v1/tree/main. Our organization pulls them via git, scans, and moves them where we can load them up. When I try to load from that directory, even using it as cache_dir and local_files_only=True, it does not work. I’m assuming because it’s expecting to have a cache structure, versus the normal HF files in your repo. |
|
I had Claude help me modify fastembed/common/model_management.py to actually load the model from local files, with no HF callouts or requirements for a cache structure. It should still use the HF repo if the flag isn't set, but I haven't tested it. Maybe someone could take a look.

import os
import time
import shutil
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from loguru import logger
from tqdm import tqdm
class ModelManagement:
    """Locates, downloads, and caches embedding model files.

    Resolution order in ``download_model``: a plain local directory first
    (``load_from_local``, no HF cache layout required), then the HuggingFace
    Hub, then a GCS ``.tar.gz`` URL as a last resort.
    """

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        raise NotImplementedError()

    @classmethod
    def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
        """
        Gets the model description from the model_name.

        Args:
            model_name (str): The name of the model.

        Raises:
            ValueError: If the model_name is not supported.

        Returns:
            Dict[str, Any]: The model description.
        """
        # Case-insensitive lookup so e.g. "qdrant/..." matches "Qdrant/...".
        for model in cls.list_supported_models():
            if model_name.lower() == model["model"].lower():
                return model

        raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")

    @classmethod
    def load_from_local(cls, model_name: str, cache_dir: str) -> Path:
        """
        Loads a model from a plain local directory (no HF cache layout required).

        Args:
            model_name (str): The name of the model (kept for interface
                compatibility; the path is resolved from ``cache_dir`` alone).
            cache_dir (str): Directory that directly contains the model files.

        Raises:
            FileNotFoundError: If the directory or any required file is missing.

        Returns:
            Path: The path to the local model directory.
        """
        model_dir = Path(cache_dir)
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory {model_dir} does not exist.")

        # Minimal file set; extend as needed for models with extra artifacts.
        required_files = ["config.json", "model.onnx"]
        for file in required_files:
            if not (model_dir / file).exists():
                raise FileNotFoundError(f"Required file {file} not found in {model_dir}")
        return model_dir

    @classmethod
    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
        """
        Downloads a file from Google Cloud Storage.

        Args:
            url (str): The URL to download the file from.
            output_path (str): The path to save the downloaded file to.
            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.

        Raises:
            PermissionError: On HTTP 403 (missing or invalid credentials).
            requests.HTTPError: On any other non-2xx response.

        Returns:
            str: The path to the downloaded file.
        """
        if os.path.exists(output_path):
            return output_path

        # Stream so large model archives are never held fully in memory; the
        # context manager guarantees the connection is released.
        with requests.get(url, stream=True) as response:
            if response.status_code == 403:
                raise PermissionError(
                    "Authentication Error: You do not have permission to access this resource. "
                    "Please check your credentials."
                )
            # BUGFIX: previously any other error status (404, 500, ...) was
            # silently written to disk as if it were the model archive.
            response.raise_for_status()

            # Get the total size of the file
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            # Warn if the total size is zero
            if total_size_in_bytes == 0:
                print(f"Warning: Content-length header is missing or zero in the response from {url}.")

            # A progress bar with an unknown total is meaningless; disable it.
            show_progress = total_size_in_bytes and show_progress
            with tqdm(
                total=total_size_in_bytes,
                unit="iB",
                unit_scale=True,
                disable=not show_progress,
            ) as progress_bar:
                with open(output_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:  # Filter out keep-alive new chunks
                            progress_bar.update(len(chunk))
                            file.write(chunk)
        return output_path

    @classmethod
    def download_files_from_huggingface(
        cls,
        hf_source_repo: str,
        cache_dir: Optional[str] = None,
        extra_patterns: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Downloads a model from HuggingFace Hub.

        Args:
            hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
            cache_dir (Optional[str]): The path to the cache directory.
            extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically
                includes the required model files.
            **kwargs: May carry ``local_files_only`` to forbid network access.

        Returns:
            str: The path to the model directory.
        """
        # Tokenizer/config files every supported model ships with.
        allow_patterns = [
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "preprocessor_config.json",
        ]
        if extra_patterns is not None:
            allow_patterns.extend(extra_patterns)

        return snapshot_download(
            repo_id=hf_source_repo,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            local_files_only=kwargs.get("local_files_only", False),
        )

    @classmethod
    def decompress_to_cache(cls, targz_path: str, cache_dir: str):
        """
        Decompresses a .tar.gz file to a cache directory.

        Args:
            targz_path (str): Path to the .tar.gz file.
            cache_dir (str): Path to the cache directory.

        Raises:
            ValueError: If ``targz_path`` is not an existing ``.tar.gz`` file,
                or if extraction fails.

        Returns:
            cache_dir (str): Path to the cache directory.
        """
        # Check if targz_path exists and is a file
        if not os.path.isfile(targz_path):
            raise ValueError(f"{targz_path} does not exist or is not a file.")

        # Check if targz_path is a .tar.gz file
        if not targz_path.endswith(".tar.gz"):
            raise ValueError(f"{targz_path} is not a .tar.gz file.")

        try:
            # NOTE(security): extractall trusts member paths; archives come from
            # a fixed GCS bucket. Consider filter="data" (Python 3.12+, also
            # backported to recent 3.8-3.11 patch releases) if archives could
            # ever be untrusted.
            with tarfile.open(targz_path, "r:gz") as tar:
                tar.extractall(path=cache_dir)
        except tarfile.TarError as e:
            # Clean up a partially extracted temp directory before re-raising
            # (only if it looks like a temp dir created for this extraction).
            if "tmp" in cache_dir:
                shutil.rmtree(cache_dir)
            raise ValueError(f"An error occurred while decompressing {targz_path}: {e}") from e
        return cache_dir

    @classmethod
    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
        """
        Downloads a model archive from GCS and unpacks it into the cache.

        The archive is downloaded to ``<cache_dir>/<fast_model_name>.tar.gz``,
        extracted under ``<cache_dir>/tmp``, then renamed into place.

        Args:
            model_name (str): The name of the model, e.g. "BAAI/bge-small-en".
            source_url (str): The GCS URL of the ``.tar.gz`` archive.
            cache_dir (str): The path to the cache directory.

        Returns:
            Path: The directory containing the extracted model files.
        """
        # BUGFIX: this line had been commented out while `fast_model_name` was
        # still used below, so every call raised NameError.
        fast_model_name = f"fast-{model_name.split('/')[-1]}"

        cache_tmp_dir = Path(cache_dir) / "tmp"
        model_tmp_dir = cache_tmp_dir / fast_model_name
        # BUGFIX: must be a Path (a plain str has no .exists()/.glob()), and it
        # must be a subdirectory of cache_dir — otherwise the rename at the end
        # would target an already-existing directory and fail.
        model_dir = Path(cache_dir) / fast_model_name

        # Reuse an already-extracted model; the non-empty check guards against
        # a previously interrupted extraction (seen on macOS).
        if model_dir.exists() and len(list(model_dir.glob("*"))) > 0:
            return model_dir

        # Drop leftovers from an interrupted previous attempt.
        if model_tmp_dir.exists():
            shutil.rmtree(model_tmp_dir)
        cache_tmp_dir.mkdir(parents=True, exist_ok=True)

        model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
        if model_tar_gz.exists():
            model_tar_gz.unlink()

        cls.download_file_from_gcs(
            source_url,
            output_path=str(model_tar_gz),
        )

        cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
        assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
        model_tar_gz.unlink()

        # Rename from tmp to final name is atomic
        model_tmp_dir.rename(model_dir)
        return model_dir

    @classmethod
    def download_model(cls, model: Dict[str, Any], cache_dir: Path, retries=3, **kwargs) -> Path:
        """
        Attempts to load a model from a local directory first, then falls back to online sources if necessary.

        Args:
            model (Dict[str, Any]): The model description.
            cache_dir (Path): The path to the cache directory.
            retries (int): The number of times to retry (including the first attempt).
            **kwargs: Additional keyword arguments, including 'local_files_only'.

        Raises:
            ValueError: If the model cannot be found locally while
                ``local_files_only`` is set, or if all download attempts fail.

        Returns:
            Path: The path to the model directory.
        """
        model_name = model["model"]
        local_files_only = kwargs.get("local_files_only", False)

        try:
            logger.info(f"Loading from: {cache_dir}")
            return cls.load_from_local(model_name, str(cache_dir))
        except FileNotFoundError:
            if local_files_only:
                raise ValueError(f"Model {model_name} not found locally and local_files_only is set to True.")

        # Local loading failed and network access is allowed: try HF, then GCS.
        hf_source = model.get("sources", {}).get("hf")
        url_source = model.get("sources", {}).get("url")

        sleep = 3.0
        while retries > 0:
            retries -= 1

            if hf_source:
                extra_patterns = [model["model_file"]]
                extra_patterns.extend(model.get("additional_files", []))

                try:
                    return Path(
                        cls.download_files_from_huggingface(
                            hf_source,
                            cache_dir=str(cache_dir),
                            extra_patterns=extra_patterns,
                            local_files_only=local_files_only,
                        )
                    )
                except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                    logger.error(
                        f"Could not download model from HuggingFace: {e} "
                        "Falling back to other sources."
                    )

            if url_source:
                try:
                    return cls.retrieve_model_gcs(model_name, url_source, str(cache_dir))
                except Exception:
                    logger.error(f"Could not download model from url: {url_source}")

            # BUGFIX: do not sleep after the final failed attempt — there is
            # nothing left to retry.
            if retries > 0:
                logger.error(
                    f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
                )
                time.sleep(sleep)
                sleep *= 3

        raise ValueError(f"Failed to load or download model {model_name} after all attempts.")
What happened?
Unable to test this in my organization, as we do not use hugging face cache folders for models, models are downloaded via git, scanned, then allowed for usage. I see some attempt to use local files via 'local_files_only' kwarg in this PR, but this won't work apparently as I do not have files in the snapshot format. Request loading models from a normal directory, like transformers/sentence-transformers and most other frameworks. Really would like to incorporate this technology in our information retrieval, but this is a show stopper.
What Python version are you on? e.g. python --version
Python 3.10
Version
0.2.7 (Latest)
What os are you seeing the problem on?
Linux
Relevant stack traces and/or logs
No response
The text was updated successfully, but these errors were encountered: