Skip to content

Commit

Permalink
Add an integer divisibility analysis. (#18727)
Browse files Browse the repository at this point in the history
* Also extends the numeric optimization test to elide arith.remui that
exactly divides (a reasonable optimization but primarily used for
testing in this patch).
* Only implements the analysis for `arith.constant` and
`util.int.assume` in this patch.
* Renames assumption `divisor` to `udiv` to match terminology elsewhere
wrt signed/unsigned analysis.
* The lattice tracks unsigned and signed interpretations separately as
this is needed for propagation through signed ops (but this is not
implemented here).

---------

Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Oct 9, 2024
1 parent 1b719b3 commit 5270093
Show file tree
Hide file tree
Showing 19 changed files with 391 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module @basic_example {
// CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.int %[[DIM0]]<umin = 1, umax = 1024>
// CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index
// CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]]
// CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.int %[[ARG1_DIM1]]<umin = 2, umax = 2048, divisor = 2> : index
// CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.int %[[ARG1_DIM1]]<umin = 2, umax = 2048, udiv = 2> : index
// CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_RANGE]]}
// CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]]
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
Expand Down Expand Up @@ -49,10 +49,10 @@ module @basic_example {
module @unbacked_symbol {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: util.assume.int{{.*}}<umin = 1, umax = 1024>
// CHECK: util.assume.int{{.*}}<umin = 2, umax = 2048, divisor = 2>
// CHECK: util.assume.int{{.*}}<umin = 2, umax = 2048, udiv = 2>
// CHECK: tie_shape
// CHECK: util.assume.int{{.*}}<umin = 1, umax = 1024>
// CHECK: util.assume.int{{.*}}<umin = 4, umax = 4096, divisor = 4>
// CHECK: util.assume.int{{.*}}<umin = 4, umax = 4096, udiv = 4>
// CHECK: tie_shape
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int
Expand Down Expand Up @@ -100,7 +100,7 @@ module @all_bindings_dropped {
module @add_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: addi
// CHECK-NOT: divisor
// CHECK-NOT: udiv
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
Expand All @@ -114,7 +114,7 @@ module @add_expr {
module @mod_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: remui
// CHECK-NOT: divisor
// CHECK-NOT: udiv
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
Expand All @@ -128,7 +128,7 @@ module @mod_expr {
module @floordiv_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: divui
// CHECK-NOT: divisor
// CHECK-NOT: udiv
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ iree_compiler_cc_library(
srcs = [
"Explorer.cpp",
"GlobalTable.cpp",
"IntegerDivisibilityAnalysis.cpp",
"Position.cpp",
],
hdrs = [
"Explorer.h",
"GlobalTable.h",
"IntegerDivisibilityAnalysis.h",
"Position.h",
],
deps = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ iree_cc_library(
HDRS
"Explorer.h"
"GlobalTable.h"
"IntegerDivisibilityAnalysis.h"
"Position.h"
SRCS
"Explorer.cpp"
"GlobalTable.cpp"
"IntegerDivisibilityAnalysis.cpp"
"Position.cpp"
DEPS
LLVMSupport
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"

#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "iree-util-int-divisibility-analysis"

using llvm::dbgs;

namespace mlir::iree_compiler::IREE::Util {

void IntegerDivisibilityAnalysis::setToEntryState(
IntegerDivisibilityLattice *lattice) {
propagateIfChanged(lattice,
lattice->join(IntegerDivisibility::getMinDivisibility()));
}

LogicalResult IntegerDivisibilityAnalysis::visitOperation(
Operation *op, ArrayRef<const IntegerDivisibilityLattice *> operands,
ArrayRef<IntegerDivisibilityLattice *> results) {
auto inferrable = dyn_cast<InferIntDivisibilityOpInterface>(op);
if (!inferrable) {
setAllToEntryStates(results);
return success();
}

LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n");
auto argDivs = llvm::map_to_vector(
operands, [](const IntegerDivisibilityLattice *lattice) {
return lattice->getValue();
});
auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) {
auto result = dyn_cast<OpResult>(v);
if (!result)
return;
assert(llvm::is_contained(op->getResults(), result));

LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n");
IntegerDivisibilityLattice *lattice = results[result.getResultNumber()];
IntegerDivisibility oldDiv = lattice->getValue();

ChangeResult changed = lattice->join(newDiv);

// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
// the dataflow analysis in MLIR doesn't attempt to work out trip counts
// and often can't).
bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
if (isYieldedResult && !oldDiv.isUninitialized() &&
!(lattice->getValue() == oldDiv)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
changed |= lattice->join(IntegerDivisibility::getMinDivisibility());
}
propagateIfChanged(lattice, changed);
};

inferrable.inferResultDivisibility(argDivs, joinCallback);
return success();
}

} // namespace mlir::iree_compiler::IREE::Util
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_
#define IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_

#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

#include <optional>

namespace mlir::iree_compiler::IREE::Util {

class IntegerDivisibilityLattice
: public dataflow::Lattice<IntegerDivisibility> {
public:
using Lattice::Lattice;
};

class IntegerDivisibilityAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<
IntegerDivisibilityLattice> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

// At an entry point, set the lattice to the most pessimistic state,
// indicating that no further reasoning can be done.
void setToEntryState(IntegerDivisibilityLattice *lattice) override;

// Visit an operation, invoking the transfer function.
LogicalResult
visitOperation(Operation *op,
ArrayRef<const IntegerDivisibilityLattice *> operands,
ArrayRef<IntegerDivisibilityLattice *> results) override;
};

} // namespace mlir::iree_compiler::IREE::Util

#endif // IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def Util_IntAssumptionAttr : AttrDef<Util_Dialect, "IntAssumption", []
let parameters = (ins
DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$umin,
DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$umax,
DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$divisor
DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$udiv
);
let assemblyFormat = [{
`<` struct($umin, $umax, $divisor) `>`
`<` struct($umin, $umax, $udiv) `>`
}];
}

Expand Down
26 changes: 26 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,32 @@ def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> {
];
}

//===----------------------------------------------------------------------===//
// IREE::Util::InferIntDivisibilityOpInterface
//===----------------------------------------------------------------------===//

def InferIntDivisibilityOpInterface :
OpInterface<"InferIntDivisibilityOpInterface"> {
let cppNamespace = "::mlir::iree_compiler::IREE::Util";

let description = [{
Allows operations to participate in integer divisibility analysis.
}];

let methods = [
InterfaceMethod<
/*desc=*/[{

}],
/*retTy=*/"void",
/*methodName=*/"inferResultDivisibility",
/*args=*/(ins
"::llvm::ArrayRef<::mlir::iree_compiler::IREE::Util::IntegerDivisibility>":$argDivs,
"::mlir::iree_compiler::IREE::Util::SetIntDivisibilityFn":$setResultDivs)
>
];
}

//===----------------------------------------------------------------------===//
// IREE::Util::InitializerOpInterface
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 17 additions & 2 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,11 +1141,12 @@ AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) {
// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
std::optional<uint64_t> AssumeIntOp::getUnionedDivisor(unsigned operandIndex) {
std::optional<uint64_t>
AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) {
auto assumptions = getOperandAssumptions(operandIndex);
std::optional<uint64_t> divisorUnion;
for (auto assumption : assumptions) {
auto divisor = assumption.getDivisor();
auto divisor = assumption.getUdiv();
if (divisor) {
if (divisorUnion)
divisorUnion = std::gcd(*divisor, *divisorUnion);
Expand Down Expand Up @@ -1176,6 +1177,20 @@ void AssumeIntOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
}
}

void AssumeIntOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
SetIntDivisibilityFn setResultDivs) {
for (auto [index, result] : llvm::enumerate(getResults())) {
Type type = result.getType();
if (!isa<IndexType>(type) && !isa<IntegerType>(type))
continue;
auto udiv = getUnionedUnsignedDivisor(index);
if (udiv) {
setResultDivs(result,
ConstantIntDivisibility(/*udiv=*/*udiv, /*sdiv=*/*udiv));
}
}
}

void AssumeIntOp::build(OpBuilder &builder, OperationState &state,
Value singleOperand,
IntAssumptionAttr singleAssumption) {
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ def OpGroupCompilerHintOps : OpDocGroup {
let opDocGroup = OpGroupCompilerHintOps in {

def Util_AssumeIntOp : Util_PureOp<"assume.int", [
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface, ["inferResultRanges"]>
]> {
let summary = "memorializes assumptions about index/integer values.";
let description = [{
Expand Down Expand Up @@ -507,7 +508,7 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [
// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
std::optional<uint64_t> getUnionedDivisor(unsigned operandIndex);
std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
}];

let hasCustomAssemblyFormat = 1;
Expand Down
90 changes: 90 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/CallInterfaces.h"

#include <numeric>

// clang-format off: must be included after all LLVM/MLIR headers.
#include "iree/compiler/Dialect/Util/IR/UtilEnums.h.inc" // IWYU pragma: keep
// clang-format on
Expand Down Expand Up @@ -155,6 +157,94 @@ void excludeTiedOperandAndResultIndices(
ArrayRef<unsigned> excludedResultIndices,
SmallVector<int64_t> &tiedOperandIndices);

//===----------------------------------------------------------------------===//
// Forward defines for InferIntDivisibilityOpInterface
// See implementations in IntegerDivisibility.h.
//===----------------------------------------------------------------------===//

class ConstantIntDivisibility {
public:
ConstantIntDivisibility() = default;
ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv)
: udivVal(udiv), sdivVal(sdiv) {}

bool operator==(const ConstantIntDivisibility &other) const {
return udivVal == other.udivVal && sdivVal == other.sdivVal;
}

uint64_t udiv() const { return this->udivVal; }
uint64_t sdiv() const { return this->sdivVal; }

// Returns the union (computed separately for signed and unsigned bounds)
// for this range and `other`.
ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const {
return ConstantIntDivisibility(
/*udiv=*/std::gcd(udiv(), other.udiv()),
/*sdiv=*/std::gcd(sdiv(), other.sdiv()));
}

private:
uint64_t udivVal;
uint64_t sdivVal;

friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntDivisibility &div);
};

inline raw_ostream &operator<<(raw_ostream &os,
const ConstantIntDivisibility &div) {
os << "ConstantIntDivisibility(udiv = " << div.udivVal
<< ", sdiv = " << div.sdivVal << ")";
return os;
}

class IntegerDivisibility {
public:
IntegerDivisibility(ConstantIntDivisibility value)
: value(std::move(value)) {}
IntegerDivisibility(
std::optional<ConstantIntDivisibility> value = std::nullopt)
: value(std::move(value)) {}
// Gets the minimum divisibility of 1 that is used to indicate that the value
// cannot be analyzed further.
static IntegerDivisibility getMinDivisibility() {
return IntegerDivisibility(ConstantIntDivisibility(1, 1));
}

bool isUninitialized() const { return !value.has_value(); }
const ConstantIntDivisibility &getValue() const {
assert(!isUninitialized());
return *value;
}

bool operator==(const IntegerDivisibility &rhs) const {
return value == rhs.value;
}

static IntegerDivisibility join(const IntegerDivisibility &lhs,
const IntegerDivisibility &rhs) {
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue()));
}

void print(raw_ostream &os) const { os << value; }

private:
std::optional<ConstantIntDivisibility> value;
};

inline raw_ostream &operator<<(raw_ostream &os,
const IntegerDivisibility &div) {
div.print(os);
return os;
}

using SetIntDivisibilityFn =
llvm::function_ref<void(Value, const ConstantIntDivisibility &)>;

//===----------------------------------------------------------------------===//
// Shape-aware interface utilities
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 5270093

Please sign in to comment.