From 5cee2af47a77bf522fd957599913fe1b050170c3 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 4 Jul 2023 17:07:25 -0700 Subject: [PATCH] [DRAFT][TIR] Prototype of T.macro --- python/tvm/script/parser/_core.py | 2 +- python/tvm/script/parser/core/entry.py | 7 +++ python/tvm/script/parser/tir/__init__.py | 4 +- python/tvm/script/parser/tir/entry.py | 21 ++++++++- python/tvm/script/parser/tir/parser.py | 58 +++++++++++++++++++++++- 5 files changed, 86 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py index 4f5411dc368fd..16e9dca190c25 100644 --- a/python/tvm/script/parser/_core.py +++ b/python/tvm/script/parser/_core.py @@ -18,5 +18,5 @@ # pylint: disable=unused-import from .core import dispatch, doc, utils from .core.dispatch import OpMethod, register_op -from .core.entry import parse +from .core.entry import gen_ast, parse from .core.parser import Parser diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 5315c0f6755e4..9be6aec73c24f 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -25,6 +25,13 @@ from .parser import Parser +def gen_ast(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any: + # Simply generate the AST. It will be parsed at the time of inclusion. + source = Source(program) + node = source.as_ast() + return node + + def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any: """Register a method for a operand type, AST operator node and operand index. diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py index ad16821a89a33..6f34cdb110432 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tir/__init__.py @@ -30,6 +30,6 @@ # so most tvmscript won't trigger pylint error here. prim_func = staticmethod else: - from .entry import prim_func + from .entry import prim_func, macro, include -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "include"] diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index d5bff7a856d56..1209b2955ff3d 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -16,13 +16,13 @@ # under the License. """The entry point of TVM parser for tir.""" import inspect -from typing import Callable, Union +from typing import Any, Callable, Union from tvm.ir.base import deprecated from tvm.tir import Buffer, PrimFunc from ...ir_builder.tir import buffer, ptr -from .._core import parse, utils +from .._core import doc, gen_ast, parse, utils def prim_func(func: Callable) -> Union[PrimFunc, Callable]: @@ -50,6 +50,23 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: setattr(prim_func, "dispatch_token", "tir") +def macro(func: Callable) -> doc.AST: + f = gen_ast(func, utils.inspect_function_capture(func)) + setattr(f, "__name__", func.__name__) + # We don't need to explicitly store the return value anywhere. + # This function is a decorator, so the return value will replace + # the function definition (to which the decorator it is applied) + # in that function's name space. + return f + + +# There is no dispatch_token for macro, because macro doesn't invoke parser. + + +def include(name: Union[str, doc.Name], *args, **kwargs) -> Any: + return + + class BufferProxy: """Buffer proxy class for constructing tir buffer.""" diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f81f9bd9ea785..2f869cc41095d 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -16,9 +16,10 @@ # under the License. """The base parser for tir""" +import ast import contextlib from functools import partial -from typing import Any +from typing import Any, Union import tvm from tvm.ir import GlobalVar, PrimType @@ -427,6 +428,61 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: node : doc.Expr The doc AST Expr node. """ + def is_include_macro(node: doc.Call) -> bool: + if not isinstance(node.func, doc.Attribute): + return False + attr = node.func + if not isinstance(attr.value, doc.Name): + return False + if attr.value.id != "T" or attr.attr != "include": + return False + return True + + if isinstance(node.value, doc.Call) and is_include_macro(node.value): + def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]: + for decl in decl_list: + if isinstance(decl, doc.FunctionDef) and decl.name == name: + return decl + return None + + call = node.value + macro_name = call.args[0] + + if not isinstance(macro_name, str): + if not isinstance(macro_name, doc.Name): + self.report_error(node, "Invalid macro name in T.include") + macro_name = macro_name.id + + macro = self.var_table.get().get(macro_name) + if macro is None: + self.report_error(node, f"Undefined macro '{macro_name}'") + + if isinstance(macro, doc.Module): + macro_def = find_macro_def(macro_name, macro.body) + elif not isinstance(macro, doc.FunctionDef) or macro.name != macro_name: + macro_def = None + + if macro_def is None: + self.report_error(macro, f"Undefined macro {macro_name}") + + params = macro_def.args.args + args = call.args[1:] + body = macro_def.body + assert len(params) == len(args) + + + with self.var_table.with_frame(): + for p, a in zip(params, args): + lhs = doc.to_doc(ast.fix_missing_locations(ast.Name(id=p.arg, ctx=ast.Store()))) + if isinstance(a, doc.Name): + rhs = self.var_table.get()[a.id] + else: + rhs = self.visit(a) + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + return self.visit_body(macro_def.body) + + res = self.eval_expr(node.value) if res is None: pass