From 6c335213266f7ce7d4e16d14a9ee0d98f95eb992 Mon Sep 17 00:00:00 2001 From: Toke Emil Heldbo Reines Date: Fri, 23 Aug 2024 04:09:37 +0200 Subject: [PATCH] feat: Pass torch kwargs to bentoml.pytorch.load_model (#4930) --- src/bentoml/_internal/frameworks/pytorch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/bentoml/_internal/frameworks/pytorch.py b/src/bentoml/_internal/frameworks/pytorch.py index 99382534f67..6f442edcc38 100644 --- a/src/bentoml/_internal/frameworks/pytorch.py +++ b/src/bentoml/_internal/frameworks/pytorch.py @@ -5,6 +5,7 @@ from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING +from typing import Any import cloudpickle @@ -46,6 +47,7 @@ def get(tag_like: str | Tag) -> Model: def load_model( bentoml_model: str | Tag | Model, device_id: t.Optional[str] = "cpu", + **torch_load_args: Any, ) -> torch.nn.Module: """ Load a model from a BentoML Model with given name. @@ -76,7 +78,9 @@ def load_model( weight_file = bentoml_model.path_of(MODEL_FILENAME) with Path(weight_file).open("rb") as file: - model: "torch.nn.Module" = torch.load(file, map_location=device_id) + model: "torch.nn.Module" = torch.load( + file, map_location=device_id, **torch_load_args + ) return model