Skip to content

Commit

Permalink
[TVMScript][Relax] Allow return statement in DataflowBlock (#17131)
Browse files Browse the repository at this point in the history
Prior to this commit, TVMScript required the return value of a Relax
to be specified outside of any `with R.dataflow()` blocks.  This
resulted in a common pattern, where the return value of a function was
first called with `R.output(ret_value)`, to mark `ret_value` as a
`tvm::relax::Var` instead of a `tvm::relax::DataflowVar`, followed
immediately by a `return ret_value` statement.

This commit updates the TVMScript parser to allow a `return` statement
inside a `with R.dataflow()` block.  This is syntactic sugar that
is equivalent to calling `R.output`, followed by a `return`.

With this change, the following two TVMScript examples are now
equivalent.  (Prior to this change, the `return_inside_dataflow`
example would raise an error during parsing.)

```python
@R.function(private=True)
def output_then_return(A: R.Tensor):
    with R.dataflow():
        B = R.add(A, A)
        C = R.multiply(B, B)
        R.output(C)

    return C

@R.function(private=True)
def return_inside_dataflow(A: R.Tensor):
    with R.dataflow():
        B = R.add(A, A)
        C = R.multiply(B, B)
        return C
```
  • Loading branch information
Lunderberg authored Sep 18, 2024
1 parent ff8e416 commit a242046
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 48 deletions.
69 changes: 28 additions & 41 deletions src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,23 @@ void BlockFrameNode::EnterWithScope() {
}
}

class DataflowBlockRewriter : public tvm::relax::ExprMutator {
class VarReplacer : public tvm::relax::ExprMutator {
public:
static tvm::relax::DataflowBlock Rewrite(const tvm::relax::DataflowBlock& block,
const Array<tvm::relax::Var>& output_vars) {
DataflowBlockRewriter rewriter(output_vars);
return Downcast<tvm::relax::DataflowBlock>(rewriter.VisitBindingBlock(block));
explicit VarReplacer(
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual>
var_remap) {
var_remap_ = std::move(var_remap);
}

private:
explicit DataflowBlockRewriter(const Array<tvm::relax::Var>& output_vars) {
for (const tvm::relax::Var& var : output_vars) {
output_var_set_.insert(var.get());
}
}

tvm::relax::Var VisitVarDef_(const tvm::relax::DataflowVarNode* op) final {
auto it = output_var_set_.find(op);
if (it != output_var_set_.end()) {
// Rewrite dataflow vars to global vars
auto n = make_object<tvm::relax::VarNode>(*op);
tvm::relax::Var new_var(n);
this->var_remap_[op->vid] = new_var;
return new_var;
tvm::relax::Var VisitVarDef(const tvm::relax::Var& var) override {
// ExprMutator only applies var_remap_ at usage sites. This
// applies var_remap_ at each definition site as well.
if (auto it = var_remap_.find(var->vid); it != var_remap_.end()) {
return it->second;
} else {
return GetRef<tvm::relax::Var>(op);
return var;
}
}

private:
std::unordered_set<const tvm::relax::VarNode*> output_var_set_;
};

void BlockFrameNode::ExitWithScope() {
Expand All @@ -164,25 +151,27 @@ void BlockFrameNode::ExitWithScope() {

// Step 3. Rewrite the dataflow block.
if (is_dataflow) {
// Step 3.1. Rewrite block binding
block = DataflowBlockRewriter::Rewrite(Downcast<tvm::relax::DataflowBlock>(block), output_vars);

// Step 3.2. Collect global vars' reference in bindings
Map<tvm::relax::Id, tvm::relax::Var> new_global_vars;
for (const tvm::relax::Binding& binding : block->bindings) {
if (!binding->var->IsInstance<tvm::relax::DataflowVarNode>()) {
new_global_vars.Set(binding->var->vid, binding->var);
}
// Step 3.0. Define a map to replace variables
Array<tvm::relax::Var> new_output_vars;
std::unordered_map<tvm::relax::Id, tvm::relax::Var, ObjectPtrHash, ObjectPtrEqual> var_remap;
for (const auto& output_var : output_vars) {
tvm::relax::Var new_output_var(output_var->name_hint(), GetStructInfo(output_var));
new_output_vars.push_back(new_output_var);
var_remap[output_var->vid] = new_output_var;
}
VarReplacer mutator(std::move(var_remap));

// Step 3.1. Rewrite block binding
block = mutator.VisitBindingBlock(block);

// Step 3.3. Rewrite output vars
Array<tvm::relax::Var> new_output_vars;
for (const auto& var : output_vars) {
auto it = new_global_vars.find(var->vid);
ICHECK(it != new_global_vars.end());
new_output_vars.push_back((*it).second);
}
output_vars = std::move(new_output_vars);

// Step 3.4 Rewrite usage of output var, if any
auto function = FindFunctionFrame("R.dataflow()");
if (function->output.defined()) {
function->output = mutator.VisitExpr(function->output.value());
}
}

// Step 3. Get the last frame from the IRBuilder frame stack.
Expand All @@ -196,8 +185,6 @@ void BlockFrameNode::ExitWithScope() {

// Step 5. Push the block frame into the corresponding field of the last frame.
if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
ICHECK(!seq_frame->output.defined())
<< "The function is not expected to have output values when emitting blocks.";
auto frame = GetRef<SeqExprFrame>(seq_frame);
frame->binding_blocks.push_back(block);
} else {
Expand Down
23 changes: 16 additions & 7 deletions src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,29 @@ void FuncRetValue(const tvm::relax::Expr& value) {
const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
tvm::relax::Expr normalized_value = block_builder->Normalize(value);

IRBuilder ir_builder = IRBuilder::Current();

// Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of
// a function body. Therefore if there is any unended block frame when dealing with function
// return, we should end the block frame.
Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
if (block_frame.defined()) {
block_frame.value()->ExitWithScope();
ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
<< "ValueError: Relax functions don't support return in true/false branch of If Node.";

if (auto opt = ir_builder->GetLastFrame<BlockFrame>()) {
auto block_frame = opt.value();
for (const auto& var : tvm::relax::FreeVars(normalized_value)) {
if (var->IsInstance<tvm::relax::DataflowVarNode>()) {
block_frame->output_vars.push_back(var);
}
}
}
// Step 2. Add the output value to the function frame.
FunctionFrame frame = FindFunctionFrame("return");
CHECK(!frame->output.defined())
<< "ValueError: Relax functions don't support multiple return statement. Please make sure "
"the return statement appears at the end of function.";
<< "ValueError: "
<< "Relax functions do not support multiple return statement. "
<< "However, return of " << normalized_value << " occurred after a return of "
<< frame->output << ". "
<< "Please make sure function only has a single return statement, "
<< "which appears at the end of function.";

frame->output = std::move(normalized_value);
}
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,5 +2410,36 @@ def inferred_sinfo(
tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)


def test_return_from_dataflow_block():
"""Return statements imply
The `R.output` statement in a `R.dataflow()` block marks a
variable that should be a `relax.Var` instead of a
`relax.DataflowVar`, allowing it to be used outside of the
`DataflowBlock` that defined it. A relax function's output is not
part of any binding, and must not contain any `DataflowVar`, so
these are exposed implicitly.
"""

@R.function(private=True)
def output_then_return(A: R.Tensor([16], "float16")):
with R.dataflow():
B = R.add(A, A)
C = R.multiply(B, B)
R.output(C)

return C

@R.function(private=True)
def return_inside_dataflow(A: R.Tensor([16], "float16")):
with R.dataflow():
B = R.add(A, A)
C = R.multiply(B, B)
return C

tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit a242046

Please sign in to comment.