diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 2e7f4311f950..246b38f6f83b 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -33,68 +33,187 @@ namespace tvm { namespace relax { -class BindingCanonicalizer : public ExprMutator { +namespace { + +struct CanonicalizationPlan { + Map replace_usage; + Map replace_binding; + std::unordered_set bindings_to_remove; +}; + +/*! \brief Utility class to identify usage location + * + * Canonicalization of a variable binding may require information from + * later in the function. For example, replacing `dataflow_x = expr` + * with `var_x = expr` to avoid a trivial binding of `var_x = + * dataflow_x` later in the function. This utility examines a relax + * expression, and plans the changes to be made in a mutation pass. + */ +class CanonicalizePlanner : public ExprVisitor { public: - BindingCanonicalizer() {} - - using ExprMutator::VisitExpr_; - - Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override { - if (auto tuple_var = tuple_get_item->tuple.as()) { - if (auto tuple_value = LookupBinding(tuple_var.value())) { - if (auto explicit_tuple = tuple_value.as()) { - CHECK_GE(tuple_get_item->index, 0) - << "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index - << ", but negative indices are not supported in this context."; - CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size()) - << "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index - << ", but the tuple size is only " << explicit_tuple->fields.size(); - return VisitExpr(explicit_tuple->fields[tuple_get_item->index]); + static CanonicalizationPlan Collect(const Expr& expr) { + CanonicalizePlanner visitor; + visitor.VisitExpr(expr); + + CanonicalizationPlan plan; + + std::unordered_set handled; + + for (const auto& binding_iter : visitor.trivial_bindings_) { + Var bound_var = binding_iter.first; + Var bound_to = binding_iter.second; + + while (auto opt = visitor.trivial_bindings_.Get(bound_to)) { + // This may be a trivial binding into a trivial binding. In + // that case, unwrap the bindings until we find the earliest + // non-trivial binding. + bound_to = opt.value(); + } + + while (auto opt = plan.replace_binding.Get(bound_to->vid)) { + // The variable we are binding to may have already been + // replaced, if it fell into Case 4 (Var = DataflowVar). In + // that case, we check against its replacement instead. + bound_to = opt.value(); + } + + if (bound_var.as() || !bound_to.as()) { + // Case 1: Var = Var + // Case 2: DataflowVar = Var + // Case 3: DataflowVar = DataflowVar + // + // For these three cases, the trivial binding can be + // unwrapped, using the bound variable directly at the point + // of use. + plan.replace_usage.Set(bound_var->vid, bound_to); + plan.bindings_to_remove.insert(bound_var->vid); + handled.insert(bound_to); + } else { + // Case 4: Var = DataflowVar + // + // Replacing a Var with a DataflowVar could result in illegal + // use of a DataflowVar outside of a DataflowBlock. Instead, + // we replace in the opposite direction, replacing the binding + // of the DataflowVar with a binding of the Var. + plan.replace_binding.Set(bound_to->vid, bound_var); + plan.replace_usage.Set(bound_to->vid, bound_var); + plan.bindings_to_remove.insert(bound_var->vid); + handled.insert(bound_var); + } + } + + // If a Var has been defined inside a DataflowBlock, is only used + // within a DataflowBlock, and is not already handled by removal + // of trivial bindings, then we can replace it with a DataflowVar. + for (const auto& var : visitor.defined_inside_dataflow_) { + if (!var.as() && !visitor.used_outside_dataflow_.count(var) && + !handled.count(var)) { + DataflowVar new_var(var->name_hint(), GetStructInfo(var)); + plan.replace_binding.Set(var->vid, new_var); + plan.replace_usage.Set(var->vid, new_var); + } + } + + return plan; + } + + private: + void VisitBindingBlock_(const DataflowBlockNode* block) override { + bool cache = inside_dataflow_; + inside_dataflow_ = true; + ExprVisitor::VisitBindingBlock_(block); + inside_dataflow_ = cache; + } + + void VisitBinding(const Binding& binding) override { + bool has_same_struct_info = true; + Expr value; + if (auto ptr = binding.as()) { + value = ptr->value; + } else if (auto ptr = binding.as()) { + has_same_struct_info = + StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value)); + value = ptr->value; + } else { + LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); + } + + // Unwrap TupleGetItem, if the Tuple being accessed is known. + if (auto tuple_get_item = value.as()) { + Expr tuple = tuple_get_item->tuple; + while (auto tuple_var = tuple.as()) { + if (auto opt = known_bindings_.Get(tuple_var.value())) { + tuple = opt.value(); + } else { + break; } } + + if (auto ptr = tuple.as()) { + value = ptr->fields[tuple_get_item->index]; + } + } + + if (auto parent = value.as(); parent && has_same_struct_info) { + trivial_bindings_.Set(binding->var, parent.value()); } - return ExprMutator::VisitExpr_(tuple_get_item); + + known_bindings_.Set(binding->var, value); + + ExprVisitor::VisitBinding(binding); } - void VisitBinding_(const VarBindingNode* binding) override { - // Unlike default visitor, we do not permit the struct info to change - // if the new value's struct info is different (this preserves user annotations) - Expr new_value = this->VisitExpr(binding->value); - Var new_var = this->VisitVarDef(binding->var); - - if (auto opt_var = new_value.as(); - opt_var && CanCanonicalizeVar(new_var, opt_var.value())) { - var_remap_[new_var->vid] = opt_var.value(); - } else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - this->builder_->EmitNormalized(GetRef(binding)); + void VisitVarDef(const Var& var) override { + if (inside_dataflow_) { + defined_inside_dataflow_.insert(var); + } + } + + void VisitExpr_(const VarNode* var) override { + if (!inside_dataflow_) { + used_outside_dataflow_.insert(GetRef(var)); + } + } + + bool inside_dataflow_{false}; + + Map trivial_bindings_; + Map known_bindings_; + std::unordered_set defined_inside_dataflow_; + std::unordered_set used_outside_dataflow_; +}; + +/*! \brief The mutator class to apply a CanonicalizationPlan */ +class BindingCanonicalizer : public ExprMutator { + public: + static Expr Apply(Expr expr) { + auto used_outside_dataflow = CanonicalizePlanner::Collect(expr); + BindingCanonicalizer mutator(std::move(used_outside_dataflow)); + return mutator.VisitExpr(expr); + } + + private: + explicit BindingCanonicalizer(CanonicalizationPlan plan) : plan_(plan) {} + + void VisitBinding(const Binding& binding) override { + if (!plan_.bindings_to_remove.count(binding->var->vid)) { + ExprMutator::VisitBinding(binding); + } + } + + Var VisitVarDef(const Var& var) override { + if (auto opt = plan_.replace_binding.Get(var->vid)) { + return ExprMutator::VisitVarDef(opt.value()); } else { - this->builder_->EmitNormalized(VarBinding(new_var, new_value)); + return ExprMutator::VisitVarDef(var); } } - void VisitBinding_(const MatchCastNode* binding) override { - // If we have a trivial shape check (the struct_info_ of LHS and RHS is the same), - // we can canonicalize to a var binding - Expr new_value = this->VisitExpr(binding->value); - bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value)); - - if (has_same_struct_info) { - if (auto parent = new_value.as(); - parent && CanCanonicalizeVar(binding->var, parent.value())) { - // LHS and RHS have the same struct info, and occur in a - // context where the RHS can replace the LHS. - var_remap_[binding->var->vid] = parent.value(); - } else { - // LHS and RHS have the same struct info, but the RHS is not a - // drop-in replacement for the LHS. - builder_->EmitNormalized(VarBinding(binding->var, new_value)); - } - } else if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + Expr VisitExpr_(const VarNode* var) override { + if (auto opt = plan_.replace_usage.Get(var->vid)) { + return ExprMutator::VisitExpr(opt.value()); } else { - // we can't elide in the same way as with var bindings because - // the struct info comparison has semantics - builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info)); + return ExprMutator::VisitExpr_(var); } } @@ -200,31 +319,11 @@ class BindingCanonicalizer : public ExprMutator { } private: - bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2, - std::function check_eq) { - // annotations differ if one is present but not the other - // or they're both present and they differ - bool both_present = obj1.defined() && obj2.defined(); - bool neither_present = !obj1.defined() && !obj2.defined(); - return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2)); - } - - bool CanCanonicalizeVar(Var var, Var parent_var) { - // Cases when we conservatively do not unify: - // 1. The struct_info_ of the child differs from that of the parent - // In this case, we could be overriding user annotations. - // 2. If the child is a Var and the parent is a DataflowVar. - // That could result in a DataflowVar leaving the current DataflowBlock. - bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_, - [&](const ObjectRef& lhs, const ObjectRef& rhs) { - return tvm::StructuralEqual()(lhs, rhs); - }); - bool var_to_dataflow = (!var.as() && parent_var.as()); - return !annotations_differ && !var_to_dataflow; - } + CanonicalizationPlan plan_; }; +} // namespace -Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); } +Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); } namespace transform { diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 49b7d11a804e..a8b71aa5ebfe 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1403,8 +1403,7 @@ def before(x: R.Tensor((1024,))): @R.function(private=True) def expected(x: R.Tensor((1024,))): with R.dataflow(): - a = R.add(x, x) - b = a + b = R.add(x, x) R.output(b) return b diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py index 08c9e3110705..3addfab2e88c 100644 --- a/tests/python/relax/test_optimize_layout_transform.py +++ b/tests/python/relax/test_optimize_layout_transform.py @@ -130,10 +130,9 @@ def main( (lv1, lv2), out_sinfo=R.Tensor((4, 4), dtype="float32"), ) - lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform( + gv: R.Tensor((16,), dtype="float32") = R.layout_transform( lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None ) - gv: R.Tensor((16,), dtype="float32") = lv2_1 R.output(gv) return gv @@ -256,10 +255,9 @@ def main( (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32"), ) - lv6: R.Tensor((16,), dtype="float32") = R.layout_transform( + gv: R.Tensor((16,), dtype="float32") = R.layout_transform( lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None ) - gv: R.Tensor((16,), dtype="float32") = lv6 R.output(gv) return gv @@ -399,10 +397,9 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32" pad_value=None, axis_separators=[], ) - lv_2 = R.call_tir( + gv = R.call_tir( Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32") ) - gv: R.Tensor((14,), dtype="float32") = lv_2 R.output(gv) return gv diff --git a/tests/python/relax/test_remove_redundant_reshape.py b/tests/python/relax/test_remove_redundant_reshape.py index 11e8c87cf1aa..a28141616c10 100644 --- a/tests/python/relax/test_remove_redundant_reshape.py +++ b/tests/python/relax/test_remove_redundant_reshape.py @@ -52,8 +52,7 @@ def main( x: R.Tensor((1, 1001, 1, 1), dtype="float16") ) -> R.Tensor((1, 1001), dtype="float16"): with R.dataflow(): - lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001])) - gv: R.Tensor((1, 1001), dtype="float16") = lv + gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001])) R.output(gv) return gv diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index aed5dad5574c..92057ce46a95 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -295,10 +295,9 @@ class MultipleUse: @R.function def main() -> R.Tensor((), "int32"): with R.dataflow(): - y = R.const(1) + n = R.const(1) # multiple uses -> cannot coalesce - m = R.add(y, y) - n = y + m = R.add(n, n) R.output(n) return n @@ -353,15 +352,244 @@ class UsedInMultipleOutputs: @R.function def main() -> R.Tensor((), "int32"): with R.dataflow(): - x = R.const(1) - l = x - m = x - n = x - R.output(l, m, n) + n = R.const(1) + R.output(n) return n verify(UsedInMultipleOutputs, UsedInMultipleOutputs) +def test_canonicalize_var_to_dataflow_var_if_legal(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + DataflowVar instances may only be used inside a DataflowBlock. If + a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and + `y` is a `Var`, replacing `y` with `x` may result in usage of a + `DataflowVar` outside of a `DataflowBlock`. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = R.add(y, R.const(1)) + R.output(y, z) + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = R.add(y, R.const(1)) + R.output(z) + return z + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_update_dataflow_computations_if_var_replacement_occurs(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + DataflowBlocks may produce additional outputs after the first + output Var, and these additional outputs may be in terms of the + first output. Computations that depend on a replaced var must be + updated to remain well-formed. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + lv1 = R.add(x, R.const(1)) + gv1 = lv1 + gv2 = R.add(lv1, R.const(1)) + R.output(gv1, gv2) + return (gv1, gv2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + # lv1 has been replaced with gv1 + gv1 = R.add(x, R.const(1)) + # So gv1 must be used in the computation of gv2 + gv2 = R.add(gv1, R.const(1)) + R.output(gv1, gv2) + return (gv1, gv2) + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_update_dataflow_computations_if_var_replacement_occurs_after_usage(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + Like test_update_dataflow_computations_if_var_replacement_occurs, + but the usage of a DataflowVar occurs before the trivial binding + that causes it to be replaced. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + lv1 = R.add(x, R.const(1)) + gv2 = R.add(lv1, R.const(1)) + gv1 = lv1 + R.output(gv1, gv2) + return (gv1, gv2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + # lv1 has been replaced with gv1 + gv1 = R.add(x, R.const(1)) + # So gv1 must be used in the computation of gv2 + gv2 = R.add(gv1, R.const(1)) + # Even though the trivial binding of "gv1 = lv1" + # occurred in this position. + R.output(gv1, gv2) + return (gv1, gv2) + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_canonicalize_trivial_binding_to_dataflow_var(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + DataflowVar instances may only be used inside a DataflowBlock. If + a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and + `y` is a `Var`, replacing `y` with `x` may result in usage of a + `DataflowVar` outside of a `DataflowBlock`. + + If a binding exists solely to convert from DataflowVar into Var, + then canonicalization replaces the earlier DataflowVar with a Var. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = y + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + R.output(y) + return y + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_canonicalize_multiple_trivial_binding_to_dataflow_var(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + Like test_canonicalize_trivial_binding_to_dataflow_var, but there + exist multiple trivial bindings to the DataflowVar. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(w: R.Tensor): + with R.dataflow(): + x = R.add(w, R.const(1)) + y = x + z = x + R.output(y, z) + return (y, z) + + @tvm.script.ir_module + class Expected: + @R.function + def main(w: R.Tensor): + with R.dataflow(): + x = R.add(w, R.const(1)) + R.output(x) + return (x, x) + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_canonicalize_trivial_var_binding_inside_dataflow_block(): + """Canonicalize Var to DataflowVar inside DataflowBlock + + Canonicalization handles cases where a Var could be replaced by a + DataflowVar, and where a Var is a trivial binding. If these two + cases both occur, should produce reasonable results. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = y + R.output(y, z) + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + R.output(y) + return y + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + +def test_canonicalize_across_non_dataflow_tuple(): + """Canonicalize Var to DataflowVar inside DataflowBlock""" + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = (y,) + gv = R.add(z[0], R.const(1)) + R.output(z, gv) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + with R.dataflow(): + y = R.add(x, R.const(1)) + z = (y,) + gv = R.add(y, R.const(1)) + R.output(gv) + return gv + + after = relax.transform.CanonicalizeBindings()(Before) + assert_structural_equal(Expected, after) + + if __name__ == "__main__": tvm.testing.main()