Skip to content

Commit

Permalink
feat: Pass torch kwargs to bentoml.pytorch.load_model (#4930)
Browse files Browse the repository at this point in the history
  • Loading branch information
TokeReines authored Aug 23, 2024
1 parent df638dd commit 6c33521
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/bentoml/_internal/frameworks/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING
from typing import Any

import cloudpickle

Expand Down Expand Up @@ -46,6 +47,7 @@ def get(tag_like: str | Tag) -> Model:
def load_model(
bentoml_model: str | Tag | Model,
device_id: t.Optional[str] = "cpu",
**torch_load_args: Any,
) -> torch.nn.Module:
"""
Load a model from a BentoML Model with given name.
Expand Down Expand Up @@ -76,7 +78,9 @@ def load_model(

weight_file = bentoml_model.path_of(MODEL_FILENAME)
with Path(weight_file).open("rb") as file:
model: "torch.nn.Module" = torch.load(file, map_location=device_id)
model: "torch.nn.Module" = torch.load(
file, map_location=device_id, **torch_load_args
)
return model


Expand Down

0 comments on commit 6c33521

Please sign in to comment.