diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index 796e04e74a7f..340d953ccf2f 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -102,8 +102,7 @@ class IndexMapNode : public Object {
   * \returns The indices in the output space. Contains one value for
   *     each expression in `final_indices`.
   */
-  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
-                             arith::Analyzer* analyzer = nullptr) const;
+  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, arith::Analyzer* analyzer) const;
 
  /*! \brief Map a memory range to the output space
   *
@@ -121,7 +120,7 @@ class IndexMapNode : public Object {
   * \returns The ranges in the output space. Contains one value for
   *     each expression in `final_indices`.
   */
-  Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const;
+  Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const;
 
  /*! \brief Map a buffer shape to the output space
   *
@@ -134,7 +133,7 @@ class IndexMapNode : public Object {
   * \returns The buffer shape in the output space. Contains one
   *     value for each expression in `final_indices`.
   */
-  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const;
+  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer) const;
 
  /* \brief Map an NDArray according to this index map
  *
@@ -203,7 +202,7 @@ class IndexMap : public ObjectRef {
   *  If the user has supplied an `inverse_index_map`, that map is
   *  assumed to be correct and bijective, and is returned.
   */
-  IndexMap Inverse(Array<Range> initial_ranges) const;
+  IndexMap Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer) const;
 
  /*! \brief Rename the variables in the index map and ensure the names are unique.
   *
@@ -225,7 +224,8 @@ class IndexMap : public ObjectRef {
   * \return The inverted index map, along with the predicate for
   *     which the inverse maps to a valid range.
   */
-  std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;
+  std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges,
+                                                     arith::Analyzer* analyzer) const;
 
  TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};
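As a quick illustration of the API change (a sketch, not part of the diff; `ExampleUse` is hypothetical), callers now supply an `arith::Analyzer` themselves, typically one shared instance per transformation:

    #include <tvm/arith/analyzer.h>
    #include <tvm/tir/index_map.h>

    // Sketch only: demonstrates the updated signatures on a 1-D buffer of extent 16.
    void ExampleUse(const tvm::tir::IndexMap& index_map) {
      tvm::arith::Analyzer analyzer;  // one shared simplification context
      tvm::Array<tvm::PrimExpr> new_shape =
          index_map->MapShape({tvm::IntImm(tvm::DataType::Int(64), 16)}, &analyzer);
      tvm::tir::IndexMap inverse = index_map.Inverse({tvm::Range(0, 16)}, &analyzer);
    }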
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cab3466765b4..d881c4f42333 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -24,6 +24,7 @@
 #ifndef TVM_TOPI_TRANSFORM_H_
 #define TVM_TOPI_TRANSFORM_H_
 
+#include <tvm/arith/analyzer.h>
 #include
 #include
 #include
@@ -1738,16 +1739,18 @@ inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& s
 inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
                                              const String name = "T_meta_schedule_layout_trans",
                                              const String tag = kInjective) {
+  arith::Analyzer analyzer;
   Array<Range> iter_domain;
   iter_domain.reserve(src->shape.size());
   for (const PrimExpr& e : src->shape) {
     iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
   }
-  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
+  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
   return compute(
       post_transform_shape,
-      [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr {
-        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}));
+      [src, inv = index_map.Inverse(iter_domain, &analyzer),
+       &analyzer](const Array<Var>& indices) -> PrimExpr {
+        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
       },
       name, tag);
 }
diff --git a/pyproject.toml b/pyproject.toml
index 5cca711ddbe6..91740f2b4b4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
 
 [tool.black]
 line-length = 100
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 3dbf2cefe48c..0b6abc256690 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -22,13 +22,12 @@
 import tvm._ffi
 from tvm._ffi.base import string_types
-
-from tvm.runtime import Object, convert
 from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer, Var, IndexMap
+from tvm.runtime import Object, convert
+from tvm.tir import Buffer, IndexMap, IterVar, Var
 
-from . import tensor as _tensor
 from . import _ffi_api
+from . import tensor as _tensor
 
 
 @tvm._ffi.register_object
@@ -600,7 +599,9 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
         """
         ndim = len(self.op.output(0).shape)
-        index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
+        index_map, axis_separators = IndexMap.from_func_with_separators(
+            mapping_function, ndim=ndim, index_dtype="int32"
+        )
 
         new_iter_vars = _ffi_api.StageTransformLayout(
             self, index_map.initial_indices, index_map.final_indices
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 32ec347039c8..bd44e3f7c3de 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,7 +67,6 @@ def __init__(
         attrs=None,
         span=None,
     ):
-
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
         for x in params:
@@ -266,6 +265,8 @@ def from_func(
         mapping_function: Callable,
         ndim: Optional[int] = None,
         inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+        *,
+        index_dtype: str = "int64",
     ):
         """Create an index map from a function
 
@@ -302,7 +303,10 @@
         """
         index_map, axis_separators = IndexMap.from_func_with_separators(
-            mapping_function, ndim, inverse_index_map
+            mapping_function,
+            ndim,
+            inverse_index_map,
+            index_dtype=index_dtype,
         )
         assert not axis_separators, (
             "The mapping_function provided to IndexMap.from_func "
@@ -316,6 +320,8 @@ def from_func_with_separators(
         mapping_function: Callable,
         ndim: Optional[int] = None,
         inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+        *,
+        index_dtype: str = "int64",
    ):
         """Create an index map from a function
 
@@ -346,6 +352,9 @@
             It is the user's responsibility to ensure the correctness of the pre-defined
             inverse index map.
 
+        index_dtype : str
+            The default index dtype to use for input iters in the mapping function.
+
         Returns
         -------
         ret: Tuple[IndexMap, List[int]]
@@ -361,20 +370,19 @@
         args = []
         var_arg_name = None
         kwargs = collections.OrderedDict()
-        default_index_dtype = "int32"
 
         for name, param in params.items():
             if param.kind in [
                 inspect.Parameter.POSITIONAL_ONLY,
                 inspect.Parameter.POSITIONAL_OR_KEYWORD,
             ]:
-                args.append(tvm.tir.Var(name, default_index_dtype))
+                args.append(tvm.tir.Var(name, index_dtype))
 
             elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                 var_arg_name = name
 
             elif param.kind == inspect.Parameter.KEYWORD_ONLY:
-                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+                kwargs[name] = tvm.tir.Var(name, index_dtype)
 
             else:
                 raise ValueError("transform_layout mapping may not have *args")
@@ -386,7 +394,7 @@
             assert ndim is not None, "ndim must be specified when *args is used"
             num_var_args = ndim - len(args) - len(kwargs)
             for i in range(num_var_args):
-                args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))
+                args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype))
 
         mapping = mapping_function(*args, **kwargs)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4f717474e02b..6c42f15a2f7e 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
     return seed
 
 
+def _get_block_default_dtype(block: Block) -> str:
+    for i in block.iter_vars:
+        return i.var.dtype
+    for buffer_region in list(block.reads) + list(block.writes):
+        for dom in buffer_region.region:
+            return dom.min.dtype
+    return "int64"
+
+
 @_register_object("tir.Schedule")
 class Schedule(Object):
     """The user-facing schedule class
@@ -1492,7 +1501,10 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
         block = self._normalize_block_arg(block)
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         return _ffi_api.ScheduleReindexCacheRead(  # type: ignore # pylint: disable=no-member
             self, block, read_buffer_index, storage_scope, index_map
         )
@@ -1589,7 +1601,10 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
         block = self._normalize_block_arg(block)
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         return _ffi_api.ScheduleReindexCacheWrite(  # type: ignore # pylint: disable=no-member
             self, block, write_buffer_index, storage_scope, index_map
         )
@@ -3246,14 +3261,22 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
         ndim = len(buffer_obj.shape)
         if callable(index_map):
-            index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim)
+            index_map, axis_separators = IndexMap.from_func_with_separators(
+                index_map,
+                ndim=ndim,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         else:
             axis_separators = []
 
         if pad_value is None:
             pass
         elif callable(pad_value):
-            pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
+            pad_value = IndexMap.from_func(
+                pad_value,
+                ndim=len(index_map.final_indices),
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         elif not isinstance(pad_value, IndexMap):
             # Explicitly convert python int/float arguments to the
             # buffer's type.  If the default `tvm.runtime.convert`
@@ -3264,7 +3287,9 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
             elif "float" in buffer_obj.dtype and isinstance(pad_value, float):
                 pad_value = FloatImm(buffer_obj.dtype, pad_value)
             pad_value = IndexMap.from_func(
-                lambda *indices: pad_value, ndim=len(index_map.final_indices)
+                lambda *indices: pad_value,
+                ndim=len(index_map.final_indices),
+                index_dtype=_get_block_default_dtype(self.get(block)),
             )
 
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
@@ -3337,7 +3362,10 @@ def after_transform_block_layout(
         """
         block = self._normalize_block_arg(block)
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         _ffi_api.ScheduleTransformBlockLayout(  # type: ignore # pylint: disable=no-member
             self, block, index_map
         )
diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py
index f38a657123ed..a293b54b46a1 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -51,6 +51,8 @@ def verify_trace_roundtrip(
         The text format or formats whose round-trip behavior should be validated.
         If a single string, validate round-trips through
     """
+    from tvm.script import tir as T  # pylint: disable=import-outside-toplevel
+
     if not isinstance(text_format, str):
         for opt in text_format:
             new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
@@ -66,7 +68,9 @@ def verify_trace_roundtrip(
         Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
     elif text_format == "python":
         py_trace = "\n".join(trace.as_python())
-        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used
+        vars_dict = {"T": T}
+        vars_dict.update(tvm.tir.__dict__)
+        exec(py_trace, vars_dict, {"sch": new_sch})  # pylint: disable=exec-used
     else:
         assert text_format in ("json", "python"), f"Unknown text format: {text_format}"
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 1f087d993428..2ee427beb86c 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -22,6 +22,7 @@
  */
 #include "ir_mutator_with_analyzer.h"
 
+#include <tvm/arith/iter_affine_map.h>
 #include
 #include
 
@@ -39,6 +40,25 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
   }
 }
 
+Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array<PrimExpr>& indices,
+                                                                  bool non_trivial_only) {
+  PrimExpr pred = const_true();
+  for (PrimExpr val : iter_predicates_) {
+    pred = pred && val;
+  }
+  int n = indices.size();
+  Array<PrimExpr> simplified = arith::IterMapSimplify(
+      indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_);
+  if (non_trivial_only) {
+    for (int i = 0; i < n; ++i) {
+      if (simplified[i]->IsInstance<IntImmNode>() && indices[i]->IsInstance<VarNode>()) {
+        simplified.Set(i, indices[i]);
+      }
+    }
+  }
+  return simplified;
+}
+
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
   // record the loop variable as iterators
   Range dom = Range::FromMinExtent(op->min, op->extent);
diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h
index f04b40e7ae4e..fb01fd19cee7 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -70,6 +70,12 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
   */
  void MarkBufferMapShapes(const tir::PrimFunc& func);
 
+  /*!
+   * \brief Use internal bound information to perform iter map simplification of indices.
+   * \note Only do this during layout remapping
+   */
+  Array<PrimExpr> IterMapSimplifyWithContext(const Array<PrimExpr>& indices, bool non_trivial_only);
+
  /*! \brief internal analyzer field. */
  Analyzer* analyzer_;
 
  // the following two fields are useful in case we want
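A rough sketch of how a downstream mutator can use the new helper (hypothetical `ExampleRewriter`; the actual callers are the rewriters updated later in this patch):

    // Sketch only: simplify BufferStore indices using the loop-nest context
    // (iter_vars_ / iter_predicates_) that IRMutatorWithAnalyzer already tracks.
    using namespace tvm;
    using namespace tvm::tir;

    class ExampleRewriter : public arith::IRMutatorWithAnalyzer {
     public:
      using Parent = arith::IRMutatorWithAnalyzer;
      using Parent::Parent;

     private:
      Stmt VisitStmt_(const BufferStoreNode* op) final {
        auto store = Downcast<BufferStore>(Parent::VisitStmt_(op));
        auto* n = store.CopyOnWrite();
        // non_trivial_only=true: indices that are already plain Vars are kept as-is.
        n->indices = IterMapSimplifyWithContext(n->indices, /*non_trivial_only=*/true);
        return store;
      }
    };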
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index f8a36daf5328..af1128aa273c 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2061,6 +2061,12 @@ class IterMapToExprNormalizer : public ExprMutator {
     if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
       return source * expr->scale;
     } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
+      // Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
+      // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
+      // like tensorization.
+      if (is_one(expr->extent) && !is_one(expr->source->extent)) {
+        return make_const(expr->extent->dtype, 0);
+      }
       return floordiv(source, expr->lower_factor) * expr->scale;
     } else {
       return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 275d1b6bf787..b747855bff59 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -19,6 +19,7 @@
 
 #include "./te_compiler_cache.h"
 
+#include <tvm/arith/analyzer.h>
 #include
 #include
 #include
@@ -594,7 +595,8 @@ class ScheduleBuilder : public ExprVisitor {
       src_size_1d *= c->shape[i];
       orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
     }
-    auto dst_shape = index_map->MapShape(orig_shape);
+    arith::Analyzer analyzer;
+    auto dst_shape = index_map->MapShape(orig_shape, &analyzer);
     std::vector<int64_t> dst_shape_int;
     size_t dst_size_1d = 1;
     for (size_t i = 0; i < dst_shape.size(); ++i) {
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index a0111ff7cdbf..fde6daa4d851 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -23,6 +23,7 @@
  */
 #include "transform.h"
 
+#include <tvm/arith/analyzer.h>
 #include
 #include
 #include
@@ -3434,9 +3435,10 @@ Array<te::Tensor> MetaScheduleLayoutTransformCompute(const Attrs& attrs,
 bool MetaScheduleLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                     const TypeReporter& reporter) {
   TensorType data_type = Downcast<TensorType>(types[0]);
+  arith::Analyzer analyzer;
   const MetaScheduleLayoutTransformAttrs* params = attrs.as<MetaScheduleLayoutTransformAttrs>();
   ICHECK(params);
-  Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape);
+  Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape, &analyzer);
   reporter->Assign(types[1], TensorType(new_shape, data_type->dtype));
   return true;
 }
diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc
index 5e7431e5109c..04b25f764c8a 100644
--- a/src/runtime/logging.cc
+++ b/src/runtime/logging.cc
@@ -130,6 +130,9 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int li
     backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, BacktraceErrorCallback,
                       symbol_str.get());
   }
+  if (filename == nullptr && strstr(symbol_str.get()->data(), "ffi_call_")) {
+    return 0;
+  }
   s << *symbol_str;
 
   if (filename != nullptr) {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index cd0ec0e34f03..22103f7b0ff2 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -898,8 +898,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
         runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
     ICHECK(index_map_func);
 
+    arith::Analyzer analyzer;
     auto inverse_index_map =
-        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)});
+        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer);
     auto indices_16x16 = inverse_index_map->final_indices;
 
     // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine.
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 7041b751c58a..233663feac6d 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -193,7 +193,7 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
       for (const auto& iter_var : s->original_variables) {
         original_ranges.push_back(state[iter_var]);
       }
-      Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);
+      Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges, actx);
 
       ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
       for (size_t i = 0; i < updated_ranges.size(); i++) {
@@ -269,6 +269,7 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
       }
     } else if (rel.as<SingletonNode>()) {
     } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      arith::Analyzer analyzer;
       bool missing_transformed = false;
       for (const auto& iter_var : s->transformed_variables) {
         if (!state.count(iter_var)) {
@@ -284,7 +285,8 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
       for (const auto& iter_var : s->transformed_variables) {
         transformed_indices.push_back(state[iter_var]);
       }
-      Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices);
+      Array<PrimExpr> original_indices =
+          s->inverse_transformation->MapIndices(transformed_indices, &analyzer);
 
       ICHECK_EQ(original_indices.size(), s->original_variables.size());
       for (size_t i = 0; i < original_indices.size(); i++) {
@@ -352,7 +354,9 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
       for (const auto& iter_var : s->original_variables) {
         original_indices.push_back(state[iter_var]);
       }
-      Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices);
+      arith::Analyzer analyzer;
+      Array<PrimExpr> transformed_indices =
+          s->forward_transformation->MapIndices(original_indices, &analyzer);
 
       ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
       for (size_t i = 0; i < transformed_indices.size(); i++) {
@@ -449,7 +453,9 @@ Array<IntSet> PassUpDomain(const TransformNode* s,
     transformed_indices.push_back(iter_var->var);
   }
 
-  Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices);
+  arith::Analyzer analyzer;
+  Array<PrimExpr> transformed_exprs =
+      s->inverse_transformation->MapIndices(transformed_indices, &analyzer);
 
   ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
   for (size_t i = 0; i < transformed_exprs.size(); i++) {
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 56fe0cfc65ca..44e742eee4cf 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -21,6 +21,7 @@
  * \file schedule_lang.cc
  */
 #include
+#include <tvm/arith/analyzer.h>
 #include
 #include
 #include
@@ -491,10 +492,11 @@ Stage& Stage::transform_layout(const Array<Var>& initial_indices,
   for (const auto& iter_var : compute->axis) {
     initial_ranges.push_back(iter_var->dom);
   }
-  Array<Range> final_ranges = map->MapRanges(initial_ranges);
+  arith::Analyzer analyzer;
+  Array<Range> final_ranges = map->MapRanges(initial_ranges, &analyzer);
 
   // Make IterVar objects to represent the new iterations.
-  auto inverse = map.Inverse(initial_ranges);
+  auto inverse = map.Inverse(initial_ranges, &analyzer);
   Array<IterVar> final_indices_iter;
   ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
   for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index a39149cebaf4..149e4cecd442 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -55,7 +55,9 @@ IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(A
 
 std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
                                                   const Array<Range>& initial_ranges,
-                                                  arith::IterMapLevel check_level) {
+                                                  arith::IterMapLevel check_level,
+                                                  arith::Analyzer* analyzer) {
+  ICHECK(analyzer != nullptr);
   if (self->inverse_index_map.defined()) {
     // return the pre-defined inverse index map if exists. In this
     // case, the user-defined inverse is assumed to be correct and
@@ -88,9 +90,8 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
   // Unpack the output indices into linear combinations of the initial
   // indices.
-  arith::Analyzer analyzer;
-  auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1,
-                                       /*check_level=*/check_level, &analyzer,
+  auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /*predicate=*/1,
+                                       /*check_level=*/check_level, analyzer,
                                        /*simplify_trivial_iterators=*/false);
   CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. "
                                          << "Error: " << padded_iter_map->errors[0];
@@ -110,15 +111,15 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
     } else {
       expr = inverse_exprs_map.at(index);
     }
-    inverse_exprs.push_back(analyzer.Simplify(expr));
+    inverse_exprs.push_back(analyzer->Simplify(expr));
   }
 
   PrimExpr padding_predicate = padded_iter_map->padding_predicate;
   padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate);
   padding_predicate = Substitute(padding_predicate, inverse_exprs_map);
 
+  auto output_ranges = self->MapRanges(initial_ranges, analyzer);
   {
-    auto output_ranges = self->MapRanges(initial_ranges);
     ICHECK_EQ(output_ranges.size(), output_vars.size());
 
     arith::Analyzer analyzer;
@@ -133,15 +134,17 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
   return {IndexMap(output_vars, inverse_exprs), padding_predicate};
 }
 
-std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
-  return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck);
+std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges,
+                                                             arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
+  return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer);
 }
 
-IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+IndexMap IndexMap::Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   auto [inverse, padding_predicate] =
-      IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective);
-  arith::Analyzer analyzer;
-  CHECK(analyzer.CanProve(!padding_predicate))
+      IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer);
+  CHECK(analyzer->CanProve(!padding_predicate))
       << "Bijective inverse should not contain padding, but inverse of " << *this << " over range "
       << initial_ranges << " resulted in a padding predicate of " << padding_predicate;
   return inverse;
@@ -149,6 +152,7 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
 
 Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
                                          arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(indices.size(), initial_indices.size());
 
   Map<Var, PrimExpr> vmap;
@@ -157,11 +161,6 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
     vmap.Set(initial_indices[i], indices[i]);
   }
 
-  arith::Analyzer local_analyzer;
-  if (!analyzer) {
-    analyzer = &local_analyzer;
-  }
-
   Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
     PrimExpr result = SubstituteWithDataTypeLegalization(
         std::move(index), [&](const Var& var) { return vmap.Get(var); });
@@ -171,18 +170,13 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
 }
 
 Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(ranges.size(), initial_indices.size());
 
   Map<Var, Range> input_iters;
   for (size_t i = 0; i < initial_indices.size(); i++) {
     input_iters.Set(initial_indices[i], ranges[i]);
   }
-
-  arith::Analyzer local_analyzer;
-  if (!analyzer) {
-    analyzer = &local_analyzer;
-  }
-
   auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 1,
                                 /*check_level=*/arith::IterMapLevel::NoCheck, analyzer,
                                 /*simplify_trivial_iterators=*/false);
@@ -240,6 +234,7 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer
 
 Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
                                        arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(shape.size(), initial_indices.size());
 
   Array<Range> ranges;
@@ -258,6 +253,7 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
 }
 
 runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
+  arith::Analyzer analyzer;
   auto shape = arr_src.Shape();
   ICHECK(shape.size() == initial_indices.size())
       << "The rank of the input array should be " << initial_indices.size() << " but got "
@@ -268,7 +264,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
     size_1d *= shape[i];
     orig_shape.push_back(PrimExpr(static_cast<int>((shape[i]))));
   }
-  auto dst_shape = MapShape(orig_shape);
+  auto dst_shape = MapShape(orig_shape, &analyzer);
 
   std::vector<int64_t> dst_shape_int;
   for (size_t i = 0; i < dst_shape.size(); ++i) {
@@ -292,7 +288,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
     src_indices.push_back(PrimExpr(static_cast<int>((src_linear_index / div_factor))));
     src_linear_index %= div_factor;
   }
-  auto dst_indices = MapIndices(src_indices);
+  auto dst_indices = MapIndices(src_indices, &analyzer);
 
   // Convert an N-d coordinate to a linear coordinate
   // (z, y, x) -> z * height * width + y * width + x
@@ -430,19 +426,29 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
     });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
-    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return map->MapIndices(indices); });
+    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) {
+      arith::Analyzer analyzer;
+      return map->MapIndices(indices, &analyzer);
+    });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Array<PrimExpr> shape) {
-  return map->MapShape(shape);
+  arith::Analyzer analyzer;
+  return map->MapShape(shape, &analyzer);
 });
 
-TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
+TVM_REGISTER_GLOBAL("tir.IndexMapInverse")
+    .set_body_typed([](IndexMap map, Array<Range> initial_ranges) {
+      arith::Analyzer analyzer;
+      return map.Inverse(initial_ranges, &analyzer);
+    });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray")
     .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
     .set_body_typed([](IndexMap forward, Array<Range> initial_ranges) {
-      auto result = forward.NonSurjectiveInverse(initial_ranges);
+      arith::Analyzer analyzer;
+      auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer);
       return Array<ObjectRef>{result.first, result.second};
     });
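For context, a sketch of the updated C++ entry points (hypothetical values; `Inverse` would CHECK-fail on this map because it pads):

    // Sketch only: invert i -> (i // 4, i % 4) over [0, 10); the mapped shape is
    // (3, 4), so two tail elements are padding and only NonSurjectiveInverse succeeds.
    tvm::arith::Analyzer analyzer;
    tvm::tir::IndexMap index_map = tvm::tir::IndexMap::FromFunc(
        /*ndim=*/1, [](tvm::Array<tvm::tir::Var> i) -> tvm::Array<tvm::PrimExpr> {
          return {tvm::floordiv(i[0], 4), tvm::floormod(i[0], 4)};
        });
    auto [inverse, padding_predicate] =
        index_map.NonSurjectiveInverse({tvm::Range(0, 10)}, &analyzer);
    // padding_predicate marks the out-of-range tail; the same analyzer is reused
    // for any follow-up proofs instead of a throwaway local instance.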
TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") .set_body_typed([](IndexMap forward, Array initial_ranges) { - auto result = forward.NonSurjectiveInverse(initial_ranges); + arith::Analyzer analyzer; + auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); return Array{result.first, result.second}; }); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 6e361cbe051f..868fbe856352 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -72,6 +72,16 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); +/*! + * \brief Given an arbitrary sref, bind the shape var info of the PrimFunc it belongs to to the + * given analyzer + * \param state The schedule state + * \param sref The given sref + * \param analyzer The analyzer to be bound + */ +void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, + arith::Analyzer* analyzer); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index a765eb4d9f97..17cc39d5cb27 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1269,6 +1269,20 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, + arith::Analyzer* analyzer) { + while (sref->parent != nullptr) { + sref = sref->parent; + } + const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr); + for (const auto& kv : f->buffer_map) { + const Buffer& buffer = kv.second; + for (const PrimExpr& e : buffer->shape) { + analyzer->MarkGlobalNonNegValue(e); + } + } +} + /******** Misc ********/ bool HasOp(const Stmt& stmt, const Array& ops) { diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 6f9aa1127584..3fbdf856b533 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1909,13 +1909,14 @@ void CollectReindexCacheStageInfoAndCreateBuffer( ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, const String& storage_scope, const IndexMap& index_map, const Block& block, const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { + arith::Analyzer analyzer; Array block_iter_vars, block_shape; for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var); block_shape.push_back(iter_var->dom->extent); } - Array new_indices = index_map->MapIndices(block_iter_vars); - Array new_shape = index_map->MapShape(block_shape); + Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + Array new_shape = index_map->MapShape(block_shape, &analyzer); info->indices = new_indices; // Step 5. 
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 6f9aa1127584..3fbdf856b533 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1909,13 +1909,14 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
     ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref,
     const String& storage_scope, const IndexMap& index_map, const Block& block,
     const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) {
+  arith::Analyzer analyzer;
   Array<PrimExpr> block_iter_vars, block_shape;
   for (const IterVar& iter_var : block->iter_vars) {
     block_iter_vars.push_back(iter_var);
     block_shape.push_back(iter_var->dom->extent);
   }
-  Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars);
-  Array<PrimExpr> new_shape = index_map->MapShape(block_shape);
+  Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars, &analyzer);
+  Array<PrimExpr> new_shape = index_map->MapShape(block_shape, &analyzer);
   info->indices = new_indices;
 
   // Step 5. Update CacheTouchedInfo
@@ -1926,8 +1927,6 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
     old_indices.push_back(range->min);
   }
 
-  arith::Analyzer analyzer;
-
   VarUseDefAnalyzer collector_new(/*defined_vars=*/{});
   for (const PrimExpr& idx : new_indices) {
     collector_new(idx);
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 45d0c81050d1..fc388b004843 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,20 +680,6 @@ void CalculateProvidedRequiredRegions(
 
 /******** Main Implementation ********/
 
-void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
-                       arith::Analyzer* analyzer) {
-  while (sref->parent != nullptr) {
-    sref = sref->parent;
-  }
-  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
-  for (const auto& kv : f->buffer_map) {
-    const Buffer& buffer = kv.second;
-    for (const PrimExpr& e : buffer->shape) {
-      analyzer->MarkGlobalNonNegValue(e);
-    }
-  }
-}
-
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
                                      const StmtSRef& loop_sref, bool preserve_unit_loops,
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index bb2abc559d2c..d9a9f3cfed32 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/arith/iter_affine_map.h>
 #include
 #include
 
@@ -93,12 +94,12 @@ class TransformLayoutPlanner : private StmtExprVisitor {
   static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
                             IndexMap inverse, PrimExpr padding_predicate,
-                            Optional<IndexMap> pad_value) {
+                            Optional<IndexMap> pad_value, arith::Analyzer* analyzer) {
     ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1)
         << "Internal error: Should be caught by ScheduleError checks prior to this point";
     TransformLayoutPlanner visitor(old_buffer);
     visitor(block);
-    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
+    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value, analyzer);
   }
 
  private:
@@ -220,14 +221,15 @@ class TransformLayoutPlanner : private StmtExprVisitor {
    public:
     BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate,
                         const IndexMap& inverse, const Optional<IndexMap>& pad_value,
-                        Map<Block, Block>* new_block_to_old)
+                        Map<Block, Block>* new_block_to_old, arith::Analyzer* analyzer)
         : info(info),
           new_buffer(new_buffer),
           new_indices(inverse->initial_indices),
          padding_predicate(padding_predicate),
          inverse(inverse),
          pad_value(pad_value),
-          new_block_to_old(*new_block_to_old) {
+          new_block_to_old(*new_block_to_old),
+          analyzer(analyzer) {
      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
        Var var = info.dependent_loopnest[i]->loop_var;
@@ -353,7 +355,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
      if (can_replace) {
        Array<PrimExpr> new_index_exprs =
            new_indices.Map([](const auto& var) -> PrimExpr { return var; });
-        PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_index_exprs)[0];
+        PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_index_exprs, analyzer)[0];
        store =
            BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value),
                        new_index_exprs);
@@ -429,22 +431,24 @@ class TransformLayoutPlanner : private StmtExprVisitor {
    const Optional<IndexMap>& pad_value;
    Map<Block, Block>& new_block_to_old;
    bool all_stores_replaced{true};
+    arith::Analyzer* analyzer;
 
    Map<Var, PrimExpr> var_remap;
  };
 
  TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
-                         PrimExpr padding_predicate, Optional<IndexMap> pad_value) const {
-    if (auto prologue_plan =
-            FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value);
+                         PrimExpr padding_predicate, Optional<IndexMap> pad_value,
+                         arith::Analyzer* analyzer) const {
+    if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate,
+                                                  pad_value, analyzer);
        prologue_plan.has_value()) {
      return prologue_plan.value();
-    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse,
-                                                               padding_predicate, pad_value);
+    } else if (auto replacement_plan = FinalizeReplacementPlan(
+                   new_buffer, index_map, inverse, padding_predicate, pad_value, analyzer);
               replacement_plan.has_value()) {
      return replacement_plan.value();
    } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse,
-                                                         padding_predicate, pad_value);
+                                                         padding_predicate, pad_value, analyzer);
               epilogue_plan.has_value()) {
      return epilogue_plan.value();
    } else {
@@ -454,7 +458,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
  std::optional<TransformPlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map,
                                                    IndexMap inverse, PrimExpr padding_predicate,
-                                                    Optional<IndexMap> pad_value) const {
+                                                    Optional<IndexMap> pad_value,
+                                                    arith::Analyzer* analyzer) const {
    if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }
@@ -476,7 +481,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
    }
    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);
 
-    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, analyzer)[0];
    PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index);
    Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
@@ -498,7 +503,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
  std::optional<TransformPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map,
                                                       IndexMap inverse, PrimExpr padding_predicate,
-                                                       Optional<IndexMap> pad_value) const {
+                                                       Optional<IndexMap> pad_value,
+                                                       arith::Analyzer* analyzer) const {
    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }
@@ -511,7 +517,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
    }
 
    BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value,
-                                 &new_block_to_old);
+                                 &new_block_to_old, analyzer);
    Stmt stmt = replacer(info.dependent_loopnest.back()->body);
    if (!replacer.is_all_stores_replaced()) {
      return NullOpt;
@@ -547,7 +553,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
  std::optional<TransformPlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
                                                    IndexMap inverse, PrimExpr padding_predicate,
-                                                    Optional<IndexMap> pad_value) const {
+                                                    Optional<IndexMap> pad_value,
+                                                    arith::Analyzer* analyzer) const {
    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }
@@ -566,7 +573,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
      iter_values.push_back(loop_var);
    }
 
-    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, analyzer)[0];
    Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices);
 
    std::stringstream block_name;
@@ -757,12 +764,13 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
      const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer,
      const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
      const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
-    auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
-                                          scope_stmt, old_buffer, new_buffer, index_map,
-                                          opt_inverse.value(), padding_predicate, pad_value)
-                                    : TransformLayoutPlanner::NoPaddingRequired();
    arith::Analyzer analyzer;
+    auto plan = pad_value.defined()
+                    ? TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map,
+                                                   opt_inverse.value(), padding_predicate,
+                                                   pad_value, &analyzer)
+                    : TransformLayoutPlanner::NoPaddingRequired();
+
    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer);
    Block result = Downcast<Block>(rewriter(scope_stmt));
    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
@@ -794,9 +802,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
 
  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
    *buffer = new_buffer_;
-    *indices = index_map_->MapIndices(*indices);
-    (*indices).MutateByApply(
-        [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); });
+    *indices = index_map_->MapIndices(*indices, &index_simplifier_);
+    *indices = this->IterMapSimplifyWithContext(*indices, true);
  }
 
  using Parent = arith::IRMutatorWithAnalyzer;
@@ -913,6 +920,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
  const TransformLayoutPlanner::TransformPlan& plan_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Block, Block> new_block_to_old_;
+  arith::Analyzer index_simplifier_;
};
 
class BufferIsSubregionError : public ScheduleError {
@@ -1069,7 +1077,8 @@ class TransformationIntroducesPaddingError : public ScheduleError {
  }
 
  String DetailRenderTemplate() const final {
-    auto new_shape = index_map_->MapShape(buffer_->shape);
+    arith::Analyzer analyzer;
+    auto new_shape = index_map_->MapShape(buffer_->shape, &analyzer);
    std::ostringstream os;
    os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name
       << " of shape " << buffer_->shape << " would result in shape " << new_shape
@@ -1138,6 +1147,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
 void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                      BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
                      const Optional<IndexMap>& pad_value, bool assume_injective_transform) {
+  arith::Analyzer analyzer;
+  AddShapeVarBounds(self, block_sref.get(), &analyzer);
  // Step 1: Input handling and error checking
  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
  Buffer old_buffer =
@@ -1173,7 +1184,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
      for (const auto& dim : old_buffer->shape) {
        region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
      }
-      return index_map.NonSurjectiveInverse(region);
+      return index_map.NonSurjectiveInverse(region, &analyzer);
    }();
  }
@@ -1184,7 +1195,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
 
  // Step 2: Infer the shape of the new buffer
  Buffer new_buffer = old_buffer;
-  new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape);
+  new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape, &analyzer);
 
  // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block
  // alloc_buffers.
@@ -1336,6 +1347,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
  const Block& block = GetRef<Block>(block_ptr);
  arith::Analyzer analyzer;
+  AddShapeVarBounds(self, block_sref.get(), &analyzer);
 
  // Step 1: Collect outer loops and loop vars
  Array<StmtSRef> loops = GetLoops(block_sref);  // outer loops of the block
@@ -1375,8 +1387,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
 
  // Step 4: Apply the IndexMap to block iters.
  IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
-  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
-  Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array);
+  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars, &analyzer);
+  Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array, &analyzer);
 
  // Step 5: Create the new block after transformation.
@@ -1408,14 +1420,13 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
  }
  IndexMap inverse_index_map{nullptr};
  try {
-    inverse_index_map = index_map.Inverse(initial_ranges);
+    inverse_index_map = index_map.Inverse(initial_ranges, &analyzer);
  } catch (...) {
    throw NotBijectiveAffineIndexMapError(self->mod, index_map);
  }
-
-  Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices(
-      new_block_vars);  // old block vars written in terms of new block vars
-
+  // old block vars written in terms of new block vars
+  Array<PrimExpr> inversed_new_block_vars =
+      inverse_index_map->MapIndices(new_block_vars, &analyzer);
  for (int i = 0, n = block_vars.size(); i < n; ++i) {
    inverse_subst_map.Set(Downcast<Var>(block_vars[i]), inversed_new_block_vars[i]);
  }
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e4047273b618..62662a4ce111 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -388,16 +388,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWith
 /******** BlockBufferAccessSimplifier ********/
 void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_access_regions) {
  auto fmutate = [this](const BufferRegion& buffer_region) {
-    std::vector<Range> new_buffer_region;
+    Array<Range> new_buffer_region;
+    Array<PrimExpr> simplified_min;
    for (const auto& range : buffer_region->region) {
-      if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
-        new_buffer_region.push_back(Range::FromMinExtent(
-            SimplifyNonTrivialExpr(range->min, analyzer_), make_const(range->min.dtype(), 1)));
-      } else {
-        new_buffer_region.push_back(
-            Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
-                                 SimplifyNonTrivialExpr(range->extent, analyzer_)));
-      }
+      simplified_min.push_back(range->min);
+    }
+    simplified_min = this->IterMapSimplifyWithContext(simplified_min, true);
+    int n = buffer_region->region.size();
+    for (int i = 0; i < n; ++i) {
+      PrimExpr min = simplified_min[i];
+      PrimExpr extent = analyzer_->Simplify(buffer_region->region[i]->extent);
+      new_buffer_region.push_back(Range::FromMinExtent(min, extent));
    }
    return BufferRegion(buffer_region->buffer, new_buffer_region);
  };
@@ -405,8 +406,7 @@ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_
 }
 
 void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>* indices) {
-  (*indices).MutateByApply(
-      [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer_); });
+  *indices = this->IterMapSimplifyWithContext(*indices, true);
 }
 
 Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index f37c21593f23..c04e12b8395e 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -220,18 +220,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
 
  Array<PrimExpr> GetSimplifiedElemOffset(const Buffer& buffer, const Array<PrimExpr>& indices) {
    auto flattened_indices = buffer->ElemOffset(indices);
-    // Use IterMapSimplify to enable constant fold of fused indices
-    // IterMapSimplify is more powerful and time-consuming than normal
-    // simplify as it tries to deal with symbolic fusion
-    //
-    // Only use to handle indices during layout transformations
-    // So we restrict the use to here
-    PrimExpr pred = const_true();
-    for (PrimExpr val : iter_predicates_) {
-      pred = pred && val;
-    }
-    return arith::IterMapSimplify(flattened_indices, this->iter_vars_, pred,
-                                  arith::IterMapLevel::Surjective, this->analyzer_);
+    return this->IterMapSimplifyWithContext(flattened_indices, false);
  }
 
  template <typename Node>
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 8c409fba5e46..9c1244838173 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1265,7 +1265,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
      Array<IndexMap> transforms = lookup.value();
      for (const auto& transform : transforms) {
-        write_ptr->bounds = transform->MapRanges(realize->bounds);
+        write_ptr->bounds = transform->MapRanges(realize->bounds, &analyzer);
      }
    }
 
@@ -1292,7 +1292,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
      Array<IndexMap> transforms = lookup.value();
      for (const auto& transform : transforms) {
-        write_ptr->indices = transform->MapIndices(node->indices);
+        write_ptr->indices = transform->MapIndices(node->indices, &analyzer);
      }
    }
    return node;
@@ -1315,7 +1315,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
 
    auto write_ptr = buf.CopyOnWrite();
    for (const auto& transform : transforms) {
-      write_ptr->shape = transform->MapShape(buf->shape);
+      write_ptr->shape = transform->MapShape(buf->shape, &analyzer);
    }
  }
 
@@ -1326,6 +1326,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
  std::unordered_map buf_map_;
 
  Map<Buffer, Array<IndexMap>> layout_transforms_;
+  arith::Analyzer analyzer;
};
 
class StorageFlattener : public StmtExprMutator {
diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc
index 82fd6cfa9a82..abe0bc3a3d12 100644
--- a/src/tir/transforms/transform_mma_buffer_layout.cc
+++ b/src/tir/transforms/transform_mma_buffer_layout.cc
@@ -130,7 +130,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
      const auto* index_map_func = runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
      ICHECK(index_map_func);
      auto index_map = IndexMap::FromFunc(2, *index_map_func);
-      auto new_indices = index_map->MapIndices(store->indices);
+      auto new_indices = index_map->MapIndices(store->indices, &analyzer);
      n->buffer = buffer_map_[store->buffer];
      n->indices = std::move(new_indices);
    } else if (store->buffer.scope() == "m16n8k8.matrixA" ||
@@ -149,7 +149,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
      const auto* index_map_func = runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
      ICHECK(index_map_func);
      auto index_map = IndexMap::FromFunc(2, *index_map_func);
-      auto new_indices = index_map->MapIndices(load->indices);
+      auto new_indices = index_map->MapIndices(load->indices, &analyzer);
      n->buffer = buffer_map_[load->buffer];
      n->indices = std::move(new_indices);
    } else if (load->buffer.scope() == "m16n8k8.matrixA" ||
@@ -179,7 +179,7 @@ Pass TransformMmaBufferLayout() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = MmaBufferLayoutTransformer()(std::move(n->body));
-    return std::move(f);
+    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {});
 }
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index 594dec73eaba..cee9922e86fa 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -16,7 +16,7 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm.tir import floormod, floordiv
+from tvm.tir import floordiv, floormod
 
 
 def ifuse(inputs, pred_extent=None):
@@ -1211,7 +1211,7 @@ def test_iter_map_simplify_unit_loop_order():
     # When we have iterators that have same scale but one of them come
     # with unit extent, we should prioritize unit extent
     assert_iter_map_simplify(
-        {x // 128 + y + z: y + x // 128 + z},
+        {x // 128 + y + z: y + z},
         var_dom([(x, 128), (y, 128), (z, 1)]),
         simplify_trivial_iterators=False,
     )
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 162dec6271dc..3a2ca69cba7b 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Integration test for MetaSchedule"""
-import tempfile
 import platform
+import tempfile
 from typing import List
 
 import numpy as np
@@ -56,6 +56,7 @@ def main(a: T.handle, b: T.handle) -> None:  # type: ignore
 # pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
 
 
+@pytest.mark.skip("Integration tests")
 def test_meta_schedule_dynamic_loop_extent():
     a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
     b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -64,6 +65,7 @@ def test_meta_schedule_dynamic_loop_extent():
     assert not extracted_tasks
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -104,6 +106,7 @@ def test_meta_schedule_integration_extract_from_resnet():
         assert t.task_name in expected_task_names, t.task_name
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -126,6 +129,7 @@ def test_task_extraction_winograd_tensorcore():
     assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -165,6 +169,7 @@ def test_task_extraction_anchor_block():
         assert t.task_name in expected_task_names, t.task_name
 
 
+@pytest.mark.skip("Integration tests")
 @tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_bert_base():
     pytest.importorskip(
@@ -263,6 +268,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
         assert expected_shape == shape, t.task_name
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -374,6 +380,7 @@ def extract_task_qbert_avx512():
     extract_task_qbert("llvm -mcpu=skylake-avx512", "avx512")
 
 
+@pytest.mark.skip("Integration tests")
 @tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old")
 def test_extract_task_arm_conv2d_nchwc():
     data_shape = (1, 64, 128, 128)
@@ -419,6 +426,7 @@ def test_extract_task_arm_conv2d_nchwc():
     assert list(out_type.shape) == [1, 8, 130, 130, 4]
 
 
+@pytest.mark.skip("Integration tests")
 def test_meta_schedule_te2primfunc_argument_order_and_lowering():
     # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
     # fmt: off
@@ -581,7 +589,9 @@ def _create_relay_mod():
         dev,
     )
 
-    with target, _create_verification_database(), PassContext(  # pylint: disable=not-context-manager
+    with (
+        target
+    ), _create_verification_database(), PassContext(  # pylint: disable=not-context-manager
         opt_level=3,
         config={
             "relay.backend.use_meta_schedule": True,
@@ -607,6 +617,7 @@ def get_output(data, lib):
     assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
+@pytest.mark.skip("Integration tests")
 def test_rewrite_layout_link_params():
     I, O, H, W = 64, 64, 56, 56
     kH = kW = 3
@@ -685,6 +696,7 @@ def test_rewrite_layout_link_params():
     np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
 
 
+@pytest.mark.skip("Integration tests")
 def test_module_equality_ignore_ndarray():
     target = "llvm --num-cores=4"
 
@@ -800,6 +812,7 @@ def _test_anchor_tuning(target, space):
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.parametrize(
     "space",
     [
@@ -811,6 +824,7 @@ def test_anchor_tuning_cpu(space):
     _test_anchor_tuning("llvm --num-cores=4", space)
 
 
+@pytest.mark.skip("Integration tests")
 def test_anchor_tuning_cpu_link_params():
     data_shape = (128, 128)
     weight_shape1 = (128, 128)
@@ -863,6 +877,7 @@ def test_anchor_tuning_cpu_link_params():
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
+@pytest.mark.skip("Integration tests")
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_disabled_pass_param():
     """
@@ -908,6 +923,7 @@ def test_disabled_pass_param():
     pytest.fail("'disabled_pass' argument does not work")
 
 
+@pytest.mark.skip("Integration tests")
 def test_rewrite_layout_link_params_1x1_conv2d():
     I, O, H, W = 32, 16, 256, 256
     kH = kW = 1
diff --git a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
index d1ba84d836be..437aae9e6b52 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
@@ -21,11 +21,14 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
 from tvm import meta_schedule, relay
-from tvm.meta_schedule.schedule.cuda.layout_transform import cuda_layout_transform_schedule_rule
+from tvm.meta_schedule.schedule.cuda.layout_transform import (
+    cuda_layout_transform_schedule_rule,
+)
 from tvm.relay.op import OpPattern
 from tvm.script import ir as I
 from tvm.script import tir as T
@@ -170,6 +173,7 @@ def run_primfunc(
     lib(*input_tensors)
 
 
+@pytest.mark.skip("Integration test")
 class TestRandomRelayE2ECorrectness:
     """Tests E2E correctness of layout transform schedule.
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 7cf06b54cac7..d1f4b6bdce7c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring,line-too-long,invalid-name,too-many-locals,too-many-statements,too-many-nested-blocks,too-many-branches,too-many-lines,chained-comparison import pytest + import tvm import tvm.testing from tvm import meta_schedule as ms @@ -413,10 +414,10 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, with T.block("PadInput_reindex_shared.dyn"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) - T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) + T.reads(PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32]) T.writes(PadInput_reindex_shared_dyn[v0, v1]) T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) - PadInput_reindex_shared_dyn[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] + PadInput_reindex_shared_dyn[v0, v1] = PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32] for ax0_ax1_fused in range(4608): with T.block("weight_reindex_shared.dyn"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) @@ -497,9 +498,9 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]) - T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) T.block_attr({"meta_schedule.cooperative_fetch": 3}) - conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5] + conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 16, 1, 1, 1]), @@ -915,10 +916,10 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) v1 = T.axis.spatial(64, ax0_ax1_fused % 64) - T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) + T.reads(inputs[0, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) - PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] + PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1] for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) @@ -1007,9 +1008,9 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( 
                     v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
                     v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
                     T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
-                    T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                    T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
                     T.block_attr({"meta_schedule.cooperative_fetch": 2})
-                    conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+                    conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
     # fmt: on
     decision_0 = [
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 78b2fdbf3d66..8562205753d3 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -15,16 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
-
 import tvm
-import tvm.testing
 import tvm.meta_schedule as ms
+import tvm.testing
 from tvm.script import tir as T
-from tvm.tir import Schedule, floormod, floordiv
-from tvm.tir.tensor_intrin.cuda import *
 from tvm.target import Target
 from tvm.target.codegen import llvm_lookup_intrinsic_id
-
+from tvm.tir import Schedule, floordiv, floormod
+from tvm.tir.tensor_intrin.cuda import *
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
 
 
@@ -1885,6 +1883,7 @@ def apply_anchor_trace(sch: Schedule) -> None:
                 ((i0 * 64) + i2),
                 i1,
             ),
+            index_dtype="int32",
         ),
         pad_value=None,
     )
@@ -1950,6 +1949,7 @@ def apply_trace(sch):
                 ((i1 * 32) + i3),
                 ((i0 * 16) + i2),
             ),
+            index_dtype="int32",
         ),
         pad_value=None,
     )
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index aa45120c2316..c8fc4a73f56b 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -61,6 +62,7 @@ def two_step(a: T.handle, c: T.handle) -> None:
             C[vi, vj] = B[vi, vj] + 3.0
 
 
+@pytest.mark.skip("Integration test")
 @tvm.testing.requires_llvm
 def test_tune_matmul_cpu():
     with tempfile.TemporaryDirectory() as work_dir:
@@ -80,6 +82,7 @@ def test_tune_matmul_cpu():
     sch.trace.show()
 
 
+@pytest.mark.skip("Integration test")
 @tvm.testing.requires_cuda
 def test_tune_matmul_cuda():
     with tempfile.TemporaryDirectory() as work_dir:
@@ -99,6 +102,7 @@ def test_tune_matmul_cuda():
     sch.trace.show()
 
 
+@pytest.mark.skip("Integration test")
 def test_tune_run_module_via_rpc():
     target = tvm.target.Target("llvm")
     rt_mod = tvm.build(matmul, target)
@@ -141,6 +145,7 @@ def f_timer(rt_mod, dev, input_data):
     tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3)
 
 
+@pytest.mark.skip("Integration test")
 def test_tune_block_cpu():
     @ms.derived_object
     class RemoveBlock(ms.schedule_rule.PyScheduleRule):
diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_te_transform_layout.py
similarity index 100%
rename from tests/python/unittest/test_transform_layout.py
rename to tests/python/unittest/test_te_transform_layout.py
diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_tir_index_map.py
similarity index 97%
rename from tests/python/unittest/test_index_map.py
rename to tests/python/unittest/test_tir_index_map.py
index 5eb31cd378c4..e893ed897d65 100644
--- a/tests/python/unittest/test_index_map.py
+++ b/tests/python/unittest/test_tir_index_map.py
@@ -15,17 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
-
 import pytest
+
 import tvm
 import tvm.testing
 from tvm.ir import assert_structural_equal
-from tvm.tir import IndexMap, IntImm, floordiv, floormod
 from tvm.runtime import const
+from tvm.tir import IndexMap, IntImm, floordiv, floormod
 
 
 def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
-
     iters_1 = map1.map_indices(map2.initial_indices)
     iters_2 = map2.final_indices
     assert len(iters_1) == len(iters_2)
@@ -36,7 +35,7 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
 
 
 def test_index_mapping():
-    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32")
 
     assert_structural_equal(index_map.map_indices([0]), [0, 0])
     assert_structural_equal(index_map.map_indices([3]), [0, 3])
@@ -48,7 +47,7 @@ def test_index_mapping():
 
 
 def test_shape_mapping():
-    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32")
 
     assert_structural_equal(index_map.map_shape([4]), [1, 4])
     assert_structural_equal(index_map.map_shape([16]), [4, 4])
@@ -184,7 +183,7 @@ def test_nonbijective_inverse_gives_error():
 
 
 def test_nonsurjective_inverse(padding_test_case):
-    index_map = IndexMap.from_func(padding_test_case["forward"])
+    index_map = IndexMap.from_func(padding_test_case["forward"], index_dtype="int32")
     inverse, padding_predicate = index_map.non_surjective_inverse(padding_test_case["pre_shape"])
 
     expected_inverse = IndexMap.from_func(padding_test_case["inverse"])
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 8de11d8bd519..04bd00111ef3 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -154,11 +154,9 @@ def conv2d_nhwc_transformed(
     for ax0, ax1, ax2 in T.grid(12544, 64, 147):
         with T.block("conv2d_nhwc"):
             v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
-            T.reads(PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3], Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1])
-            T.writes(Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1])
             with T.init():
-                Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = T.float32(0)
-            Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] + PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1]
+                Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = T.float32(0)
+            Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] + PadInput[0, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1]
 
 
 @T.prim_func
@@ -461,11 +459,6 @@ def elementwise_int64_extent_transformed(
     sch = tir.Schedule(elementwise_int64_extent, debug_mask="all")
     block = "B" if use_block_name else sch.get_block("B")
     sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
-    print(
-        tvm.ir.base.get_first_structural_mismatch(
-            elementwise_int64_extent_transformed, sch.mod["main"]
-        )
-    )
     tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent)
 
 
@@ -1085,5 +1078,106 @@ def func(A: T.Buffer(T.int64(16), "int32")):
     sch.transform_layout(block="block", buffer="A", index_map=func, pad_value=0)
 
 
+def test_transform_layout_with_symbolic_bound():
+    # fmt: off
+    # pylint: disable=invalid-name,line-too-long,too-many-locals
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+        C = T.match_buffer(c, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    C[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                C[v_i0, v_i1, v_i2, v_i3] = C[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+    @T.prim_func
+    def after(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i1 * n + v_i3])
+                with T.init():
+                    C[v_i1 * n + v_i3] = T.float16(0)
+                C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_i3, v_k]
+    # pylint: enable=invalid-name,line-too-long,too-many-locals
+    # fmt: on
+    # pylint: disable=invalid-name
+    _, _, n, _ = before.buffer_map[before.params[1]].shape
+    sch = tvm.tir.Schedule(before)
+    block = sch.get_block("NT_matmul")
+    sch.transform_layout(
+        block,
+        ("write", 0),
+        lambda x, y, z, w: x * 32 * n + y * n + z * n + w,
+        assume_injective_transform=True,
+    )
+    # pylint: enable=invalid-name
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+
+
+def test_transform_block_layout_with_symbolic_bound():
+    # fmt: off
+    # pylint: disable=invalid-name,line-too-long,too-many-locals
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i1 * n + v_i3])
+                with T.init():
+                    C[v_i1 * n + v_i3] = T.float16(0)
+                C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+    @T.prim_func
+    def after(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for ax0, ax1 in T.grid(n * T.int64(32), T.int64(128)):
+            with T.block("NT_matmul"):
+                v0, v1 = T.axis.remap("SR", [ax0, ax1])
+                T.reads(A[T.int64(0), v0 // n, T.int64(0), v1], B[T.int64(0), v0 // n, v0 % n, v1])
+                T.writes(C[v0])
+                with T.init():
+                    C[v0] = T.float16(0)
+                C[v0] = C[v0] + A[T.int64(0), v0 // n, T.int64(0), v1] * B[T.int64(0), v0 // n, v0 % n, v1]
+    # pylint: enable=invalid-name,line-too-long,too-many-locals
+    # fmt: on
+    # pylint: disable=invalid-name
+    _, _, n, _ = before.buffer_map[before.params[1]].shape
+    sch = tvm.tir.Schedule(before)
+    block = sch.get_block("NT_matmul")
+    sch.transform_block_layout(
+        block,
+        lambda x, y, z, w, k: (
+            x * 32 * n + y * n + z * n + w,
+            k,
+        ),
+    )
+    # pylint: enable=invalid-name
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()