[Bug] [Relax] [LambdaLift] Argument type mismatch: expected R.Tensor, given R.Object #17406

Open
Cookiee235 opened this issue Sep 23, 2024 · 1 comment
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/llm.py", line 35, in <module>
    ex = relax.build(mod, target='llvm')
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-latest/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm-latest/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-latest/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/software/tvm-latest/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm-latest/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm-latest/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-latest/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/software/tvm-latest/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  31: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  30: tvm::transform::Pass::operator()(tvm::IRModule) const
  29: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: _ZN3tvm7runtime13PackedFuncObj
  24: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::LowerRuntimeBuiltin()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::LowerRuntimeBuiltin()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  23: tvm::relax::LowerRuntimeBuiltin(tvm::RelayExpr const&)
  22: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  21: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  20: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  19: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  18: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  17: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  16: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  15: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  14: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  13: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  12: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  11: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::GlobalVarNode const*)
  10: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  9: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  8: tvm::relax::LowerRuntimeBuiltinMutator::VisitExpr_(tvm::relax::CallNode const*)
  7: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  6: tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm-latest/src/relax/ir/block_builder.cc", line 158
TVMError: Argument 0 type mismatch: expected R.Tensor((2, 3), dtype="float32"), given R.Object
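The last line reports a failed struct-info (type) check on a call argument. As a rough illustration only, here is a hypothetical pure-Python model of such a check, assuming it requires each argument's struct info to be a subtype of the matching parameter's. The classes and functions below are invented for this sketch and are not TVM's API. Under that assumption, R.Object is strictly more general than R.Tensor((2, 3), dtype="float32"), so the check cannot prove that the argument is a tensor and rejects it. The report does not identify which call in the lowered module triggers the error.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ObjectSInfo:
    """Hypothetical stand-in for R.Object: any runtime object."""

    def __str__(self) -> str:
        return "R.Object"


@dataclass(frozen=True)
class TensorSInfo:
    """Hypothetical stand-in for R.Tensor(shape, dtype); None means unknown."""
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None

    def __str__(self) -> str:
        return f'R.Tensor({self.shape}, dtype="{self.dtype}")'


def is_base_of(base, derived) -> bool:
    """True if every value described by `derived` is also described by `base`."""
    if isinstance(base, ObjectSInfo):
        return True  # every value is an Object
    if isinstance(base, TensorSInfo):
        if not isinstance(derived, TensorSInfo):
            return False  # an Object is not statically known to be a tensor
        shape_ok = base.shape is None or base.shape == derived.shape
        dtype_ok = base.dtype is None or base.dtype == derived.dtype
        return shape_ok and dtype_ok
    return False


def check_call_args(params, args) -> None:
    """Raise if an argument's struct info is not a subtype of its parameter's."""
    for i, (param, arg) in enumerate(zip(params, args)):
        if not is_base_of(param, arg):
            raise TypeError(
                f"Argument {i} type mismatch: expected {param}, given {arg}"
            )


if __name__ == "__main__":
    tensor_2x3 = TensorSInfo((2, 3), "float32")
    check_call_args([tensor_2x3], [tensor_2x3])  # a tensor argument passes
    try:
        check_call_args([tensor_2x3], [ObjectSInfo()])
    except TypeError as err:
        print(err)
```

Running the script prints a message of the same shape as the one in the traceback: Argument 0 type mismatch: expected R.Tensor((2, 3), dtype="float32"), given R.Object.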

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module

        @R.function
        def outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True):

            @R.function
            def inner_func(x1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
                s = R.add(x1, c1)
                return s

            return inner_func

        in_call: R.Callable((R.Tensor((2, 3), dtype="float32"),), R.Tensor((2, 3), dtype="float32"), True) = outer_func(x)
        res: R.Tensor((2, 3), dtype="float32") = in_call(y)
        res_1 = R.add(res, x)
        return res_1


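mod = Module
# LambdaLift hoists the nested outer_func and inner_func to module-level private functions.
mod = relax.transform.LambdaLift()(mod)
mod.show()

with tvm.transform.PassContext(opt_level=4):
    # Fails here with the TVMError shown above.
    ex = relax.build(mod, target='llvm')
    vm = relax.VirtualMachine(ex, tvm.cpu())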

CC @Lunderberg @junrushao

Cookiee235 added the needs-triage and type: bug labels on Sep 23, 2024.
Cookiee235 changed the title from "[Bug] [Relax] [Transform] Argument type mismatch: expected R.Tensor, given R.Object" to "[Bug] [Relax] [LambdaLift] Argument type mismatch: expected R.Tensor, given R.Object" on Sep 23, 2024.
Cookiee235 commented Sep 23, 2024

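When LambdaLift lifts the nested inner_func, the lifted main_outer_func is given the return type R.Object instead of the R.Callable annotated in the original module, and in_call in main is likewise typed R.Object. This loss of type information leads to an unexpected crash during relax.build.

The Relax IR after LambdaLift is: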

@I.ir_module
class Module:
    @R.function(private=True)
    def main_inner_func(x1: R.Tensor((2, 3), dtype="float32"), c1: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        s: R.Tensor((2, 3), dtype="float32") = R.add(x1, c1)
        return s

    @R.function(private=True)
    def main_outer_func(c1: R.Tensor((2, 3), dtype="float32")) -> R.Object:
        cls = Module
        inner_func: R.Object = R.make_closure(cls.main_inner_func, (c1,))
        return inner_func

    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module
        in_call: R.Object = cls.main_outer_func(x)
        res: R.Tensor((2, 3), dtype="float32") = R.invoke_pure_closure(in_call, (y,), sinfo_args=(R.Tensor((2, 3), dtype="float32"),))
        res_1: R.Tensor((2, 3), dtype="float32") = R.add(res, x)
        return res_1
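The lifted IR follows the usual lambda-lifting (closure-conversion) pattern: inner_func becomes the private global main_inner_func with the captured c1 appended as an extra parameter, main_outer_func packs c1 into a closure with R.make_closure, and main calls that closure through R.invoke_pure_closure. The sketch below is a hypothetical pure-Python model of that run-time behaviour. It represents 2x3 tensors as nested lists, and make_closure / invoke_closure are invented stand-ins for the Relax builtins, not TVM functions. The model suggests that the lifted program is well defined at run time. Comparing the two IR listings, what changes is the static type: the R.Callable return annotation becomes R.Object.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

Tensor = List[List[float]]  # stand-in for R.Tensor((2, 3), dtype="float32")


def add(a: Tensor, b: Tensor) -> Tensor:
    """Elementwise add, standing in for R.add."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


@dataclass(frozen=True)
class Closure:
    """Stand-in for what R.make_closure builds: a function plus captured values."""
    func: Callable[..., Tensor]
    captured: Tuple[Tensor, ...]


def make_closure(func: Callable[..., Tensor], captured) -> Closure:
    return Closure(func, tuple(captured))


def invoke_closure(closure: Closure, args) -> Tensor:
    # The lifted function takes the call arguments first, then the captured
    # values, matching the parameter order of main_inner_func(x1, c1).
    return closure.func(*args, *closure.captured)


def main_inner_func(x1: Tensor, c1: Tensor) -> Tensor:
    return add(x1, c1)


def main_outer_func(c1: Tensor) -> Closure:
    return make_closure(main_inner_func, (c1,))


def main(x: Tensor, y: Tensor) -> Tensor:
    in_call = main_outer_func(x)         # closure capturing x as c1
    res = invoke_closure(in_call, (y,))  # y + x
    return add(res, x)                   # y + 2 * x


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    y = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    print(main(x, y))  # [[3.0, 5.0, 7.0], [9.0, 11.0, 13.0]]
```

In this model the closure object carries no parameter or return types, much like a value typed R.Object. Keeping an R.Callable struct info on main_outer_func's return value is what would let a type checker see that in_call accepts and returns an R.Tensor((2, 3), dtype="float32").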

@Lunderberg, could you help review this bug?
