Skip to content

Commit

Permalink
[MetaSchedule] Fix a multilevel tiling error on dynamic relax workload (#17465)
Browse files Browse the repository at this point in the history

Fix a segfault in the meta-schedule tiling primitive on dynamic workloads.

Co-authored-by: wrongtest <[email protected]>
  • Loading branch information
wrongtest-intellif and wrongtest authored Oct 16, 2024
1 parent 35d6a1b commit 58a43c8
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1581,14 +1581,14 @@ std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::Sche
tir::IterVarType type = GetLoopIterType(loop_sref);
if (type == tir::kDataPar) {
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (*extent != -1) {
if (extent && *extent != -1) {
cum_space_len *= *extent;
} else {
return std::make_pair(-1, -1);
}
} else if (type == tir::kCommReduce) {
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (*extent != -1) {
if (extent && *extent != -1) {
cum_reduce_len *= *extent;
} else {
return std::make_pair(-1, -1);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ Array<ExprRV> ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int
int max_innermost_factor,
Optional<Array<Integer>> decision) {
TVM_TIR_SCHEDULE_BEGIN();
// Use a None RV object to denote auto-inferred tile factors.
return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
max_innermost_factor, &decision));
max_innermost_factor, &decision),
/*convert_negone_to_none=*/true);
TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_);
throw;
}
Expand Down
12 changes: 10 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,12 @@ class ConcreteScheduleNode : public ScheduleNode {
/*!
* \brief Add a list of integers as random variables into the symbol table
* \param value The list of integers to be added to the symbol table
* \param convert_negone_to_none Whether to convert negative one to a None RV,
* which is the convention of certain primitives.
* \return The new random variables created
*/
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value,
bool convert_negone_to_none = false);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
/*!
Expand Down Expand Up @@ -362,10 +365,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) {
return std::move(rv);
}

inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value) {
inline Array<ExprRV> ConcreteScheduleNode::CreateRV(const std::vector<int64_t>& value,
bool convert_negone_to_none) {
Array<ExprRV> results;
results.reserve(value.size());
for (int64_t v : value) {
if (convert_negone_to_none && v == -1) {
results.push_back(ExprRV(nullptr));
continue;
}
results.push_back(CreateRV(v));
}
return results;
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ Array<String> TranslateAddOutputRVs(
ICHECK(!rv_names->count(output))
<< "ValueError: The random variable has been produced once: " << rv_names->at(output);
String result{ObjectPtr<StringObj>{nullptr}};
if (output->IsInstance<BlockRVNode>()) {
if (!output.defined()) {
result = "_";
} else if (output->IsInstance<BlockRVNode>()) {
result = "b" + std::to_string(i);
} else if (output->IsInstance<LoopRVNode>()) {
result = "l" + std::to_string(i);
Expand Down
8 changes: 5 additions & 3 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array<runtime::Int>& candidat
Array<ExprRV> TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
int max_innermost_factor,
Optional<Array<Integer>> decision) {
Array<ExprRV> results = CreateRV(tir::SamplePerfectTile(
&this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision));

// Use a None RV object to denote auto-inferred tile factors.
Array<ExprRV> results =
CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
max_innermost_factor, &decision),
/*convert_negone_to_none=*/true);
static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{loop_rv},
Expand Down
28 changes: 28 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,33 @@ def test_sample_perfect_tile_after_copy():
sch_copy.sample_perfect_tile(i, n=4)


def test_sample_perfect_tile_on_dynamic_loops():
    """Sampling perfect tiles on a dynamic-extent loop yields a trivial tiling."""

    @T.prim_func
    def workload(a: T.handle) -> None:
        n = T.int32()
        A = T.match_buffer(a, (n, 1024))
        for i, j in T.grid(n, 1024):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                A[vi, vj] = 1.0

    schedule = tir.Schedule(workload, debug_mask="all")
    block_rv = schedule.get_block("B")
    dynamic_loop, static_loop = schedule.get_loops(block_rv)

    # Loop with the constant extent: the four sampled factors must multiply
    # back to 1024.
    static_rvs = schedule.sample_perfect_tile(static_loop, n=4)
    static_values = [schedule.get(rv) for rv in static_rvs]
    static_product = (
        static_values[0] * static_values[1] * static_values[2] * static_values[3]
    )
    assert static_product == 1024

    # Loop with the symbolic extent: the leading factor comes back as None,
    # and the remaining three factors multiply to 1.
    dynamic_rvs = schedule.sample_perfect_tile(dynamic_loop, n=4)
    assert dynamic_rvs[0] is None
    inner_values = [schedule.get(rv) for rv in dynamic_rvs[1:]]
    inner_product = inner_values[0] * inner_values[1] * inner_values[2]
    assert inner_product == 1

    verify_trace_roundtrip(schedule, mod=workload)


if __name__ == "__main__":
tvm.testing.main()
35 changes: 35 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,41 @@ def test_split_with_inferred_factor():
verify_trace_roundtrip(sch=sch, mod=elementwise)


def test_split_with_dynamic_inferred_factor():
    """Split the loops of a workload with symbolic extents, passing ``None``
    for the factor that is left to be inferred, and compare with ``expected``."""

    @T.prim_func
    def before(a: T.handle, b: T.handle) -> None:
        N = T.int32()
        M = T.int32()
        A = T.match_buffer(a, (N, 128, M))
        B = T.match_buffer(b, (N, 128, M))
        for i, j, k in T.grid(N, 128, M):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0

    @T.prim_func
    def expected(a: T.handle, b: T.handle) -> None:
        N, M = T.int32(), T.int32()
        A = T.match_buffer(a, (N, 128, M))
        B = T.match_buffer(b, (N, 128, M))
        for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16):
            with T.block("B"):
                vi = T.axis.spatial(N, i_0 * 16 + i_1)
                vj = T.axis.spatial(128, j_0 * 32 + j_1)
                vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1)
                T.where(i_0 * 16 + i_1 < N and k_0 * ((M + 15) // 16) + k_1 < M)
                B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0)

    schedule = tir.Schedule(before, debug_mask="all")
    loop_i, loop_j, loop_k = schedule.get_loops(schedule.get_block("B"))

    # Splits are applied in loop order i, j, k; a None entry marks the factor
    # that is not given explicitly.
    split_plan = (
        (loop_i, [None, 16]),
        (loop_j, [4, 32]),
        (loop_k, [16, None]),
    )
    for loop_rv, split_factors in split_plan:
        schedule.split(loop_rv, factors=split_factors)

    assert_structural_equal_ignore_global_symbol(expected, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=before)


def test_split_with_predicate():
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
Expand Down

0 comments on commit 58a43c8

Please sign in to comment.