
[TIR] Allow symbolic bounds in IndexMap analysis #15264

Merged

Conversation

junrushao
Member

This PR adds the bounds of shape variables to the arithmetic analyzer so that it is possible to simplify certain expressions.

@tvm-bot
Collaborator

tvm-bot commented Jul 8, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@junrushao junrushao marked this pull request as ready for review July 8, 2023 01:53
@junrushao junrushao force-pushed the feature/2023-07-07/symbolic-shape-layout branch 7 times, most recently from 757600f to 0733726, on July 8, 2023 19:04
src/arith/iter_affine_map.cc (review thread; outdated, resolved)
@junrushao junrushao force-pushed the feature/2023-07-07/symbolic-shape-layout branch 4 times, most recently from 31c4756 to 34a60f9, on July 8, 2023 23:00
This PR adds the bounds of shape variables to the arithmetic analyzer
so that it is possible to simplify certain expressions.
@junrushao junrushao force-pushed the feature/2023-07-07/symbolic-shape-layout branch from 34a60f9 to 7500285, on July 8, 2023 23:03
@MasterJH5574 MasterJH5574 merged commit a60cd0f into apache:main Jul 9, 2023
6 checks passed
junrushao added a commit to junrushao/tvm that referenced this pull request Jul 9, 2023
Following apache#15264, this PR makes changes accordingly to the Unity branch
to enable symbolic bounds in IndexMap analysis.
tqchen pushed a commit that referenced this pull request Jul 9, 2023
Following #15264, this PR makes changes accordingly to the Unity branch
to enable symbolic bounds in IndexMap analysis.
junrushao added a commit to junrushao/tvm that referenced this pull request Jul 15, 2023
This PR adds the bounds of shape variables to the arithmetic analyzer
so that it is possible to simplify certain expressions.
junrushao added a commit to junrushao/tvm that referenced this pull request Jul 15, 2023
Following apache#15264, this PR makes changes accordingly to the Unity branch
to enable symbolic bounds in IndexMap analysis.
@LeiWang1999
Contributor

LeiWang1999 commented Dec 20, 2023

Hi @junrushao, I have encountered an issue and bisected it to this pull request. Here is my case:

import tvm
from tvm.script import tir as T
from tvm.tir import IndexMap

def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [16, 16], dtype="float16")
        B = T.match_buffer(b, [16, 16], dtype="float16")
        
        for i, j in T.grid(16, 16):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(B[vi, vj])
                T.writes(A[vi, vj])
                A[vi, vj] = B[vi, vj]

ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)

block_b = sch.get_block("B")
sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16)
print("========================inject transform=============================")
print(sch.mod["main"].script())

index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16)
inversed_index_map = index_map.inverse([16, 16])
def inverse_permutation(i, j):
    return inversed_index_map.map_indices([i, j])
sch.transform_layout(block_b, ('read', 0), inverse_permutation)
print("========================inverse inject transform=============================")
print(sch.mod["main"].script())
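As a quick sanity check independent of TVM, the forward layout above can be verified to be a bijection on the 16x16 domain in plain Python (this just restates the two mapping functions from the repro), which is the property that makes an inverse well-defined:

```python
# The forward layout from the repro, restated standalone.
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j):
    thread_id = kernel_i * 2 + kernel_j // 8
    local_id = kernel_j % 8
    return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id)

# Collect the image of all 256 input coordinates; 256 distinct outputs on a
# 16x16 codomain means the map is injective, hence a bijection, so an
# inverse must exist.
image = {ldmatrix_trans_permutation_16x16_32x8_16x16(i, j)
         for i in range(16) for j in range(16)}
assert len(image) == 256
```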

Before this PR, the output is:

========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8]

========================inverse inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi, vj]

As we can see, the index map can be simplified and inverted.

After this PR, the output is:

========================inject transform=============================
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
Traceback (most recent call last):
  File "/home/t-leiwang/ladder_workspace/tvm_gpu_gemm/discuss_inversemap.py", line 42, in <module>
    sch.transform_layout(block_b, ('read', 0), inverse_permutation)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/tir/schedule/schedule.py", line 3296, in transform_layout
    _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::tir::BlockRV const&, int, int, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)#17}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const [clone .isra.0]
  4: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  3: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  2: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
  1: tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>, tvm::arith::Analyzer*) const
  0: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel, tvm::arith::Analyzer*)
  File "/home/t-leiwang/mlc_workspace/tvm_rebase/src/tir/ir/index_map.cc", line 96
TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could not parse mapping as sum of iterators.  Error: IterMapExpr or subclasses should only result from calls in IterMapRewriter using DirectMutate.  Indirect return occurred in i

The index map is not well simplified, and inverting it throws an error.

@tqchen
Member

tqchen commented Dec 20, 2023

@LeiWang1999 do you mind digging further? Specifically, it would be good to know what the difference is in terms of the input to the index map. Likely the analyzer now has more context.

IterMapExpr or subclasses should only result from calls in IterMapRewriter using DirectMutate.  Indirect return occurred in i

seems to indicate that there are some issues in the internal IterMapRewriter

@tqchen
Member

tqchen commented Dec 20, 2023

@LeiWang1999 what you hit seems related to the index dtype (i32 vs. i64), which is a bit unfortunate, but necessary as we transition to enabling i64.

index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32")

If you add index_dtype="int32", given your loops are in i32, the inverse map seems to be OK. Without it, casts are introduced, and as a result we cannot compute the inverse.

@LeiWang1999
Contributor

Thanks tq. I found that even with index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32"), the inverse map is still not OK; the output is:

========================inject transform=============================
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 16 * 8 + (vi * 2 + vj // 8) % 8, (vi * 2 + vj // 8) % 16 // 8 * 8 + vj % 8]
========================inverse inject transform=============================
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[(vi * 2 + vj // 8) // 2, vj % 16])
            T.writes(A[vi, vj])
            A[vi, vj] = B[(vi * 2 + vj // 8) // 2, vj % 16]

After applying the inverse_map to the map, the layout should remain consistent with its state prior to the transformation.

Before this pull request, the code was functioning as expected.

========================inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi // 8 * 8 + vi % 4 * 2 + vj // 8, vi % 8 // 4 * 8 + vj % 8]

========================inverse inject transform=============================
# from tvm.script import tir as T
@T.prim_func
def func(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(A[vi, vj])
            A[vi, vj] = B[vi, vj]

I'll take a look tomorrow.
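For reference, the expected round-trip behavior can be checked numerically without TVM by building a table-based inverse of the forward layout. This is plain Python, not IndexMap.inverse, so it only validates the math, not the simplifier:

```python
# The forward layout from the repro, restated standalone.
def forward(i, j):
    thread_id = i * 2 + j // 8
    local_id = j % 8
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col

# Table-based inverse: valid because the map is a bijection on 16x16.
inverse = {forward(i, j): (i, j) for i in range(16) for j in range(16)}

# Composing the inverse with the forward map must give the identity, which
# is exactly the layout the second transform_layout is expected to restore.
assert all(inverse[forward(i, j)] == (i, j)
           for i in range(16) for j in range(16))
```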

@LeiWang1999
Contributor

Inserting

(*indices).MutateByApply(
    [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); });

before

*indices = this->IterMapSimplifyWithContext(*indices, true);

works for me to simplify the indices. It looks like this call was removed by this commit; I wonder if there is a specific rationale behind it.

@tqchen
Member

tqchen commented Dec 21, 2023

Previously, we wanted to have an option that controls trivial simplification (simplifying an iterator of range 1 to zero).

See implementation here https://github.com/apache/tvm/blob/main/src/arith/ir_mutator_with_analyzer.cc#L52

Running SimplifyNonTrivialExpr first means we cannot offer that option. I think it might be fine to update IterMapSimplifyWithContext to include SimplifyNonTrivialExpr while still preserving the non-trivial option.

I think the best approach would be to enhance IterMapSimplify to handle this case, since this involves affine map transformations, and ideally we don't rely on other simplifiers.

We can see that (vi * 2 + vj // 8) // 16 should be readily transformed to vi // 8 + vj // 128. My suspicion is that if you apply IterMap simplification again on the result, it should also work (maybe some intermediate values were not further simplified). It would be good to dig into what happens in this case without relying on other simplifiers (so it has the potential to work for the future symbolic case).
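The claimed rewrite can be brute-force checked over the 16x16 iteration domain in plain Python (this only confirms the two expressions agree on this domain; it does not exercise IterMap itself):

```python
# Check that (vi * 2 + vj // 8) // 16 == vi // 8 + vj // 128
# at every point of the 16x16 iteration space in the repro.
assert all((vi * 2 + vj // 8) // 16 == vi // 8 + vj // 128
           for vi in range(16) for vj in range(16))
```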
