[Relax][MetaSchedule] Support CPU weight prepack (#17445)
This PR adds support for CPU weight prepacking. Specifically, it adds a new pass
`AttachAttrLayoutFreeBuffers` that attaches the layout-free-buffer attribute to
the weight parameters, so that MetaSchedule can optimize the prepacking process.

After that pass and tuning, a second new pass `SplitLayoutRewritePreproc` splits
the layout-rewrite preprocessing into separate TIR functions, so that the
parameter-transform functions can be lifted with the existing `LiftTransformParams` pass.
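
For reference, the pipeline wires the new passes in roughly this order (a minimal sketch of the Python API, assuming a Relax `IRModule` named `mod`; the authoritative sequence is in `python/tvm/relax/pipeline.py` below):

from tvm import relax

mod = relax.transform.AttachAttrLayoutFreeBuffers()(mod)
# ... MetaSchedule tuning and database application happen in between ...
mod = relax.transform.SplitLayoutRewritePreproc()(mod)
mod = relax.transform.LiftTransformParams()(mod)
mod = relax.transform.FoldConstant()(mod)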
Hzfengsy authored Oct 16, 2024
1 parent c6a5b78 commit 8025041
Showing 11 changed files with 1,083 additions and 3 deletions.
21 changes: 21 additions & 0 deletions include/tvm/relax/transform.h
@@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_war
*/
TVM_DLL Pass RealizeVDevice();

/*!
 * \brief Attach layout-free buffers to the tir::PrimFunc.
 *
 * This pass attaches layout-free buffers to a tir::PrimFunc according to
 * how the function is used in the Relax function. Currently, the layout-free
 * buffers are the model weights and Relax constants.
*
* \note We recommend applying CanonicalizeBindings before this pass.
* \return The Pass.
*/
TVM_DLL Pass AttachAttrLayoutFreeBuffers();

/*!
 * \brief Split the layout rewrite preproc block into a separate tir::PrimFunc.
 *
 * This pass is used for weight prepacking after meta_schedule tuning.
*
* \return The Pass.
*/
TVM_DLL Pass SplitLayoutRewritePreproc();

/*!
* \brief Lift transformation of the parameters of a function.
*
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/nn/__init__.py
@@ -23,6 +23,8 @@
from .modules import (
GELU,
Conv1D,
Conv2D,
Conv3D,
ConvTranspose1D,
Embedding,
GroupNorm,
50 changes: 49 additions & 1 deletion python/tvm/relax/pipeline.py
@@ -109,6 +109,7 @@ def static_shape_tuning_pipeline(
total_trials: int,
target: Union[str, tvm.target.Target],
work_dir: str = "tuning_logs",
cpu_weight_prepack: bool = False,
):
"""Tune the static shape model and store the log to database.
@@ -122,18 +123,65 @@
work_dir : str
The directory to store the tuning logs.
cpu_weight_prepack : bool
Whether to enable the cpu weight prepack feature.
Note
----
`cpu_weight_prepack` is expected to be `True` when running on CPU for
better performance. However, it requires an explicit layout-transformation
step that calls the corresponding vm function, which changes the deployment
interface, so it is disabled by default. Here is an example of enabling it:
.. code-block:: python
mod = relax.pipeline.static_shape_tuning_pipeline(
total_trials=1000,
target="llvm -num-cores 16",
work_dir="tuning_logs",
cpu_weight_prepack=True,
)(mod)
ex = relax.build(mod, target=target)
vm = relax.VirtualMachine(ex, device=tvm.cpu())
# Transform the params using the vm function
# the name should be f"{func_name}_transform_params"
params = vm["main_transform_params"](params["main"])
input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32"))
out = vm["main"](input_data, *params).numpy()
"""

@tvm.transform.module_pass(opt_level=0)
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
if cpu_weight_prepack:
pre_tuning_layout_rewrite = [transform.AttachAttrLayoutFreeBuffers()]
post_tuning_layout_rewrite = [
transform.SplitLayoutRewritePreproc(),
transform.LiftTransformParams(),
transform.FoldConstant(),
]
else:
pre_tuning_layout_rewrite = []
post_tuning_layout_rewrite = []

with tvm.target.Target(target):
mod = tvm.transform.Sequential(
[
transform.DecomposeOpsForInference(),
transform.CanonicalizeBindings(),
zero_pipeline(),
*pre_tuning_layout_rewrite,
# Skip tuning if total_trials is 0
(
transform.MetaScheduleTuneIRMod({}, work_dir, total_trials)
if total_trials > 0
else tvm.transform.Sequential([])
),
transform.MetaScheduleApplyDatabase(work_dir),
*post_tuning_layout_rewrite,
]
)(mod)
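# Usage note (a sketch, not part of this diff): with total_trials=0 the tuning
# step above collapses to an empty Sequential, so the pipeline only applies
# tuning records already present in `work_dir`, e.g.:
#
#   mod = relax.pipeline.static_shape_tuning_pipeline(
#       total_trials=0,
#       target="llvm -num-cores 16",
#       work_dir="tuning_logs",
#   )(mod)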

2 changes: 2 additions & 0 deletions python/tvm/relax/transform/__init__.py
@@ -21,6 +21,7 @@
AllocateWorkspace,
AlterOpImpl,
AnnotateTIROpPattern,
AttachAttrLayoutFreeBuffers,
AttachGlobalSymbol,
BindParams,
BindSymbolicVars,
@@ -73,6 +74,7 @@
RewriteDataflowReshape,
RunCodegen,
SplitCallTIRByPattern,
SplitLayoutRewritePreproc,
StaticPlanBlockMemory,
ToMixedPrecision,
ToNonDataflow,
29 changes: 29 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
return _ffi_api.MergeCompositeFunctions() # type: ignore


def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass:
"""Attach layout free buffers to the tir::PrimFunc.
This pass is used to attach layout free buffers to the tir::PrimFunc according to
the function usage in the relax function. Currently, the layout free buffers are the model
weights and relax constants.
Note that we recommend applying CanonicalizeBindings before this pass.
Returns
-------
ret : tvm.transform.Pass
The registered pass for attaching layout free buffers.
"""
return _ffi_api.AttachAttrLayoutFreeBuffers() # type: ignore
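# Illustration (hypothetical, not this pass's actual output): after the pass
# runs, a call_tir callee whose argument 1 is a model weight carries the
# attribute consumed by MetaSchedule's RewriteLayout postproc, e.g. in TVMScript:
#
#   @T.prim_func
#   def matmul(A: T.Buffer((128, 128), "float32"),
#              W: T.Buffer((128, 128), "float32"),
#              C: T.Buffer((128, 128), "float32")):
#       T.func_attr({"layout_free_buffers": [1]})  # argument 1 (W) is layout-free
#       for i, j, k in T.grid(128, 128, 128):
#           with T.block("matmul"):
#               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
#               with T.init():
#                   C[vi, vj] = T.float32(0)
#               C[vi, vj] = C[vi, vj] + A[vi, vk] * W[vk, vj]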


def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass:
"""Split the TIR layout rewrite into multiple TIR functions.
This pass is used for weight prepacking after meta_schedule tuning.
Returns
-------
ret : tvm.transform.Pass
The registered pass for splitting TIR layout rewrite.
"""
return _ffi_api.SplitLayoutRewritePreproc() # type: ignore


def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm.ir.transform.Pass:
"""Lift transformation of the parameters of a function.
8 changes: 7 additions & 1 deletion src/meta_schedule/postproc/rewrite_layout.cc
@@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode {
void InitializeWithTuneContext(const TuneContext& context) final {}

// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
bool Apply(const tir::Schedule& sch) final {
try {
return tir::RewriteLayout(sch);
} catch (const std::runtime_error& e) {
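// Layout rewriting may not apply to this schedule; returning false reports
// the postproc as failed so tuning can continue with other candidates.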
return false;
}
}

Postproc Clone() const {
ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
113 changes: 113 additions & 0 deletions src/relax/transform/attach_attr_layout_free_buffers.cc
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/attach_attr_layout_free_buffers.cc
 * \brief Attach the layout_free_buffers attribute to layout-free buffers.
*/

#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace relax {

class AttrAttacher : public ExprMutator {
public:
static IRModule Transform(const IRModule& mod) {
AttrAttacher mutator(mod);
for (auto [gvar, func] : mod->functions) {
if (func->IsInstance<relax::FunctionNode>()) {
// clear the layout_free_exprs_ for each function
mutator.layout_free_exprs_.clear();
mutator.builder_->UpdateFunction(gvar, Downcast<BaseFunc>(mutator.VisitExpr(func)));
}
}
return mutator.builder_->GetContextIRModule();
}

private:
explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {}

using ExprMutator::VisitExpr_;
Expr VisitExpr_(const FunctionNode* op) final {
if (auto opt_num_input = op->attrs.GetAttr<Integer>(attr::kNumInput)) {
ICHECK(layout_free_exprs_.empty()) << "encountered a non-global function with the num_input attribute";
size_t num_input = opt_num_input.value()->value;
for (size_t i = num_input; i < op->params.size(); i++) {
layout_free_exprs_.insert(op->params[i].get());
}
}
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const ConstantNode* op) final {
layout_free_exprs_.insert(op);
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const CallNode* op) final {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
if (call->op != call_tir_op_) {
return call;
}
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
// Compute the layout free buffers
Array<Integer> layout_free_buffers;
for (size_t i = 0; i < call_tir_args.size(); i++) {
if (layout_free_exprs_.count(call_tir_args[i].get())) {
layout_free_buffers.push_back(Integer(i));
}
}
// Attach the layout free buffers to the tir::PrimFunc
tir::PrimFunc func = WithAttr(Downcast<tir::PrimFunc>(mod_->Lookup(gv)), "layout_free_buffers",
layout_free_buffers);
// Renew defs
func = tir::RenewDefs(func);
// Add the updated tir::PrimFunc in the IRModule
// Note that the block builder automatically deduplicates identical TIR functions,
// so we don't need to worry about duplicate insertions
GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint);
// Create a new call node with the updated tir::PrimFunc
auto n = make_object<CallNode>(*op);
n->args = {new_gv, Tuple(call_tir_args)};
return Call(n);
}

private:
IRModule mod_;
std::unordered_set<const ExprNode*> layout_free_exprs_;
};
namespace transform {

Pass AttachAttrLayoutFreeBuffers() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); };
auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {});
// Apply DeadCodeElimination to remove unused tir::PrimFunc
return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers");
}

TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers")
.set_body_typed(AttachAttrLayoutFreeBuffers);
} // namespace transform
} // namespace relax
} // namespace tvm