diff --git a/src/bentoml/_internal/frameworks/pytorch.py b/src/bentoml/_internal/frameworks/pytorch.py index 99382534f67..6f442edcc38 100644 --- a/src/bentoml/_internal/frameworks/pytorch.py +++ b/src/bentoml/_internal/frameworks/pytorch.py @@ -5,6 +5,7 @@ from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING +from typing import Any import cloudpickle @@ -46,6 +47,7 @@ def get(tag_like: str | Tag) -> Model: def load_model( bentoml_model: str | Tag | Model, device_id: t.Optional[str] = "cpu", + **torch_load_args: Any, ) -> torch.nn.Module: """ Load a model from a BentoML Model with given name. @@ -76,7 +78,9 @@ def load_model( weight_file = bentoml_model.path_of(MODEL_FILENAME) with Path(weight_file).open("rb") as file: - model: "torch.nn.Module" = torch.load(file, map_location=device_id) + model: "torch.nn.Module" = torch.load( + file, map_location=device_id, **torch_load_args + ) return model