Skip to content

Commit

Permalink
[Unity][Transform] Improved canonicalization of non-dataflow Var (apa…
Browse files Browse the repository at this point in the history
…che#15941)

* [Unity][Transform] Improved canonicalization of non-dataflow Var

Prior to this commit, `relax.transform.CanonicalizeBindings` removed
trivial bindings `var_y = var_x` where `var_y: relax.DataflowVar`
and `var_x: relax.Var`, but did not remove trivial bindings where
`var_y: relax.Var` and `var_x: relax.DataflowVar`.  This was to avoid
invalid use of a `relax.DataflowVar` outside of a dataflow block.

This commit updates `CanonicalizeBindings` to handle this type of
binding as well.  To ensure that no `relax.DataflowVar` instances are
used outside of a dataflow block, this is done by replacing `var_x:
relax.DataflowVar` at its point of definition, instead of replacing
`var_y: relax.Var` at its point of use.

This commit also canonicalizes `relax.Var` definitions to
`relax.DataflowVar`, if the binding occurs within a dataflow block,
and the variable is never used outside of a dataflow block.

* Simplify unwrapping of known bindings

* Updated to use Map<Id,Var>, to avoid while(true) loops
  • Loading branch information
Lunderberg authored Oct 25, 2023
1 parent bcdbc3e commit 4d19c8a
Show file tree
Hide file tree
Showing 5 changed files with 414 additions and 92 deletions.
247 changes: 173 additions & 74 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,68 +33,187 @@
namespace tvm {
namespace relax {

class BindingCanonicalizer : public ExprMutator {
namespace {

/*! \brief The rewrites selected by the planning pass, keyed by variable Id. */
struct CanonicalizationPlan {
// Id of a variable -> the Var to substitute wherever that variable is referenced.
Map<Id, Var> replace_usage;
// Id of a variable -> the Var to define in its place at its point of definition.
Map<Id, Var> replace_binding;
// Ids of bound variables whose bindings are skipped (not re-emitted) by the mutator.
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
};

/*! \brief Utility class to identify usage location
*
* Canonicalization of a variable binding may require information from
* later in the function. For example, a binding `dataflow_x = expr`
* may need to be replaced with `var_x = expr`, in order to avoid a
* trivial binding of `var_x = dataflow_x` later in the function. This
* utility examines a relax expression, and plans the changes to be
* made in a mutation pass.
*/
class CanonicalizePlanner : public ExprVisitor {
public:
BindingCanonicalizer() {}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
if (auto explicit_tuple = tuple_value.as<TupleNode>()) {
CHECK_GE(tuple_get_item->index, 0)
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but negative indices are not supported in this context.";
CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size())
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but the tuple size is only " << explicit_tuple->fields.size();
return VisitExpr(explicit_tuple->fields[tuple_get_item->index]);
/*!
 * \brief Visit an expression and decide how its bindings should be rewritten.
 *
 * \param expr The relax expression to analyze.
 *
 * \return The plan listing which usages to replace, which definitions
 * to replace, and which bindings to remove.
 */
static CanonicalizationPlan Collect(const Expr& expr) {
  CanonicalizePlanner planner;
  planner.VisitExpr(expr);

  CanonicalizationPlan result;

  // Variables already dealt with while resolving trivial bindings.
  // They are excluded from the Var -> DataflowVar conversion below.
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> settled;

  for (const auto& entry : planner.trivial_bindings_) {
    Var lhs = entry.first;
    Var rhs = entry.second;

    // A trivial binding may point at another trivial binding.  Walk
    // the chain back to the earliest variable that is not itself
    // trivially bound.
    while (auto earlier = planner.trivial_bindings_.Get(rhs)) {
      rhs = earlier.value();
    }

    // The definition of that variable may already have been replaced
    // (Case 4 below, Var = DataflowVar).  If so, compare against the
    // replacement instead.
    while (auto replacement = result.replace_binding.Get(rhs->vid)) {
      rhs = replacement.value();
    }

    const bool is_var_bound_to_dataflow_var =
        !lhs.as<DataflowVarNode>() && rhs.as<DataflowVarNode>();

    if (is_var_bound_to_dataflow_var) {
      // Case 4: Var = DataflowVar
      //
      // Substituting the DataflowVar for the Var could place a
      // DataflowVar outside of a DataflowBlock.  Replace in the other
      // direction instead: the DataflowVar's binding becomes a
      // binding of the Var.
      result.replace_binding.Set(rhs->vid, lhs);
      result.replace_usage.Set(rhs->vid, lhs);
      result.bindings_to_remove.insert(lhs->vid);
      settled.insert(lhs);
    } else {
      // Case 1: Var = Var
      // Case 2: DataflowVar = Var
      // Case 3: DataflowVar = DataflowVar
      //
      // The trivial binding is unwrapped: the variable it was bound
      // to is used directly at each point of use.
      result.replace_usage.Set(lhs->vid, rhs);
      result.bindings_to_remove.insert(lhs->vid);
      settled.insert(rhs);
    }
  }

  // A Var that is defined inside a DataflowBlock, never used outside
  // of one, and not already settled above can become a DataflowVar.
  for (const auto& candidate : planner.defined_inside_dataflow_) {
    if (candidate.as<DataflowVarNode>()) continue;
    if (planner.used_outside_dataflow_.count(candidate)) continue;
    if (settled.count(candidate)) continue;

    DataflowVar as_dataflow_var(candidate->name_hint(), GetStructInfo(candidate));
    result.replace_binding.Set(candidate->vid, as_dataflow_var);
    result.replace_usage.Set(candidate->vid, as_dataflow_var);
  }

  return result;
}

private:
/*! \brief Visit a dataflow block with `inside_dataflow_` set, restoring the prior value afterwards. */
void VisitBindingBlock_(const DataflowBlockNode* block) override {
  const bool previous_state = inside_dataflow_;

  inside_dataflow_ = true;
  ExprVisitor::VisitBindingBlock_(block);

  inside_dataflow_ = previous_state;
}

/*! \brief Record what a binding binds to, noting trivial `var = var` bindings.
*
* The bound value (after unwrapping a TupleGetItem of a known Tuple)
* is stored in `known_bindings_`.  If that value is a Var and the
* struct info is unchanged, the binding is also stored in
* `trivial_bindings_`.
*
* \param binding The VarBinding or MatchCast being visited.
*/
void VisitBinding(const Binding& binding) override {
// Stays true for a VarBinding.  For a MatchCast, true only if the bound
// variable and the cast value have structurally equal struct info.
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}

// Unwrap TupleGetItem, if the Tuple being accessed is known.
if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
Expr tuple = tuple_get_item->tuple;
// Follow variables through their recorded bindings until reaching
// either a non-variable expression or a variable with no known binding.
while (auto tuple_var = tuple.as<Var>()) {
if (auto opt = known_bindings_.Get(tuple_var.value())) {
tuple = opt.value();
} else {
break;
}
}

// Only an explicit Tuple can be indexed here; otherwise `value`
// remains the TupleGetItem expression.
if (auto ptr = tuple.as<TupleNode>()) {
value = ptr->fields[tuple_get_item->index];
}
}

// A binding whose (unwrapped) value is a variable with unchanged
// struct info is a candidate for removal.
if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
trivial_bindings_.Set(binding->var, parent.value());
}
// NOTE(review): the `return` below returns a value from a void method
// and names `tuple_get_item`, which is out of scope here.  It looks
// like a line from the removed side of the diff that was interleaved
// by the page rendering -- confirm against the actual source file.
return ExprMutator::VisitExpr_(tuple_get_item);

known_bindings_.Set(binding->var, value);

// Continue with the default traversal of the binding.
ExprVisitor::VisitBinding(binding);
}

void VisitBinding_(const VarBindingNode* binding) override {
// Unlike default visitor, we do not permit the struct info to change
// if the new value's struct info is different (this preserves user annotations)
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (auto opt_var = new_value.as<Var>();
opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
var_remap_[new_var->vid] = opt_var.value();
} else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
/*! \brief Record variables whose definition is visited while inside a dataflow block. */
void VisitVarDef(const Var& var) override {
  if (!inside_dataflow_) {
    return;
  }
  defined_inside_dataflow_.insert(var);
}

/*! \brief Record variables that are referenced while not inside a dataflow block. */
void VisitExpr_(const VarNode* var) override {
  if (inside_dataflow_) {
    return;
  }
  used_outside_dataflow_.insert(GetRef<Var>(var));
}

bool inside_dataflow_{false};

Map<Var, Var> trivial_bindings_;
Map<Var, Expr> known_bindings_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_outside_dataflow_;
};

/*! \brief The mutator class to apply a CanonicalizationPlan */
class BindingCanonicalizer : public ExprMutator {
public:
/*!
 * \brief Canonicalize the bindings of an expression.
 *
 * First collects a CanonicalizationPlan for the expression, then
 * applies that plan with a BindingCanonicalizer.
 *
 * \param expr The expression to canonicalize.
 *
 * \return The result of visiting `expr` with the mutator.
 */
static Expr Apply(Expr expr) {
  CanonicalizationPlan plan = CanonicalizePlanner::Collect(expr);
  BindingCanonicalizer rewriter(std::move(plan));
  return rewriter.VisitExpr(expr);
}

private:
/*!
 * \brief Construct a mutator that applies the given plan.
 *
 * \param plan The plan to apply.  It is taken by value and moved into
 * the member, so that passing an rvalue does not copy the plan's
 * containers.
 */
explicit BindingCanonicalizer(CanonicalizationPlan plan) : plan_(std::move(plan)) {}

/*! \brief Visit a binding, unless the plan marks its variable for removal. */
void VisitBinding(const Binding& binding) override {
  const bool is_marked_for_removal = plan_.bindings_to_remove.count(binding->var->vid);
  if (is_marked_for_removal) {
    // Skipping the default visit means this binding is not processed further.
    return;
  }
  ExprMutator::VisitBinding(binding);
}

/*! \brief Visit a variable definition, substituting the planned replacement if one exists.
*
* \param var The variable being defined.
*
* \return The result of the default VisitVarDef, applied either to
* the replacement from `plan_.replace_binding` or to `var` itself.
*/
Var VisitVarDef(const Var& var) override {
if (auto opt = plan_.replace_binding.Get(var->vid)) {
return ExprMutator::VisitVarDef(opt.value());
} else {
// NOTE(review): the EmitNormalized line below names `new_var` and
// `new_value`, neither of which is declared in this method.  It looks
// like a line from the removed side of the diff that was interleaved
// by the page rendering -- confirm against the actual source file.
this->builder_->EmitNormalized(VarBinding(new_var, new_value));
return ExprMutator::VisitVarDef(var);
}
}

void VisitBinding_(const MatchCastNode* binding) override {
// If we have a trivial shape check (the struct_info_ of LHS and RHS is the same),
// we can canonicalize to a var binding
Expr new_value = this->VisitExpr(binding->value);
bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value));

if (has_same_struct_info) {
if (auto parent = new_value.as<Var>();
parent && CanCanonicalizeVar(binding->var, parent.value())) {
// LHS and RHS have the same struct info, and occur in a
// context where the RHS can replace the LHS.
var_remap_[binding->var->vid] = parent.value();
} else {
// LHS and RHS have the same struct info, but the RHS is not a
// drop-in replacement for the LHS.
builder_->EmitNormalized(VarBinding(binding->var, new_value));
}
} else if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
/*! \brief Visit a variable usage, substituting the planned replacement if one exists.
*
* \param var The variable being referenced.
*
* \return The result of visiting the replacement from
* `plan_.replace_usage`, or of the default visit of `var` if the plan
* has no entry for it.
*/
Expr VisitExpr_(const VarNode* var) override {
if (auto opt = plan_.replace_usage.Get(var->vid)) {
return ExprMutator::VisitExpr(opt.value());
} else {
// NOTE(review): the two comment lines and the EmitNormalized call
// below name `binding` and `new_value`, neither of which is declared
// in this method.  They look like lines from the removed side of the
// diff that were interleaved by the page rendering -- confirm against
// the actual source file.
// we can't elide in the same way as with var bindings because
// the struct info comparison has semantics
builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info));
return ExprMutator::VisitExpr_(var);
}
}

Expand Down Expand Up @@ -200,31 +319,11 @@ class BindingCanonicalizer : public ExprMutator {
}

private:
/*!
 * \brief Check whether two optional annotations differ.
 *
 * \param obj1 The first annotation, which may be undefined.
 * \param obj2 The second annotation, which may be undefined.
 * \param check_eq Equality predicate, called only when both
 * annotations are defined.
 *
 * \return True if exactly one annotation is defined, or if both are
 * defined and `check_eq` reports them unequal.  False otherwise.
 */
bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
                       std::function<bool(const ObjectRef&, const ObjectRef&)> check_eq) {
  const bool has_first = obj1.defined();
  const bool has_second = obj2.defined();

  // One present and the other absent: they differ.
  if (has_first != has_second) {
    return true;
  }
  // Neither present: nothing to compare.
  if (!has_first) {
    return false;
  }
  // Both present: defer to the supplied equality check.
  return !check_eq(obj1, obj2);
}

/*!
 * \brief Decide whether `var` may be replaced by `parent_var`.
 *
 * Unification is conservatively refused in two situations:
 *
 * 1. The struct_info_ of `var` differs from that of `parent_var`,
 *    since replacing it could override user annotations.
 *
 * 2. `var` is a Var while `parent_var` is a DataflowVar, since that
 *    could let a DataflowVar leave the current DataflowBlock.
 *
 * \param var The variable being bound.
 * \param parent_var The variable it is bound to.
 *
 * \return True if neither situation applies.
 */
bool CanCanonicalizeVar(Var var, Var parent_var) {
  auto structurally_equal = [&](const ObjectRef& lhs, const ObjectRef& rhs) {
    return tvm::StructuralEqual()(lhs, rhs);
  };

  const bool annotations_differ =
      AnnotationsDiffer(var->struct_info_, parent_var->struct_info_, structurally_equal);
  const bool var_to_dataflow = (!var.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());

  return !(annotations_differ || var_to_dataflow);
}
CanonicalizationPlan plan_;
};
} // namespace

/*! \brief Visit `e` with a default-constructed BindingCanonicalizer and return the result. */
Expr CanonicalizeBindings(const Expr& e) {
  BindingCanonicalizer canonicalizer;
  return canonicalizer.VisitExpr(e);
}
/*!
 * \brief Canonicalize the bindings of a relax expression.
 *
 * \param expr The expression to canonicalize.
 *
 * \return The expression returned by BindingCanonicalizer::Apply.
 */
Expr CanonicalizeBindings(const Expr& expr) {
  Expr canonicalized = BindingCanonicalizer::Apply(expr);
  return canonicalized;
}

namespace transform {

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,8 +1403,7 @@ def before(x: R.Tensor((1024,))):
@R.function(private=True)
def expected(x: R.Tensor((1024,))):
with R.dataflow():
a = R.add(x, x)
b = a
b = R.add(x, x)
R.output(b)
return b

Expand Down
9 changes: 3 additions & 6 deletions tests/python/relax/test_optimize_layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@ def main(
(lv1, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv

Expand Down Expand Up @@ -256,10 +255,9 @@ def main(
(lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv6
R.output(gv)
return gv

Expand Down Expand Up @@ -399,10 +397,9 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
gv = R.call_tir(
Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relax/test_remove_redundant_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def main(
x: R.Tensor((1, 1001, 1, 1), dtype="float16")
) -> R.Tensor((1, 1001), dtype="float16"):
with R.dataflow():
lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
gv: R.Tensor((1, 1001), dtype="float16") = lv
gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
R.output(gv)
return gv

Expand Down
Loading

0 comments on commit 4d19c8a

Please sign in to comment.