Merge branch 'dev'
jstzwj committed Sep 6, 2024
2 parents 7bf92bb + 40ffa99 commit 72ea99c
Showing 5 changed files with 69 additions and 155 deletions.
16 changes: 10 additions & 6 deletions README.md
@@ -50,7 +50,7 @@ pip install -e .
 ## Quick Start
 Run the command in the console:
 ```bash
-python -m olah.server
+olah-cli
 ```

 Then set the Environment Variable `HF_ENDPOINT` to the mirror site (Here is http://localhost:8090).
@@ -95,21 +95,21 @@ You can check the path `./repos`, in which olah stores all cached datasets and m
 ## Start the server
 Run the command in the console:
 ```bash
-python -m olah.server
+olah-cli
 ```

 Or you can specify the host address and listening port:
 ```bash
-python -m olah.server --host localhost --port 8090
+olah-cli --host localhost --port 8090
 ```
 **Note: Please change --mirror-netloc and --mirror-lfs-netloc to the actual URLs of the mirror sites when modifying the host and port.**
 ```bash
-python -m olah.server --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
+olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
 ```

 The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter:
 ```bash
-python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
+olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors
 ```

 **Note that the cached data between different versions cannot be migrated. Please delete the cache folder before upgrading to the latest version of Olah.**
@@ -118,7 +118,7 @@ python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors

 Additional configurations can be controlled through a configuration file by passing the `configs.toml` file as a command parameter:
 ```bash
-python -m olah.server -c configs.toml
+olah-cli -c configs.toml
 ```

 The complete content of the configuration file can be found at [assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml).
@@ -132,6 +132,8 @@ port = 8090
 ssl-key = ""
 ssl-cert = ""
 repos-path = "./repos"
+cache-size-limit = ""
+cache-clean-strategy = "LRU"
 hf-scheme = "https"
 hf-netloc = "huggingface.co"
 hf-lfs-netloc = "cdn-lfs.huggingface.co"
@@ -144,6 +146,8 @@ mirrors-path = ["./mirrors_dir"]
 - `port`: Sets the port that Olah listens to.
 - `ssl-key` and `ssl-cert`: When enabling HTTPS, specify the file paths for the key and certificate.
 - `repos-path`: Specifies the directory for storing cached data.
+- `cache-size-limit`: Specifies cache size limit (For example, 100G, 500GB, 2TB). Olah will scan the size of the cache folder every hour. If it exceeds the limit, olah will delete some cache files.
+- `cache-clean-strategy`: Specifies cache cleaning strategy (Available strategies: LRU, FIFO, LARGE_FIRST).
 - `hf-scheme`: Network protocol for the Hugging Face official site (usually no need to modify).
 - `hf-netloc`: Network location of the Hugging Face official site (usually no need to modify).
 - `hf-lfs-netloc`: Network location for Hugging Face official site's LFS files (usually no need to modify).
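
Reviewer note: the two new options (`cache-size-limit`, `cache-clean-strategy`) can be pictured with a small sketch. It assumes binary multiples (1G = 1024^3 bytes) and uses two hypothetical helpers, `parse_size_limit` and `pick_files_to_evict`; neither name comes from Olah, and the actual cleaner may parse sizes or order files differently.

```python
import re
from typing import List, Tuple

# Assumed unit table (binary multiples). Olah's real parsing may differ.
_UNITS = {"K": 1024, "M": 1024**2, "G": 1024**3, "T": 1024**4}


def parse_size_limit(text: str) -> int:
    """Parse strings such as '100G', '500GB' or '2TB' into bytes (hypothetical helper)."""
    match = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([KMGT])B?\s*", text, flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"Unrecognized size: {text!r}")
    number, unit = match.groups()
    return int(float(number) * _UNITS[unit.upper()])


# Each entry: (path, size_in_bytes, last_access_time, creation_time)
FileInfo = Tuple[str, int, float, float]

# Sort key per strategy; files that sort first are evicted first.
_SORT_KEYS = {
    "LRU": lambda f: f[2],           # least recently accessed first
    "FIFO": lambda f: f[3],          # oldest file first
    "LARGE_FIRST": lambda f: -f[1],  # largest file first
}


def pick_files_to_evict(files: List[FileInfo], limit: int, strategy: str = "LRU") -> List[str]:
    """Return paths to delete, in order, until the total size fits within `limit`."""
    total = sum(f[1] for f in files)
    evicted: List[str] = []
    for path, size, _, _ in sorted(files, key=_SORT_KEYS[strategy]):
        if total <= limit:
            break
        evicted.append(path)
        total -= size
    return evicted
```

Under these assumptions an empty `cache-size-limit` would simply skip the scan, which matches the default of `""` shown in the config hunk.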
14 changes: 9 additions & 5 deletions README_zh.md
@@ -95,22 +95,22 @@ huggingface-cli download --repo-type dataset --resume-download Salesforce/wikite
 ## Start the server
 Run the following command in the console:
 ```bash
-python -m olah.server
+olah-cli
 ```

 Or you can specify the host address and listening port:
 ```bash
-python -m olah.server --host localhost --port 8090
+olah-cli --host localhost --port 8090
 ```
 **Note: when changing the host and port, remember to set `--mirror-netloc` and `--mirror-lfs-netloc` to the actual URLs of the mirror site.**

 ```bash
-python -m olah.server --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
+olah-cli --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
 ```

 The default mirror cache path is `./repos`; you can change it with the `--repos-path` parameter:
 ```bash
-python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
+olah-cli --host localhost --port 8090 --repos-path ./hf_mirrors
 ```

 **Note that cached data cannot be migrated between versions; delete the cache folder before upgrading olah.**
@@ -119,7 +119,7 @@ python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors

 More options can be controlled through a configuration file; pass `configs.toml` as a command parameter to set the configuration file path:
 ```bash
-python -m olah.server -c configs.toml
+olah-cli -c configs.toml
 ```

 The complete configuration file is available at [assets/full_configs.toml](https://github.com/vtuber-plan/olah/blob/main/assets/full_configs.toml)
@@ -133,6 +133,8 @@ port = 8090
 ssl-key = ""
 ssl-cert = ""
 repos-path = "./repos"
+cache-size-limit = ""
+cache-clean-strategy = "LRU"
 hf-scheme = "https"
 hf-netloc = "huggingface.co"
 hf-lfs-netloc = "cdn-lfs.huggingface.co"
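@@ -146,6 +148,8 @@ mirrors-path = ["./mirrors_dir"]
 - port: Sets the port that olah listens on
 - ssl-key and ssl-cert: File paths of the key and certificate, passed in when HTTPS needs to be enabled
 - repos-path: Directory used to store cached data
+- cache-size-limit: Specifies the cache size limit (for example, 100G, 500GB, 2TB). Olah scans the size of the cache folder every hour. If the limit is exceeded, Olah deletes some cache files
+- cache-clean-strategy: Specifies the cache cleaning strategy (available strategies: LRU, FIFO, LARGE_FIRST)
 - hf-scheme: Network protocol of the official huggingface site (usually no need to change)
 - hf-netloc: Network location of the official huggingface site (usually no need to change)
 - hf-lfs-netloc: Network location of LFS files on the official huggingface site (usually no need to change)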
184 changes: 43 additions & 141 deletions olah/proxy/files.py
@@ -85,125 +85,6 @@ def get_contiguous_ranges(
range_start_pos = end_pos
return ranges_and_cache_list

-async def _file_full_header(
-    app,
-    save_path: str,
-    head_path: str,
-    client: httpx.AsyncClient,
-    method: str,
-    url: str,
-    headers: Dict[str, str],
-    allow_cache: bool,
-) -> Tuple[int, Dict[str, str], bytes]:
-    assert method.lower() == "head"
-    if not app.app_settings.config.offline:
-        if os.path.exists(head_path):
-            cache_rq = await read_cache_request(head_path)
-            response_headers_dict = {
-                k.lower(): v for k, v in cache_rq["headers"].items()
-            }
-            if "location" in response_headers_dict:
-                parsed_url = urlparse(response_headers_dict["location"])
-                if len(parsed_url.netloc) != 0:
-                    new_loc = urljoin(
-                        app.app_settings.config.mirror_lfs_url_base(),
-                        get_url_tail(response_headers_dict["location"]),
-                    )
-                    response_headers_dict["location"] = new_loc
-            return cache_rq["status_code"], response_headers_dict, cache_rq["content"]
-        else:
-            if "range" in headers:
-                headers.pop("range")
-            response = await client.request(
-                method=method,
-                url=url,
-                headers=headers,
-                timeout=WORKER_API_TIMEOUT,
-            )
-            response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
-            if allow_cache and method.lower() == "head":
-                if response.status_code == 200:
-                    await write_cache_request(
-                        head_path,
-                        response.status_code,
-                        response_headers_dict,
-                        response.content,
-                    )
-                elif response.status_code >= 300 and response.status_code <= 399:
-                    from_url = urlparse(url)
-                    parsed_url = urlparse(response.headers["location"])
-                    if len(parsed_url.netloc) != 0:
-                        new_loc = urljoin(
-                            app.app_settings.config.mirror_lfs_url_base(),
-                            get_url_tail(response.headers["location"]),
-                        )
-                        response_headers_dict["location"] = new_loc
-                    # Redirect, add original location info
-                    if check_url_has_param_name(
-                        response_headers_dict["location"], ORIGINAL_LOC
-                    ):
-                        raise Exception(f"Invalid field {ORIGINAL_LOC} in the url.")
-                    else:
-                        response_headers_dict["location"] = add_query_param(
-                            response_headers_dict["location"],
-                            ORIGINAL_LOC,
-                            response.headers["location"],
-                        )
-                    await write_cache_request(
-                        head_path,
-                        response.status_code,
-                        response_headers_dict,
-                        response.content,
-                    )
-                elif response.status_code == 403:
-                    pass
-                elif response.status_code == 404:
-                    pass
-                else:
-                    raise Exception(
-                        f"Unexpected HTTP status code {response.status_code}"
-                    )
-            return response.status_code, response_headers_dict, response.content
-    else:
-        if os.path.exists(head_path):
-            cache_rq = await read_cache_request(head_path)
-            response_headers_dict = {
-                k.lower(): v for k, v in cache_rq["headers"].items()
-            }
-        else:
-            response_headers_dict = {}
-            cache_rq = {
-                "status_code": 200,
-                "headers": response_headers_dict,
-                "content": b"",
-            }
-
-        new_headers = {}
-        if "content-type" in response_headers_dict:
-            new_headers["content-type"] = response_headers_dict["content-type"]
-        if "content-length" in response_headers_dict:
-            new_headers["content-length"] = response_headers_dict["content-length"]
-        if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
-            new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = (
-                response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
-            )
-        if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
-            new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = (
-                response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
-            )
-        if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
-            new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = (
-                response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
-            )
-        if "etag" in response_headers_dict:
-            new_headers["etag"] = response_headers_dict["etag"]
-        if "location" in response_headers_dict:
-            new_headers["location"] = urljoin(
-                app.app_settings.config.mirror_lfs_url_base(),
-                get_url_tail(response_headers_dict["location"]),
-            )
-        return cache_rq["status_code"], new_headers, cache_rq["content"]
-

async def _get_file_range_from_cache(
cache_file: OlahCache, start_pos: int, end_pos: int
@@ -238,7 +119,8 @@ async def _get_file_range_from_remote(
     end_pos: int,
 ):
     headers = {}
-    headers["authorization"] = remote_info.headers.get("authorization", None)
+    if remote_info.headers.get("authorization", None) is not None:
+        headers["authorization"] = remote_info.headers.get("authorization", None)
     headers["range"] = f"bytes={start_pos}-{end_pos - 1}"

     chunk_bytes = 0
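
Reviewer note: the guard added in this hunk keeps a `None` value out of the outgoing headers when the client sent no credential. A minimal sketch of the same pattern follows; `build_range_headers` is a hypothetical helper written for illustration and is not part of Olah.

```python
from typing import Dict, Mapping, Optional


def build_range_headers(incoming: Mapping[str, str], start_pos: int, end_pos: int) -> Dict[str, str]:
    """Build headers for a ranged request over the half-open interval [start_pos, end_pos)."""
    headers: Dict[str, str] = {}
    authorization: Optional[str] = incoming.get("authorization")
    # Forward the credential only when the client actually sent one,
    # so the dict never carries a None value.
    if authorization is not None:
        headers["authorization"] = authorization
    # HTTP byte ranges are inclusive on both ends, hence end_pos - 1.
    headers["range"] = f"bytes={start_pos}-{end_pos - 1}"
    return headers
```

Looking the value up once and reusing it also avoids the repeated `.get("authorization", None)` call seen in the added lines.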
@@ -433,6 +315,32 @@ async def _file_chunk_head(
yield b""


+async def _resource_etag(hf_url: str, authorization: Optional[str]=None, offline: bool = False) -> Optional[str]:
+    ret_etag = None
+    sha256_hash = hashlib.sha256()
+    sha256_hash.update(hf_url.encode("utf-8"))
+    content_hash = sha256_hash.hexdigest()
+    if offline:
+        ret_etag = f'"{content_hash[:32]}-10"'
+    else:
+        etag_headers = {}
+        if authorization is not None:
+            etag_headers["authorization"] = authorization
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.request(
+                    method="head",
+                    url=hf_url,
+                    headers=etag_headers,
+                    timeout=WORKER_API_TIMEOUT,
+                )
+            if "etag" in response.headers:
+                ret_etag = response.headers["etag"]
+            else:
+                ret_etag = f'"{content_hash[:32]}-10"'
+        except httpx.TimeoutException:
+            ret_etag = None
+    return ret_etag
async def _file_realtime_stream(
app,
repo_type: Literal["models", "datasets", "spaces"],
@@ -531,28 +439,22 @@ async def _file_realtime_stream(
     if commit is not None:
         response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
     # Create fake headers when offline mode
-    sha256_hash = hashlib.sha256()
-    sha256_hash.update(hf_url.encode("utf-8"))
-    content_hash = sha256_hash.hexdigest()
-    if app.app_settings.config.offline:
-        response_headers["etag"] = f'"{content_hash[:32]}-10"'
+    etag = await _resource_etag(
+        hf_url=hf_url,
+        authorization=request.headers.get("authorization", None),
+        offline=app.app_settings.config.offline,
+    )
+    response_headers["etag"] = etag
+
+    if etag is None:
+        error_response = error_proxy_timeout()
+        yield error_response.status_code
+        yield error_response.headers
+        yield error_response.body
+        return
     else:
-        if method.lower() == "head":
-            async with httpx.AsyncClient() as client:
-                response = await client.request(
-                    method="head",
-                    url=hf_url,
-                    headers={
-                        "authorization": request.headers.get("authorization", None)
-                    },
-                    timeout=WORKER_API_TIMEOUT,
-                )
-            if "etag" in response.headers:
-                response_headers["etag"] = response.headers["etag"]
-            else:
-                response_headers["etag"] = f'"{content_hash[:32]}-10"'
-    yield 200
-    yield response_headers
+        yield 200
+        yield response_headers

     async with httpx.AsyncClient() as client:
         if method.lower() == "get":
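
Reviewer note: when offline, or when the upstream HEAD response carries no `etag` header, `_resource_etag` falls back to a value derived from the URL alone. The sketch below isolates that derivation so it can be checked independently; `fallback_etag` is a name chosen here for illustration and does not exist in the repository.

```python
import hashlib


def fallback_etag(hf_url: str) -> str:
    """Derive a deterministic placeholder ETag from the URL, mirroring the f-string in the diff."""
    content_hash = hashlib.sha256(hf_url.encode("utf-8")).hexdigest()
    # First 32 hex characters, suffixed with "-10" and wrapped in double quotes.
    return f'"{content_hash[:32]}-10"'
```

Because the value depends only on the URL, repeated requests for the same resource receive the same placeholder, but it does not change when the remote file's content changes.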
5 changes: 3 additions & 2 deletions olah/server.py
@@ -7,7 +7,6 @@

 from contextlib import asynccontextmanager
 import os
-import shutil
 import glob
 import argparse
 import time
@@ -162,8 +161,10 @@ async def lifespan(app: FastAPI):
 # ======================
 # Application
 # ======================
+code_file_path = os.path.abspath(__file__)
 app = FastAPI(lifespan=lifespan, debug=False)
-templates = Jinja2Templates(directory="static")
+templates = Jinja2Templates(directory=os.path.join(os.path.dirname(code_file_path), "..", "static"))
+

 class AppSettings(BaseSettings):
     # The address of the model controller.
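
Reviewer note: the template directory is now resolved from the location of `server.py` rather than from the current working directory, which matters once the server is started through the `olah-cli` entry point from an arbitrary directory. A minimal sketch of the same idea, using a hypothetical helper `resolve_static_dir` and additionally normalizing the path (the diff itself leaves the `..` component in place):

```python
import os


def resolve_static_dir(module_file: str) -> str:
    """Return the `static` directory that sits beside the package, independent of the working directory."""
    package_dir = os.path.dirname(os.path.abspath(module_file))
    # normpath collapses the ".." component; this step is an addition for readability.
    return os.path.normpath(os.path.join(package_dir, "..", "static"))
```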
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "olah"
-version = "0.3.0"
+version = "0.3.1"
 description = "Self-hosted lightweight huggingface mirror."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -31,5 +31,8 @@ olah-cli = "olah.server:cli"
 [tool.setuptools.packages.find]
 exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

+[tool.setuptools.package-data]
+static = ["*.html"]
+
 [tool.wheel]
 exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
