Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Analysis] Improve handling of symbolic variables #15627

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
}),
cfalse, c2.Eval()->value > c1.Eval()->value);

TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2);

TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3,
Expand Down Expand Up @@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);

TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2);
TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);

Expand Down
43 changes: 33 additions & 10 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,11 @@ class StructInfoBaseChecker
// struct equal checker
StructuralEqual struct_equal_;

// Saved condition that must be true for a weaker L2 failure If
// this condition is provably false, then the condition is
// downgraded to L0.
PrimExpr upgrade_L2_fail_to_L0_{Bool(false)};

// customizable functions.
/*!
* \brief Check symbolic shape value equivalence.
Expand All @@ -461,17 +466,35 @@ class StructInfoBaseChecker
* \return CheckResult.
*/
virtual BaseCheckResult PrimValueMatchCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
// get static shape checking right.
auto* int_lhs = lhs.as<IntImmNode>();
auto* int_rhs = rhs.as<IntImmNode>();
if (int_lhs && int_rhs) {
if (int_lhs->value == int_rhs->value) {
return BaseCheckResult::kPass;
} else {
return BaseCheckResult::kFailL0;
}
PrimExpr is_same = analyzer_->Simplify(lhs == rhs);

if (analyzer_->CanProve(is_same)) {
// Sames are provably the same (e.g. same static shapes, or same
// symbolic variable).
return BaseCheckResult::kPass;
} else if (analyzer_->CanProve(!is_same)) {
// Sames are provably the same (e.g. different static shapes).
return BaseCheckResult::kFailL0;
}

// By combining the required match conditions across the entire
// comparison, a match may be recognized as illegal, even though
// each component would have been legal. For example, if matching
// shape [n, n+1] to [16, 32], the conditions `n==16` and
// `n+1==32` cannot be proven false when considered separately.
// However, their joint condition `(n==16) && (n+1==32)` can be
// proven false.
upgrade_L2_fail_to_L0_ = analyzer_->Simplify(upgrade_L2_fail_to_L0_ || !is_same);
if (analyzer_->CanProve(upgrade_L2_fail_to_L0_)) {
// This match condition is incompatible with a previous match
// condition, so the entire check can return L0 failure.
return BaseCheckResult::kFailL0;
} else {
// Nothing is incompatible so far. The `is_same` expression has
// been cached as part of `upgrade_L2_fail_to_L0_`, in case it
// can be used to provide L0 failure at a later point.
return BaseCheckResult::kFailL2;
}
return analyzer_->CanProveEqual(lhs, rhs) ? BaseCheckResult::kPass : BaseCheckResult::kFailL2;
}
/*!
* \brief CheckShape value.
Expand Down
142 changes: 86 additions & 56 deletions tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,8 @@ def fn_info(c):
tvm.ir.assert_structural_equal(rx.analysis.erase_to_well_defined(f0), f0)


def test_base_check():
def generate_base_check_test_cases():
BR = rx.analysis.BaseCheckResult
bcheck = rx.analysis.struct_info_base_check

n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
obj0 = rx.ObjectStructInfo()
Expand Down Expand Up @@ -243,90 +242,86 @@ def test_base_check():
tensor16 = rx.TensorStructInfo([n, m, 2], "int32", vdevice4)

# obj
assert bcheck(obj0, prim0) == BR.PASS
assert bcheck(obj0, shape1) == BR.PASS
assert bcheck(obj0, tensor2) == BR.PASS
assert obj0.is_base_of(tensor2)
yield (obj0, prim0, BR.PASS)
yield (obj0, shape1, BR.PASS)
yield (obj0, tensor2, BR.PASS)

# prim
assert prim0.is_base_of(prim0)
assert not prim0.is_base_of(prim1)
assert bcheck(prim0, obj0) == BR.FAIL_L1
assert bcheck(prim0, prim0) == BR.PASS
assert bcheck(prim0, prim1) == BR.FAIL_L0
yield (prim0, obj0, BR.FAIL_L1)
yield (prim0, prim0, BR.PASS)
yield (prim0, prim1, BR.FAIL_L0)

# shape
assert bcheck(shape0, obj0) == BR.FAIL_L1
assert bcheck(shape0, prim0) == BR.FAIL_L0
yield (shape0, obj0, BR.FAIL_L1)
yield (shape0, prim0, BR.FAIL_L0)

# unknown dim
assert bcheck(shape0, shape1) == BR.PASS
assert bcheck(shape1, shape0) == BR.FAIL_L1
yield (shape0, shape1, BR.PASS)
yield (shape1, shape0, BR.FAIL_L1)

# ndim mismatch
assert bcheck(shape1, shape2) == BR.FAIL_L0
yield (shape1, shape2, BR.FAIL_L0)

# lhs do not have symbolic value but ndim match
assert bcheck(shape2, shape3) == BR.PASS
yield (shape2, shape3, BR.PASS)

# rhs do not symbolic but lhs do
assert bcheck(shape3, shape2) == BR.FAIL_L2
yield (shape3, shape2, BR.FAIL_L2)

# shape mismatch
assert bcheck(shape3, shape4) == BR.FAIL_L2
assert shape4.is_base_of(rx.ShapeStructInfo([1, n, 3]))
yield (shape3, shape4, BR.FAIL_L2)
yield (shape4, rx.ShapeStructInfo([1, n, 3]), BR.PASS)

# tensor
assert bcheck(tensor0, obj0) == BR.FAIL_L1
assert bcheck(tensor0, prim0) == BR.FAIL_L0
assert bcheck(tensor0, shape0) == BR.FAIL_L0
yield (tensor0, obj0, BR.FAIL_L1)
yield (tensor0, prim0, BR.FAIL_L0)
yield (tensor0, shape0, BR.FAIL_L0)

# dtype mismatch
assert bcheck(tensor0, tensor1) == BR.FAIL_L0
assert bcheck(tensor0, tensor3) == BR.FAIL_L0
assert bcheck(tensor3, tensor4) == BR.FAIL_L0
assert bcheck(tensor1, tensor2) == BR.FAIL_L0
yield (tensor0, tensor1, BR.FAIL_L0)
yield (tensor0, tensor3, BR.FAIL_L0)
yield (tensor3, tensor4, BR.FAIL_L0)
yield (tensor1, tensor2, BR.FAIL_L0)

# vdevice mismatch
assert bcheck(tensor8, tensor9) == BR.FAIL_L0
assert bcheck(tensor9, tensor10) == BR.FAIL_L0
assert bcheck(tensor10, tensor11) == BR.FAIL_L0
assert bcheck(tensor13, tensor14) == BR.FAIL_L0
assert bcheck(tensor14, tensor15) == BR.FAIL_L0
assert bcheck(tensor15, tensor16) == BR.FAIL_L0
yield (tensor8, tensor9, BR.FAIL_L0)
yield (tensor9, tensor10, BR.FAIL_L0)
yield (tensor10, tensor11, BR.FAIL_L0)
yield (tensor13, tensor14, BR.FAIL_L0)
yield (tensor14, tensor15, BR.FAIL_L0)
yield (tensor15, tensor16, BR.FAIL_L0)

# ndim mismatch
assert bcheck(tensor2, tensor5) == BR.FAIL_L0
yield (tensor2, tensor5, BR.FAIL_L0)

# static shape mismatch
assert bcheck(tensor5, tensor6) == BR.FAIL_L0
yield (tensor5, tensor6, BR.FAIL_L0)

# match
assert tensor0.is_base_of(rx.TensorStructInfo(ndim=-1, dtype="int32"))
assert tensor0.is_base_of(tensor2)
assert tensor0.is_base_of(tensor4)
assert tensor0.is_base_of(tensor5)
assert tensor0.is_base_of(tensor6)
assert tensor2.is_base_of(tensor4)
assert tensor3.is_base_of(tensor7)
assert tensor3.is_base_of(tensor8)
assert tensor6.is_base_of(tensor12)
assert tensor6.is_base_of(tensor13)
assert tensor4.is_base_of(rx.TensorStructInfo([n, m], dtype="int32"))
yield (tensor0, rx.TensorStructInfo(ndim=-1, dtype="int32"), BR.PASS)
yield (tensor0, tensor2, BR.PASS)
yield (tensor0, tensor4, BR.PASS)
yield (tensor0, tensor5, BR.PASS)
yield (tensor0, tensor6, BR.PASS)
yield (tensor2, tensor4, BR.PASS)
yield (tensor3, tensor7, BR.PASS)
yield (tensor3, tensor8, BR.PASS)
yield (tensor6, tensor12, BR.PASS)
yield (tensor6, tensor13, BR.PASS)
yield (tensor4, rx.TensorStructInfo([n, m], dtype="int32"), BR.PASS)

# tuple
t0 = rx.TupleStructInfo([obj0, tensor0])
t1 = rx.TupleStructInfo([prim0, tensor4])
t2 = rx.TupleStructInfo([obj0, tensor0, obj0])
t3 = rx.TupleStructInfo([tensor0, obj0])

assert t0.is_base_of(t1)
yield (t0, t1, BR.PASS)
yield (t0, t2, BR.FAIL_L0)
yield (t0, t3, BR.FAIL_L1)

assert bcheck(t0, t2) == BR.FAIL_L0
assert bcheck(t0, t3) == BR.FAIL_L1

assert rx.TupleStructInfo([t0, t1]).is_base_of(rx.TupleStructInfo([t1, t1]))
assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1
yield (rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t1]), BR.PASS)
yield (rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0]), BR.FAIL_L1)

def fn_info_shape(c):
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
Expand All @@ -341,12 +336,47 @@ def fn_info_erased():
z = rx.TensorStructInfo(ndim=2, dtype="float32")
return rx.FuncStructInfo([x, y], z)

assert fn_info_shape(1).is_base_of(fn_info_shape(1))
assert fn_info_erased().is_base_of(fn_info_shape(1))
assert bcheck(fn_info_shape(1), fn_info_erased()) == BR.FAIL_L2
yield (fn_info_shape(1), fn_info_shape(1), BR.PASS)
yield (fn_info_erased(), fn_info_shape(1), BR.PASS)
yield (fn_info_shape(1), fn_info_erased(), BR.FAIL_L2)

fopaque = rx.FuncStructInfo.opaque_func()
assert fopaque.is_base_of(fn_info_shape(1))
yield (fopaque, fn_info_shape(1), BR.PASS)

# Symbolic var tests
static_shape = rx.ShapeStructInfo([1, 2, 4])
compatible_symbolic_shape = rx.ShapeStructInfo([1, n, n * n])
incompatible_symbolic_shape = rx.ShapeStructInfo([1, n, n + 1])

# Symbolic shapes may occur in multiple locations. Incompatible
# shapes may fail at L0, even if each use of the symbolic variable
# would only have failed at L2.
yield (rx.ShapeStructInfo([n, n]), rx.ShapeStructInfo([16, 32]), BR.FAIL_L0)

# If `n==2`, then the shapes are compatible.
yield (compatible_symbolic_shape, static_shape, BR.FAIL_L2)

# There is no value of `n` for which `n==2 and n+1==4`
yield (incompatible_symbolic_shape, static_shape, BR.FAIL_L0)


base_check_test_case = tvm.testing.parameter(*generate_base_check_test_cases())


def test_base_check(base_check_test_case):
base, derived, expected_result = base_check_test_case

actual_result = rx.analysis.BaseCheckResult(rx.analysis.struct_info_base_check(base, derived))
assert actual_result == expected_result, (
f"When checking if {base} is a superset of {derived}, "
f"expected the result to be {str(expected_result)}, "
f"but received {str(actual_result)}"
)

if expected_result == rx.analysis.BaseCheckResult.PASS:
assert base.is_base_of(
derived
), f"Expected {base} to be recognized as a superset of {derived}"


def _check_derive(ctx, finfo, args_sinfo, ret):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ class TestLogical(BaseCompare):
TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.And(x == 1, x != 2), x == 1),
TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool")),
Expand All @@ -965,6 +966,7 @@ class TestLogical(BaseCompare):
TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(x != 1, x == 2), x != 1),
TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")),
TestCase(
tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)),
tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),
Expand Down
Loading