Skip to content

Commit

Permalink
[TIR] Fix Primitive Rfactor DType (#15413)
Browse files Browse the repository at this point in the history
The rfactor primitive will create/rewrite two blocks, together with the
block read/write regions. However, the generated read/write region extents
are not valid when it's a int64 index. This commit fixes the issue.
  • Loading branch information
Hzfengsy authored Jul 26, 2023
1 parent 22ec541 commit 3e00253
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator {
write_regions_.reserve(old_block->writes.size());
for (const BufferRegion& write_region : old_block->writes) {
Array<Range> region = write_region->region;
region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1));
region.insert(region.begin() + factor_axis_,
Range::FromMinExtent(additional_iter_->var,
make_const(additional_iter_->var.dtype(), 1)));
Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
ICHECK(rf_buffer.defined());
write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_)));
Expand Down Expand Up @@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
Array<Range> region;
region.reserve(buf_load->indices.size());
for (const PrimExpr& index : buf_load->indices) {
region.push_back(Range::FromMinExtent(index, 1));
region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
}
buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region)));
}
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_tir_schedule_rfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest

import tvm
import tvm.testing
from tvm import te, tir, topi
Expand Down Expand Up @@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin():
verify_trace_roundtrip(s, mod=argmin_topi)


def test_reduction_rfactor_int64():
# fmt: off
@T.prim_func
def before(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
):
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(
T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)
):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i0, i1])
vk = T.axis.R(
T.int64(128),
i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner,
)
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])

@T.prim_func
def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
):
C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32")

for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)):
with T.block("update_rf"):
vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer])
with T.init():
C_rf[vi2_inner_inner, vi, vj] = 0.0
C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
* B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)]
)

for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)):
with T.block("update"):
vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1])
with T.init():
C[vi_1, vj_1] = 0.0
C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
# fmt: on

s = tir.Schedule(before, debug_mask="all")
update = s.get_block("update")
_, _, _, _, kii = s.get_loops(update)
rf_block = s.rfactor(kii, 0)
assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
assert s.get(update).same_as(s.get(s.get_block("update")))
verify_trace_roundtrip(s, mod=before)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 3e00253

Please sign in to comment.