From 35d6a1b9d27f1128bd00edef541be0d1f9f61dd9 Mon Sep 17 00:00:00 2001 From: albert qing <2628869@qq.com> Date: Wed, 16 Oct 2024 10:50:32 +0800 Subject: [PATCH] [TIR][Schedule] Add annotate_buffer_access primitive (#17423) Co-authored-by: qsqqsqqsq-intellif --- include/tvm/tir/schedule/schedule.h | 11 + include/tvm/tir/stmt.h | 10 + python/tvm/tir/schedule/schedule.py | 136 +++++++ src/tir/schedule/concrete_schedule.cc | 10 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 10 + .../primitive/annotate_buffer_access.cc | 167 +++++++++ src/tir/schedule/schedule.cc | 7 + src/tir/schedule/traced_schedule.cc | 12 + src/tir/schedule/traced_schedule.h | 2 + src/tir/transforms/compact_buffer_region.cc | 43 ++- ...est_tir_schedule_annotate_buffer_access.py | 332 ++++++++++++++++++ 12 files changed, 736 insertions(+), 6 deletions(-) create mode 100644 src/tir/schedule/primitive/annotate_buffer_access.cc create mode 100644 tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 092bd52d5634..e4b13888f948 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object { */ virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; + /*! + * \brief Annotate the buffer access of a block + * \param block_rv The block to be annotated + * \param buffer_index The index of the buffer in block's read or write region + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \param index_map The index map that defines the new read or write region + */ + virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index c77254ed34cb..38289af463d5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution"; /*! \brief Mark that a block is disallowed in auto inline. */ constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule"; +/*! \brief Mark that a block has an explicitly specified read region. + * This is used to override the default read region inference in TIR. + */ +constexpr const char* explicit_read_region = "explicit_read_region"; + +/*! \brief Mark that a block has an explicitly specified write region. + * This is used to override the default write region inference in TIR. + */ +constexpr const char* explicit_write_region = "explicit_write_region"; + /*! 
 * \brief Check if attr_key is a pragma key extension
 * \param attr_key The attr key to be compared
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index be88e234634f..17c256be3538 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -3907,3 +3907,139 @@ def unsafe_hide_buffer_access(
             buf_type,
             buf_index_array,
         )
+
+    @type_checked
+    def annotate_buffer_access(
+        self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable
+    ) -> None:
+        """Annotate the read or write region of a block
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be annotated
+        buffer_index : int
+            The index of the buffer in the block's read or write region
+        buf_type : str
+            The buffer type: "read" or "write"
+        gen_new_ranges : Callable
+            A function that takes the block's iter_vars and returns a
+            Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...]
+            which defines the new read or write region for the buffer.
+            Each element in the tuple can be:
+            - A single PrimExpr representing the iter_var itself
+            - A tuple of two PrimExprs representing the range (begin, end)
+
+        Examples
+        --------
+        Annotate a 2D read region for a buffer.
+        Before annotate_buffer_access, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_annotate_buffer_access(
+                A: T.Buffer((128, 128), "float32"),
+                C: T.Buffer((128, 128), "float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), "float32")
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do annotate_buffer_access:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_annotate_buffer_access)
+            block = sch.get_block("B")
+            sch.annotate_buffer_access(block, 0, "read",
+                lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)))
+            print(sch.mod["main"].script())
+
+        After applying annotate_buffer_access, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_annotate_buffer_access(
+                A: T.Buffer((128, 128), "float32"),
+                C: T.Buffer((128, 128), "float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), "float32")
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1])
+                        T.writes(B[vi, vj])
+                        T.block_attr({"explicit_read_region": [0]})
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + 1.0
+
+        This annotates the read region for buffer A (index 0) in block "B" to be
+        [vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain.
+
+        Note
+        ----
+        This function allows manual specification of read or write regions where
+        the compiler cannot accurately infer the access pattern, such as for
+        complex data-dependent accesses. It overrides the automatically inferred
+        region for the specified buffer and annotates the block to record that an
+        explicit region has been provided for the buffer at the given index. The
+        CompactBufferAllocation pass uses this annotation to respect the manually
+        specified region instead of relying on automatic inference.
+
+        Use this function with caution: an annotated region that fails to cover
+        all reads or writes actually performed by the block for the given buffer
+        may lead to incorrect code generation or runtime errors.
+        """
+        block_obj = self.get(block)
+        iter_vars = [x.var for x in block_obj.iter_vars]
+        new_ranges_spec = gen_new_ranges(*iter_vars)
+        if len(iter_vars) != len(new_ranges_spec):
+            raise ValueError(
+                f"Number of iter_vars ({len(iter_vars)}) must match "
+                f"number of new_ranges_spec ({len(new_ranges_spec)})"
+            )
+
+        result = []
+        for rng in new_ranges_spec:
+            if isinstance(rng, (tuple, list)):
+                if len(rng) != 2:
+                    raise ValueError(
+                        "Tuple must have exactly 2 elements to represent (begin, end)."
+                    )
+                result.extend(rng)
+            elif isinstance(rng, PrimExpr):
+                result.extend([rng, rng + 1])  # Single point represented as (rng, rng + 1)
+            else:
+                raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}")
+
+        # Create index_map using IndexMap constructor
+        index_map = IndexMap(
+            initial_indices=iter_vars,
+            final_indices=result,
+            inverse_index_map=None,
+        )
+
+        if buf_type == "read":
+            buffer_index_type = 0
+        elif buf_type == "write":
+            buffer_index_type = 1
+        else:
+            raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.")
+
+        return _ffi_api.ScheduleAnnotateBufferAccess(
+            self, block, buffer_index, buffer_index_type, index_map
+        )
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 73b5ff3fafd4..f6cb1f05ef6e 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -1059,5 +1059,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const
   this->state_->DebugVerify();
 }
 
+void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
+                                                BufferIndexType buffer_index_type,
+                                                const IndexMap& index_map) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
+                            index_map);
+  TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 092bcf0c79f9..b8ad56d2ab56 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   void EnterPostproc() override {}
   void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
                               const Array<IntImm>& buf_index_array) override;
+  void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
+                            BufferIndexType buffer_index_type, const IndexMap& index_map) override;
 
  protected:
   /******** Utility functions ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index fd1349e4a3ec..cf1ac957c89f 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w
 TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref,
                                     const String& buf_type, const Array<IntImm>& buf_index_array);
+/*!
+ * \brief Annotate the read or write region of a specific buffer in a block
+ * \param self The state of the schedule
+ * \param block_sref The sref of the block to be annotated
+ * \param buffer_index The index of the buffer in the block's read or write region
+ * \param buffer_index_type The type of the buffer index, kRead or kWrite
+ * \param index_map The IndexMap that defines the new read or write region for the buffer
+ */
+TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                                  BufferIndexType buffer_index_type, const IndexMap& index_map);
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc
new file mode 100644
index 000000000000..2c5976b035dd
--- /dev/null
+++ b/src/tir/schedule/primitive/annotate_buffer_access.cc
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+class AnnotateRegionRewriter : public StmtExprMutator {
+ public:
+  AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region,
+                         BufferIndexType buffer_index_type)
+      : buffer_(buffer),
+        buffer_index_(buffer_index),
+        new_region_(new_region),
+        buffer_index_type_(buffer_index_type) {}
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+
+    Array<BufferRegion> regions =
+        buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads;
+    ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative";
+    ICHECK_LT(buffer_index_, static_cast<int>(regions.size())) << "Buffer index out of range";
+    regions.Set(buffer_index_, new_region_);
+
+    ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
+    if (buffer_index_type_ == BufferIndexType::kWrite) {
+      n->writes = std::move(regions);
+    } else {
+      n->reads = std::move(regions);
+    }
+
+    // Annotate the block with explicit_read_region or explicit_write_region
+    Map<String, ObjectRef> new_annotations = n->annotations;
+    String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
+                                ? attr::explicit_write_region
+                                : attr::explicit_read_region;
+    if (new_annotations.count(annotation_key)) {
+      Array<Integer> buffer_indices = Downcast<Array<Integer>>(new_annotations[annotation_key]);
+      bool found = false;
+      for (const Integer& index : buffer_indices) {
+        if (index->value == buffer_index_) {
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        buffer_indices.push_back(Integer(buffer_index_));
+        new_annotations.Set(annotation_key, buffer_indices);
+      }
+    } else {
+      new_annotations.Set(annotation_key, Array<Integer>{Integer(buffer_index_)});
+    }
+    n->annotations = std::move(new_annotations);
+
+    return Block(n);
+  }
+
+ private:
+  Buffer buffer_;
+  int buffer_index_;
+  BufferRegion new_region_;
+  BufferIndexType buffer_index_type_;
+};
+
+void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
+                          BufferIndexType buffer_index_type, const IndexMap& index_map) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, buffer_index_type);
+
+  arith::Analyzer analyzer;
+  Array<PrimExpr> block_iter_vars;
+  for (const IterVar& iter_var : block->iter_vars) {
+    block_iter_vars.push_back(iter_var->var);
+  }
+  Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars, &analyzer);
+  ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even.";
+  Array<Range> new_ranges;
+  for (size_t i = 0; i < new_indices.size(); i += 2) {
+    // (begin, end) represents a region
+    new_ranges.push_back(Range::FromMinExtent(
+        new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i])));
+  }
+
+  BufferRegion new_region(buffer, new_ranges);
+
+  AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type);
+  Stmt new_stmt = mutator(GetRef<Stmt>(block_sref->stmt));
+
+  self->Replace(block_sref, new_stmt, {{GetRef<Block>(block), Downcast<Block>(new_stmt)}});
+}
+
+struct AnnotateBufferAccessTraits : public UnpackedInstTraits<AnnotateBufferAccessTraits> {
+  static constexpr const char* kName = "AnnotateBufferAccess";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 4;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
+                                      Integer buffer_index_type, IndexMap index_map) {
+    return sch->AnnotateBufferAccess(block, buffer_index->value,
+                                     static_cast<BufferIndexType>(buffer_index_type->value),
+                                     index_map);
+  }
+
+  static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) {
+    std::ostringstream oss;
+    oss << "lambda ";
+    for (size_t i = 0; i < index_map->initial_indices.size(); ++i) {
+      if (i != 0) oss << ", ";
+      oss << index_map->initial_indices[i];
+    }
+    oss << ": [";
+    for (size_t i = 0; i < index_map->final_indices.size(); i += 2) {
+      if (i != 0) oss << ", ";
+      if (index_map->final_indices[i].same_as(index_map->final_indices[i + 1])) {
+        oss << index_map->final_indices[i];
+      } else {
+        oss << "(" << index_map->final_indices[i] << ", " << index_map->final_indices[i + 1] << ")";
+      }
+    }
+    oss << "]";
+    return String(oss.str());
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
+                                 Integer buffer_index_type, IndexMap index_map) {
+    PythonAPICall py("annotate_buffer_access");
+    py.Input("block", block);
+    py.Input("buffer_index", buffer_index->value);
+
+    std::ostringstream os;
+    os << "\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
+       << "\"";
+    py.Input("buf_type", os.str());
+
+    py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map));
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 44f9b8f42c68..2c3661d17ecc 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
     .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess")
     .set_body_method<Schedule>(&ScheduleNode::UnsafeHideBufferAccess);
+/******** (FFI) Annotate buffer access ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess")
+    .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
+                       int buffer_index_type, const IndexMap& index_map) {
+      return self->AnnotateBufferAccess(block_rv, buffer_index,
+                                        static_cast<BufferIndexType>(buffer_index_type),
+                                        index_map);
+    });
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 1611109d7735..d790f21e671a 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -769,5 +769,17 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const S
                                  /*outputs=*/{}));
 }
 
+void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
+                                              BufferIndexType buffer_index_type,
+                                              const IndexMap& index_map) {
+  ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map);
+  static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map},
+      /*attrs=*/{},
+      /*outputs=*/{}));
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 78629e84f039..1c21c3e2c894 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -142,6 +142,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void EnterPostproc() final;
   void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
                               const Array<IntImm>& buf_index_array) final;
+  void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
+                            BufferIndexType buffer_index_type, const IndexMap& index_map) final;
 };
 
 }  // namespace tir
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index f562a057e595..7385af49528b 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -136,7 +136,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   }
 
   void VisitExpr_(const BufferLoadNode* op) final {
-    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    auto explicit_it = explicit_access_annotations_.find(op->buffer);
+    if (explicit_it != explicit_access_annotations_.end()) {
+      VisitBufferAccess(explicit_it->second);
+    } else {
+      VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    }
     StmtExprVisitor::VisitExpr_(op);
   }
 
@@ -235,17 +240,38 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       auto& regions = access_annotations_[p.first];
       p.second.swap(regions);
     }
-    // Step 2. Record relax position of ancestor_loops_
+
+    // Step 2. Record explicit read/write region annotations
+    auto record_explicit_region = [&](const String& attr_key, BufferIndexType index_type) {
+      auto it = op->annotations.find(attr_key);
+      if (it != op->annotations.end()) {
+        Array<Integer> buffer_indices = Downcast<Array<Integer>>((*it).second);
+        for (const auto& index : buffer_indices) {
+          int buffer_index = index->value;
+          // Bound-check against the region array the index actually refers to.
+          const Array<BufferRegion>& regions =
+              index_type == BufferIndexType::kRead ? op->reads : op->writes;
+          if (buffer_index >= 0 && buffer_index < static_cast<int>(regions.size())) {
+            const BufferRegion& explicit_region = regions[buffer_index];
+            explicit_access_annotations_[explicit_region->buffer] = explicit_region;
+          }
+        }
+      }
+    };
+
+    record_explicit_region(attr::explicit_read_region, BufferIndexType::kRead);
+    record_explicit_region(attr::explicit_write_region, BufferIndexType::kWrite);
+
+    // Step 3. Record relax position of ancestor_loops_
     for (const Buffer& buffer : op->alloc_buffers) {
       VisitBufferDef(buffer->data);
     }
-    // Step 3. Visit match buffers
+    // Step 4. Visit match buffers
     for (const MatchBufferRegion& region : op->match_buffers) {
       VisitBufferAccess(region->source);
     }
-    // Step 4. Visit block body recursively
+    // Step 5. Visit block body recursively
     StmtExprVisitor::VisitStmt_(op);
-    // Step 5. Recover read/write region annotations
+    // Step 6. Recover read/write region annotations
     for (auto& p : cur_access_annotations) {
       auto& regions = access_annotations_[p.first];
       if (p.second.empty()) {
@@ -254,7 +280,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
         regions.swap(p.second);
       }
     }
-    // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
+    // Step 7. Clear explicit access annotations
+    explicit_access_annotations_.clear();
+    // Step 8. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
     for (const Buffer& buffer : op->alloc_buffers) {
       ICHECK_EQ(var2buffer_[buffer->data].size(), 1)
           << "Block allocation buffer should not be aliased";
@@ -489,6 +517,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   /*! \brief The map from Buffer to its access regions annotated by current block. */
   std::unordered_map<Buffer, Array<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
       access_annotations_;
+  /*! \brief The map from Buffer to its explicit access region annotated by the block. */
+  std::unordered_map<Buffer, BufferRegion, ObjectPtrHash, ObjectPtrEqual>
+      explicit_access_annotations_;
 };
 
 /*! \brief The storage alignment for a dimension */
diff --git a/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py
new file mode 100644
index 000000000000..cc09a807dcac
--- /dev/null
+++ b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py
@@ -0,0 +1,332 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
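+#
+# The tests below exercise annotate_buffer_access end to end: explicit read and
+# write regions, a data-dependent resize access, repeated annotation of the same
+# buffer, and interaction with split/reorder/compute_at, with trace round-trips.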
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+
+
+def test_annotate_read_buffer_access():
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi - 1 : vi - 1 + 2, vj - 1 : vj - 1 + 2])
+                T.writes(B[vi, vj])
+                T.block_attr({"explicit_read_region": [0]})
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("B")
+    sch.annotate_buffer_access(
+        block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1))
+    )
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_annotate_write_buffer_access():
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])
+                T.writes(B[vi : vi + 2, vj : vj + 2])
+                T.block_attr({"explicit_write_region": [0]})
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("B")
+    sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2)))
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_annotate_buffer_access_for_resize():
+    # fmt: off
+    @T.prim_func
+    def resize_before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
+        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
+            with T.block("resize"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(x[v_i0, v_i1, 0:32, 0:32])
+                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
+                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))
+
+    @T.prim_func
+    def resize_expected(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")):
+        for i0, i1, i2, i3 in T.grid(1, 1, 16, 16):
+            with T.block("resize"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3])
+                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
+                T.block_attr({"explicit_read_region": [0]})
+                resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)]))
+    # fmt: on
+    sch = tir.Schedule(resize_before, debug_mask="all")
+    block = sch.get_block("resize")
+    sch.annotate_buffer_access(
+        block,
+        0,
+        "read",
+        gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [
+            v_i0,
+            v_i1,
+            (v_i2 * 2 - 3, v_i2 * 2 + 3),
+            (v_i3 * 2 - 3, v_i3 * 2 + 3),
+        ],
+    )
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], resize_expected)
+    verify_trace_roundtrip(sch=sch, mod=resize_before)
+
+
+def test_annotate_buffer_access_read_and_write():
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi - 1 : vi + 2, vj - 1 : vj + 2])
+                T.writes(B[vi : vi + 2, vj : vj + 2])
+                T.block_attr({"explicit_read_region": [0], "explicit_write_region": [0]})
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("B")
+
+    sch.annotate_buffer_access(
+        block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2))
+    )
+
+    sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2)))
+
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_double_annotate_buffer_access_read():
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi - 2 : vi + 3, vj - 2 : vj + 3])
+                T.writes(B[vi, vj])
+                T.block_attr({"explicit_read_region": [0]})
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("B")
+
+    sch.annotate_buffer_access(
+        block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2))
+    )
+
+    sch.annotate_buffer_access(
+        block, 0, "read", lambda vi, vj: ((vi - 2, vi + 3), (vj - 2, vj + 3))
+    )
+
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+def test_annotate_buffer_access_with_compute_at_for_resize():
+    # fmt: off
+    @T.prim_func
+    def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
+        x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32")
+        for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200):
+            with T.block("cache"):
+                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
+        for i0, i1, i2, i3 in T.grid(1, 3, 100, 100):
+            with T.block("resize"):
+                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))]
+
+    @T.prim_func
+    def after(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
+        x_global = T.alloc_buffer((1, 3, 200, 200))
+        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
+            for ax0, ax1 in T.grid(24, 24):
+                with T.block("cache"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1 = T.axis.spatial(3, i1)
+                    v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0)
+                    v3 = T.axis.spatial(200, i3_0 * 20 - 3 + ax1)
+                    T.where(3 <= i2_0 * 20 + ax0 and i2_0 * 20 + ax0 < 203 and 3 <= i3_0 * 20 + ax1 and i3_0 * 20 + ax1 < 203)
+                    T.reads(x[v0, v1, v2, v3])
+                    T.writes(x_global[v0, v1, v2, v3])
+                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
+            for i2_1, i3_1 in T.grid(10, 10):
+                with T.block("resize"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
+                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
+                    T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6])
+                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
+                    T.block_attr({"explicit_read_region": [0]})
+                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]
+
+    @T.prim_func
+    def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")):
+        x_global = T.alloc_buffer((1, 3, 200, 200))
+        for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10):
+            for ax0, ax1 in T.grid(200, 200):
+                with T.block("cache"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1])
+                    T.reads(x[v0, v1, v2, v3])
+                    T.writes(x_global[v0, v1, v2, v3])
+                    x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3]
+            for i2_1, i3_1 in T.grid(10, 10):
+                with T.block("resize"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1)
+                    v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1)
+                    T.reads(x_global[v_i0, v_i1, 0:200, 0:200])
+                    T.writes(y[v_i0, v_i1, v_i2, v_i3])
+                    y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))]
+    # fmt: on
+
+    # Schedule with annotate_buffer_access
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("resize")
+    cache_block = sch.get_block("cache")
+
+    # Annotate buffer access
+    sch.annotate_buffer_access(
+        block,
+        0,
+        "read",
+        lambda vn, vc, vh, vw: (vn, vc, (vh * 2 - 3, vh * 2 + 3), (vw * 2 - 3, vw * 2 + 3)),
+    )
+
+    h, w = sch.get_loops(block)[-2:]
+    ho, hi = sch.split(h, factors=[10, 10])
+    wo, wi = sch.split(w, factors=[10, 10])
+    sch.reorder(ho, wo, hi, wi)
+    sch.compute_at(cache_block, wo)
+
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+    # Schedule without annotate_buffer_access
+    sch_without_annotate = tir.Schedule(before, debug_mask="all")
+    block_without_annotate = sch_without_annotate.get_block("resize")
+    cache_block_without_annotate = sch_without_annotate.get_block("cache")
+
+    h, w = sch_without_annotate.get_loops(block_without_annotate)[-2:]
+    ho, hi = sch_without_annotate.split(h, factors=[10, 10])
+    wo, wi = sch_without_annotate.split(w, factors=[10, 10])
+    sch_without_annotate.reorder(ho, wo, hi, wi)
+    sch_without_annotate.compute_at(cache_block_without_annotate, wo)
+
+    assert_structural_equal_ignore_global_symbol(
+        sch_without_annotate.mod["main"], after_without_annotate_buffer_access
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
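---

For quick reference beyond the tests, a minimal usage sketch of the new primitive
(the function name `elemwise` is illustrative; the printed script should show the
widened T.reads region and the explicit_read_region block attribute, as in the
docstring example above):

    from tvm import tir
    from tvm.script import tir as T


    @T.prim_func
    def elemwise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        B = T.alloc_buffer((128, 128), "float32")
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0


    sch = tir.Schedule(elemwise)
    # Each lambda parameter corresponds to one block iter_var; each (begin, end)
    # pair becomes one Range of the new read region for buffer A (index 0).
    sch.annotate_buffer_access(
        sch.get_block("B"), 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1))
    )
    print(sch.mod["main"].script())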