
[Bug] [Relax] Missing IR structure checking and correction #17211

Open
Cookiee235 opened this issue Jul 28, 2024 · 4 comments
Labels
needs-triage, type: bug

Comments

@Cookiee235 (Contributor) commented Jul 28, 2024

Hi all, I set check_well_formed=True in the Relax IR construction below and can run mod.show() to print the IR successfully, so the Relax IR appears to pass the well-formedness check. However, compilation crashes when executing ex = relax.build(mod, target='llvm'). The crash message shows that
"Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))"

Based on my analysis, if we replace the code gv1 = R.call_tir(cls.relu, (x), out_sinfo=R.Tensor((1, 512, 64, 64))) (Line 26) with gv1 = R.nn.relu(x) (Line 27) or gv1 = R.call_tir(cls.relu, (x,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) (Line 28), the script runs fine.
Given that the Relax IR constructor can infer the full type information for gv1 = R.nn.relu(x) from the context, why doesn't it fill in the missing dtype for gv1 on Line 26?

To take a step back: if the Relax IR constructor cannot complete the missing information and we set check_well_formed=True in the Relax IR construction, an exception should be thrown early at mod = Module rather than at relax.build(). Failing early would make the code more robust.

That said, I would prefer that the IR constructor fill in missing information or correct inconsistent constraints based on the IR's context.

Actual behavior

Traceback (most recent call last):
  File "demo_simple.py", line 26, in <module>
    ex = relax.build(mod, target='llvm')  # crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  37: tvm::transform::Pass::operator()(tvm::IRModule) const
  36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  35: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  30: tvm::relax::CallTIRMutator::Run()
  29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  16: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
  15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
  9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
  8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm/src/relax/ir/block_builder.cc", line 159
TVMError: Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))

Environment

  • TVM: 0.17.dev0

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module(check_well_formed=True)
class Module:
    @T.prim_func(private=True)
    #def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32")):
    def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "float32"), B: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)))):
        T.func_attr({"op_pattern": 0})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(512), T.int64(64), T.int64(64)):
            with T.block("relu"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(A[v_i0, v_i1, v_i2, v_i3])
                T.writes(B[v_i0, v_i1, v_i2, v_i3])
                B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    @R.function
    def main(x: R.Tensor((1, 512, 64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv1 = R.call_tir(cls.relu, (x), out_sinfo=R.Tensor((1, 512, 64, 64)))  # crash
            # gv1 = R.nn.relu(x)  # run well
            # gv1 = R.call_tir(cls.relu, (x,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32"))  # run well
            R.output(gv1)
        return gv1

mod = Module
mod.show()

mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')
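
For reference, the well-formedness check can also be invoked explicitly right after construction. Per this report, it is expected to pass even though relax.build later rejects the module; the snippet below is only a sketch using the public analysis API, appended to the reproduction script above.

# Hypothetical early check on the original module (before the transforms above).
# Per this report, it is expected to report the module as well-formed even though
# relax.build later fails.
from tvm.relax.analysis import well_formed
print(well_formed(Module))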

cc @Lunderberg @junrushao @tqchen

@Lunderberg (Contributor)

Good catch, and I think this is arising from a number of different edge cases.

  • The StructInfo for R.call_tir is always inferred from the out_sinfo, not from the TIR function's signature. This is for historical reasons, as TIR functions only recently started holding the annotations that would allow them to perform shape inference. As a result, no errors are seen during the initial well-formed check.

  • The defaults of T.Buffer and R.Tensor are different. If unspecified, T.Buffer defaults to a "float32" datatype, whereas R.Tensor defaults to DataType::Void, which represents an unknown datatype that might be inferred later in compilation. There is no equivalent in TIR, which must have a known datatype for each buffer. (See the sketch after this list.)

  • There is no rule that would infer the unknown Relax datatype from the mandatory TIR datatype. As a result, the out_sinfo remains the incomplete R.Tensor(shape), rather than R.Tensor(shape, dtype="float32").

  • The error is raised during CallTIRRewrite, which rewrites low-level calls from having an implied allocation for the output to having an explicit argument for the output. Here, this rewrites the R.call_tir(cls.relu, [x], out_sinfo=R.Tensor([1,512,64,64])) into cls.relu(x, output_allocation), where output_allocation has shape R.Tensor([1,512,64,64]). This is the first point at which the TIR function's signature is actually inspected.

  • Currently, when checking whether the constraints required by a subroutine are satisfied, the check must either pass or fail. There is no mechanism for the subroutine's constraints to be hoisted into the calling scope. Since "tensor of arbitrary element type" is not a valid argument for "tensor with float32 element type", the check fails.
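
To make the differing defaults concrete, here is a minimal, self-contained sketch (the module and names are illustrative, assuming a recent TVM build): the T.Buffer parameters parse with a float32 element type, while the R.Tensor annotations parse with an unknown ("void") dtype.

import tvm
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

@I.ir_module
class Defaults:
    @T.prim_func(private=True)
    def copy(A: T.Buffer((16,)), B: T.Buffer((16,))):  # dtype omitted -> "float32"
        for i in range(16):
            B[i] = A[i]

    @R.function
    def identity(x: R.Tensor((16,))) -> R.Tensor((16,)):  # dtype omitted -> unknown
        return x

prim = Defaults["copy"]
print(prim.buffer_map[prim.params[0]].dtype)       # float32, the TIR default
print(Defaults["identity"].params[0].struct_info)  # tensor struct info with an unknown dtype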

I think there are a number of improvements that could be made to close each of these loopholes.

  1. Improved well-formed checker. If out_sinfo is explicitly stated in R.call_tir, then IsBaseOf(inferred_sinfo, out_sinfo) must return true.

  2. Infer the dtype of out_sinfo in R.call_tir. If out_sinfo is a Tensor, or a Tuple of tensors, and one of those tensors has DataType::Void(), normalize the out_sinfo argument to include the datatype from the PrimFunc.

  3. Improved struct inference for R.call_tir. Now that PrimFuncs have a known shape for each argument, the output inference of R.call_tir could be improved. For backwards compatibility, an explicit out_sinfo argument would still take precedence. However, if out_sinfo is omitted (which currently causes an immediate error), the output struct info would instead be inferred by assuming that the last len(params) - len(args) parameters are output parameters.

  4. Improved normalization in the block builder. If an operator has restrictions on an argument, normalization could expose those constraints at the Relax level, rather than only marking it as pass/fail. For example, normalization of an operator whose argument must be DataType::Float(32), but which received DataType::Void(), could produce a binding new_arg = R.match_cast(arg, R.Tensor(arg.struct_info.shape, "float32")), then use new_arg in its call (sketched below).

I think all of these would be useful changes to make, but some would have wider impacts than others. The well-formed checks could be added with the smallest risk of breakage, but also place the greatest load on new developers. Improved normalization would provide the greatest ease-of-use, but would require the most widespread changes. @tqchen, since some of these would be much more involved changes, do you have preferences/thoughts on them?
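
To illustrate what improvement (4) could look like in the IR, here is a hedged sketch (illustrative module, not an implemented transform): instead of rejecting an argument whose dtype is unknown, the normalizer would insert an R.match_cast binding that makes the float32 requirement explicit before the constrained call.

import tvm
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class AfterProposedNormalization:
    @R.function
    def main(x: R.Tensor((1, 512, 64, 64))) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
        # x has an unknown dtype; match_cast refines it to float32 so that the
        # operator's element-type constraint is satisfied explicitly.
        y = R.match_cast(x, R.Tensor((1, 512, 64, 64), dtype="float32"))
        z = R.nn.relu(y)
        return z

AfterProposedNormalization.show()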

@Cookiee235 (Contributor, Author)

A similar bug occurs, as shown below.
From what I have seen, the well-formed checker commonly corrects the return type and shape. However, when the type of the Relax function's return variable is an R.Tuple(), the well-formed checker does not seem to work.

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/res_type.py", line 82, in <module>
    mod_outputs = vm['main'](input_0, input_1)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::_LookupFunction(tvm::runtime::String const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
  4: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
  3: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
  2: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16Pack
  0: tvm::runtime::relax_vm::CheckTensorInfo(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/software/tvm/src/runtime/relax_vm/builtin.cc", line 247
ValueError: Check failed: (DataType(ptr->dl_tensor.dtype) == dtype) is false: ErrorContext(fn=main, loc=return, annotation=R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32")))  expect Tensor with dtype float32 but get int32

Steps to reproduce

import tvm
from tvm import relax
import numpy as np

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 1

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0
    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        cls = Module
        A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
        return (A, B, C)

    @R.function
    def main_2() -> R.Tuple(R.Tensor, R.Tensor):
        cls = Module
        args: R.Tuple(R.Tensor, R.Tensor, R.Tensor) = cls.func()
        gv1: R.Tensor = args[0]
        gv2: R.Tensor = args[2]
        return (gv1, gv2)
    @R.function
    def main(v3_0: R.Tensor((1, 22, 1), dtype="float16"), v6_0: R.Tensor((1, 37), dtype="float16")) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="float32")):  # if the return value is a tuple, the well-formed checker cannot correct it!
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            res: R.Tuple(R.Tensor, R.Tensor) = cls.main_2()
            R.output(res)
        return res


mod = Module
mod.show()
mod = tvm.relax.transform.LegalizeOps()(mod)

mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())

input_0 = tvm.nd.array(10 * np.random.random([1, 22, 1]).astype('float16'))
input_1 = tvm.nd.array(10 * np.random.random([1, 37]).astype('float16'))
mod_outputs = vm['main'](input_0, input_1)

@Lunderberg (Contributor)

Hmm. I think this is something that should be catchable by propagating the known struct info, but currently isn't caught.

  1. In main_2, cls.func() returns a tuple with known dtype and static shapes, but is assigned to a variable with unknown dtype and shape. This is legal, because the set of all R.Tuple(R.Tensor, R.Tensor, R.Tensor) is a superset of the set of all R.Tuple(R.Tensor((16,16), "int32"), R.Tensor((16,16), "int32"), R.Tensor((32,32), "int32")).
  2. In main, even if the return type of cls.main_2() isn't explicitly specified, it gets inferred as R.Tuple(R.Tensor, R.Tensor).
  3. The return type of main may be more specific than that of its body. This is intended to keep the return type stable: even if an optimization prevents shape inference from reaching all the way to the end of the function, the function still has accurate annotations. However, it also means that the return struct info may be a sub-type of the body's struct info.
  4. Whenever the return type is a sub-type of the body's struct info, a runtime assert is inserted. This is the assert that triggers the error message.

I think this is a limitation in the StructInfo inference, which should catch the IRModule as ill-formed at compile-time, rather than runtime. However, it would first require a few extra steps of StructInfo inference that aren't currently performed.

  1. If an expression has more specific StructInfo than the variable it is bound to, propagate from the expression to the variable.
  2. If the body of a function has more specific StructInfo than the current return type, propagate from the body to the return type.
  3. If a function has more specific StructInfo than the GlobalVar used to represent it, propagate from the function to the GlobalVar.

For the example, this would let the "int32" type returned by cls.func be propagated through main_2 and into main. At that point, it could be recognized as an error to return "int32" from a function that is marked as returning "float32". A sketch follows below.
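
A minimal, self-contained sketch of the kind of mismatch this propagation would surface (illustrative module; the comment on its behavior follows the discussion above rather than a tested guarantee):

import tvm
from tvm.relax.analysis import well_formed
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Example:
    @R.function(private=True)
    def make() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32")):
        A = R.zeros((16, 16), dtype="int32")
        B = R.ones((16, 16), dtype="int32")
        return (A, B)

    @R.function
    def main() -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = Example
        # The annotation erases the known "int32" dtypes; with the propagation steps
        # above, the conflict with main's "float32" return annotation could be
        # reported at this binding instead of at runtime.
        res: R.Tuple(R.Tensor, R.Tensor) = cls.make()
        return res

# Per the discussion above, this is expected to pass today's well-formed check
# even though the dtypes conflict.
print(well_formed(Example))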

@Lunderberg (Contributor)

And one step is now implemented that should make it harder for these inconsistent shapes to emerge: in #17216, the out_sinfo field is made optional and is inferred from the PrimFunc signature if omitted. While it doesn't yet catch cases where an explicit out_sinfo is inconsistent with the callee's signature, it does move in that direction.
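
Assuming the change described in #17216, the failing call from the first reproduction could then omit out_sinfo entirely and have the output struct info inferred from the PrimFunc signature. A condensed sketch of the intended usage (smaller shapes, not a tested snippet):

import tvm
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

@I.ir_module
class Sketch:
    @T.prim_func(private=True)
    def relu(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        for i in range(16):
            B[i] = T.max(A[i], T.float32(0))

    @R.function
    def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Sketch
        with R.dataflow():
            # out_sinfo omitted; per #17216 it is inferred from relu's signature
            gv = R.call_tir(cls.relu, (x,))
            R.output(gv)
        return gv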
