Skip to content

Commit

Permalink
[Unity][Frontend][NN] Op print_ (#15604)
Browse files Browse the repository at this point in the history
* print done

* fix
  • Loading branch information
LeshengJin authored Aug 24, 2023
1 parent 619cbec commit 71cdd46
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
10 changes: 9 additions & 1 deletion python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,15 @@ def finalize(self) -> List[rx.Var]:

def print_(self, tensor: Tensor) -> None:
"""Encloses the side effect of NDArray printing"""
raise NotImplementedError
self.effect = rx.BlockBuilder.current().emit(
rx.call_pure_packed(
rx.extern("effect.print"),
self.effect,
tensor._expr, # pylint: disable=protected-access
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=self.effect.name_hint,
)


@register_func("effect.print")
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ...block_builder import BlockBuilder
from ...struct_info import TensorStructInfo, TupleStructInfo
from .core import Tensor
from .spec import SpecBuilder

IntExpr = Union[int, _tir.PrimExpr]

Expand Down Expand Up @@ -938,3 +939,7 @@ def _convert(arg):
),
name=name_hint,
)


def print_(array: Tensor):
SpecBuilder.current().io_effect.print_(array)
38 changes: 38 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import pytest
import torch
import sys
import io

import tvm
import tvm.testing
Expand Down Expand Up @@ -304,5 +307,40 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten
tvm.ir.assert_structural_equal(irmodule, Expected)


def test_print():
class Model(Module):
def test(self, x: Tensor):
z = op.add(x, x)
op.print_(z)
return x

# fmt: off
@I.ir_module
class Expected:
@R.function
def _initialize_effect() -> R.Tuple(R.Object):
with R.dataflow():
_io: R.Object = R.null_value()
lv: R.Tuple(R.Object) = (_io,)
gv: R.Tuple(R.Object) = lv
R.output(gv)
return gv

@R.function
def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, x)
_io1: R.Object = R.call_pure_packed("effect.print", _io, add, sinfo_args=(R.Object(),))
gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io1,)
R.output(gv1)
return gv1
# fmt: on

m = Model()
irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], "float32")}})

tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 71cdd46

Please sign in to comment.