Skip to content

Commit

Permalink
[Unity] Add instruments to relay translator (#15601)
Browse files Browse the repository at this point in the history
* [Unity] Add instruments to relay translator

Sometimes its useful to instrument relay passes that are run during
relay to relax translation, and this patch adds a new argument to relay
translator to accept an instruments list argument that gets passed onto
the PassContext used while running relay prefix passes

* Fix test case
  • Loading branch information
quic-sanirudh authored Aug 27, 2023
1 parent c9a7378 commit d2972f3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/tvm/relax/testing/relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
# pylint: disable=too-many-nested-blocks, unused-variable
"""Relay to Relax translator."""

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import tvm
from tvm import relax, relay
from tvm.ir.module import IRModule
from tvm.ir.instrument import PassInstrument
from tvm.relax.testing import nn
from tvm.relay.backend.te_compiler import select_implementation
from tvm.runtime import NDArray
Expand All @@ -37,6 +38,7 @@ def from_relay(
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
disabled_pass: Optional[List[str]] = None,
translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
append_op_attrs: bool = False,
Expand All @@ -60,6 +62,10 @@ def from_relay(
pass_config: Optional[Dict[str, Any]]
Pass configuration.
instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations to be passed onto relay
while calling relay passes
disabled_pass: Optional[List[str]]
Passes to disable.
Expand Down Expand Up @@ -255,6 +261,7 @@ def visit_func(node):
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
instruments=instruments,
):
mod = tvm.IRModule.from_expr(func)
mod = seq(mod)
Expand Down
26 changes: 26 additions & 0 deletions tests/python/relax/test_relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,5 +312,31 @@ def test_append_op_attrs():
assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs


def test_instruments_support():
x = relay.var("x", shape=(10, 16))
y = relay.var("y", shape=(10, 16))
out = relay.add(x, y)
mod = tvm.IRModule.from_expr(out)

@tvm.instrument.pass_instrument
class SampleRunBeforeAfterInstrument:
def __init__(self):
self.events = []

def run_before_pass(self, mod, info):
self.events.append("run before " + info.name)

def run_after_pass(self, mod, info):
self.events.append("run after " + info.name)

my_test = SampleRunBeforeAfterInstrument()
relax_mod_with_attrs = relay_translator.from_relay(
mod["main"], target="llvm", instruments=[my_test]
)

assert "run after " in "".join(my_test.events)
assert "run before " in "".join(my_test.events)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit d2972f3

Please sign in to comment.