Skip to content

Commit

Permalink
[torch] Materialize all derivable bounds and divisor information in t…
Browse files Browse the repository at this point in the history
…he IR. (#18646)

* Adds new util ops: util.assume.divisible, util.assume.narrow,
util.assume.range
* Adds new pass torch-iree-bind-symbolic-shapes which will lower
torch.bind_symbolic_shape ops if present in the IR (these are currently
suppressed in the frontend with a flag, so adding this pass
unconditionally is a no-op)
* Canonicalizes all dynamics dims so that equal-dims are represented
program wide with the same SSA value and related-dims are derived from
the same root SSA values.
* Followon steps will clone the assume annotations into dispatches so
that codegen can make decisions based on the knowledge

---------

Signed-off-by: Stella Laurenzo <[email protected]>
Co-authored-by: Ben Vanik <[email protected]>
  • Loading branch information
stellaraccident and benvanik authored Oct 2, 2024
1 parent 8de9856 commit 462ecb6
Show file tree
Hide file tree
Showing 9 changed files with 779 additions and 3 deletions.
472 changes: 472 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ iree_cc_library(
HDRS
"Passes.h"
SRCS
"BindSymbolicShapes.cpp"
"BitCastQuantTensor.cpp"
"ConvertTMTensorToLinalgExt.cpp"
"FuncConversion.cpp"
Expand Down
4 changes: 4 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ void createTorchToIREEPipeline(
// model) and those constants get somewhat obscured by TorchToArith.
llvm::ArrayRef<std::string> emptyArrayRef;

// Dynamic shape bindings add a lot of structure to the IR which we prefer to
// leverage and eliminate prior to any other activity, so do this first.
pm.addNestedPass<func::FuncOp>(createBindSymbolicShapesPass());

if (options.strictSymbolicShapes) {
pm.addNestedPass<func::FuncOp>(createSetStrictSymbolicShapesPass());
// Run canonicalization in case any previously non-strict dynamic code can
Expand Down
5 changes: 5 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

include "mlir/Pass/PassBase.td"

def BindSymbolicShapesPass :
InterfacePass<"torch-iree-bind-symbolic-shapes", "mlir::FunctionOpInterface"> {
let summary = "Process torch dynamic shape bindings into IREE analyzable forms";
}

def BitCastQuantTensorPass :
InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> {
let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ iree_lit_test_suite(
"assume_strict_symbols.mlir"
"auto_input_conversion.mlir"
"attention.mlir"
"bind_symbolic_shapes.mlir"
"bitcast_quant_tensor.mlir"
"func_conversion.mlir"
"func_conversion_invalid.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(torch-iree-bind-symbolic-shapes))" --split-input-file --verify-diagnostics %s | FileCheck %s

// This example was captured from a program which has a dynamic batch size and
// tiled inner dim on one of the arguments, causing a symbolic relationship on
// the second dimension.
// CHECK-LABEL: @basic_example
module @basic_example {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
// CHECK-DAG: %[[ARG1_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG0_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK-DAG: %[[POS0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[POS1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %1, %[[POS0]] :
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %1, %[[POS1]] :
// CHECK: %[[ARG0_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
// CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.range %[[ARG0_DIM0_NARROW]] in [1, 1024] : index
// CHECK: %[[ARG0_DIM1_NARROW:.*]] = util.assume.narrow %[[DIM1]] : index to i32
// CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.range %[[ARG0_DIM1_NARROW]] in [1, 1024] : index
// CHECK: %[[ARG0_TIE:.*]] = flow.tensor.tie_shape %[[ARG0_ANCHOR]] : tensor<?x?xf32>{%[[ARG0_DIM0_RANGE]], %[[ARG0_DIM1_RANGE]]}
// CHECK: %[[ARG0_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG0_TIE]]
// CHECK: %[[ARG1_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
// CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.range %[[ARG1_DIM0_NARROW]] in [1, 1024]
// CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index
// CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]]
// CHECK: %[[ARG1_DIM1_NARROW:.*]] = util.assume.narrow %[[ARG1_DIM1]] : index to i32
// CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.range %[[ARG1_DIM1_NARROW]] in [2, 2048] : index
// CHECK: %[[ARG1_DIM1_DIV:.*]] = util.assume.divisible %[[ARG1_DIM1_RANGE]] by 2
// CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_DIV]]}
// CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]]
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
%2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
%int1_0 = torch.constant.int 1
%5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
return %5 : !torch.vtensor<[?,?],f32>
}
}

// -----
// This example was captured from a torch program that used a symbol that did
// not correspond to any dimension (being used in an expression as part of
// distinct dimensions). This exercises a special case in the pass for deferring
// to runtime resolution of the dim.
// We just verify that the vital information has been captured.
// CHECK-LABEL: @unbacked_symbol
module @unbacked_symbol {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: util.assume.narrow
// CHECK: util.assume.range{{.*}} [1, 1024]
// CHECK: util.assume.narrow
// CHECK: util.assume.range{{.*}} [2, 2048]
// CHECK: util.assume.divisible{{.*}} by 2
// CHECK: tie_shape
// CHECK: util.assume.narrow
// CHECK: util.assume.range{{.*}} [1, 1024]
// CHECK: util.assume.narrow
// CHECK: util.assume.range{{.*}} [4, 4096]
// CHECK: util.assume.divisible{{.*}} by 4
// CHECK: tie_shape
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int
%2 = torch.symbolic_int "4*s4" {min_val = 0, max_val = 4096} : !torch.int
%3 = torch.symbolic_int "s4" {min_val = 2, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%4 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.repeat %arg0, %4 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %5, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
%int1_0 = torch.constant.int 1
%6 = torch.aten.add.Tensor %5, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %6, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
return %6 : !torch.vtensor<[?,?],f32>
}
}

// -----
// CHECK-LABEL: @all_bindings_dropped
module @all_bindings_dropped {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-NOT: torch.symbolic_int
// CHECK-NOT: torch.bind_symbolic_shape
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
%2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
%int1_0 = torch.constant.int 1
%5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
return %5 : !torch.vtensor<[?,?],f32>
}
}

// -----
// CHECK-LABEL: @add_expr
module @add_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: addi
// CHECK-NOT: divisible
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 + 2)> : !torch.vtensor<[?,?],f32>
return
}
}

// -----
// CHECK-LABEL: @mod_expr
module @mod_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: remui
// CHECK-NOT: divisible
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 mod 2)> : !torch.vtensor<[?,?],f32>
return
}
}

// -----
// CHECK-LABEL: @floordiv_expr
module @floordiv_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: divui
// CHECK-NOT: divisible
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 floordiv 2)> : !torch.vtensor<[?,?],f32>
return
}
}

// -----
// Verifies that unsupported dim expressions warn (and do not assert).
// CHECK-LABEL: @unsupported_non_symbolic
module @unsupported_non_symbolic {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
// expected-warning@+1 {{Symbolic shape expression not supported: d0}}
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<(d0)[s0, s1] -> (s0, s1 + d0)> : !torch.vtensor<[?,?],f32>
return
}
}

// -----
// Torch uses high values to signal unbounded ranges. Ensure they are
// suppressed.
// CHECK-LABEL: @torch_unbounded_max_range
module @torch_unbounded_max_range {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK-NOT: util.assume.range
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32>
return
}
}
74 changes: 74 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,80 @@ def OpGroupCompilerHintOps : OpDocGroup {

let opDocGroup = OpGroupCompilerHintOps in {

def Util_AssumeDivisibleOp :
Util_PureOp<"assume.divisible", [SameOperandsAndResultType]> {
let summary = "Memorializes knowledge that an index/integer value is divisible by some constant.";

let arguments = (ins
Util_Range:$operand,
Util_IndexAttr:$divisor
);
let results = (outs
Util_Range:$result
);
let assemblyFormat = [{
$operand `by` $divisor attr-dict `:` type($operand)
}];
let builders = [
OpBuilder<(ins
"Value":$operand,
"uint64_t":$divisor
),
[{
IntegerAttr divisorAttr = $_builder.getIntegerAttr(
$_builder.getIndexType(), divisor);
build($_builder, $_state, operand.getType(), operand, divisorAttr);
}]>,
];
}

def Util_AssumeNarrowOp :
Util_PureOp<"assume.narrow", [SameOperandsAndResultType]> {
let summary = "Memorializes knowledge that an index/integer value can be narrowed to a type.";

let arguments = (ins
Util_Range:$operand,
TypeAttr:$narrow_type
);
let results = (outs
Util_Range:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` $narrow_type
}];
}

def Util_AssumeRangeOp :
Util_PureOp<"assume.range", [SameOperandsAndResultType]> {
let summary = "Memorializes knowledge that an index/integer value is always within some range.";

let arguments = (ins
Util_Range:$operand,
Util_IndexAttr:$min_value,
Util_IndexAttr:$max_value
);
let results = (outs
Util_Range:$result
);
let assemblyFormat = [{
$operand `in` ` ` `[` $min_value `,` $max_value `]` `:` type($operand) attr-dict
}];
let builders = [
OpBuilder<(ins
"Value":$operand,
"uint64_t":$minValue,
"uint64_t":$maxValue
),
[{
IntegerAttr minAttr = $_builder.getIntegerAttr(
$_builder.getIndexType(), minValue);
IntegerAttr maxAttr = $_builder.getIntegerAttr(
$_builder.getIndexType(), maxValue);
build($_builder, $_state, operand.getType(), operand, minAttr, maxAttr);
}]>,
];
}

def Util_OptimizationBarrierOp : Util_Op<"optimization_barrier", [
SameOperandsAndResultType,
]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,20 @@ class DropCompilerHintsPass
void runOnOperation() override {
// We can't use patterns and applyPatternsAndFoldGreedily because that
// automatically does canonicalization.
getOperation()->walk([&](IREE::Util::OptimizationBarrierOp op) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
getOperation()->walk([&](Operation *genericOp) {
if (auto op = dyn_cast<IREE::Util::OptimizationBarrierOp>(genericOp)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto op = dyn_cast<IREE::Util::AssumeDivisibleOp>(genericOp)) {
op.replaceAllUsesWith({op.getOperand()});
op.erase();
} else if (auto op = dyn_cast<IREE::Util::AssumeRangeOp>(genericOp)) {
op.replaceAllUsesWith({op.getOperand()});
op.erase();
} else if (auto op = dyn_cast<IREE::Util::AssumeNarrowOp>(genericOp)) {
op.replaceAllUsesWith({op.getOperand()});
op.erase();
}
});
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,33 @@ module @deeply_nested {
}
}
}

// -----

// CHECK-LABEL: @assume.divisible
util.func @assume.divisible() -> i32 {
// CHECK-NOT: util.assume.divisible
%c1 = arith.constant 12 : i32
%0 = util.assume.divisible %c1 by 2 : i32
util.return %0 : i32
}

// -----

// CHECK-LABEL: @assume.narrow
util.func @assume.narrow() -> i32 {
// CHECK-NOT: util.assume.narrow
%c1 = arith.constant 12 : i32
%0 = util.assume.narrow %c1 : i32 to i8
util.return %0 : i32
}

// -----

// CHECK-LABEL: @assume.range
util.func @assume.range() -> i32 {
// CHECK-NOT: util.assume.range
%c1 = arith.constant 12 : i32
%0 = util.assume.range %c1 in [2, 20] : i32
util.return %0 : i32
}

0 comments on commit 462ecb6

Please sign in to comment.