From ec2aa2418840400a9fd54a370e8c7b53c0db14d8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Aug 2023 07:12:27 -0500 Subject: [PATCH 1/4] [Unity][UnitTests] Parameterized struct info analysis IsBaseOf tests --- .../test_analysis_struct_info_analysis.py | 126 ++++++++++-------- 1 file changed, 70 insertions(+), 56 deletions(-) diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index 879194037c05..e170a655dfb8 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -203,9 +203,8 @@ def fn_info(c): tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0) -def test_base_check(): +def generate_base_check_test_cases(): BR = rx.analysis.BaseCheckResult - bcheck = rx.analysis.struct_info_base_check n, m = tir.Var("n", "int64"), tir.Var("m", "int64") obj0 = rx.ObjectStructInfo() @@ -243,76 +242,73 @@ def test_base_check(): tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4) # obj - assert bcheck(obj0, prim0) == BR.PASS - assert bcheck(obj0, shape1) == BR.PASS - assert bcheck(obj0, tensor2) == BR.PASS - assert obj0.is_base_of(tensor2) + yield (obj0, prim0, BR.PASS) + yield (obj0, shape1, BR.PASS) + yield (obj0, tensor2, BR.PASS) # prim - assert prim0.is_base_of(prim0) - assert not prim0.is_base_of(prim1) - assert bcheck(prim0, obj0) == BR.FAIL_L1 - assert bcheck(prim0, prim0) == BR.PASS - assert bcheck(prim0, prim1) == BR.FAIL_L0 + yield (prim0, obj0, BR.FAIL_L1) + yield (prim0, prim0, BR.PASS) + yield (prim0, prim1, BR.FAIL_L0) # shape - assert bcheck(shape0, obj0) == BR.FAIL_L1 - assert bcheck(shape0, prim0) == BR.FAIL_L0 + yield (shape0, obj0, BR.FAIL_L1) + yield (shape0, prim0, BR.FAIL_L0) # unknown dim - assert bcheck(shape0, shape1) == BR.PASS - assert bcheck(shape1, shape0) == BR.FAIL_L1 + yield (shape0, shape1, BR.PASS) + yield (shape1, shape0, BR.FAIL_L1) # ndim mismatch - assert bcheck(shape1, shape2) == BR.FAIL_L0 + yield (shape1, shape2, BR.FAIL_L0) # lhs do not have symbolic value but ndim match - assert bcheck(shape2, shape3) == BR.PASS + yield (shape2, shape3, BR.PASS) # rhs do not symbolic but lhs do - assert bcheck(shape3, shape2) == BR.FAIL_L2 + yield (shape3, shape2, BR.FAIL_L2) # shape mismatch - assert bcheck(shape3, shape4) == BR.FAIL_L2 - assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3])) + yield (shape3, shape4, BR.FAIL_L2) + yield (shape4, rx.ShapeStructInfo([1, n, 3]), BR.PASS) # tensor - assert bcheck(tensor0, obj0) == BR.FAIL_L1 - assert bcheck(tensor0, prim0) == BR.FAIL_L0 - assert bcheck(tensor0, shape0) == BR.FAIL_L0 + yield (tensor0, obj0, BR.FAIL_L1) + yield (tensor0, prim0, BR.FAIL_L0) + yield (tensor0, shape0, BR.FAIL_L0) # dtype mismatch - assert bcheck(tensor0, tensor1) == BR.FAIL_L0 - assert bcheck(tensor0, tensor3) == BR.FAIL_L0 - assert bcheck(tensor3, tensor4) == BR.FAIL_L0 - assert bcheck(tensor1, tensor2) == BR.FAIL_L0 + yield (tensor0, tensor1, BR.FAIL_L0) + yield (tensor0, tensor3, BR.FAIL_L0) + yield (tensor3, tensor4, BR.FAIL_L0) + yield (tensor1, tensor2, BR.FAIL_L0) # vdevice mismatch - assert bcheck(tensor8, tensor9) == BR.FAIL_L0 - assert bcheck(tensor9, tensor10) == BR.FAIL_L0 - assert bcheck(tensor10, tensor11) == BR.FAIL_L0 - assert bcheck(tensor13, tensor14) == BR.FAIL_L0 - assert bcheck(tensor14, tensor15) == BR.FAIL_L0 - assert bcheck(tensor15, tensor16) == BR.FAIL_L0 + yield (tensor8, tensor9, BR.FAIL_L0) + yield (tensor9, tensor10, BR.FAIL_L0) + yield (tensor10, tensor11, BR.FAIL_L0) + yield (tensor13, tensor14, BR.FAIL_L0) + yield (tensor14, tensor15, BR.FAIL_L0) + yield (tensor15, tensor16, BR.FAIL_L0) # ndim mismatch - assert bcheck(tensor2, tensor5) == BR.FAIL_L0 + yield (tensor2, tensor5, BR.FAIL_L0) # static shape mismatch - assert bcheck(tensor5, tensor6) == BR.FAIL_L0 + yield (tensor5, tensor6, BR.FAIL_L0) # match - assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32")) - assert tensor0.is_base_of(tensor2) - assert tensor0.is_base_of(tensor4) - assert tensor0.is_base_of(tensor5) - assert tensor0.is_base_of(tensor6) - assert tensor2.is_base_of(tensor4) - assert tensor3.is_base_of(tensor7) - assert tensor3.is_base_of(tensor8) - assert tensor6.is_base_of(tensor12) - assert tensor6.is_base_of(tensor13) - assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32")) + yield (tensor0, rx.TensorStructInfo(ndim=-1, dtype="int32"), BR.PASS) + yield (tensor0, tensor2, BR.PASS) + yield (tensor0, tensor4, BR.PASS) + yield (tensor0, tensor5, BR.PASS) + yield (tensor0, tensor6, BR.PASS) + yield (tensor2, tensor4, BR.PASS) + yield (tensor3, tensor7, BR.PASS) + yield (tensor3, tensor8, BR.PASS) + yield (tensor6, tensor12, BR.PASS) + yield (tensor6, tensor13, BR.PASS) + yield (tensor4, rx.TensorStructInfo([n, m], dtype="int32"), BR.PASS) # tuple t0 = rx.TupleStructInfo([obj0, tensor0]) @@ -320,13 +316,12 @@ def test_base_check(): t2 = rx.TupleStructInfo([obj0, tensor0, obj0]) t3 = rx.TupleStructInfo([tensor0, obj0]) - assert t0.is_base_of(t1) + yield (t0, t1, BR.PASS) + yield (t0, t2, BR.FAIL_L0) + yield (t0, t3, BR.FAIL_L1) - assert bcheck(t0, t2) == BR.FAIL_L0 - assert bcheck(t0, t3) == BR.FAIL_L1 - - assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1])) - assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 + yield (rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t1]), BR.PASS) + yield (rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0]), BR.FAIL_L1) def fn_info_shape(c): n, m = tir.Var("n", "int64"), tir.Var("m", "int64") @@ -341,12 +336,31 @@ def fn_info_erased(): z = rx.TensorStructInfo(ndim=2, dtype="float32") return rx.FuncStructInfo([x, y], z) - assert fn_info_shape(1).is_base_of(fn_info_shape(1)) - assert fn_info_erased().is_base_of(fn_info_shape(1)) - assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2 + yield (fn_info_shape(1), fn_info_shape(1), BR.PASS) + yield (fn_info_erased(), fn_info_shape(1), BR.PASS) + yield (fn_info_shape(1), fn_info_erased(), BR.FAIL_L2) fopaque = rx.FuncStructInfo.opaque_func() - assert fopaque.is_base_of(fn_info_shape(1)) + yield (fopaque, fn_info_shape(1), BR.PASS) + + +base_check_test_case = tvm.testing.parameter(*generate_base_check_test_cases()) + + +def test_base_check(base_check_test_case): + base, derived, expected_result = base_check_test_case + + actual_result = rx.analysis.BaseCheckResult(rx.analysis.struct_info_base_check(base, derived)) + assert actual_result == expected_result, ( + f"When checking if {base} is a superset of {derived}, " + f"expected the result to be {str(expected_result)}, " + f"but received {str(actual_result)}" + ) + + if expected_result == rx.analysis.BaseCheckResult.PASS: + assert base.is_base_of( + derived + ), f"Expected {base} to be recognized as a superset of {derived}" def _check_derive(ctx, finfo, args_sinfo, ret): From 1c8b1b9a272457d60c6e3f3d84b88d433b479e78 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Aug 2023 08:00:32 -0500 Subject: [PATCH 2/4] [Arith] Added simplification rule for multiple equality compares The expression `(x==y) && (x==z)` requires that `y==z`. When `y` and `z` are constants, this can allow better constant folding by rewriting `(x==c1) && (x==c2)` into `(x==c1) && (c1==c2)`. This commit adds the above rewrite, and the corresponding rewrite of the negative expression. --- src/arith/rewrite_simplify.cc | 2 ++ tests/python/unittest/test_arith_rewrite_simplify.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40088fd963d7..63becf8eb77f 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { }), cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2)); TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2); TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3, @@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2); TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 46ac0f975157..0b0a43a7d3d3 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -951,6 +951,7 @@ class TestLogical(BaseCompare): TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")), TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")), TestCase(tvm.tir.And(x == 1, x != 2), x == 1), + TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")), TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool")), @@ -965,6 +966,7 @@ class TestLogical(BaseCompare): TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")), TestCase(tvm.tir.Or(x != 1, x == 2), x != 1), + TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")), TestCase( tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)), tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1), From 1b9ad6efcba437aa8db18fcb6a1900f2ca0d2679 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 27 Aug 2023 08:18:56 -0500 Subject: [PATCH 3/4] [Relax][Analysis] Improve handling of symbolic variables Previously, the struct info analysis returned `kFailL2` (compatible, depending on runtime-inferred symbolic variables) based on each match expression considered in isolation. This ignored cases such as matching a square tensor `[n,n]` against a static shape `[16,32]`. While no one dimension is incompatible on its own, the match requires that `(n==16) && (n==32)`. This this can be statically proven to be false, the `StructInfoBaseChecker` should return `kFailL0` (statically proven to be incompatible) instead. This commit updates the `StructInfoBaseChecker` to track implied requirements for symbolic variables across multiple matched dimensions. --- src/relax/analysis/struct_info_analysis.cc | 43 ++++++++++++++----- .../test_analysis_struct_info_analysis.py | 16 +++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 4a633e9df4b7..4c5d382f7c4e 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -453,6 +453,11 @@ class StructInfoBaseChecker // struct equal checker StructuralEqual struct_equal_; + // Saved condition that must be true for a weaker L2 failure If + // this condition is provably false, then the condition is + // downgraded to L0. + PrimExpr upgrade_L2_fail_to_L0_{Bool(false)}; + // customizable functions. /*! * \brief Check symbolic shape value equivalence. @@ -461,17 +466,35 @@ class StructInfoBaseChecker * \return CheckResult. */ virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) { - // get static shape checking right. - auto* int_lhs = lhs.as(); - auto* int_rhs = rhs.as(); - if (int_lhs && int_rhs) { - if (int_lhs->value == int_rhs->value) { - return BaseCheckResult::kPass; - } else { - return BaseCheckResult::kFailL0; - } + PrimExpr is_same = analyzer_->Simplify(lhs == rhs); + + if (analyzer_->CanProve(is_same)) { + // Sames are provably the same (e.g. same static shapes, or same + // symbolic variable). + return BaseCheckResult::kPass; + } else if (analyzer_->CanProve(!is_same)) { + // Sames are provably the same (e.g. different static shapes). + return BaseCheckResult::kFailL0; + } + + // By combining the required match conditions across the entire + // comparison, a match may be recognized as illegal, even though + // each component would have been legal. For example, if matching + // shape [n, n+1] to [16, 32], the conditions `n==16` and + // `n+1==32` cannot be proven false when considered separately. + // However, their joint condition `(n==16) && (n+1==32)` can be + // proven false. + upgrade_L2_fail_to_L0_ = analyzer_->Simplify(upgrade_L2_fail_to_L0_ || !is_same); + if (analyzer_->CanProve(upgrade_L2_fail_to_L0_)) { + // This match condition is incompatible with a previous match + // condition, so the entire check can return L0 failure. + return BaseCheckResult::kFailL0; + } else { + // Nothing is incompatible so far. The `is_same` expression has + // been cached as part of `upgrade_L2_fail_to_L0_`, in case it + // can be used to provide L0 failure at a later point. + return BaseCheckResult::kFailL2; } - return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2; } /*! * \brief CheckShape value. diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index e170a655dfb8..ee7f4912a639 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -343,6 +343,22 @@ def fn_info_erased(): fopaque = rx.FuncStructInfo.opaque_func() yield (fopaque, fn_info_shape(1), BR.PASS) + # Symbolic var tests + static_shape = rx.ShapeStructInfo([1, 2, 4]) + compatible_symbolic_shape = rx.ShapeStructInfo([1, n, n * n]) + incompatible_symbolic_shape = rx.ShapeStructInfo([1, n, n + 1]) + + # Symbolic shapes may occur in multiple locations. Incompatible + # shapes may fail at L0, even if each use of the symbolic variable + # would only have failed at L2. + yield (rx.ShapeStructInfo([n, n]), rx.ShapeStructInfo([16, 32]), BR.FAIL_L0) + + # If `n==2`, then the shapes are compatible. + yield (compatible_symbolic_shape, static_shape, BR.FAIL_L2) + + # There is no value of `n` for which `n==2 and n+1==4` + yield (incompatible_symbolic_shape, static_shape, BR.FAIL_L0) + base_check_test_case = tvm.testing.parameter(*generate_base_check_test_cases()) From ae8bd9bb54e5ea845ca2b48f109e1b67261fb90c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Sep 2023 14:25:59 -0500 Subject: [PATCH 4/4] ci bump