Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Easier use of Function.bind_params with relax.op.create.* #17323

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions src/relax/op/op_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,16 @@ Array<Expr> GetCallArgs(const Call& call) {
return args;
}

void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
void CheckNumArguments(const Call& call) {
Op op = Downcast<Op>(call->op);
int expected_input = op->arguments.size();
if (static_cast<int>(call->args.size()) != expected_input) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << op << " expects " << expected_input << " arguments"
<< ", but was called with " << call->args.size() << " arguments");
LOG(FATAL) << "Operator " << op << " expects " << expected_input << " arguments"
<< ", but was called with " << call->args.size() << " arguments";
}
}

TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) {
TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg) {
Op op = Downcast<Op>(call->op);

ICHECK_EQ(op->arguments.size(), call->args.size())
Expand All @@ -59,24 +58,19 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const
if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
return tensor_sinfo.value();
} else {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << op << " requires argument " << i_arg << " ("
<< op->arguments[i_arg]->name << ") to be a tensor. "
<< "However, the argument " << arg << " is instead of type " << sinfo);
// Unreachable, but [[noreturn]] attribute on virtual function
// `ReportFatal` is insufficient to silence -Wreturn-type, as
// child class might not be [[noreturn]].
return TensorStructInfo();
LOG(FATAL) << "Operator " << op << " requires argument " << i_arg << " ("
<< op->arguments[i_arg]->name << ") to be a tensor. "
<< "However, the argument " << arg << " is instead of type " << sinfo;
}
}

Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
CheckNumArguments(call, ctx);
Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call) {
CheckNumArguments(call);

Op op = Downcast<Op>(call->op);
Array<TensorStructInfo> input_tensor_sinfo;
for (size_t i = 0; i < call->args.size(); ++i) {
input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i));
}
return input_tensor_sinfo;
}
Expand Down
90 changes: 72 additions & 18 deletions src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ namespace relax {

/************ Op input struct info getter ************/

/*!
* \brief Check that the operator has the correct number of arguments
*
* Verify that the number of arguments matches the expected number for
* the operator.
*
* \param call The context Call to the operator.
*/
void CheckNumArguments(const Call& call);

/*!
* \brief Check that the operator has
*
Expand All @@ -54,7 +64,17 @@ namespace relax {
*
* \param ctx The error reporting context.
*/
void CheckNumArguments(const Call& call, const BlockBuilder& ctx);
inline void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
CheckNumArguments(call);
}

/*!
* \brief Get the tensor struct info of the operator input.
* \param call The context Call to the operator.
* \param i_arg The index of the argument to check
* \return The tensor struct info of the argument
*/
TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg);

/*!
* \brief Get the tensor struct info of the operator input.
Expand All @@ -63,7 +83,19 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx);
* \param ctx The error reporting context.
* \return The tensor struct info of the argument
*/
TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx);
inline TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg,
const BlockBuilder& ctx) {
return GetInputTensorStructInfo(call, i_arg);
}

/*!
* \brief Get the tensor struct info of the operator input.
* \param call The context Call to the operator.
* \return The tensor struct info of each input.
* \note This function require every input to be Tensor. The number of call arguments is required
* to match the number of inputs of the op being called.
*/
Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call);

/*!
* \brief Get the tensor struct info of the operator input.
Expand All @@ -73,7 +105,20 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const
* \note This function require every input to be Tensor. The number of call arguments is required
* to match the number of inputs of the op being called.
*/
Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx);
inline Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
return GetInputTensorStructInfo(call);
}

/*!
* \brief Get the tensor struct info of the unary operator input.
* \param call The context Call to the operator.
* \return The tensor struct info of the unary operator input.
* \throw Throw exception if the number of input is not one, or the struct info of the input is not
* a tensor struct info.
*/
inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call) {
return GetInputTensorStructInfo(call)[0];
}

/*!
* \brief Get the tensor struct info of the unary operator input.
Expand All @@ -84,7 +129,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBu
* a tensor struct info.
*/
inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
return GetInputTensorStructInfo(call, ctx)[0];
return GetUnaryInputTensorStructInfo(call);
}

/*!
Expand All @@ -101,22 +146,19 @@ Array<TensorStructInfo> GetTensorStructInfoFromTuple(const Call& call, const Blo
namespace detail {
/*! \brief Implementation helper for GetArgStructInfo */
template <typename ArgType>
ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const BlockBuilder& ctx,
size_t index) {
ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, size_t index) {
if (!call->args[index]->struct_info_.defined()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " op should have arguments with defined StructInfo. "
<< "However, args[" << index << "] has undefined struct info.");
LOG(FATAL) << "Operator " << op << " should have arguments with defined StructInfo. "
<< "However, args[" << index << "] has undefined struct info.";
}

auto sinfo = GetStructInfo(call->args[index]);
auto typed_sinfo = sinfo.as<ArgType>();

if (!typed_sinfo.defined()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " requires that args[" << index << "] be a "
<< ArgType::ContainerType::_type_key << ", but was instead " << sinfo
<< " of type " << sinfo->GetTypeKey());
LOG(FATAL) << "Operator " << op << " requires that args[" << index << "] be a "
<< ArgType::ContainerType::_type_key << ", but was instead " << sinfo << " of type "
<< sinfo->GetTypeKey();
}

return typed_sinfo.value();
Expand All @@ -125,9 +167,8 @@ ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const BlockBuild
/*! \brief Implementation helper for GetArgStructInfo */
template <typename... ArgTypes, size_t... Indices>
std::tuple<ArgTypes...> GetArgStructInfoHelper(const Call& call, const Op& op,
const BlockBuilder& ctx,
std::index_sequence<Indices...>) {
return std::tuple<ArgTypes...>{GetArgStructInfoByIndex<ArgTypes>(call, op, ctx, Indices)...};
return std::tuple<ArgTypes...>{GetArgStructInfoByIndex<ArgTypes>(call, op, Indices)...};
}
} // namespace detail

Expand All @@ -136,12 +177,11 @@ std::tuple<ArgTypes...> GetArgStructInfoHelper(const Call& call, const Op& op,
*
* \tparam ArgTypes The expected types of arguments, in the order they appear.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \return The tensor struct infos of tuple input.
* \throw Throw exception if input expression is not a tuple.
*/
template <typename... ArgTypes>
std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& ctx) {
std::tuple<ArgTypes...> GetArgStructInfo(const Call& call) {
Op op = Downcast<Op>(call->op);
size_t n_input = op->arguments.size();

Expand All @@ -154,7 +194,21 @@ std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& c
<< "but GetArgStructInfo was given " << sizeof...(ArgTypes) << " template arguments.";

return detail::GetArgStructInfoHelper<ArgTypes...>(
call, op, ctx, std::make_index_sequence<sizeof...(ArgTypes)>());
call, op, std::make_index_sequence<sizeof...(ArgTypes)>());
}

/*!
* \brief Get all arg struct infos as expected types
*
* \tparam ArgTypes The expected types of arguments, in the order they appear.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \return The tensor struct infos of tuple input.
* \throw Throw exception if input expression is not a tuple.
*/
template <typename... ArgTypes>
std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& ctx) {
return GetArgStructInfo<ArgTypes...>(call);
}

/************ Op registration macro ************/
Expand Down
Loading
Loading