Skip to content

Commit

Permalink
[DRAFT][TIR] Prototype of T.macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzysztof Parzyszek committed Jul 5, 2023
1 parent 9710d81 commit 5cee2af
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse
from .core.entry import gen_ast, parse
from .core.parser import Parser
7 changes: 7 additions & 0 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
from .parser import Parser


def gen_ast(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
# Simply generate the AST. It will be parsed at the time of inclusion.
source = Source(program)
node = source.as_ast()
return node


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Register a method for a operand type, AST operator node and operand index.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
# so most tvmscript won't trigger pylint error here.
prim_func = staticmethod
else:
from .entry import prim_func
from .entry import prim_func, macro, include

__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro", "include"]
21 changes: 19 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Union
from typing import Any, Callable, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import parse, utils
from .._core import doc, gen_ast, parse, utils


def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -50,6 +50,23 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
setattr(prim_func, "dispatch_token", "tir")


def macro(func: Callable) -> doc.AST:
f = gen_ast(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return f


# There is no dispatch_token for macro, because macro doesn't invoke parser.


def include(name: Union[str, doc.Name], *args, **kwargs) -> Any:
return


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
58 changes: 57 additions & 1 deletion python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""The base parser for tir"""

import ast
import contextlib
from functools import partial
from typing import Any
from typing import Any, Union

import tvm
from tvm.ir import GlobalVar, PrimType
Expand Down Expand Up @@ -427,6 +428,61 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
node : doc.Expr
The doc AST Expr node.
"""
def is_include_macro(node: doc.Call) -> bool:
if not isinstance(node.func, doc.Attribute):
return False
attr = node.func
if not isinstance(attr.value, doc.Name):
return False
if attr.value.id != "T" or attr.attr != "include":
return False
return True

if isinstance(node.value, doc.Call) and is_include_macro(node.value):
def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
for decl in decl_list:
if isinstance(decl, doc.FunctionDef) and decl.name == name:
return decl
return None

call = node.value
macro_name = call.args[0]

if not isinstance(macro_name, str):
if not isinstance(macro_name, doc.Name):
self.report_error(node, "Invalid macro name in T.include")
macro_name = macro_name.id

macro = self.var_table.get().get(macro_name)
if macro is None:
self.report_error(node, f"Undefined macro '{macro_name}'")

if isinstance(macro, doc.Module):
macro_def = find_macro_def(macro_name, macro.body)
elif not isinstance(macro, doc.FunctionDef) or macro.name != macro_name:
macro_def = None

if macro_def is None:
self.report_error(macro, f"Undefined macro {macro_name}")

params = macro_def.args.args
args = call.args[1:]
body = macro_def.body
assert len(params) == len(args)


with self.var_table.with_frame():
for p, a in zip(params, args):
lhs = doc.to_doc(ast.fix_missing_locations(ast.Name(id=p.arg, ctx=ast.Store())))
if isinstance(a, doc.Name):
rhs = self.var_table.get()[a.id]
else:
rhs = self.visit(a)
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)

return self.visit_body(macro_def.body)


res = self.eval_expr(node.value)
if res is None:
pass
Expand Down

0 comments on commit 5cee2af

Please sign in to comment.