diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index 08c0926277f1..3e5f838a270d 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -43,5 +43,5 @@ slow_summation, timeout_job, ) -from .rpc_run import rpc_run +from .runner import local_run, rpc_run from .utils import * diff --git a/python/tvm/testing/rpc_run.py b/python/tvm/testing/runner.py similarity index 66% rename from python/tvm/testing/rpc_run.py rename to python/tvm/testing/runner.py index 08c00ca4d172..5b677df4bd8f 100644 --- a/python/tvm/testing/rpc_run.py +++ b/python/tvm/testing/runner.py @@ -22,16 +22,14 @@ if TYPE_CHECKING: import numpy as np - from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig from tvm.runtime import Device, Module, NDArray # pylint: disable=import-outside-toplevel,protected-access -def _args_to_remote(args, device): +def _args_to_device(args, device): import numpy as np - from tvm.runtime.ndarray import NDArray, empty uploaded_args = [] @@ -45,7 +43,7 @@ def _args_to_remote(args, device): return uploaded_args -def _args_to_local(args): +def _args_to_numpy(args): from tvm.runtime.ndarray import NDArray downloaded_args = [] @@ -77,6 +75,77 @@ def export_with(func): return export_func, output_format +def local_run( # pylint: disable=too-many-arguments,too-many-locals + mod: "Module", + device_type: str, + args: List[Union["np.ndarray", "NDArray", int, float]], + evaluator_config: Optional["EvaluatorConfig"] = None, + export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar", + output_format: Optional[str] = None, +): + """Run a TVM module on a local device. + + Parameters + ---------- + mod : Module + The TVM module to run. + device_type : str + The device type to run the module on. + args : List[Union[np.ndarray, NDArray, int, float]] + The arguments to be fed to the module. + evaluator_config : Optional[EvaluatorConfig] + The evaluator configuration to use. + export_func : Union[Callable[Module, str], Literal["tar", "ndk"]] + The function to export the module to a file. + If callable, it must be a function that takes two arguments: the module to export and the + path to export to. + If "tar", the module will be exported to a tar file. + If "ndk", the module will be exported to a shared library. + output_format : Optional[str] + The format of the exported module. + If not specified, it will be inferred from the `export_func` argument. + + Returns + ------- + args : List[Union[np.ndarray, NDArray, int, float]] + The results of running the module. + """ + import os.path as osp + import tempfile + + from tvm.meta_schedule.runner import EvaluatorConfig + from tvm.runtime import device, load_module + + evaluator_config = EvaluatorConfig._normalized(evaluator_config) + export_func, output_format = _normalize_export_func(export_func, output_format) + + with tempfile.TemporaryDirectory() as tmp_dir: + artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format) + export_func(mod, artifact_path) + device: Device = device(device_type, 0) + + try: + args = _args_to_device(args, device) + remote_mod = load_module(artifact_path) + profile_result = remote_mod.time_evaluator( + func_name=remote_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + )(*args) + print(profile_result) + remote_mod(*args) + args = _args_to_numpy(args) + finally: + pass + + return args + + def rpc_run( # pylint: disable=too-many-arguments,too-many-locals mod: "Module", device_type: str, @@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals If not specified, the default RPC configuration will be used, which reads the following environment variables: - TVM_TRACKER_HOST - - TVM_TRACKER_PORmod + - TVM_TRACKER_PORT - TVM_TRACKER_KEY export_func : Union[Callable[Module, str], Literal["tar", "ndk"]] The function to export the module to a file. @@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format) _, remote_path = osp.split(artifact_path) session = rpc_config.connect_server() - device: Device = session.device(dev_type=device_type, dev_id=0) + device: Device = session.device(device_type, 0) export_func(mod, artifact_path) try: session.upload(artifact_path, remote_path) - args = _args_to_remote(args, device) + args = _args_to_device(args, device) remote_mod = session.load_module(remote_path) profile_result = remote_mod.time_evaluator( func_name=remote_mod.entry_name, @@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals )(*args) print(profile_result) remote_mod(*args) - args = _args_to_local(args) + args = _args_to_numpy(args) finally: session.remove(remote_path) session.remove(remote_path + "." + output_format)