Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Allow symbolic bounds in IndexMap analysis #15264

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ class IndexMapNode : public Object {
* \returns The indices in the output space. Contains one value for
* each expression in `final_indices`.
*/
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
arith::Analyzer* analyzer = nullptr) const;
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, arith::Analyzer* analyzer) const;

/*! \brief Map a memory range to the output space
*
Expand All @@ -121,7 +120,7 @@ class IndexMapNode : public Object {
* \returns The ranges in the output space. Contains one value for
* each expression in `final_indices`.
*/
Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const;
Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const;

/*! \brief Map a buffer shape to the output space
*
Expand All @@ -134,7 +133,7 @@ class IndexMapNode : public Object {
* \returns The buffer shape in the output space. Contains one
* value for each expression in `final_indices`.
*/
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const;
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer) const;

/* \brief Map an NDArray according to this index map
*
Expand Down Expand Up @@ -203,7 +202,7 @@ class IndexMap : public ObjectRef {
* If the user has supplied an `inverse_index_map`, that map is
* assumed to be correct and bijective, and is returned.
*/
IndexMap Inverse(Array<Range> initial_ranges) const;
IndexMap Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer) const;

/*! \brief Rename the variables in the index map and ensure the names are unique.
*
Expand All @@ -225,7 +224,8 @@ class IndexMap : public ObjectRef {
* \return The inverted index map, along with the predicate for
* which the inverse maps to a valid range.
*/
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges,
arith::Analyzer* analyzer) const;

TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
Expand Down Expand Up @@ -1738,16 +1739,18 @@ inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& s
/*!
 * \brief Transform the layout of a tensor according to an IndexMap.
 *
 * The scraped diff interleaved the pre-change and post-change lines of this
 * function (duplicate MapShape/MapIndices calls); this is the coherent
 * post-change version, where every IndexMap query shares one Analyzer so
 * symbolic shape bounds can be used during simplification.
 *
 * \param src The source tensor.
 * \param index_map The mapping from the source layout to the target layout.
 * \param name The operation name.
 * \param tag The operation tag.
 * \return A tensor holding `src` re-laid-out per `index_map`.
 */
inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                             const String name = "T_meta_schedule_layout_trans",
                                             const String tag = kInjective) {
  arith::Analyzer analyzer;
  // Iteration domain of the source buffer: [0, extent) per dimension.
  Array<Range> iter_domain;
  iter_domain.reserve(src->shape.size());
  for (const PrimExpr& e : src->shape) {
    iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
  }
  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
  // NOTE(review): `analyzer` is captured by reference; this assumes `compute`
  // invokes the fcompute lambda eagerly (while `analyzer` is alive) — confirm
  // against te::compute's contract.
  return compute(
      post_transform_shape,
      [src, inv = index_map.Inverse(iter_domain, &analyzer),
       &analyzer](const Array<Var>& indices) -> PrimExpr {
        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
      },
      name, tag);
}
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
[tool.isort]
profile = "black"
src_paths = ["python", "tests/python"]

[tool.black]
line-length = 100
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@

import tvm._ffi
from tvm._ffi.base import string_types

from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer, Var, IndexMap
from tvm.runtime import Object, convert
from tvm.tir import Buffer, IndexMap, IterVar, Var

from . import tensor as _tensor
from . import _ffi_api
from . import tensor as _tensor


@tvm._ffi.register_object
Expand Down Expand Up @@ -600,7 +599,9 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
"""

ndim = len(self.op.output(0).shape)
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
index_map, axis_separators = IndexMap.from_func_with_separators(
mapping_function, ndim=ndim, index_dtype="int32"
)

new_iter_vars = _ffi_api.StageTransformLayout(
self, index_map.initial_indices, index_map.final_indices
Expand Down
20 changes: 14 additions & 6 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
attrs=None,
span=None,
):

param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
Expand Down Expand Up @@ -266,6 +265,8 @@ def from_func(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
*,
index_dtype: str = "int64",
):
"""Create an index map from a function

Expand Down Expand Up @@ -302,7 +303,10 @@ def from_func(

"""
index_map, axis_separators = IndexMap.from_func_with_separators(
mapping_function, ndim, inverse_index_map
mapping_function,
ndim,
inverse_index_map,
index_dtype=index_dtype,
)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
Expand All @@ -316,6 +320,8 @@ def from_func_with_separators(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
*,
index_dtype: str = "int64",
):
"""Create an index map from a function

Expand Down Expand Up @@ -346,6 +352,9 @@ def from_func_with_separators(
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.

index_dtype : str
The default index dtype to use for input iters in the mapping function.

Returns
-------
ret: Tuple[IndexMap, List[int]]
Expand All @@ -361,20 +370,19 @@ def from_func_with_separators(
args = []
var_arg_name = None
kwargs = collections.OrderedDict()
default_index_dtype = "int32"

for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))
args.append(tvm.tir.Var(name, index_dtype))

elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name

elif param.kind == inspect.Parameter.KEYWORD_ONLY:
kwargs[name] = tvm.tir.Var(name, default_index_dtype)
kwargs[name] = tvm.tir.Var(name, index_dtype)

else:
raise ValueError("transform_layout mapping may not have *args")
Expand All @@ -386,7 +394,7 @@ def from_func_with_separators(
assert ndim is not None, "ndim must be specified when *args is used"
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype))

mapping = mapping_function(*args, **kwargs)

Expand Down
40 changes: 34 additions & 6 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
return seed


def _get_block_default_dtype(block: Block) -> str:
    """Pick a default index dtype for ``block``.

    Preference order: dtype of the first block iteration variable, then the
    dtype of the first min-bound among the block's read/write regions, and
    finally ``"int64"`` when the block exposes neither.
    """
    # Prefer the dtype of the first iteration variable, if any exist.
    first_iter_dtype = next((iv.var.dtype for iv in block.iter_vars), None)
    if first_iter_dtype is not None:
        return first_iter_dtype
    # Otherwise fall back to the first region bound of any read/write access.
    region_dtypes = (
        rng.min.dtype
        for access in list(block.reads) + list(block.writes)
        for rng in access.region
    )
    return next(region_dtypes, "int64")


@_register_object("tir.Schedule")
class Schedule(Object):
"""The user-facing schedule class
Expand Down Expand Up @@ -1492,7 +1501,10 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
block = self._normalize_block_arg(block)

if callable(index_map):
index_map = IndexMap.from_func(index_map)
index_map = IndexMap.from_func(
index_map,
index_dtype=_get_block_default_dtype(self.get(block)),
)
return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope, index_map
)
Expand Down Expand Up @@ -1589,7 +1601,10 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
block = self._normalize_block_arg(block)

if callable(index_map):
index_map = IndexMap.from_func(index_map)
index_map = IndexMap.from_func(
index_map,
index_dtype=_get_block_default_dtype(self.get(block)),
)
return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope, index_map
)
Expand Down Expand Up @@ -3246,14 +3261,22 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->

ndim = len(buffer_obj.shape)
if callable(index_map):
index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim)
index_map, axis_separators = IndexMap.from_func_with_separators(
index_map,
ndim=ndim,
index_dtype=_get_block_default_dtype(self.get(block)),
)
else:
axis_separators = []

if pad_value is None:
pass
elif callable(pad_value):
pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
pad_value = IndexMap.from_func(
pad_value,
ndim=len(index_map.final_indices),
index_dtype=_get_block_default_dtype(self.get(block)),
)
elif not isinstance(pad_value, IndexMap):
# Explicitly convert python int/float arguments to the
# buffer's type. If the default `tvm.runtime.convert`
Expand All @@ -3264,7 +3287,9 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
elif "float" in buffer_obj.dtype and isinstance(pad_value, float):
pad_value = FloatImm(buffer_obj.dtype, pad_value)
pad_value = IndexMap.from_func(
lambda *indices: pad_value, ndim=len(index_map.final_indices)
lambda *indices: pad_value,
ndim=len(index_map.final_indices),
index_dtype=_get_block_default_dtype(self.get(block)),
)

buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
Expand Down Expand Up @@ -3337,7 +3362,10 @@ def after_transform_block_layout(
"""
block = self._normalize_block_arg(block)
if callable(index_map):
index_map = IndexMap.from_func(index_map)
index_map = IndexMap.from_func(
index_map,
index_dtype=_get_block_default_dtype(self.get(block)),
)
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
self, block, index_map
)
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/tir/schedule/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def verify_trace_roundtrip(
The text format or formats whose round-trip behavior should be
validated. If a single string, validate round-trips through
"""
from tvm.script import tir as T # pylint: disable=import-outside-toplevel

if not isinstance(text_format, str):
for opt in text_format:
new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
Expand All @@ -66,7 +68,9 @@ def verify_trace_roundtrip(
Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
elif text_format == "python":
py_trace = "\n".join(trace.as_python())
exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint: disable=exec-used
vars_dict = {"T": T}
vars_dict.update(tvm.tir.__dict__)
exec(py_trace, vars_dict, {"sch": new_sch}) # pylint: disable=exec-used
else:
assert text_format in ("json", "python"), f"Unknown text format: {text_format}"

Expand Down
20 changes: 20 additions & 0 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include "ir_mutator_with_analyzer.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>

Expand All @@ -39,6 +40,25 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
}
}

/*!
 * \brief Simplify indices with iter-map analysis, using the predicates and
 *        iteration variables collected from the enclosing scope as context.
 * \param indices The indices to simplify.
 * \param non_trivial_only When true, discard a simplification that rewrote a
 *        bare variable index into an integer constant and keep the original
 *        variable instead.
 * \return The simplified indices, one per input index.
 */
Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array<PrimExpr>& indices,
                                                                  bool non_trivial_only) {
  // Conjoin every recorded scope predicate into a single constraint for the
  // iter-map analysis.
  PrimExpr pred = const_true();
  for (const PrimExpr& val : iter_predicates_) {  // const& avoids refcount churn per element
    pred = pred && val;
  }
  Array<PrimExpr> simplified = arith::IterMapSimplify(
      indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_);
  if (non_trivial_only) {
    // Restore an index that was a plain variable but got collapsed to an
    // integer constant: callers opting in prefer keeping the variable form.
    size_t n = indices.size();  // size_t: avoids signed/unsigned comparison
    for (size_t i = 0; i < n; ++i) {
      if (simplified[i]->IsInstance<IntImmNode>() && indices[i]->IsInstance<VarNode>()) {
        simplified.Set(i, indices[i]);
      }
    }
  }
  return simplified;
}

Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
Expand Down
6 changes: 6 additions & 0 deletions src/arith/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
*/
void MarkBufferMapShapes(const tir::PrimFunc& func);

/*!
 * \brief Use internal bound information to perform iter-map simplification of indices.
* \note Only do this during layout remapping
*/
Array<PrimExpr> IterMapSimplifyWithContext(const Array<PrimExpr>& indices, bool non_trivial_only);

/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
Expand Down
6 changes: 6 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2061,6 +2061,12 @@ class IterMapToExprNormalizer : public ExprMutator {
if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
return source * expr->scale;
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
// Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
// simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
// like tensorization.
if (is_one(expr->extent) && !is_one(expr->source->extent)) {
return make_const(expr->extent->dtype, 0);
}
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "./te_compiler_cache.h"

#include <tvm/arith/analyzer.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/name_supply.h>
#include <tvm/ir/type_functor.h>
Expand Down Expand Up @@ -594,7 +595,8 @@ class ScheduleBuilder : public ExprVisitor {
src_size_1d *= c->shape[i];
orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
}
auto dst_shape = index_map->MapShape(orig_shape);
arith::Analyzer analyzer;
auto dst_shape = index_map->MapShape(orig_shape, &analyzer);
std::vector<int64_t> dst_shape_int;
size_t dst_size_1d = 1;
for (size_t i = 0; i < dst_shape.size(); ++i) {
Expand Down
Loading
Loading