Skip to content

Commit

Permalink
[TIR] Output DeclBuffer in LowerTVMBuiltin (#15243)
Browse files Browse the repository at this point in the history
* [TIR] Output DeclBuffer in LowerTVMBuiltin

For the `stack_shape` and `stack_tcode` buffers, generate a
`DeclBuffer`.

This is a subset of the changes made in
#14778, broken out for ease of
testing and review.

* Updated LowerTVMBuiltin tests for DeclBuffer
  • Loading branch information
Lunderberg authored Jul 7, 2023
1 parent 916542e commit 5a78da4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class BuiltinLower : public StmtExprMutator {
if (scope.max_sizes.shape_stack != -1) {
scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)},
DataType::Int(64), "stack_shape");
stmt = DeclBuffer(scope.stack_shape, stmt);
stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack),
stmt);
}
Expand All @@ -159,6 +160,7 @@ class BuiltinLower : public StmtExprMutator {
stmt =
LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt);

stmt = DeclBuffer(scope.stack_tcode, stmt);
stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack),
stmt);
}
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def check_packed_func(target="llvm"):

# Recursively visit PrimFunc until we meet the for-loop:
while True:
if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
if isinstance(
node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt, tvm.tir.DeclBuffer)
):
node = node.body
elif isinstance(node, tvm.tir.SeqStmt):
node = node[0]
Expand All @@ -98,7 +100,7 @@ def check_packed_func(target="llvm"):
#
# let stack_value = tir.tvm_stack_alloca("arg_value", 4)
#
alloca_value = alloca_tcode.body
alloca_value = alloca_tcode.body.body
assert isinstance(alloca_value, tvm.tir.LetStmt)

expected_value = tvm.tir.call_intrin(
Expand Down

0 comments on commit 5a78da4

Please sign in to comment.