From 2c40869832a796e84caab5971eacbdcc9a4f28ab Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 24 Oct 2023 15:12:34 -0500 Subject: [PATCH] [Unity][UnitTest] Enable BindParams test for R.Prim This test was implemented in https://github.com/apache/tvm/pull/15626, but was initially disabled as it depended on functionality not introduced until https://github.com/apache/tvm/pull/15577. Since that PR has landed, cleaning up and enabling the unit test. --- tests/python/relax/test_bind_params.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py index a92e4fe8e510..189a44303d6c 100644 --- a/tests/python/relax/test_bind_params.py +++ b/tests/python/relax/test_bind_params.py @@ -111,21 +111,23 @@ def expected() -> R.Shape([16]): prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32") -@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577") def test_bind_prim_value(prim_value_dtype): + N = tir.Var("N", prim_value_dtype) + value = tir.const(16, prim_value_dtype) + @R.function - def before(A: R.Prim(value="N", dtype=prim_value_dtype)): + def before(A: R.Prim(value=N)): R.func_attr({"global_symbol": "main"}) - B: R.Prim(value="N", dtype=prim_value_dtype) = A + B: R.Prim(value=N) = A return B @R.function - def expected() -> R.Prim(value=16, dtype=prim_value_dtype): + def expected() -> R.Prim(value=value): R.func_attr({"global_symbol": "main"}) - B = R.PrimValue(value=16, dtype=dtype) + B = R.prim_value(value) return B - after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))}) + after = before.bind_params({"A": relax.PrimValue(value)}) tvm.ir.assert_structural_equal(expected, after)