[TIR] Implement TIR macros #15260

Merged
merged 14 commits into from
Jul 11, 2023

Contributor

@kparzysz-quic kparzysz-quic commented Jul 7, 2023

This patch introduces a new symbol: T.macro. T.macro is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted into a PrimFunc by calling the macro the same way as a function.

For example:

@T.macro
def copy_backwards(dst, src, size):
    with T.block("backwards"):
        for i in T.serial(size):
            ai = T.axis.remap("S", [i])
            T.reads(src[0:size])
            T.writes(dst[0:size])
            dst[ai] = src[size - ai - 1]

@T.prim_func
def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
    copy_backwards(A, B, 128)

@T.prim_func
def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")):
    copy_backwards(A, B, 128)

The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types.

@T.prim_func
def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
    # with T.block("root"):
    for i in range(128):
        with T.block("copy"):
            ai = T.axis.spatial(128, i)
            T.reads(B[0:128])
            T.writes(A[0:128])
            A[ai] = B[128 - ai - 1]

@T.prim_func
def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")):
    # with T.block("root"):
    for i in range(128):
        with T.block("copy"):
            ai = T.axis.spatial(128, i)
            T.reads(B[0:128])
            T.writes(A[0:128])
            A[ai] = B[128 - ai - 1]
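
One way to picture the expansion is as substitution on a syntax tree. The sketch below is a toy model in plain Python, not the TVMScript parser (the review discussion further down mentions re-parsing the macro body instead). `MACRO_PARAMS`, `MACRO_BODY`, `SubstituteParams`, and `expand` are hypothetical names, and the body is a simplified stand-in for `copy_backwards` without blocks or axes.

```python
import ast

# Hypothetical, simplified stand-in for the body of copy_backwards.
MACRO_PARAMS = ["dst", "src", "size"]
MACRO_BODY = """
for i in range(size):
    dst[i] = src[size - i - 1]
"""


class SubstituteParams(ast.NodeTransformer):
    """Replace each parameter name with the expression given at the call site."""

    def __init__(self, bindings):
        self.bindings = bindings

    def visit_Name(self, node):
        if node.id in self.bindings:
            # Parse the argument text as an expression and splice it in.
            return ast.parse(self.bindings[node.id], mode="eval").body
        return node


def expand(params, body, args):
    """Return the macro body as source text, with args substituted for params."""
    tree = ast.parse(body)  # ast.parse returns an ast.Module
    tree = SubstituteParams(dict(zip(params, args))).visit(tree)
    return ast.unparse(ast.fix_missing_locations(tree))


# Mimics the call copy_backwards(A, B, 128) from the example.
expanded = expand(MACRO_PARAMS, MACRO_BODY, ["A", "B", "128"])
print(expanded)
```

The printed text is the loop with `A`, `B`, and `128` in place of the parameters. Because the substitution works on names only, the same macro text serves callers with different buffer types, which is the point of the `foo_int32` / `foo_int8` pair.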

Semantics:

  • A function decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are allowed but not required.
  • For macro_name(arg1, arg2, arg3, ...), the values are substituted into the body of the macro, and the body with the substituted values is then inserted at the point where the call is located.
  • Macros are hygienic by default, but can be made non-hygienic via a keyword argument hygienic to T.macro, e.g. T.macro(hygienic=False).
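
A rough illustration of the hygienic versus non-hygienic distinction, assuming hygiene means that free names in the macro body resolve to values at the definition site rather than at the call site. `ToyMacro` and its string-expression body are hypothetical; this is plain Python with `eval`, not the TVMScript parser.

```python
class ToyMacro:
    """Holds an expression plus the environment captured at definition time."""

    def __init__(self, expr, def_env, hygienic=True):
        self.expr = expr
        self.def_env = dict(def_env)  # snapshot of names visible at the definition
        self.hygienic = hygienic

    def expand(self, call_env):
        # Hygienic: free names come from the definition site.
        # Non-hygienic: free names come from the call site.
        env = self.def_env if self.hygienic else call_env
        return eval(self.expr, {}, dict(env))


definition_env = {"x_value": 128}
hygienic_macro = ToyMacro("x_value + 1", definition_env, hygienic=True)
unhygienic_macro = ToyMacro("x_value + 1", definition_env, hygienic=False)

# At the call site, a different x_value is in scope (e.g. a loop variable).
call_site_env = {"x_value": 3}
print(hygienic_macro.expand(call_site_env))    # uses x_value = 128
print(unhygienic_macro.expand(call_site_env))  # uses x_value = 3
```

The hygienic default avoids accidental capture of a caller's variable that happens to share a name with one used in the macro body.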

This patch introduces two new symbols: `T.macro` and `T.insert`.
`T.macro` is a decorator that, when applied to a function, turns the
body of that function into a piece of TIR that can be inserted via
`T.insert` into a PrimFunc.

For example:

```python
@T.macro
def copy_backwards(dst, src, size):
    with T.block("backwards"):
        for i in T.serial(size):
            ai = T.axis.remap("S", [i])
            T.reads(src[0:size])
            T.writes(dst[0:size])
            dst[ai] = src[size - ai - 1]

@T.prim_func
def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
    T.insert(copy_backwards, A, B, 128)

@T.prim_func
def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")):
    T.insert(copy_backwards, A, B, 128)
```

The above will generate two PrimFuncs that do the same backwards copy,
but applied to buffers with different data types.

Semantics:
- Function that is decorated with @T.macro can have any parameters that
  follow Python syntax, i.e. positional, keyword, etc. Type annotations
  are not required, but are allowed.
- The arguments to `T.insert` are macro name followed by the argument
  list.
  For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into
  the body of the macro as in the call `arg1(arg2, arg3, ...)`.
  The body with the substituted values is then inserted at the point
  where the `T.insert` is located.
@tvm-bot
Collaborator

tvm-bot commented Jul 7, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@kparzysz-quic
Contributor Author

cc: @yzh119

Krzysztof Parzyszek added 2 commits July 7, 2023 06:41
One linter suggested something that the other didn't like...
Contributor

@Lunderberg Lunderberg left a comment

I like it, just some questions on simplifying the implementation, which variables should be in-scope within the body of the macro, and whether macros should be provided for use in Relax as well as TIR.

I also liked your point from the discuss thread about the root tir::Block being handled differently if trying to generate a Stmt externally, and agree that it is better to re-parse the body of the macro in order to avoid those special cases.

# where the `T.insert` is located.


class TIRMacro:
Contributor

Should this be specific to TIR, or should it apply to any dialect supported by TVMScript? Thinking that this would be quite useful on the unity branch as well, where a Relax method for an end-to-end model often contains many repeated elements. Implementing those as a macro would also allow Relax's shape propagation to resolve differently for each expansion of the macro (e.g. in a chain of convolutions).

If we want it to be more general, we could move the implementation over to the tvm.script.parser.ir namespace instead.

Contributor Author

Would it be I.macro then?

Contributor

It would, yes. Within TVMScript, the default set of global definitions is defined here. This provides both tvm.script.tir as T and tvm.script.ir as I.

Contributor Author

Can we do this in a separate PR?

Contributor Author

@kparzysz-quic kparzysz-quic left a comment

Thanks for the review!

@kparzysz-quic
Contributor Author

Do you have any more comments? @Lunderberg

Contributor

@Lunderberg Lunderberg left a comment

Looks good, and thank you for making the changes!

I think the only open question was which namespace the macro support should go under. By placing it under the tvm.script.tir namespace, we'll need to keep a backwards-compatibility import if/when we move it to tvm.script.ir, unless done in the current PR. That said, that would be a pretty useful import anyways, regardless of the backwards compatibility, so I'm not worried there.

@kparzysz-quic kparzysz-quic merged commit fddbec7 into apache:main Jul 11, 2023
6 checks passed
@kparzysz-quic kparzysz-quic deleted the parser-macro branch July 11, 2023 16:37
@junrushao
Member

Didn't have the chance to review timely but this looks really awesome to me!

@kparzysz-quic
Contributor Author

@Lunderberg, @junrushao
I'm starting to work on making this more general and extending it to Relax. Relax code has a lot of assignments, the targets of which would disappear from the frame in the current implementation of expand_macro. My first thought was to do something like

@R.macro(outputs=["var2", "var3"])
def something():
    var1 = ...
    var2 = ...
    var3 = ...

@R.function
def something_else(...):
    something()
    ... = var3

Admittedly, var3 = something() looks a lot prettier, but it sort of breaks the idea behind a macro: the macro would have to "return" a value, which implies a lot more complications in parsing or expanding it.

I'm wondering if you have any better ideas.

@Lunderberg
Contributor

In Relax, I think this use case could be handled by a function that returns a relax.Expr. Unlike in TIR, Relax has an additional normalization step that flattens out any relax::SeqExpr. So, if you had the function something() return a relax::SeqExpr, then call it with var3 = something(), I think that would result in the behavior you want.

For multiple return values, I can't remember if Relax's normalizer would remove trivial Tuple bindings, where the function return value is a tuple which gets immediately unpacked, but that would be a reasonable extension to add to it.
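
As a toy model of the flattening idea (not Relax's normalizer, and with no handling of name clashes), a nested sequence bound to a variable can be hoisted into the enclosing sequence. `Seq` and `flatten` are hypothetical names, and strings stand in for Relax expressions.

```python
from dataclasses import dataclass


@dataclass
class Seq:
    """Toy sequence expression: ordered (name, value) bindings plus a result name."""

    bindings: list
    result: str


def flatten(seq):
    """Hoist bindings of any nested Seq into the enclosing Seq."""
    out = []
    for name, value in seq.bindings:
        if isinstance(value, Seq):
            inner = flatten(value)
            out.extend(inner.bindings)        # inner bindings come first
            out.append((name, inner.result))  # then bind the outer name to the inner result
        else:
            out.append((name, value))
    return Seq(out, seq.result)


# A helper "returns" a sequence; the caller binds it with var3 = something().
something = Seq([("var1", "f()"), ("var2", "g(var1)")], "var2")
caller = Seq([("var3", something)], "var3")

flat = flatten(caller)
for name, value in flat.bindings:
    print(f"{name} = {value}")
```

After flattening, the helper's bindings appear directly in the caller, followed by a binding of `var3` to the helper's result.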

@kparzysz-quic
Contributor Author

For the purposes of the rest of the comment, I renamed TIRMacro to ScriptMacro, and created a subclass RelaxMacro.

There are some complications with macros returning values, specifically when the call to the macro is a part of a larger expression. The problem is that there are no specific parser visitors for subexpressions: the whole expression is evaluated by the evaluator all at once, so I can't do the same "AST injection" as I did for TIR (at least not without adding detailed visitors and capturing AST.Call nodes).

For example,

@R.macro
def m():
    return 1

@R.function
def foo():
    x = m() + 1

The entire m() + 1 will be evaluated, causing Python to actually call the RelaxMacro object. I went down that route a bit, added a __call__ operator to RelaxMacro, parsed it as if it was R.function, and returned the body. This works, but has severe limitations: the macro's parameters need to be annotated and can only be objects that an R.function can take, so no str for example.

I'm inclined to follow the TIR route instead, but I wanted to consult with you in case this is a bad idea.
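
The nesting issue described above can be seen with Python's own `ast` module. This is a plain-Python sketch, not the TVMScript parser; `ToyMacro` is a hypothetical stand-in for a macro object that defines `__call__`.

```python
import ast

tree = ast.parse("x = m() + 1")
assign = tree.body[0]

# A statement-level visitor sees the assignment; the macro call is buried in its value.
print(type(assign).__name__)             # Assign
print(type(assign.value).__name__)       # BinOp
print(type(assign.value.left).__name__)  # Call


class ToyMacro:
    """Hypothetical macro object that is invoked like a function."""

    def __init__(self, value):
        self.value = value
        self.calls = 0

    def __call__(self):
        self.calls += 1
        return self.value


m = ToyMacro(1)

# Evaluating the whole right-hand side at once ends up calling the object.
code = compile(ast.Expression(body=assign.value), "<toy>", "eval")
result = eval(code, {"m": m})
print(result, m.calls)
```

Intercepting the call before evaluation would need a visitor that walks into the expression and handles `Call` nodes itself, which is the extra machinery the comment refers to.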

@kparzysz-quic
Contributor Author

kparzysz-quic commented Jul 17, 2023

I think I got it to work:

import tvm
from tvm import relax
from tvm.script import relax as R

@R.macro
def alloc_and_shape(dtype: str):
    alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype=dtype)
    shape = R.shape_of(alloc)
    return shape

@R.function
def foo(x: R.Tensor((4, 4), "float32")):
    shape = alloc_and_shape(dtype="float32")
    return shape


print(alloc_and_shape)
print()
print(foo)

Produces

@R.macro
def alloc_and_shape(dtype: str):
    alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype=dtype)
    shape = R.shape_of(alloc)
    return shape


# from tvm.script import relax as R

@R.function
def foo(x: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
    alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([4, 4]), R.dtype("float32"), R.prim_value(0))
    shape: R.Shape([4, 4]) = R.shape_of(alloc)
    shape_1: R.Shape([4, 4]) = shape
    return shape_1

This is still via the __call__ route.

I'll wait for main to be merged into unity, and I'll create a (draft?) PR for this in unity.

junrushao pushed a commit to junrushao/tvm that referenced this pull request Jul 24, 2023
* [TIR] Implement TIR macros

This patch introduces two new symbols: `T.macro` and `T.insert`.
`T.macro` is a decorator that, when applied to a function, turns the
body of that function into a piece of TIR that can be inserted via
`T.insert` into a PrimFunc.

For example:

```python
@T.macro
def copy_backwards(dst, src, size):
    with T.block("backwards"):
        for i in T.serial(size):
            ai = T.axis.remap("S", [i])
            T.reads(src[0:size])
            T.writes(dst[0:size])
            dst[ai] = src[size - ai - 1]

@T.prim_func
def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
    T.insert(copy_backwards, A, B, 128)

@T.prim_func
def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")):
    T.insert(copy_backwards, A, B, 128)
```

The above will generate two PrimFuncs that do the same backwards copy,
but applied to buffers with different data types.

Semantics:
- Function that is decorated with @T.macro can have any parameters that
  follow Python syntax, i.e. positional, keyword, etc. Type annotations
  are not required, but are allowed.
- The arguments to `T.insert` are macro name followed by the argument
  list.
  For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into
  the body of the macro as in the call `arg1(arg2, arg3, ...)`.
  The body with the substituted values is then inserted at the point
  where the `T.insert` is located.

* Fix linter

* Fix linter again

One linter suggested something that the other didn't like...

* Get rid of T.insert, apply macro via function-call syntax

* Store closure vars in TIRMacro

* ast.parse always returns ast.Module, hence doc is doc.Module

* Simplify `expand_macro`, capture environment variables

* Implement macro hygiene

* Fix linter

* Make T.macro work same as T.macro()

The previous commit inadvertently made T.macro (without parentheses)
illegal, only abbreviated form allowed was T.macro(). Restore T.macro
as a valid decorator use.

* Edit comment: insertion -> expansion

* Add import pytest

* One more typo...

* Remove stale testcase
junrushao pushed a commit to junrushao/tvm that referenced this pull request Jul 27, 2023
junrushao pushed a commit to junrushao/tvm that referenced this pull request Jul 30, 2023