Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Add instruments to relay translator #15601

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/tvm/relax/testing/relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
# pylint: disable=too-many-nested-blocks, unused-variable
"""Relay to Relax translator."""

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import tvm
from tvm import relax, relay
from tvm.ir.module import IRModule
from tvm.ir.instrument import PassInstrument
from tvm.relax.testing import nn
from tvm.relay.backend.te_compiler import select_implementation
from tvm.runtime import NDArray
Expand All @@ -37,6 +38,7 @@ def from_relay(
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
instruments: Optional[Sequence[PassInstrument]] = None,
disabled_pass: Optional[List[str]] = None,
translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
append_op_attrs: bool = False,
Expand All @@ -60,6 +62,10 @@ def from_relay(
pass_config: Optional[Dict[str, Any]]
Pass configuration.

instruments : Optional[Sequence[PassInstrument]]
The list of pass instrument implementations to be passed onto relay
while calling relay passes

disabled_pass: Optional[List[str]]
Passes to disable.

Expand Down Expand Up @@ -255,6 +261,7 @@ def visit_func(node):
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
instruments=instruments,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it, though I'm wondering if we should have it forward the instrumentation of the existing PassContext, rather than explicitly accept it as an argument. That way, debug instrumentation set by a much higher scope would also be preserved, without requiring an update to the immediately-calling scope.

instruments = tvm.transform.PassCurrent().instruments

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review @Lunderberg.

I actually did it that way initially, but since the other arguments to PassContext such as opt_level and pass_config were taken as arguments to the relay translator, I modified it take instruments as an argument as well.

I can change it back to be forwarded from current PassContext as you've suggested, either way is fine with me, but in that case, should we also remove opt_level and pass_config from being arguments and forward them from current context as well? Since that would be a breaking change, I'm a bit hesitant to do that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, good point. Since the other arguments are currently specified explicitly, probably better to keep it that way for consistency.

In the long-term, I think that entering a PassContext should probably distinguish between options that override the previous (opt_level), options that append to the previous (instruments, disabled_passes).

):
mod = tvm.IRModule.from_expr(func)
mod = seq(mod)
Expand Down
26 changes: 26 additions & 0 deletions tests/python/relax/test_relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,5 +312,31 @@ def test_append_op_attrs():
assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs


def test_instruments_support():
x = relay.var("x", shape=(10, 16))
y = relay.var("y", shape=(10, 16))
out = relay.add(x, y)
mod = tvm.IRModule.from_expr(out)

@tvm.instrument.pass_instrument
class SampleRunBeforeAfterInstrument:
def __init__(self):
self.events = []

def run_before_pass(self, mod, info):
self.events.append("run before " + info.name)

def run_after_pass(self, mod, info):
self.events.append("run after " + info.name)

my_test = SampleRunBeforeAfterInstrument()
relax_mod_with_attrs = relay_translator.from_relay(
mod["main"], target="llvm", instruments=[my_test]
)

assert "run after " in "".join(my_test.events)
assert "run before " in "".join(my_test.events)


if __name__ == "__main__":
pytest.main([__file__])
Loading