From 462ecb691a523a1b329cd14707fdcbed1d85f116 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 1 Oct 2024 20:19:07 -0700 Subject: [PATCH] [torch] Materialize all derivable bounds and divisor information in the IR. (#18646) * Adds new util ops: util.assume.divisible, util.assume.narrow, util.assume.range * Adds new pass torch-iree-bind-symbolic-shapes which will lower torch.bind_symbolic_shape ops if present in the IR (these are currently suppressed in the frontend with a flag, so adding this pass unconditionally is a no-op) * Canonicalizes all dynamics dims so that equal-dims are represented program wide with the same SSA value and related-dims are derived from the same root SSA values. * Followon steps will clone the assume annotations into dispatches so that codegen can make decisions based on the knowledge --------- Signed-off-by: Stella Laurenzo Co-authored-by: Ben Vanik --- .../InputConversion/BindSymbolicShapes.cpp | 472 ++++++++++++++++++ .../Torch/InputConversion/CMakeLists.txt | 1 + .../input/Torch/InputConversion/Passes.cpp | 4 + .../input/Torch/InputConversion/Passes.td | 5 + .../Torch/InputConversion/test/CMakeLists.txt | 1 + .../test/bind_symbolic_shapes.mlir | 178 +++++++ .../iree/compiler/Dialect/Util/IR/UtilOps.td | 74 +++ .../Util/Transforms/DropCompilerHints.cpp | 17 +- .../Transforms/test/drop_compiler_hints.mlir | 30 ++ 9 files changed, 779 insertions(+), 3 deletions(-) create mode 100644 compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp create mode 100644 compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp new file mode 100644 index 000000000000..b37ab37ea410 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp @@ -0,0 +1,472 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +#include + +namespace Torch = mlir::torch::Torch; +namespace TorchConversion = mlir::torch::TorchConversion; + +namespace mlir::iree_compiler::TorchInput { + +#define GEN_PASS_DEF_BINDSYMBOLICSHAPESPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + +namespace { + +Type getNarrowestType(Builder &builder, + std::optional> minMaxBounds) { + if (!minMaxBounds) + return {}; + + auto maxBound = minMaxBounds->second; + if (maxBound <= std::numeric_limits::max()) + return builder.getIntegerType(32); + else + return builder.getIntegerType(64); +} + +// Torch "binds" symbolic shape information to all tensors in the program +// which are not static. 
It does this by emitting side-effecting +// torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops +// which match 1:1 to terminal symbols in the Torch program. +// +// This is a somewhat different representation than we need in order to be +// usable within IREE: +// +// 1. We only want shape information and assertion at the boundaries where +// they can come from runtime values of unknown lineage. +// 2. IREE operates in terms of index values and "binding" them to tensors +// so that later dim lookups are memoized. +// 3. IREE's value analyses operate on real index SSA values, not "symbolic" +// values that only exist in the ether. +// +// These constraints can only be met if we assume that all Torch symbols are +// "backed" by a dimension or argument, so just a free-floating relational +// symbol. Such "backed" symbols are the most dominant form of Torch programs, +// but it is possible to create them such that symbols do not relate to any +// one dimension (although this typically does not happen naturally at +// program boundaries). In this pass we assume that any such relational +// symbols are not actionable by us, and we therefore drop them. It is possible +// for the frontend or user to fix this situation, and we therefore assume +// that anyone who cares will have done so. These cases are emitted as warnings +// in this pass because they signal potential missed optimization opportunties +// that we would like to know about. +// +// The approach we use from here will roughly map a torch.bind_symbolic_shape +// op to a flow.tensor.tie_shape op, preserving only the needed dynamic +// dimensions. Dimensions will be derived from util ops which annotate +// constraints and relationships. +// +// All other bind_symbolic_shape ops will be dropped. +class BindSymbolicShapesPass final + : public impl::BindSymbolicShapesPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + bool isEligibleBinding(Torch::BindSymbolicShapeOp bindOp) { + auto operand = bindOp.getOperand(); + // Torch programs are single block and use structured control flow, so + // presume this is an entrypoint. + if (llvm::isa(operand)) + return true; + + // Mutable tensors can exist at the boundary and must be "copied" to a + // vtensor prior to use. Therefore, we anchor on the point of copy. + if (operand.getDefiningOp()) + return true; + + return false; + } + + struct SymbolInfo { + SymbolInfo(Torch::SymbolicIntOp symbolDefOp) : symbolDefOp(symbolDefOp) { + auto minVal = symbolDefOp.getMinValAttr(); + auto maxVal = symbolDefOp.getMaxValAttr(); + if (minVal && maxVal) { + uint64_t minValInt = minVal.getValue().getZExtValue(); + uint64_t maxValInt = maxVal.getValue().getZExtValue(); + // Note that torch represents open ranges in strange ways with various + // magic numbers in the high range of the uint64_t type. We somewhat + // arbitrarily say that anything over a fourth of the uint64_t + // range (which is half of the positive int64_t range, should these have + // originated as signed quantities), is a ridiculously large number not + // suitable as a shape dimension, and we drop the hint. + if (maxValInt >= minValInt && + maxValInt < std::numeric_limits::max() / 4) { + // Note that in Torch, min values are "weird" because they encode + // some special cases about broadcast behavior. 
Here we just discard + // them, but in the future, there may be more to derive here. + minMaxBounds = std::make_pair(1, maxValInt); + } + } + } + + // Gets the canonical dim for this symbol, returning {} if there + // is no canonical dim. + Value getCanonicalDimValue(OpBuilder &builder) { + if (canonicalDimValue) + return canonicalDimValue; + if (equalityDimInfos.empty()) + return {}; + canonicalDimValue = getEqualityDimValue(builder, 0); + return canonicalDimValue; + } + + // Gets the dim value for one of the entries in equalityDimInfos, + // materializing an op if needed. + Value getEqualityDimValue(OpBuilder &builder, unsigned index) { + auto [producer, position] = equalityDimInfos[index]; + // Scrunch all dim ops up as far as they will go so that they can be + // shared among any legal consumers. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfterValue(producer); + Value dimValue = + builder.create(producer.getLoc(), producer, position); + return dimValue; + } + + Operation *symbolDefOp; + + // If the symbol carries min/max bounds, note them here. + std::optional> minMaxBounds; + + // All dimensions that should be considered equal by {producer_tensor, + // position}. When materializing shape expressions, we always use the + // first from this list so that simple SSA equality can be used across + // the graph. + SmallVector> equalityDimInfos; + + Value canonicalDimValue; + }; + + struct TensorBinding { + Operation *bindOp; + + // Symbol ops that that bind to symbols of the affine map. + llvm::SmallVector symbols; + + // The value (tensor) this binding annotates. + Value annotatesValue; + + // Torch type of the annotated tensor. + Torch::ValueTensorType torchType; + + // Corresponding builtin tensor type. + RankedTensorType builtinTensorType; + + // The affine map representing the dimensions. + AffineMap shapeMap; + + // When prepared, we convert from the torch type to builtin and back. This + // is the back value. Our work gets done feeding into this. + TorchConversion::FromBuiltinTensorOp rewrittenTorchOp; + + // Anchor op for building IR on native types. + Operation *anchorOp = nullptr; + + // All dim materializations we were able to make. If all are defined once + // processing is complete, then we can tie the shape. This will be fully + // populated after the associateEqualityDims phase, and subsequent + // materializations should take the first value so that all related shapes + // anchor the same. + llvm::SmallVector materializedDims; + + // Perform IR preparation for any bindings we may want to preserve. + void prepare() { + OpBuilder builder(bindOp); + TorchConversion::ToBuiltinTensorOp builtinConversion; + { + // Scrunch all ToBuiltinTensor ops as high up as they can go. We'll + // hang tensor.dim ops off of these across all dependent bindings so + // we need to make sure that it is always topologically legal. The + // easiest way to do this is to put common dependencies like this + // as far up as they will go, which means that each binding op (which + // is already guaranteed to be topologically legal) stays so. 
+ OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfterValue(annotatesValue); + builtinConversion = builder.create( + bindOp->getLoc(), builtinTensorType, annotatesValue); + } + rewrittenTorchOp = builder.create( + bindOp->getLoc(), torchType, builtinConversion.getResult()); + annotatesValue.replaceAllUsesExcept(rewrittenTorchOp.getResult(), + builtinConversion); + annotatesValue = builtinConversion.getResult(); + anchorOp = rewrittenTorchOp; + + materializedDims.resize(builtinTensorType.getRank()); + } + + std::optional> + evaluateExprBounds(AffineExpr expr, + llvm::DenseMap &symbolInfos) { + if (!expr.isSymbolicOrConstant()) + return {}; + llvm::SmallVector> lowerBounds; + llvm::SmallVector> upperBounds; + lowerBounds.reserve(symbols.size()); + upperBounds.reserve(symbols.size()); + for (auto [pos, symbolValue] : llvm::enumerate(symbols)) { + const SymbolInfo &symbolInfo = symbolInfos.at(symbolValue); + if (!symbolInfo.minMaxBounds) { + lowerBounds.push_back({}); + upperBounds.push_back({}); + } else { + lowerBounds.push_back(symbolInfo.minMaxBounds->first); + upperBounds.push_back(symbolInfo.minMaxBounds->second); + } + } + + auto upperBound = getBoundForAffineExpr( + expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, + upperBounds, /*isUpper=*/true); + if (!upperBound) + return {}; + + auto lowerBound = getBoundForAffineExpr( + expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, + upperBounds, /*isUpper=*/false); + if (!lowerBound) + return {}; + + return std::make_pair(*lowerBound, *upperBound); + } + + // For any dims in the shapeMap that are terminal, set up the root + // bindings. + void associateEqualityDims(llvm::DenseMap &symbolInfos) { + OpBuilder builder(anchorOp); + for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { + if (expr.getKind() != AffineExprKind::SymbolId) + continue; + auto symbolPos = llvm::cast(expr).getPosition(); + Value symbol = symbols[symbolPos]; + auto symbolInfoIt = symbolInfos.find(symbol); + assert(symbolInfoIt != symbolInfos.end() && + "No symbol info for symbol"); + auto &symbolInfo = symbolInfoIt->second; + symbolInfo.equalityDimInfos.emplace_back(annotatesValue, index); + } + } + + Value materializeDimExpr(Location loc, OpBuilder &builder, + AffineExpr genericExpr, + llvm::DenseMap &symbolInfos) { + if (auto binaryExpr = llvm::dyn_cast(genericExpr)) { + auto lhs = + materializeDimExpr(loc, builder, binaryExpr.getLHS(), symbolInfos); + if (!lhs) + return {}; + auto rhs = + materializeDimExpr(loc, builder, binaryExpr.getRHS(), symbolInfos); + if (!rhs) + return {}; + + switch (binaryExpr.getKind()) { + case AffineExprKind::Add: + return builder.create(loc, lhs, rhs); + case AffineExprKind::Mul: + return builder.create(loc, lhs, rhs); + case AffineExprKind::Mod: + return builder.create(loc, lhs, rhs); + case AffineExprKind::FloorDiv: + return builder.create(loc, lhs, rhs); + case AffineExprKind::CeilDiv: + return builder.create(loc, lhs, rhs); + default: + break; + } + } + + switch (genericExpr.getKind()) { + case AffineExprKind::Constant: + return builder.create( + loc, builder.getIndexAttr( + llvm::cast(genericExpr).getValue())); + case AffineExprKind::DimId: + // Unsupported. 
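+        // Dimension exprs (d0, d1, ...) are not expected in bind_symbolic_shape
+        // maps; falling through reaches the unsupported-expression warning below.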
+ break; + case AffineExprKind::SymbolId: { + auto symExpr = llvm::cast(genericExpr); + auto pos = symExpr.getPosition(); + if (pos >= symbols.size()) + break; + Value symbolValue = symbols[pos]; + auto foundIt = symbolInfos.find(symbolValue); + if (foundIt == symbolInfos.end()) + break; + SymbolInfo &info = foundIt->second; + return info.getCanonicalDimValue(builder); // May legally return {} + } + default: + break; + } + + std::string s; + llvm::raw_string_ostream os(s); + genericExpr.print(os); + emitWarning(loc) << "Symbolic shape expression not supported: " << s + << " (falling back to runtime symbol resolution)"; + return {}; + } + + void materializeDims(llvm::DenseMap &symbolInfos) { + OpBuilder builder(anchorOp); + for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { + if (!builtinTensorType.isDynamicDim(index)) + continue; + + Value dimValue = + materializeDimExpr(anchorOp->getLoc(), builder, expr, symbolInfos); + if (!dimValue) { + // Certain classes of symbolic expressions may not terminate on + // distinct dimensions (i.e. `s0 * 4` with no symbol that corresponds) + // to `s0`. In this case, we just do runtime resolution of the symbol. + dimValue = builder.create(bindOp->getLoc(), + annotatesValue, index); + } + + // Add optimization assumptions if the divisor or bounds are known. + int64_t divisor = expr.getLargestKnownDivisor(); + auto bounds = evaluateExprBounds(expr, symbolInfos); + if (divisor != 1 || bounds) { + Type narrowType = getNarrowestType(builder, bounds); + if (narrowType) { + dimValue = builder.create( + bindOp->getLoc(), dimValue, TypeAttr::get(narrowType)); + } + if (bounds) { + dimValue = builder.create( + bindOp->getLoc(), dimValue, bounds->first, bounds->second); + } + if (divisor != 1) { + dimValue = builder.create( + bindOp->getLoc(), dimValue, divisor); + } + } + + materializedDims[index] = dimValue; + } + } + + void tieShape(llvm::DenseMap &symbolInfos) { + llvm::SmallVector dynamicDims; + dynamicDims.reserve(materializedDims.size()); + for (size_t pos = 0; pos < materializedDims.size(); ++pos) { + if (builtinTensorType.isDynamicDim(pos)) { + Value dimValue = materializedDims[pos]; + if (!dimValue) { + emitWarning(bindOp->getLoc()) + << "Discarding symbolic shape information from PyTorch: Not " + << "all symbols resolved to a known dim value (first missing " + << "at position " << pos << ")"; + return; + } + + dynamicDims.push_back(dimValue); + } + } + + OpBuilder builder(anchorOp); + Value tieShape = builder.create( + bindOp->getLoc(), builtinTensorType, annotatesValue, dynamicDims); + rewrittenTorchOp.setOperand(tieShape); + } + }; + + void runOnOperation() override { + ConversionTarget target(getContext()); + TypeConverter typeConverter; + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + llvm::SmallVector cleanupOpList; + llvm::SmallVector bindings; + // Mapping of SSA value for a torch.symbolic_int (or related op) to its + // info. + llvm::DenseMap symbolInfos; + + // Walk the ops we care about and stash for analysis. 
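+    // Eligibility is checked up front: only bindings on block arguments or
+    // copy.to_vtensor results are retained (see isEligibleBinding above); all
+    // symbolic_int/bind ops are queued for erasure regardless.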
+ getOperation()->walk([&](Operation *childOp) { + if (auto symbolOp = llvm::dyn_cast(childOp)) { + cleanupOpList.push_back(symbolOp); + symbolInfos.insert_or_assign(symbolOp.getResult(), + SymbolInfo(symbolOp)); + } else if (auto bindOp = + llvm::dyn_cast(childOp)) { + cleanupOpList.push_back(bindOp); + if (!isEligibleBinding(bindOp)) + return; + auto torchType = + llvm::cast(bindOp.getOperand().getType()); + auto builtinType = llvm::dyn_cast_or_null( + typeConverter.convertType(torchType)); + if (!builtinType) { + emitError(childOp->getLoc()) + << "cannot convert torch type to builtin: " << torchType; + return signalPassFailure(); + } + bindings.push_back(TensorBinding{ + /*bindOp=*/childOp, + /*symbols=*/bindOp.getShapeSymbols(), + /*annotatesValue=*/bindOp.getOperand(), + /*torchType=*/torchType, + /*builtinType=*/builtinType, + /*shapeMap=*/bindOp.getShapeExpressions().getAffineMap()}); + } + }); + + // For every tensor value of interest, convert to a builtin tensor type and + // back, RAUW'ing the result. This will meet the eventual final conversion + // with additional graph forking. + for (auto &binding : bindings) { + binding.prepare(); + } + + // Find all associations to a single symbol and set up the roots. + for (auto &binding : bindings) { + binding.associateEqualityDims(symbolInfos); + } + + // Materialize all dimension expressions and constraints. + for (auto &binding : bindings) { + binding.materializeDims(symbolInfos); + } + + // Now that all is known, insert tie shape. + for (auto &binding : bindings) { + binding.tieShape(symbolInfos); + } + + // Erase all found ops. + for (auto *op : llvm::reverse(cleanupOpList)) { + op->erase(); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt index 1db408527651..4e4878482b4a 100644 --- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt @@ -34,6 +34,7 @@ iree_cc_library( HDRS "Passes.h" SRCS + "BindSymbolicShapes.cpp" "BitCastQuantTensor.cpp" "ConvertTMTensorToLinalgExt.cpp" "FuncConversion.cpp" diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index 00ab1a444854..a0682b5c6e41 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -36,6 +36,10 @@ void createTorchToIREEPipeline( // model) and those constants get somewhat obscured by TorchToArith. llvm::ArrayRef emptyArrayRef; + // Dynamic shape bindings add a lot of structure to the IR which we prefer to + // leverage and eliminate prior to any other activity, so do this first. 
+ pm.addNestedPass(createBindSymbolicShapesPass()); + if (options.strictSymbolicShapes) { pm.addNestedPass(createSetStrictSymbolicShapesPass()); // Run canonicalization in case any previously non-strict dynamic code can diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index 91b7792d08b3..251cbb219ecf 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -9,6 +9,11 @@ include "mlir/Pass/PassBase.td" +def BindSymbolicShapesPass : + InterfacePass<"torch-iree-bind-symbolic-shapes", "mlir::FunctionOpInterface"> { + let summary = "Process torch dynamic shape bindings into IREE analyzable forms"; +} + def BitCastQuantTensorPass : InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width"; diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt index 6f86276c3733..cabc6b29e754 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt @@ -6,6 +6,7 @@ iree_lit_test_suite( "assume_strict_symbols.mlir" "auto_input_conversion.mlir" "attention.mlir" + "bind_symbolic_shapes.mlir" "bitcast_quant_tensor.mlir" "func_conversion.mlir" "func_conversion_invalid.mlir" diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir new file mode 100644 index 000000000000..e3d6061fa222 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir @@ -0,0 +1,178 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(torch-iree-bind-symbolic-shapes))" --split-input-file --verify-diagnostics %s | FileCheck %s + +// This example was captured from a program which has a dynamic batch size and +// tiled inner dim on one of the arguments, causing a symbolic relationship on +// the second dimension. 
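+// Note how equal symbols resolve to shared tensor.dim SSA values across both
+// arguments, and how the derived dimension (s1 * 2) is computed from that
+// shared root before being annotated and tied.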
+// CHECK-LABEL: @basic_example +module @basic_example { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} { + // CHECK-DAG: %[[ARG1_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?],f32> -> tensor + // CHECK-DAG: %[[ARG0_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor + // CHECK-DAG: %[[POS0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[POS1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %1, %[[POS0]] : + // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %1, %[[POS1]] : + // CHECK: %[[ARG0_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32 + // CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.range %[[ARG0_DIM0_NARROW]] in [1, 1024] : index + // CHECK: %[[ARG0_DIM1_NARROW:.*]] = util.assume.narrow %[[DIM1]] : index to i32 + // CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.range %[[ARG0_DIM1_NARROW]] in [1, 1024] : index + // CHECK: %[[ARG0_TIE:.*]] = flow.tensor.tie_shape %[[ARG0_ANCHOR]] : tensor{%[[ARG0_DIM0_RANGE]], %[[ARG0_DIM1_RANGE]]} + // CHECK: %[[ARG0_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG0_TIE]] + // CHECK: %[[ARG1_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32 + // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.range %[[ARG1_DIM0_NARROW]] in [1, 1024] + // CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index + // CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]] + // CHECK: %[[ARG1_DIM1_NARROW:.*]] = util.assume.narrow %[[ARG1_DIM1]] : index to i32 + // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.range %[[ARG1_DIM1_NARROW]] in [2, 2048] : index + // CHECK: %[[ARG1_DIM1_DIV:.*]] = util.assume.divisible %[[ARG1_DIM1_RANGE]] by 2 + // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_DIV]]} + // CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]] + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int + %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + %int1_0 = torch.constant.int 1 + %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + return %5 : !torch.vtensor<[?,?],f32> + } +} + +// ----- +// This example was captured from a torch program that used a symbol that did +// not correspond to any dimension (being used in an expression as part of +// distinct dimensions). This exercises a special case in the pass for deferring +// to runtime resolution of the dim. +// We just verify that the vital information has been captured. 
+// CHECK-LABEL: @unbacked_symbol +module @unbacked_symbol { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + // CHECK: util.assume.narrow + // CHECK: util.assume.range{{.*}} [1, 1024] + // CHECK: util.assume.narrow + // CHECK: util.assume.range{{.*}} [2, 2048] + // CHECK: util.assume.divisible{{.*}} by 2 + // CHECK: tie_shape + // CHECK: util.assume.narrow + // CHECK: util.assume.range{{.*}} [1, 1024] + // CHECK: util.assume.narrow + // CHECK: util.assume.range{{.*}} [4, 4096] + // CHECK: util.assume.divisible{{.*}} by 4 + // CHECK: tie_shape + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int + %2 = torch.symbolic_int "4*s4" {min_val = 0, max_val = 4096} : !torch.int + %3 = torch.symbolic_int "s4" {min_val = 2, max_val = 1024} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32> + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %4 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.repeat %arg0, %4 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %5, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32> + %int1_0 = torch.constant.int 1 + %6 = torch.aten.add.Tensor %5, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %6, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32> + return %6 : !torch.vtensor<[?,?],f32> + } +} + +// ----- +// CHECK-LABEL: @all_bindings_dropped +module @all_bindings_dropped { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + // CHECK-NOT: torch.symbolic_int + // CHECK-NOT: torch.bind_symbolic_shape + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int + %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + %int1_0 = torch.constant.int 1 + %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32> + return %5 : !torch.vtensor<[?,?],f32> + } +} + +// ----- +// CHECK-LABEL: @add_expr +module @add_expr { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { + // CHECK: addi + // CHECK-NOT: divisible + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" 
{min_val = 0, max_val = 1024} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 + 2)> : !torch.vtensor<[?,?],f32> + return + } +} + +// ----- +// CHECK-LABEL: @mod_expr +module @mod_expr { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { + // CHECK: remui + // CHECK-NOT: divisible + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 mod 2)> : !torch.vtensor<[?,?],f32> + return + } +} + +// ----- +// CHECK-LABEL: @floordiv_expr +module @floordiv_expr { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { + // CHECK: divui + // CHECK-NOT: divisible + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 floordiv 2)> : !torch.vtensor<[?,?],f32> + return + } +} + +// ----- +// Verifies that unsupported dim expressions warn (and do not assert). +// CHECK-LABEL: @unsupported_non_symbolic +module @unsupported_non_symbolic { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + // expected-warning@+1 {{Symbolic shape expression not supported: d0}} + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<(d0)[s0, s1] -> (s0, s1 + d0)> : !torch.vtensor<[?,?],f32> + return + } +} + +// ----- +// Torch uses high values to signal unbounded ranges. Ensure they are +// suppressed. 
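+// (The pass drops any max_val at or above floor(uint64_max / 4) =
+// 4611686018427387903, so no util.assume.range should appear here.)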
+// CHECK-LABEL: @torch_unbounded_max_range +module @torch_unbounded_max_range { + func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { + // CHECK-NOT: util.assume.range + %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32> + return + } +} diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index b6466d61d758..881d8d652edb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -458,6 +458,80 @@ def OpGroupCompilerHintOps : OpDocGroup { let opDocGroup = OpGroupCompilerHintOps in { +def Util_AssumeDivisibleOp : + Util_PureOp<"assume.divisible", [SameOperandsAndResultType]> { + let summary = "Memorializes knowledge that an index/integer value is divisible by some constant."; + + let arguments = (ins + Util_Range:$operand, + Util_IndexAttr:$divisor + ); + let results = (outs + Util_Range:$result + ); + let assemblyFormat = [{ + $operand `by` $divisor attr-dict `:` type($operand) + }]; + let builders = [ + OpBuilder<(ins + "Value":$operand, + "uint64_t":$divisor + ), + [{ + IntegerAttr divisorAttr = $_builder.getIntegerAttr( + $_builder.getIndexType(), divisor); + build($_builder, $_state, operand.getType(), operand, divisorAttr); + }]>, + ]; +} + +def Util_AssumeNarrowOp : + Util_PureOp<"assume.narrow", [SameOperandsAndResultType]> { + let summary = "Memorializes knowledge that an index/integer value can be narrowed to a type."; + + let arguments = (ins + Util_Range:$operand, + TypeAttr:$narrow_type + ); + let results = (outs + Util_Range:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` $narrow_type + }]; +} + +def Util_AssumeRangeOp : + Util_PureOp<"assume.range", [SameOperandsAndResultType]> { + let summary = "Memorializes knowledge that an index/integer value is always within some range."; + + let arguments = (ins + Util_Range:$operand, + Util_IndexAttr:$min_value, + Util_IndexAttr:$max_value + ); + let results = (outs + Util_Range:$result + ); + let assemblyFormat = [{ + $operand `in` ` ` `[` $min_value `,` $max_value `]` `:` type($operand) attr-dict + }]; + let builders = [ + OpBuilder<(ins + "Value":$operand, + "uint64_t":$minValue, + "uint64_t":$maxValue + ), + [{ + IntegerAttr minAttr = $_builder.getIntegerAttr( + $_builder.getIndexType(), minValue); + IntegerAttr maxAttr = $_builder.getIntegerAttr( + $_builder.getIndexType(), maxValue); + build($_builder, $_state, operand.getType(), operand, minAttr, maxAttr); + }]>, + ]; +} + def Util_OptimizationBarrierOp : Util_Op<"optimization_barrier", [ SameOperandsAndResultType, ]> { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index ff7cefd90b0e..a6f072c1197f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -19,9 +19,20 @@ class DropCompilerHintsPass void runOnOperation() override { // We can't use patterns and applyPatternsAndFoldGreedily because that // 
automatically does canonicalization.
-    getOperation()->walk([&](IREE::Util::OptimizationBarrierOp op) {
-      op.replaceAllUsesWith(op.getOperands());
-      op.erase();
+    getOperation()->walk([&](Operation *genericOp) {
+      if (auto op = dyn_cast<IREE::Util::OptimizationBarrierOp>(genericOp)) {
+        op.replaceAllUsesWith(op.getOperands());
+        op.erase();
+      } else if (auto op = dyn_cast<IREE::Util::AssumeDivisibleOp>(genericOp)) {
+        op.replaceAllUsesWith({op.getOperand()});
+        op.erase();
+      } else if (auto op = dyn_cast<IREE::Util::AssumeNarrowOp>(genericOp)) {
+        op.replaceAllUsesWith({op.getOperand()});
+        op.erase();
+      } else if (auto op = dyn_cast<IREE::Util::AssumeRangeOp>(genericOp)) {
+        op.replaceAllUsesWith({op.getOperand()});
+        op.erase();
+      }
     });
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
index 717d2bfc106e..c0db60a6538f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
@@ -73,3 +73,33 @@ module @deeply_nested {
     }
   }
 }
+
+// -----
+
+// CHECK-LABEL: @assume.divisible
+util.func @assume.divisible() -> i32 {
+  // CHECK-NOT: util.assume.divisible
+  %c1 = arith.constant 12 : i32
+  %0 = util.assume.divisible %c1 by 2 : i32
+  util.return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @assume.narrow
+util.func @assume.narrow() -> i32 {
+  // CHECK-NOT: util.assume.narrow
+  %c1 = arith.constant 12 : i32
+  %0 = util.assume.narrow %c1 : i32 to i8
+  util.return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @assume.range
+util.func @assume.range() -> i32 {
+  // CHECK-NOT: util.assume.range
+  %c1 = arith.constant 12 : i32
+  %0 = util.assume.range %c1 in [2, 20] : i32
+  util.return %0 : i32
+}