diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 6069f4289cf3..cade5457b03f 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -912,7 +912,9 @@ class RFactorBlockCreator : public BaseBlockCreator { write_regions_.reserve(old_block->writes.size()); for (const BufferRegion& write_region : old_block->writes) { Array region = write_region->region; - region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1)); + region.insert(region.begin() + factor_axis_, + Range::FromMinExtent(additional_iter_->var, + make_const(additional_iter_->var.dtype(), 1))); Optional rf_buffer = buffer_map.Get(write_region->buffer); ICHECK(rf_buffer.defined()); write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); @@ -1005,7 +1007,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { Array region; region.reserve(buf_load->indices.size()); for (const PrimExpr& index : buf_load->indices) { - region.push_back(Range::FromMinExtent(index, 1)); + region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1))); } buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region))); } diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index c1eb04b7c314..43374d375105 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring import pytest + import tvm import tvm.testing from tvm import te, tir, topi @@ -1643,5 +1644,61 @@ def test_reduction_rfactor_topi_argmin(): verify_trace_roundtrip(s, mod=argmin_topi) +def test_reduction_rfactor_int64(): + # fmt: off + @T.prim_func + def before( + A: T.Buffer((T.int64(128), T.int64(128)), "float32"), + B: T.Buffer((T.int64(128), T.int64(128)), "float32"), + C: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid( + T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4) + ): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R( + T.int64(128), + i2_outer * T.int64(32) + i2_inner_outer * T.int64(4) + i2_inner_inner, + ) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + @T.prim_func + def expected(A: T.Buffer((T.int64(128), T.int64(128)), "float32"), + B: T.Buffer((T.int64(128), T.int64(128)), "float32"), + C: T.Buffer((T.int64(128), T.int64(128)), "float32"), + ): + C_rf = T.alloc_buffer((T.int64(4), T.int64(128), T.int64(128)), "float32") + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(T.int64(128), T.int64(128), T.int64(4), T.int64(8), T.int64(4)): + with T.block("update_rf"): + vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer= T.axis.remap("SSSRR", [i2_inner_inner, i0, i1, i2_outer, i2_inner_outer]) + with T.init(): + C_rf[vi2_inner_inner, vi, vj] = 0.0 + C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( + A[vi, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)] + * B[vj, (((vi2_outer * T.int64(32)) + (vi2_inner_outer * T.int64(4))) + vi2_inner_inner)] + ) + + for i0_1, i1_1, i2_inner_inner_1 in T.grid(T.int64(128), T.int64(128), T.int64(4)): + with T.block("update"): + vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) + with T.init(): + C[vi_1, vj_1] = 0.0 + C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + # fmt: on + + s = tir.Schedule(before, debug_mask="all") + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) + rf_block = s.rfactor(kii, 0) + assert_structural_equal_ignore_global_symbol(s.mod["main"], expected) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) + verify_trace_roundtrip(s, mod=before) + + if __name__ == "__main__": tvm.testing.main()