Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Implement TIR macros #15260

Merged
merged 14 commits into from
Jul 11, 2023
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse
from .core.entry import parse, parse_macro
from .core.parser import Parser
31 changes: 20 additions & 11 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@
from .parser import Parser


def _default_globals() -> Dict[str, Any]:
import tvm # pylint: disable=import-outside-toplevel
from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
from tvm.script.parser import tir # pylint: disable=import-outside-toplevel

extra_vars = {"tvm": tvm, "I": ir, "ir": ir, "T": tir, "tir": tir}
return extra_vars


def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Generate the AST, and the source code for __repr__."""
# The AST will be converted into TIR at the time of expansion.
source = Source(program)
source_txt = source.source
source_ast = source.as_ast()
closure_vars = extra_vars or _default_globals()
return source_ast, source_txt, closure_vars


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Register a method for a operand type, AST operator node and operand index.

Expand All @@ -42,17 +61,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
The parsed TVMScript program.
"""
if extra_vars is None:
import tvm # pylint: disable=import-outside-toplevel
from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
from tvm.script.parser import tir # pylint: disable=import-outside-toplevel

extra_vars = {
"tvm": tvm,
"I": ir,
"ir": ir,
"T": tir,
"tir": tir,
}
extra_vars = _default_globals()

ann = {}
if inspect.isfunction(program):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
# so most tvmscript won't trigger pylint error here.
prim_func = staticmethod
else:
from .entry import prim_func
from .entry import prim_func, macro

__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro"]
99 changes: 97 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Union
from typing import Any, Callable, Dict, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import parse, utils
from .._core import doc, parse, parse_macro, utils


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -50,6 +50,101 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
setattr(prim_func, "dispatch_token", "tir")


# Semantics of TIR macros:
# - Function that is decorated with @T.macro can have any parameters that
# follow Python syntax, i.e. positional, keyword, etc. Type annotations
# are not required, but are allowed.
# - Macro use follows the same syntax as a function call.
# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into
# the body of the macro, and the body with the substituted values is then
# inserted at the point where the call to the macro is located.


class TIRMacro:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be specific to TIR, or should it apply to any dialect supported by TVMScript? Thinking that this would be quite useful on the unity branch as well, where a Relax method for an end-to-end model often contains many repeated elements. Implementing those as a macro would also allow Relax's shape propagation to resolve differently for each expansion of the macro (e.g. in a chain of convolutions).

If we want it to be more general, we could move the implementation over to the tvm.script.parser.ir namespace instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be I.macro then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would, yes. Within TVMScript, the default set of global definitions is defined here. This provides both tvm.script.tir as T and tvm.script.ir as I.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this in a separate PR?

"""Representation of T.macro."""

def __init__(
self,
source_ast: doc.AST,
source_txt: str,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source_ast = source_ast
self.source_txt = source_txt
self.closure_vars = closure_vars
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source_txt


def macro(*args, hygienic: bool = True) -> Callable:
"""Decorator for macro definitions.

Parameters
----------
hygienic: bool
Specifies whether the macro is hygienic or not.
A macro is hygienic if all symbols used in the macro's body are resolved
to values from the location of the macro definition. A non-hygienic macro
will have its symbols resolved to values at the time of the macro's use.

Example:
```
import tvm
from tvm.script import tir as T

x_value = 128

@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value] ### x_value binds to 128

@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value] ### x_value will bind at the time of use


@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]

@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
"""

def _decorator(func: Callable) -> TIRMacro:
source_ast, source_txt, closure_vars = parse_macro(
func, utils.inspect_function_capture(func)
)
obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
obj.__name__ = func.__name__
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj

if len(args) == 0:
return _decorator
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])

raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)


# There is no dispatch_token for macro, because macro doesn't invoke parser.


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
60 changes: 59 additions & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""The base parser for tir"""

import contextlib
import inspect
from functools import partial
from typing import Any
from typing import Any, Union

import tvm
from tvm.ir import GlobalVar, PrimType
Expand All @@ -29,6 +30,8 @@
from ...ir_builder.base import IRBuilder
from ...ir_builder.base import IRBuilderFrame as Frame
from .._core import Parser, dispatch, doc
from ..core.parser import VarTable
from .entry import TIRMacro


def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
Expand Down Expand Up @@ -427,6 +430,12 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
node : doc.Expr
The doc AST Expr node.
"""

if isinstance(node.value, doc.Call):
callee = self.eval_expr(node.value.func)
if isinstance(callee, TIRMacro):
return expand_macro(self, callee, node.value)

res = self.eval_expr(node.value)
if res is None:
pass
Expand All @@ -447,6 +456,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
pass
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
return None # For pylint


@dispatch.register(token="tir", type_name="If")
Expand Down Expand Up @@ -528,3 +538,51 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)


def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None:
"""Bind arguments to the macro invocation to the parameters in the macro definition,
and pass the macro body for further parsing.
"""

assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}"

def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
for decl in decl_list:
if isinstance(decl, doc.FunctionDef) and decl.name == name:
return decl
return None

macro_def = find_macro_def(callee.__name__, callee.source_ast.body)
assert macro_def is not None, f"Invalid macro AST for {callee.__name__}"
# `macro_def` is the FunctionDef of the macro.

args = [self.eval_expr(arg) for arg in call.args]
kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords}
param_binding = inspect.signature(callee.func).bind(*args, **kwargs)
param_binding.apply_defaults()
local_vars = param_binding.arguments

if callee.hygienic:
# If the macro was hygienic, construct new var_table with a single frame that
# contains the captured environment, and process the macro's body with that
# frame.
saved_var_table = self.var_table
self.var_table = VarTable()
with self.var_table.with_frame():
for k, v in callee.closure_vars.items():
self.var_table.add(k, v)
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)

self.var_table = saved_var_table

else:
# Otherwise, dynamically resolve symbols in the macro's body.
with self.var_table.with_frame():
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)
107 changes: 107 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Unittests for tvm.script.parser.tir"""

import pytest
import tvm.testing
from tvm.script.parser import tir as T
from tvm import ir, tir
Expand Down Expand Up @@ -71,5 +72,111 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
assert matmul.__name__ == "matmul"


def test_tir_macro_decorator_signature():
@T.prim_func
def evaluate0():
T.evaluate(0)

# Ok, no parentheses
@T.macro
def func1():
T.evaluate(0)

assert func1.hygienic

@T.prim_func
def use1():
func1()

tvm.ir.assert_structural_equal(use1, evaluate0)

# Ok, empty parentheses
@T.macro()
def func2():
T.evaluate(0)

assert func2.hygienic

@T.prim_func
def use2():
func2()

tvm.ir.assert_structural_equal(use1, evaluate0)

with pytest.raises(ValueError):
# Wrong: non-keyword argument
@T.macro(True)
def func3():
T.evaluate()


def test_tir_macro_signature():
@T.macro
def assign(i, *args, t1, **kwargs):
vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk]

@T.prim_func
def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
assign(i, j, k, t1=A, t2=B, t3=C)

@T.prim_func
def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro)


def test_tir_macro_hygienic():
x_value = 128

@T.macro(hygienic=True)
def static_capture(A, B):
B[()] = A[x_value]

@T.prim_func
def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B)

@T.prim_func
def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in range(10):
B[()] = A[128]

tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic)


def test_tir_macro_non_hygienic():
x_value = 128

@T.macro(hygienic=False)
def dynamic_capture(A, B):
B[()] = A[x_value]

@T.prim_func
def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B)

@T.prim_func
def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
for x_value in range(10):
B[()] = A[x_value]

tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)


if __name__ == "__main__":
tvm.testing.main()
Loading