[Bugfix][Relax] Preserve existing DataflowBlock in ConvertToDataflow (#17148)

The `relax.transform.ConvertToDataflow` pass identifies portions of a
Relax function that satisfy the requirements of a
`relax::DataflowBlock`, and converts those portions into a new
`DataflowBlock`, provided they contain at least some minimum number of
operations.  Prior to this commit, if a function contained a region
that would be converted into a `DataflowBlock`, but also contained
existing `DataflowBlock`s smaller than the minimum size required for
creating a new `DataflowBlock`, those existing blocks would be
erroneously converted to non-dataflow bindings.

This commit updates the `ConvertToDataflow` pass to preserve all
existing `DataflowBlock`s present in the input.
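
For reference, the sketch below shows how the pass might be invoked from
Python.  The module is hypothetical and only mirrors the shape of the
regression tests added by this commit.

# Sketch only: apply ConvertToDataflow to a small module.
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function(pure=False)
    def main(x: R.Tensor, y: R.Tensor):
        # An existing DataflowBlock, smaller than the minimum size
        # required to create a new block.
        with R.dataflow():
            a = R.add(x, x)
            R.output(a)

        R.print(format="impure_function")

        # A run of pure bindings long enough to become a new DataflowBlock.
        b = R.add(y, y)
        c = R.add(b, b)
        d = R.add(c, c)
        return (a, d)

# With this fix, the existing block around `a` is preserved and the
# bindings for `b`, `c`, and `d` are wrapped in a new DataflowBlock.
mod = relax.transform.ConvertToDataflow()(Module)
mod.show()
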
Lunderberg authored Sep 13, 2024
1 parent 3755571 commit eb011c7
Showing 2 changed files with 173 additions and 50 deletions.
117 changes: 67 additions & 50 deletions src/relax/transform/convert_dataflow.cc
@@ -28,6 +28,8 @@
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include <optional>

namespace tvm {
namespace relax {

@@ -39,85 +41,100 @@ class DataflowBlockExtractor : public ExprMutator {
Array<BindingBlock> new_blocks;
Expr new_body = VisitExpr(seq->body);
bool changed = !new_body.same_as(seq->body);
- bool dataflow_streak = false;
- Array<Binding> dataflow_bindings;

// Accumulated bindings that are not going to be added to a
// DataflowBlock, either because they would be illegal within a
// DataflowBlock, or because there were insufficient bindings to
// make a dataflowblock. Because these bindings occur prior to
// `dataflow_bindings`, this array may only be accumulated into
// when `dataflow_bindings` is empty.
Array<Binding> non_dataflow_bindings;

// Current bindings that may legally be added to a DataflowBlock.
Array<Binding> dataflow_bindings;

// If present, a DataflowBlock whose bindings are currently in
// `dataflow_bindings`. Used to propagate DataflowBlock to the
// output, even if it doesn't meet the minimum size.
Optional<DataflowBlock> input_dataflow_block;

// Handle any bindings currently in `dataflow_bindings`. These
// are either pushed to their own block, or to the end of
// `non_dataflow_bindings`, depending on whether the bindings meet
// the minimum size requirement.
auto push_dataflow_bindings = [&]() {
if (dataflow_bindings.empty()) {
// No Dataflow bindings, so no action required.
return;
}
if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) {
// The df block is below the minimum length, and no input
// DataflowBlock needs to be preserved. Combine the blocks
// and reset the dataflow collection.

non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
dataflow_bindings.end());

} else {
// A new DataflowBlock can be generated, with bindings that
// occur after the non-dataflow bindings.
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
new_blocks.push_back(DataflowBlock(dataflow_bindings));
non_dataflow_bindings = {};

// Making a dataflow block doesn't imply that the function was
// changed. A change requires that this either be a new
// dataflow block, or have additional dataflow bindings in the
// current block.
changed = changed || !input_dataflow_block.defined() ||
input_dataflow_block.value()->bindings.size() != dataflow_bindings.size();
}

dataflow_bindings = {};
input_dataflow_block = NullOpt;
};

for (auto block : seq->blocks) {
BindingBlock new_block = this->VisitBindingBlock(block);
changed = changed || !new_block.same_as(block);

// For an existing dataflow block, we add to the current streak
// or start a new streak in case there will be more dataflow operations
// coming up
- if (new_block.as<DataflowBlock>()) {
- if (!dataflow_streak) {
- dataflow_streak = true;
- }
+ if (auto dataflow_block = new_block.as<DataflowBlock>()) {
dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(),
new_block->bindings.end());
+ input_dataflow_block = dataflow_block;
continue;
}

// for a binding block, attempt to extract dataflow blocks inside
auto binding_block = Downcast<BindingBlock>(new_block);
- for (size_t i = 0; i < binding_block->bindings.size(); i++) {
- auto binding = binding_block->bindings[i];
+ for (const auto& binding : binding_block->bindings) {
Expr value = GetBoundValue(binding);
// dataflow values: not an if node and not an impure call
bool is_dataflow = (!value.as<IfNode>()) &&
(!(value.as<CallNode>() && IsImpureCall(Downcast<Call>(value))));
- if (!dataflow_streak) {
- // we can start a dataflow streak
- if (is_dataflow) {
- dataflow_streak = true;
- dataflow_bindings = {binding};
- } else {
- non_dataflow_bindings.push_back(binding);
- }
+ if (is_dataflow) {
+ // extend the streak
+ dataflow_bindings.push_back(binding);
} else {
- if (is_dataflow) {
- // extend the streak
- dataflow_bindings.push_back(binding);
- } else {
- // this is the end of the streak
- dataflow_streak = false;
-
- // if the df block is below the minimum length, combine the blocks
- // and reset the dataflow collection
- if (dataflow_bindings.size() < min_size_) {
- non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
- dataflow_bindings.end());
- dataflow_bindings = {};
- } else {
- // otherwise insert both collections
- changed = true;
- new_blocks.push_back(BindingBlock(non_dataflow_bindings));
- new_blocks.push_back(DataflowBlock(dataflow_bindings));
- non_dataflow_bindings = {};
- dataflow_bindings = {};
- }
- non_dataflow_bindings.push_back(binding);
- }
+ // End the streak, if one currently exists.
+ push_dataflow_bindings();
+ non_dataflow_bindings.push_back(binding);
}
}
}

// handle any remaining bindings
- if (dataflow_bindings.size() < min_size_) {
- non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
- dataflow_bindings.end());
- new_blocks.push_back(BindingBlock(non_dataflow_bindings));
- } else {
- changed = true;
- new_blocks.push_back(BindingBlock(non_dataflow_bindings));
- new_blocks.push_back(DataflowBlock(dataflow_bindings));
- }
+ push_dataflow_bindings();
+ new_blocks.push_back(BindingBlock(non_dataflow_bindings));

- if (!changed) {
+ if (changed) {
+ return SeqExpr(new_blocks, new_body);
+ } else {
return GetRef<SeqExpr>(seq);
}
- return SeqExpr(new_blocks, new_body);
}

private:
106 changes: 106 additions & 0 deletions tests/python/relax/test_transform_convert_dataflow.py
@@ -489,5 +489,111 @@ def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
return v


class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare):
"""Preserve existing DataflowBlocks
This is a regression test. In previous implementations, a
DataflowBlock in the input, without enough bindings to become a
new dataflow block, could be accidentally ommitted.
This test is identical to
`TestPreserveExistingDataflowBlocksAtEnd`, except that the
existing dataflow block is at the beginning of the function.
"""

@I.ir_module
class Before:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This DataflowBlock is below the minimum size for a new
# block, but already exists in the input IRModule.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)

R.print(format="impure_function")

# This sequence is large enough that it may be converted
# to a DataflowBlock.
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)

return (A1, B3)

@I.ir_module
class Expected:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This dataflow block should be preserved in the output.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)

R.print(format="impure_function")

with R.dataflow():
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
R.output(B3)

return (A1, B3)


class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare):
"""Preserve existing DataflowBlocks
This is a regression test. In previous implementations, a
DataflowBlock in the input, without enough bindings to become a
new dataflow block, could be accidentally ommitted.
This test is identical to
`TestPreserveExistingDataflowBlocksAtBeginning`, except that the
existing dataflow block is at the end of the function.
"""

@I.ir_module
class Before:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This sequence is large enough that it may be converted
# to a DataflowBlock.
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)

R.print(format="impure_function")

# This DataflowBlock is below the minimum size for a new
# block, but already exists in the input IRModule.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)

return (A1, B3)

@I.ir_module
class Expected:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
with R.dataflow():
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
R.output(B3)

R.print(format="impure_function")

# This dataflow block should be preserved in the output.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)

return (A1, B3)


if __name__ == "__main__":
tvm.testing.main()
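
For each test case, the ExtractCompare base class presumably applies the
pass to `Before` and compares the result against `Expected`.  A rough,
hypothetical equivalent of that check, reusing the modules from the last
test above and the pass's default minimum block size, is:

# Rough equivalent of the comparison performed for the test case above;
# `Before` and `Expected` refer to the modules of the preceding test.
import tvm
from tvm import relax

After = relax.transform.ConvertToDataflow()(Before)
tvm.ir.assert_structural_equal(After, Expected)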
