diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b60e60c3cfc9..6195313fddae 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1581,14 +1581,14 @@ std::pair GetCumulativeSpaceAndReductionLength(const tir::Sche tir::IterVarType type = GetLoopIterType(loop_sref); if (type == tir::kDataPar) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_space_len *= *extent; } else { return std::make_pair(-1, -1); } } else if (type == tir::kCommReduce) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_reduce_len *= *extent; } else { return std::make_pair(-1, -1); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index f6cb1f05ef6e..dd1a376deaf8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -246,8 +246,10 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int int max_innermost_factor, Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); + // Use a None RV object to denote auto-inferred tile factors. return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, - max_innermost_factor, &decision)); + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); throw; } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index b8ad56d2ab56..4aebe3036cf2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -219,9 +219,12 @@ class ConcreteScheduleNode : public ScheduleNode { /*! 
 * \brief Add a list of integers as random variables into the symbol table * \param value The list of integers to be added to the symbol table + * \param convert_negone_to_none Convert negative one to a None RV, + * which is the convention of certain primitives. * \return The new random variables created */ - inline Array CreateRV(const std::vector& value); + inline Array CreateRV(const std::vector& value, + bool convert_negone_to_none = false); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! @@ -362,10 +365,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, + bool convert_negone_to_none) { Array results; results.reserve(value.size()); for (int64_t v : value) { + if (convert_negone_to_none && v == -1) { + results.push_back(ExprRV(nullptr)); + continue; + } results.push_back(CreateRV(v)); } return results; diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 6e243bf19198..7421cbbf32df 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -227,7 +227,9 @@ Array TranslateAddOutputRVs( ICHECK(!rv_names->count(output)) << "ValueError: The random variable has been produced once: " << rv_names->at(output); String result{ObjectPtr{nullptr}}; - if (output->IsInstance()) { + if (!output.defined()) { + result = "_"; + } else if (output->IsInstance()) { result = "b" + std::to_string(i); } else if (output->IsInstance()) { result = "l" + std::to_string(i); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d790f21e671a..784ecdeb32cb 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -70,9 +70,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidat Array 
TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision) { - Array results = CreateRV(tir::SamplePerfectTile( - &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); - + // Use a None RV object to denote auto-inferred tile factors. + Array results = + CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{loop_rv}, diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index 8ae576e9b922..f37c818e7992 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -212,5 +212,33 @@ def test_sample_perfect_tile_after_copy(): sch_copy.sample_perfect_tile(i, n=4) +def test_sample_perfect_tile_on_dynamic_loops(): + """Currently dynamic loop is trivially tiled""" + + @T.prim_func + def workload(a: T.handle) -> None: + n = T.int32() + A = T.match_buffer(a, (n, 1024)) + for i, j in T.grid(n, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 1.0 + + sch = tir.Schedule(workload, debug_mask="all") + di, si = sch.get_loops(sch.get_block("B")) + + factors = sch.sample_perfect_tile(si, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 1024 + + factors = sch.sample_perfect_tile(di, n=4) + assert factors[0] is None + factors = [sch.get(i) for i in factors[1:]] + prod = factors[0] * factors[1] * factors[2] + assert prod == 1 + verify_trace_roundtrip(sch, mod=workload) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py 
b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index f5e5b3b54e76..22344acfe1d4 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -389,6 +389,41 @@ def test_split_with_inferred_factor(): verify_trace_roundtrip(sch=sch, mod=elementwise) +def test_split_with_dynamic_inferred_factor(): + @T.prim_func + def before(a: T.handle, b: T.handle) -> None: + N = T.int32() + M = T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i, j, k in T.grid(N, 128, M): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle) -> None: + N, M = T.int32(), T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16): + with T.block("B"): + vi = T.axis.spatial(N, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 32 + j_1) + vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1) + T.where(i_0 * 16 + i_1 < N and k_0 * ((M + 15) // 16) + k_1 < M) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0) + + sch = tir.Schedule(before, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 16]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, None]) + assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + def test_split_with_predicate(): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B")