Skip to content

Commit

Permalink
[Testing] Add tvm.testing.local_run
Browse files Browse the repository at this point in the history
This PR introduces `tvm.testing.local_run`, which serves as a convenient
numpy-in numpy-out interface to quickly run a `runtime.Module` in TVM
and obtain its running time and outputs.

Example:

```python

@I.ir_module
class Module:
  ...

n = 128
np_a = np.random.uniform(-1, 1, [1, 32, 1, 128]).astype(np.float16)
np_b = np.random.uniform(-1, 1, [1, 32, n, 128]).astype(np.float16)
np_c = np.random.uniform(-1, 1, [1, 1, 1, n]).astype(np.float16)
np_d = np.random.uniform(-1, 1, [1, 32, 1, n]).astype(np.float32)

_, _, _, np_d = local_run(
    tvm.build(Module, target="llvm"),
    device_type="cpu",
    args=[np_a, np_b, np_c, np_d],
)
```
  • Loading branch information
junrushao committed Jul 8, 2023
1 parent 24ae0d5 commit f8cbc0f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@
slow_summation,
timeout_job,
)
from .rpc_run import rpc_run
from .runner import local_run, rpc_run
from .utils import *
85 changes: 77 additions & 8 deletions python/tvm/testing/rpc_run.py → python/tvm/testing/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@

if TYPE_CHECKING:
import numpy as np

from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray

# pylint: disable=import-outside-toplevel,protected-access


def _args_to_remote(args, device):
def _args_to_device(args, device):
import numpy as np

from tvm.runtime.ndarray import NDArray, empty

uploaded_args = []
Expand All @@ -45,7 +43,7 @@ def _args_to_remote(args, device):
return uploaded_args


def _args_to_local(args):
def _args_to_numpy(args):
from tvm.runtime.ndarray import NDArray

downloaded_args = []
Expand Down Expand Up @@ -77,6 +75,77 @@ def export_with(func):
return export_func, output_format


def local_run( # pylint: disable=too-many-arguments,too-many-locals
mod: "Module",
device_type: str,
args: List[Union["np.ndarray", "NDArray", int, float]],
evaluator_config: Optional["EvaluatorConfig"] = None,
export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar",
output_format: Optional[str] = None,
):
"""Run a TVM module on a local device.
Parameters
----------
mod : Module
The TVM module to run.
device_type : str
The device type to run the module on.
args : List[Union[np.ndarray, NDArray, int, float]]
The arguments to be fed to the module.
evaluator_config : Optional[EvaluatorConfig]
The evaluator configuration to use.
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
The function to export the module to a file.
If callable, it must be a function that takes two arguments: the module to export and the
path to export to.
If "tar", the module will be exported to a tar file.
If "ndk", the module will be exported to a shared library.
output_format : Optional[str]
The format of the exported module.
If not specified, it will be inferred from the `export_func` argument.
Returns
-------
args : List[Union[np.ndarray, NDArray, int, float]]
The results of running the module.
"""
import os.path as osp
import tempfile

from tvm.meta_schedule.runner import EvaluatorConfig
from tvm.runtime import device, load_module

evaluator_config = EvaluatorConfig._normalized(evaluator_config)
export_func, output_format = _normalize_export_func(export_func, output_format)

with tempfile.TemporaryDirectory() as tmp_dir:
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
export_func(mod, artifact_path)
device: Device = device(device_type, 0)

try:
args = _args_to_device(args, device)
remote_mod = load_module(artifact_path)
profile_result = remote_mod.time_evaluator(
func_name=remote_mod.entry_name,
dev=device,
number=evaluator_config.number,
repeat=evaluator_config.repeat,
min_repeat_ms=evaluator_config.min_repeat_ms,
f_preproc="cache_flush_cpu_non_first_arg"
if evaluator_config.enable_cpu_cache_flush
else "",
)(*args)
print(profile_result)
remote_mod(*args)
args = _args_to_numpy(args)
finally:
pass

return args


def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
mod: "Module",
device_type: str,
Expand All @@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
If not specified, the default RPC configuration will be used, which reads the following
environment variables:
- TVM_TRACKER_HOST
- TVM_TRACKER_PORmod
- TVM_TRACKER_PORT
- TVM_TRACKER_KEY
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
The function to export the module to a file.
Expand Down Expand Up @@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
_, remote_path = osp.split(artifact_path)
session = rpc_config.connect_server()
device: Device = session.device(dev_type=device_type, dev_id=0)
device: Device = session.device(device_type, 0)

export_func(mod, artifact_path)
try:
session.upload(artifact_path, remote_path)
args = _args_to_remote(args, device)
args = _args_to_device(args, device)
remote_mod = session.load_module(remote_path)
profile_result = remote_mod.time_evaluator(
func_name=remote_mod.entry_name,
Expand All @@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
)(*args)
print(profile_result)
remote_mod(*args)
args = _args_to_local(args)
args = _args_to_numpy(args)
finally:
session.remove(remote_path)
session.remove(remote_path + "." + output_format)
Expand Down

0 comments on commit f8cbc0f

Please sign in to comment.