Skip to content

Commit

Permalink
_get_file_range_from_remote authorization bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 5, 2024
1 parent 1e4d6c9 commit 472dcff
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
11 changes: 10 additions & 1 deletion olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ async def _get_file_range_from_remote(
end_pos: int,
):
headers = {}
headers["authorization"] = remote_info.headers.get("authorization", None)
headers["range"] = f"bytes={start_pos}-{end_pos - 1}"

chunk_bytes = 0
raw_data = b""
async with client.stream(
Expand Down Expand Up @@ -537,7 +539,14 @@ async def _file_realtime_stream(
else:
if method.lower() == "head":
async with httpx.AsyncClient() as client:
response = await client.request(method="head", url=hf_url,headers={},timeout=WORKER_API_TIMEOUT)
response = await client.request(
method="head",
url=hf_url,
headers={
"authorization": request.headers.get("authorization", None)
},
timeout=WORKER_API_TIMEOUT,
)
if "etag" in response.headers:
response_headers["etag"] = response.headers["etag"]
else:
Expand Down
17 changes: 15 additions & 2 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,13 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request):
if org is None and repo is None:
return error_repo_not_found()
if not app.app_settings.config.offline:
new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
new_commit = await get_newest_commit_hf(
app,
repo_type,
org,
repo,
authorization=request.headers.get("authorization", None),
)
if new_commit is None:
return error_repo_not_found()
else:
Expand All @@ -269,11 +275,18 @@ async def meta_proxy(repo_type: str, org_repo: str, request: Request):
authorization=request.headers.get("authorization", None),
)


@app.head("/api/{repo_type}/{org}/{repo}")
@app.get("/api/{repo_type}/{org}/{repo}")
async def meta_proxy(repo_type: str, org: str, repo: str, request: Request):
if not app.app_settings.config.offline:
new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
new_commit = await get_newest_commit_hf(
app,
repo_type,
org,
repo,
authorization=request.headers.get("authorization", None),
)
if new_commit is None:
return error_repo_not_found()
else:
Expand Down
3 changes: 2 additions & 1 deletion olah/utils/repo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ async def get_newest_commit_hf(
repo_type: Optional[Literal["models", "datasets", "spaces"]],
org: Optional[str],
repo: str,
authorization: Optional[str] = None,
) -> Optional[str]:
"""
Retrieves the newest commit hash for a repository.
Expand All @@ -188,7 +189,7 @@ async def get_newest_commit_hf(
return await get_newest_commit_hf_offline(app, repo_type, org, repo)
try:
async with httpx.AsyncClient() as client:
response = await client.get(url, timeout=WORKER_API_TIMEOUT)
response = await client.get(url, headers={"authorization": authorization}, timeout=WORKER_API_TIMEOUT)
if response.status_code != 200:
return await get_newest_commit_hf_offline(app, repo_type, org, repo)
obj = json.loads(response.text)
Expand Down

0 comments on commit 472dcff

Please sign in to comment.