[TIR] Allow symbolic bounds in IndexMap analysis #15264
Conversation
This PR adds the bounds of shape variables to the arithmetic analyzer so that it is possible to simplify certain expressions.
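A minimal sketch of the idea (an editor's illustration, not code from this PR; it assumes `Analyzer.bind` accepts a `Range`): once the analyzer knows a shape variable's bounds, it can prove rewrites that are impossible for an unbounded variable.

```python
import tvm
from tvm import arith, tir

n = tir.Var("n", "int64")
analyzer = arith.Analyzer()

# Without a bound the analyzer cannot prove n >= 0, so max(n, 0) stays as-is.
print(analyzer.simplify(tir.max(n, tir.IntImm("int64", 0))))

# Binding the shape variable's range (shape extents are non-negative) gives
# the analyzer enough context to fold max(n, 0) down to n.
analyzer.bind(n, tvm.ir.Range.from_min_extent(
    tir.IntImm("int64", 0), tir.IntImm("int64", 1 << 31)))
print(analyzer.simplify(tir.max(n, tir.IntImm("int64", 0))))
```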
Following #15264, this PR makes the corresponding changes on the Unity branch to enable symbolic bounds in IndexMap analysis.
Hi @junrushao, I have encountered an issue and bisected it to this pull request. Here is my case:

```python
import tvm
from tvm.script import tir as T
from tvm.tir import IndexMap

def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [16, 16], dtype="float16")
        B = T.match_buffer(b, [16, 16], dtype="float16")
        for i, j in T.grid(16, 16):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi, vj])
                T.writes(A[vi, vj])
                A[vi, vj] = B[vi, vj]

ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("B")

sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16)
print("========================inject transform=============================")
print(sch.mod["main"].script())

index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16)
inversed_index_map = index_map.inverse([16, 16])

def inverse_permutation(i, j):
    return inversed_index_map.map_indices([i, j])

sch.transform_layout(block_b, ('read', 0), inverse_permutation)
print("========================inverse inject transform=============================")
print(sch.mod["main"].script())
```

Before this PR, the output was:

```
========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8]
========================inverse inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi, vj]
```

As we can see, the index map is simplified and can be inverted.
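The simplified read index above and the longer form in the post-PR output below are numerically equivalent; a quick exhaustive check (editor's sketch) confirms it:

```python
# Check that the unsimplified read indices (post-PR output below) equal the
# simplified ones (pre-PR output above) for every vi, vj in [0, 16).
for vi in range(16):
    for vj in range(16):
        t = vi * 2 + vj // 8  # the thread_id sub-expression
        assert t // 16 * 8 + t % 8 == vi // 8 * 8 + vi % 4 * 2 + vj // 8
        assert t % 16 // 8 * 8 + vj % 8 == vi % 8 // 4 * 8 + vj % 8
```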
After this PR, the output is:

```
========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
Traceback (most recent call last):
  File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/discuss_inversemap.py", line 42, in <module>
    sch.transform_layout(block_b, ('read', 0), inverse_permutation)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/schedule.py", line 3296, in transform_layout
    _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const [clone .isra.0]
  4: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  3: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  2: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  1: tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>, tvm::arith::Analyzer*) const
  0: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel, tvm::arith::Analyzer*)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/src/tir/ir/index_map.cc", line 96
TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could not parse mapping as sum of iterators. Error: IterMapExpr or subclasses should only result from calls in IterMapRewriter using DirectMutate. Indirect return occurred in i
```

This index map is not well optimized, and inverting it throws an error.
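For reference, a sketch of the round-trip property the script exercises (an editor's illustration reusing the functions from the reproduction above; it assumes `map_indices` folds constant inputs to `IntImm`):

```python
# Composing the forward map with its inverse over the full [16, 16] domain
# should give back the original indices. Before this PR these assertions
# pass; after it, the inverse() call itself raises the error shown above.
index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16)
inverse_map = index_map.inverse([16, 16])
for i in range(16):
    for j in range(16):
        r, c = index_map.map_indices([i, j])
        ri, rj = inverse_map.map_indices([r, c])
        assert (int(ri), int(rj)) == (i, j)
```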
@LeiWang1999 do you mind digging further? Specifically, it would be good to know what the difference is in terms of the input to the index map; likely the analyzer has more context. The error seems to indicate that there are some issues in the internal IterMapRewriter.
@LeiWang1999 what you met seems to be related to the index bound (i32/i64 related), which is a bit unfortunate, but as we transition to enable i64 it is necessary.
If you add index_dtype="int32", given that your loops are in i32, the inverse map seems to be OK. If we don't do that, there will be casts, and as a result we cannot construct the inverse.
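A sketch of the suggested workaround (assuming the `index_dtype` keyword is passed to `IndexMap.from_func`, as the comment above implies):

```python
# Construct the map with i32 indices so the mapped expressions carry no i64
# casts, which is what blocks the inverse from being built.
index_map = IndexMap.from_func(
    ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32"
)
inversed_index_map = index_map.inverse([16, 16])
```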
Thanks tq. I found that with index_dtype="int32" the output is:

```
========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
========================inverse inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 2, vj % 16])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 2, vj % 16]
```

After applying the inverse map, the layout should be restored to its state prior to the transformation, but it is not. Before this pull request, the code worked as expected:
```
========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8]
========================inverse inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi, vj]
```

I'll take a look tomorrow.
Inserting tvm/src/tir/schedule/primitive/layout_transformation.cc lines 806 to 807 (as of f36a093) to simplify the indices works for me. It looks like that call was removed by this commit; I wonder whether there is a specific rationale behind it.
Previously, we wanted to have an option that preserves trivial simplification (simplifying an iterator of range 1 to zero); see the implementation here: https://github.com/apache/tvm/blob/main/src/arith/ir_mutator_with_analyzer.cc#L52. I think the best approach would be to enhance IterMapSimplify to handle this case, since this involves an affine map transformation, and in the best case we don't rely on another simplifier.
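For reference, the trivial-simplify behavior described above can be seen directly through the analyzer (an editor's sketch; binding a variable to an extent-1 range pins it to its min value):

```python
import tvm
from tvm import arith, tir

i = tir.Var("i", "int32")
analyzer = arith.Analyzer()
# An iterator whose range has extent 1 can only take its min value, so the
# analyzer substitutes i -> 0 and folds the expression to a constant.
analyzer.bind(i, tvm.ir.Range.from_min_extent(0, 1))
print(analyzer.simplify(i * 8 + 4))  # expected output: 4
```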