From 338fc25c405ffb29eed90eea93c63f85499a621b Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 4 Jul 2023 17:07:25 -0700 Subject: [PATCH 01/14] [TIR] Implement TIR macros This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located. --- python/tvm/script/parser/_core.py | 2 +- python/tvm/script/parser/core/entry.py | 8 ++ python/tvm/script/parser/tir/__init__.py | 4 +- python/tvm/script/parser/tir/entry.py | 45 +++++++++- python/tvm/script/parser/tir/parser.py | 90 ++++++++++++++++++- .../unittest/test_tvmscript_parser_tir.py | 28 ++++++ 6 files changed, 171 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index 4f5411dc368f..b7ba5ee4713f 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse +from .core.entry import parse, parse_macro from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 5315c0f6755e..a52dd752140d 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -25,6 +25,14 @@ from .parser import Parser +def parse_macro(program: Union[Any, str]) -> Any: + """Generate the AST, and the source code for __repr__.""" + # The AST will be converted into TIR at the time of insertion. + source = Source(program) + node = source.as_ast() + return node, source.source + + def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Register a method for a operand type, AST operator node and operand index. diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a3..14c05bc13e0d 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -30,6 +30,6 @@ # so most tvmscript won't trigger pylint error here. prim_func = staticmethod else: - from .entry import prim_func + from .entry import prim_func, macro, insert -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "insert"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d5bff7a856d5..0c7aa2455d75 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,13 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Callable, Union +from typing import Any, Callable, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import parse, utils +from .._core import doc, parse, parse_macro, utils def prim_func(func: Callable) -> Union[PrimFunc, Callable]: @@ -50,6 +50,47 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: setattr(prim_func, "dispatch_token", "tir") +# Semantics of TIR macros: +# - Function that is decorated with @T.macro can have any parameters that +# follow Python syntax, i.e. positional, keyword, etc. Type annotations +# are not required, but are allowed. +# - The arguments to `T.insert` are: macro name (either as value, or as +# a string with the name), followed by the argument list. +# For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into +# the body of the macro as in the call `arg1(arg2, arg3, ...)`. +# The body with the substituted values is then inserted at the point +# where the `T.insert` is located. + + +class TIRMacro: + """Representation of T.macro: consists of the doc.AST and the text of the source.""" + + def __init__(self, node, source): + self.doc = node + self.source = source + + def __repr__(self): + return self.source + + +def macro(func: Callable) -> doc.AST: + obj = TIRMacro(*parse_macro(func)) + setattr(obj, "__name__", func.__name__) + # We don't need to explicitly store the return value anywhere. + # This function is a decorator, so the return value will replace + # the function definition (to which the decorator it is applied) + # in that function's name space. + return obj + + +# There is no dispatch_token for macro, because macro doesn't invoke parser. + + +def insert(name: Union[str, doc.Name], *args, **kwargs) -> Any: # pylint: disable=unused-argument + """Placeholder function, so that T.insert (i.e. macro insertion) can be parsed without errors. + """ + + class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f81f9bd9ea78..9c9acc97b2d5 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -16,9 +16,10 @@ # under the License. """The base parser for tir""" +import ast import contextlib from functools import partial -from typing import Any +from typing import Any, Union import tvm from tvm.ir import GlobalVar, PrimType @@ -427,6 +428,20 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: node : doc.Expr The doc AST Expr node. """ + + def is_insert_macro(node: doc.Call) -> bool: + if not isinstance(node.func, doc.Attribute): + return False + attr = node.func + if not isinstance(attr.value, doc.Name): + return False + if attr.value.id != "T" or attr.attr != "insert": + return False + return True + + if isinstance(node.value, doc.Call) and is_insert_macro(node.value): + return process_insert_macro(self, node.value) + res = self.eval_expr(node.value) if res is None: pass @@ -528,3 +543,76 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar # Only ret_type is needed for func_signature. func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) return I.decl_function(node.name, func_signature) + + +def process_insert_macro(self: Parser, call: doc.Call) -> None: + """Bind arguments to T.insert to the parameters of the macro, and pass the macro body + for further parsing. + """ + + def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: + for decl in decl_list: + if isinstance(decl, doc.FunctionDef) and decl.name == name: + return decl + return None + + macro_name = call.args[0] + + if not isinstance(macro_name, doc.Name): + self.report_error(call, "Invalid macro name in T.insert") + macro_name = macro_name.id + + macro = self.var_table.get().get(macro_name) + if macro is None: + self.report_error(node, f"Undefined macro '{macro_name}'") + + if isinstance(macro.doc, doc.Module): + macro_def = find_macro_def(macro_name, macro.doc.body) + elif not isinstance(macro.doc, doc.FunctionDef) or macro.doc.name != macro_name: + macro_def = None + + if macro_def is None: + self.report_error(call, f"Undefined macro {macro_name}") + + # `macro_def` is a FunctionDef of the macro. + + # We have the AST for the macro definition, and the AST for the call. We need to + # substitute the actual arguments from the call for the parameters from the + # definition. To allow full flexibility of python, i.e. positional, unnamed, and + # keyword parameters, get the python interpreter to do the work: create and execute + # the following: + # ``` + # def macro_name(...macro parameters...) + # return locals() + # tmp = macro_name(...arguments from the call...) + # ``` + # Obtain the dictionary `tmp` resulting from the execution, and update the var_table + # with it. + + # Construct the function with the macro's parameters, and returning locals(). + macro_ast = doc.from_doc(macro_def) + macro_ast.body = [ + ast.Return(value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[])) + ] + macro_ast.decorator_list = [] + + # Construct the assignment with the call. + call_ast = doc.from_doc(call) + call_ast.func = ast.Name(macro_name, ctx=ast.Load()) + call_ast.args = call_ast.args[1:] + tmp_name = "__tmp_param_eval_64e98b523301204b" + assign_ast = ast.Assign(targets=[ast.Name(tmp_name, ctx=ast.Store())], value=call_ast) + + # Finalize and execute the module: + module_ast = ast.Module(body=[macro_ast, assign_ast], type_ignores=[]) + module_ast = ast.fix_missing_locations(module_ast) + cmacro = compile(module_ast, filename="", mode="exec") + local_vars = {} + exec(cmacro, self.var_table.get(), local_vars) # pylint disable=exec-used + local_vars = local_vars[tmp_name] + + with self.var_table.with_frame(): + for k, v in local_vars.items(): + self.var_table.add(k, v) + + self.visit_body(macro_def.body) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 31bf5cc10180..eeb1c3bacd77 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -71,5 +71,33 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: assert matmul.__name__ == "matmul" +def test_tir_macro(): + @T.macro + def assign(i, *args, t1, **kwargs): + vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]]) + kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk] + + @T.prim_func + def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + T.insert(assign, i, j, k, t1=A, t2=B, t3=C) + + @T.prim_func + def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro) + + if __name__ == "__main__": tvm.testing.main() From 64e99ab4c452fc51210689184a8c1d180aa4d8fc Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 7 Jul 2023 06:40:38 -0700 Subject: [PATCH 02/14] Fix linter --- python/tvm/script/parser/tir/entry.py | 3 +-- python/tvm/script/parser/tir/parser.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 0c7aa2455d75..d520612bdff4 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -87,8 +87,7 @@ def macro(func: Callable) -> doc.AST: def insert(name: Union[str, doc.Name], *args, **kwargs) -> Any: # pylint: disable=unused-argument - """Placeholder function, so that T.insert (i.e. macro insertion) can be parsed without errors. - """ + """Placeholder function, so that T.insert (i.e. macro insertion) can be parsed without errors.""" class BufferProxy: diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 9c9acc97b2d5..1e8f643e76e0 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -547,7 +547,7 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar def process_insert_macro(self: Parser, call: doc.Call) -> None: """Bind arguments to T.insert to the parameters of the macro, and pass the macro body - for further parsing. + for further parsing. """ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: @@ -608,7 +608,7 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any] module_ast = ast.fix_missing_locations(module_ast) cmacro = compile(module_ast, filename="", mode="exec") local_vars = {} - exec(cmacro, self.var_table.get(), local_vars) # pylint disable=exec-used + exec(cmacro, self.var_table.get(), local_vars) # pylint: disable=exec-used local_vars = local_vars[tmp_name] with self.var_table.with_frame(): From 3b1a26ae5d9c18c9c52bfc8cf9e95e3bd77af9a7 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 7 Jul 2023 08:30:10 -0700 Subject: [PATCH 03/14] Fix linter again One linter suggested something that the other didn't like... --- python/tvm/script/parser/tir/entry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d520612bdff4..e7a9e695b4c2 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -87,7 +87,9 @@ def macro(func: Callable) -> doc.AST: def insert(name: Union[str, doc.Name], *args, **kwargs) -> Any: # pylint: disable=unused-argument - """Placeholder function, so that T.insert (i.e. macro insertion) can be parsed without errors.""" + """Placeholder function, so that T.insert (i.e. macro insertion) + can be parsed without errors. + """ class BufferProxy: From a1bb8b3d61712bbb2a67617ed36889beaa621305 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 7 Jul 2023 17:09:03 -0700 Subject: [PATCH 04/14] Get rid of T.insert, apply macro via function-call syntax --- python/tvm/script/parser/tir/__init__.py | 4 +- python/tvm/script/parser/tir/entry.py | 19 ++--- python/tvm/script/parser/tir/parser.py | 71 +++++-------------- .../unittest/test_tvmscript_parser_tir.py | 2 +- 4 files changed, 25 insertions(+), 71 deletions(-) diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index 14c05bc13e0d..9d3fc1ec98da 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -30,6 +30,6 @@ # so most tvmscript won't trigger pylint error here. prim_func = staticmethod else: - from .entry import prim_func, macro, insert + from .entry import prim_func, macro -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "insert"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index e7a9e695b4c2..09be868b2cda 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -54,12 +54,10 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: # - Function that is decorated with @T.macro can have any parameters that # follow Python syntax, i.e. positional, keyword, etc. Type annotations # are not required, but are allowed. -# - The arguments to `T.insert` are: macro name (either as value, or as -# a string with the name), followed by the argument list. -# For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into -# the body of the macro as in the call `arg1(arg2, arg3, ...)`. -# The body with the substituted values is then inserted at the point -# where the `T.insert` is located. +# - Macro use follows the same syntax as a function call. +# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into +# the body of the macro, and the body with the substituted values is then +# inserted at the point where the call to the macro is located. class TIRMacro: @@ -75,7 +73,8 @@ def __repr__(self): def macro(func: Callable) -> doc.AST: obj = TIRMacro(*parse_macro(func)) - setattr(obj, "__name__", func.__name__) + obj.__name__ = func.__name__ + obj.func = func # We don't need to explicitly store the return value anywhere. # This function is a decorator, so the return value will replace # the function definition (to which the decorator it is applied) @@ -86,12 +85,6 @@ def macro(func: Callable) -> doc.AST: # There is no dispatch_token for macro, because macro doesn't invoke parser. -def insert(name: Union[str, doc.Name], *args, **kwargs) -> Any: # pylint: disable=unused-argument - """Placeholder function, so that T.insert (i.e. macro insertion) - can be parsed without errors. - """ - - class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 1e8f643e76e0..bee849efd86a 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -18,6 +18,7 @@ import ast import contextlib +import inspect from functools import partial from typing import Any, Union @@ -30,6 +31,7 @@ from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc +from .entry import TIRMacro def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -429,18 +431,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: The doc AST Expr node. """ - def is_insert_macro(node: doc.Call) -> bool: - if not isinstance(node.func, doc.Attribute): - return False - attr = node.func - if not isinstance(attr.value, doc.Name): - return False - if attr.value.id != "T" or attr.attr != "insert": - return False - return True - - if isinstance(node.value, doc.Call) and is_insert_macro(node.value): - return process_insert_macro(self, node.value) + if isinstance(node.value, doc.Call): + callee = self.eval_expr(node.value.func) + if isinstance(callee, TIRMacro): + return expand_macro(self, callee, node.value) res = self.eval_expr(node.value) if res is None: @@ -545,9 +539,9 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar return I.decl_function(node.name, func_signature) -def process_insert_macro(self: Parser, call: doc.Call) -> None: - """Bind arguments to T.insert to the parameters of the macro, and pass the macro body - for further parsing. +def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None: + """Bind arguments to the macro invocation to the parameters in the macro definition, + and pass the macro body for further parsing. """ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: @@ -556,13 +550,9 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any] return decl return None - macro_name = call.args[0] - - if not isinstance(macro_name, doc.Name): - self.report_error(call, "Invalid macro name in T.insert") - macro_name = macro_name.id - + macro_name = callee.__name__ macro = self.var_table.get().get(macro_name) + if macro is None: self.report_error(node, f"Undefined macro '{macro_name}'") @@ -576,40 +566,11 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any] # `macro_def` is a FunctionDef of the macro. - # We have the AST for the macro definition, and the AST for the call. We need to - # substitute the actual arguments from the call for the parameters from the - # definition. To allow full flexibility of python, i.e. positional, unnamed, and - # keyword parameters, get the python interpreter to do the work: create and execute - # the following: - # ``` - # def macro_name(...macro parameters...) - # return locals() - # tmp = macro_name(...arguments from the call...) - # ``` - # Obtain the dictionary `tmp` resulting from the execution, and update the var_table - # with it. - - # Construct the function with the macro's parameters, and returning locals(). - macro_ast = doc.from_doc(macro_def) - macro_ast.body = [ - ast.Return(value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[])) - ] - macro_ast.decorator_list = [] - - # Construct the assignment with the call. - call_ast = doc.from_doc(call) - call_ast.func = ast.Name(macro_name, ctx=ast.Load()) - call_ast.args = call_ast.args[1:] - tmp_name = "__tmp_param_eval_64e98b523301204b" - assign_ast = ast.Assign(targets=[ast.Name(tmp_name, ctx=ast.Store())], value=call_ast) - - # Finalize and execute the module: - module_ast = ast.Module(body=[macro_ast, assign_ast], type_ignores=[]) - module_ast = ast.fix_missing_locations(module_ast) - cmacro = compile(module_ast, filename="", mode="exec") - local_vars = {} - exec(cmacro, self.var_table.get(), local_vars) # pylint: disable=exec-used - local_vars = local_vars[tmp_name] + args = [self.eval_expr(arg) for arg in call.args] + kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords} + param_binding = inspect.signature(callee.func).bind(*args, **kwargs) + param_binding.apply_defaults() + local_vars = param_binding.arguments with self.var_table.with_frame(): for k, v in local_vars.items(): diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index eeb1c3bacd77..d8e4c421f6a8 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -84,7 +84,7 @@ def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None: C = T.match_buffer(c, [128, 128]) for i, j, k in T.grid(128, 128, 128): with T.block("update"): - T.insert(assign, i, j, k, t1=A, t2=B, t3=C) + assign(i, j, k, t1=A, t2=B, t3=C) @T.prim_func def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: From 295777829e50eaf9add283292937b124eb6b4e85 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 7 Jul 2023 17:16:01 -0700 Subject: [PATCH 05/14] Store closure vars in TIRMacro --- python/tvm/script/parser/tir/entry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 09be868b2cda..c9476f7420ea 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -75,6 +75,7 @@ def macro(func: Callable) -> doc.AST: obj = TIRMacro(*parse_macro(func)) obj.__name__ = func.__name__ obj.func = func + obj.closure_vars = utils.inspect_function_capture(func) # We don't need to explicitly store the return value anywhere. # This function is a decorator, so the return value will replace # the function definition (to which the decorator it is applied) From 152dea9f5319f5b5c9d80d2760ee0ee88326df6f Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 7 Jul 2023 17:17:05 -0700 Subject: [PATCH 06/14] ast.parse always returns ast.Module, hence doc is doc.Module --- python/tvm/script/parser/tir/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index bee849efd86a..b00c4cc94ed7 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -558,7 +558,7 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any] if isinstance(macro.doc, doc.Module): macro_def = find_macro_def(macro_name, macro.doc.body) - elif not isinstance(macro.doc, doc.FunctionDef) or macro.doc.name != macro_name: + else: macro_def = None if macro_def is None: From 53c718abd04a5b94fdc33f3709b5e4119715aa3d Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 06:15:54 -0700 Subject: [PATCH 07/14] Simplify `expand_macro`, capture environment variables --- python/tvm/script/parser/core/entry.py | 29 +++++++++++++------------- python/tvm/script/parser/tir/entry.py | 17 ++++++++------- python/tvm/script/parser/tir/parser.py | 20 +++++------------- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index a52dd752140d..c24f2f950142 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -25,12 +25,23 @@ from .parser import Parser -def parse_macro(program: Union[Any, str]) -> Any: +def _default_globals() -> Dict[str, Any]: + import tvm # pylint: disable=import-outside-toplevel + from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import tir # pylint: disable=import-outside-toplevel + + extra_vars = {"tvm": tvm, "I": ir, "ir": ir, "T": tir, "tir": tir} + return extra_vars + + +def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Generate the AST, and the source code for __repr__.""" # The AST will be converted into TIR at the time of insertion. source = Source(program) - node = source.as_ast() - return node, source.source + source_txt = source.source + source_ast = source.as_ast() + closure_vars = extra_vars or _default_globals() + return source_ast, source_txt, closure_vars def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: @@ -50,17 +61,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) The parsed TVMScript program. """ if extra_vars is None: - import tvm # pylint: disable=import-outside-toplevel - from tvm.script.parser import ir # pylint: disable=import-outside-toplevel - from tvm.script.parser import tir # pylint: disable=import-outside-toplevel - - extra_vars = { - "tvm": tvm, - "I": ir, - "ir": ir, - "T": tir, - "tir": tir, - } + extra_vars = _default_globals() ann = {} if inspect.isfunction(program): diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index c9476f7420ea..c260daefb625 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,7 +16,7 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc @@ -63,19 +63,22 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: class TIRMacro: """Representation of T.macro: consists of the doc.AST and the text of the source.""" - def __init__(self, node, source): - self.doc = node - self.source = source + func: Callable + + def __init__(self, source_ast: doc.AST, source_txt: str, closure_vars: Dict[str, Any]) -> None: + self.source_ast = source_ast + self.source_txt = source_txt + self.closure_vars = closure_vars + self.func = None def __repr__(self): - return self.source + return self.source_txt def macro(func: Callable) -> doc.AST: - obj = TIRMacro(*parse_macro(func)) + obj = TIRMacro(*parse_macro(func, utils.inspect_function_capture(func))) obj.__name__ = func.__name__ obj.func = func - obj.closure_vars = utils.inspect_function_capture(func) # We don't need to explicitly store the return value anywhere. # This function is a decorator, so the return value will replace # the function definition (to which the decorator it is applied) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index b00c4cc94ed7..e58f48d7eac9 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -544,27 +544,17 @@ def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None: and pass the macro body for further parsing. """ + assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}" + def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: for decl in decl_list: if isinstance(decl, doc.FunctionDef) and decl.name == name: return decl return None - macro_name = callee.__name__ - macro = self.var_table.get().get(macro_name) - - if macro is None: - self.report_error(node, f"Undefined macro '{macro_name}'") - - if isinstance(macro.doc, doc.Module): - macro_def = find_macro_def(macro_name, macro.doc.body) - else: - macro_def = None - - if macro_def is None: - self.report_error(call, f"Undefined macro {macro_name}") - - # `macro_def` is a FunctionDef of the macro. + macro_def = find_macro_def(callee.__name__, callee.source_ast.body) + assert macro_def is not None, f"Invalid macro AST for {callee.__name__}" + # `macro_def` is the FunctionDef of the macro. args = [self.eval_expr(arg) for arg in call.args] kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords} From de6e34cbb6465bbea1df395e31441e0ba42b7293 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 07:29:41 -0700 Subject: [PATCH 08/14] Implement macro hygiene --- python/tvm/script/parser/tir/entry.py | 77 +++++++++++++++---- python/tvm/script/parser/tir/parser.py | 27 ++++++- .../unittest/test_tvmscript_parser_tir.py | 42 +++++++++- 3 files changed, 127 insertions(+), 19 deletions(-) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index c260daefb625..ae29cdb24310 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -61,29 +61,78 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: class TIRMacro: - """Representation of T.macro: consists of the doc.AST and the text of the source.""" + """Representation of T.macro.""" - func: Callable - - def __init__(self, source_ast: doc.AST, source_txt: str, closure_vars: Dict[str, Any]) -> None: + def __init__( + self, + source_ast: doc.AST, + source_txt: str, + closure_vars: Dict[str, Any], + func: Callable, + hygienic: bool, + ) -> None: self.source_ast = source_ast self.source_txt = source_txt self.closure_vars = closure_vars - self.func = None + self.func = func + self.hygienic = hygienic def __repr__(self): return self.source_txt -def macro(func: Callable) -> doc.AST: - obj = TIRMacro(*parse_macro(func, utils.inspect_function_capture(func))) - obj.__name__ = func.__name__ - obj.func = func - # We don't need to explicitly store the return value anywhere. - # This function is a decorator, so the return value will replace - # the function definition (to which the decorator it is applied) - # in that function's name space. - return obj +def macro(*, hygienic: bool = True) -> Callable: + """Decorator for macro definitions. + + Parameters + ---------- + hygienic: bool + Specifies whether the macro is hygienic or not. + A macro is hygienic if all symbols used in the macro's body are resolved + to values from the location of the macro definition. A non-hygienic macro + will have its symbols resolved to values at the time of the macro's use. + + Example: + ``` + import tvm + from tvm.script import tir as T + + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] ### x_value binds to 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] ### x_value will bind at the time of use + + + @T.prim_func + def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + static_capture(A, B) ### Produces B[()] = A[128] + + @T.prim_func + def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) ### Produces B[()] = A[x_value] + ``` + """ + + def _decorator(func: Callable) -> TIRMacro: + source_ast, source_txt, closure_vars = parse_macro( + func, utils.inspect_function_capture(func) + ) + obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic) + obj.__name__ = func.__name__ + # We don't need to explicitly store the return value anywhere. + # This function is a decorator, so the return value will replace + # the function definition (to which the decorator it is applied) + # in that function's name space. + return obj + + return _decorator # There is no dispatch_token for macro, because macro doesn't invoke parser. diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e58f48d7eac9..ae137f03b133 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -31,6 +31,7 @@ from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc +from ..core.parser import VarTable from .entry import TIRMacro @@ -562,8 +563,26 @@ def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any] param_binding.apply_defaults() local_vars = param_binding.arguments - with self.var_table.with_frame(): - for k, v in local_vars.items(): - self.var_table.add(k, v) + if callee.hygienic: + # If the macro was hygienic, construct new var_table with a single frame that + # contains the captured environment, and process the macro's body with that + # frame. + saved_var_table = self.var_table + self.var_table = VarTable() + with self.var_table.with_frame(): + for k, v in callee.closure_vars.items(): + self.var_table.add(k, v) + for k, v in local_vars.items(): + self.var_table.add(k, v) + + self.visit_body(macro_def.body) + + self.var_table = saved_var_table + + else: + # Otherwise, dynamically resolve symbols in the macro's body. + with self.var_table.with_frame(): + for k, v in local_vars.items(): + self.var_table.add(k, v) - self.visit_body(macro_def.body) + self.visit_body(macro_def.body) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index d8e4c421f6a8..3ab044e8c6e1 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -71,7 +71,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: assert matmul.__name__ == "matmul" -def test_tir_macro(): +def test_tir_macro_signature(): @T.macro def assign(i, *args, t1, **kwargs): vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]]) @@ -99,5 +99,45 @@ def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None: tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro) +def test_tir_macro_hygienic(): + x_value = 128 + + @T.macro(hygienic=True) + def static_capture(A, B): + B[()] = A[x_value] + + @T.prim_func + def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + static_capture(A, B) + + @T.prim_func + def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in range(10): + B[()] = A[128] + + tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic) + + +def test_tir_macro_non_hygienic(): + x_value = 128 + + @T.macro(hygienic=False) + def dynamic_capture(A, B): + B[()] = A[x_value] + + @T.prim_func + def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in T.serial(10): + dynamic_capture(A, B) + + @T.prim_func + def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + for x_value in range(10): + B[()] = A[x_value] + + tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic) + + if __name__ == "__main__": tvm.testing.main() From 4a67795a9e181d80b413edda610cfe06c0cf5b7f Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 09:40:34 -0700 Subject: [PATCH 09/14] Fix linter --- python/tvm/script/parser/tir/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index ae137f03b133..67e14d0e9772 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -16,7 +16,6 @@ # under the License. """The base parser for tir""" -import ast import contextlib import inspect from functools import partial @@ -457,6 +456,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: pass else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") + return None # For pylint @dispatch.register(token="tir", type_name="If") From d2180fe323715f62d5775f0d378ea2c2c30f6294 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 15:18:36 -0700 Subject: [PATCH 10/14] Make T.macro work same as T.macro() The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use. --- python/tvm/script/parser/tir/entry.py | 11 +++- .../unittest/test_tvmscript_parser_tir.py | 58 +++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index ae29cdb24310..64b71d699f3d 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -81,7 +81,7 @@ def __repr__(self): return self.source_txt -def macro(*, hygienic: bool = True) -> Callable: +def macro(*args, hygienic: bool = True) -> Callable: """Decorator for macro definitions. Parameters @@ -132,7 +132,14 @@ def _decorator(func: Callable) -> TIRMacro: # in that function's name space. return obj - return _decorator + if len(args) == 0: + return _decorator + if len(args) == 1 and inspect.isfunction(args[0]): + return _decorator(args[0]) + + raise ValueError( + "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])" + ) # There is no dispatch_token for macro, because macro doesn't invoke parser. diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 3ab044e8c6e1..74cfcff10e8b 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -71,6 +71,64 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: assert matmul.__name__ == "matmul" +def test_tir_macro_decorator(): + @T.macro + def func1(n): + T.evaluate(n) + + assert func1.hygienic + + @T.macro() + def func2(n): + T.evaluate(n) + + assert func2.hygienic + + with pytests.raises(ValueException) as exc: + + @T.macro(True) + def func3(n): + T.evaluate(n) + + +def test_tir_macro_decorator_signature(): + @T.prim_func + def evaluate0(): + T.evaluate(0) + + # Ok, no parentheses + @T.macro + def func1(): + T.evaluate(0) + + assert func1.hygienic + + @T.prim_func + def use1(): + func1() + + tvm.ir.assert_structural_equal(use1, evaluate0) + + # Ok, empty parentheses + @T.macro() + def func2(): + T.evaluate(0) + + assert func2.hygienic + + @T.prim_func + def use2(): + func2() + + tvm.ir.assert_structural_equal(use1, evaluate0) + + with pytest.raises(ValueError): + # Wrong: non-keyword argument + @T.macro(True) + def func3(): + T.evaluate() + + def test_tir_macro_signature(): @T.macro def assign(i, *args, t1, **kwargs): From 0c057b5eeef7adb1a0ac2cf03819bd1772ba5292 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 15:26:21 -0700 Subject: [PATCH 11/14] Edit comment: insertion -> expansion --- python/tvm/script/parser/core/entry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index c24f2f950142..08a593d5d31b 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -36,7 +36,7 @@ def _default_globals() -> Dict[str, Any]: def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Generate the AST, and the source code for __repr__.""" - # The AST will be converted into TIR at the time of insertion. + # The AST will be converted into TIR at the time of expansion. source = Source(program) source_txt = source.source source_ast = source.as_ast() From bf9f03825f0910d3711df9e0bde2bdd6fc6af369 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sat, 8 Jul 2023 18:32:09 -0700 Subject: [PATCH 12/14] Add import pytest --- tests/python/unittest/test_tvmscript_parser_tir.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 74cfcff10e8b..b62707e30e3b 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -16,6 +16,7 @@ # under the License. """Unittests for tvm.script.parser.tir""" +import pytest import tvm.testing from tvm.script.parser import tir as T from tvm import ir, tir From 5a637dc436dd4940c08f9682cad4a774bcdeebb1 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sun, 9 Jul 2023 06:06:45 -0700 Subject: [PATCH 13/14] One more typo... --- tests/python/unittest/test_tvmscript_parser_tir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index b62707e30e3b..c44393c30450 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -85,7 +85,7 @@ def func2(n): assert func2.hygienic - with pytests.raises(ValueException) as exc: + with pytest.raises(ValueException) as exc: @T.macro(True) def func3(n): From cdb04b4b9303ca2cd0166dd374550f8c14fabd0c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Sun, 9 Jul 2023 08:37:06 -0700 Subject: [PATCH 14/14] Remove stale testcase --- .../unittest/test_tvmscript_parser_tir.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index c44393c30450..38d3e1474656 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -72,26 +72,6 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: assert matmul.__name__ == "matmul" -def test_tir_macro_decorator(): - @T.macro - def func1(n): - T.evaluate(n) - - assert func1.hygienic - - @T.macro() - def func2(n): - T.evaluate(n) - - assert func2.hygienic - - with pytest.raises(ValueException) as exc: - - @T.macro(True) - def func3(n): - T.evaluate(n) - - def test_tir_macro_decorator_signature(): @T.prim_func def evaluate0():