[Relax] Allow out_sinfo to be omitted from R.call_tir #17216
Prior to this commit, the Relax type produced by calling a TIR PrimFunc needed to be explicitly specified using the `out_sinfo` argument. These output shapes are required in order to allocate output tensors during the `CallTIRRewrite` lowering pass. However, specifying them explicitly, especially in hand-written functions, duplicates information that is already present in the `PrimFunc` signature, and introduces the potential for inconsistencies.

This commit updates the `MakeCallTIR` function to infer `out_sinfo` if not explicitly specified. This inference uses the number of Relax arguments to identify output parameters in the signature of the `PrimFunc`, which then become the return values from `R.call_tir`. Currently, this inference of `out_sinfo` occurs when constructing the `relax::Call` object, after which the `out_sinfo` is always present in the Relax IR.
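The counting rule described above can be illustrated with a small, library-free sketch. It is not the PR's C++ implementation: `infer_out_shapes`, the shape tuples, and the error messages are hypothetical names chosen for illustration, and shapes are modelled as plain tuples rather than TVM struct info.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def infer_out_shapes(param_shapes: List[Shape], num_args: int) -> List[Shape]:
    """Treat the first `num_args` parameters as inputs and the rest as outputs.

    Hypothetical helper: `param_shapes` stands in for the buffer shapes in a
    PrimFunc signature, listed in parameter order.
    """
    if num_args > len(param_shapes):
        raise ValueError(
            f"{num_args} arguments provided, but the signature only has "
            f"{len(param_shapes)} parameters"
        )
    outputs = param_shapes[num_args:]
    if not outputs:
        # Assumption for this sketch: a call with no trailing parameters
        # has nothing to infer, so it is reported as an error.
        raise ValueError("no trailing parameters left to act as outputs")
    return outputs


# A matmul-like signature: two input buffers followed by one output buffer.
matmul_params = [(16, 32), (32, 8), (16, 8)]
print(infer_out_shapes(matmul_params, num_args=2))  # [(16, 8)]
```

Because the rule depends only on the argument count, it cannot tell inputs from outputs when a call passes the wrong number of arguments, which is one reason the discussion below turns to consistency checking.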
@@ -331,8 +331,133 @@ RELAY_REGISTER_OP("relax.call_tir")
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
static Array<TensorStructInfo> InferCallTIROutputStructInfo(Expr func, Tuple args,
One thing to note is that it is not always possible to do such inference, since it is possible to have TIR functions like reshape, where the output shape is explicitly specified via the destination. For the low-level call_tir op in particular, I think it is safer to always ask for the sinfo and then explicitly check consistency to avoid errors.
Absolutely agreed that we should check for consistency after generating the IR, and that's something I want to add to the well-formed checker as well. This specific PR is meant to avoid inconsistency while generating the IR.
(And if we can't infer the output shape, then the output shape must still be explicitly provided.)
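The kind of consistency check being discussed can be sketched as follows. This is a toy model, not the TVM well-formed checker: `check_out_sinfo` is a hypothetical helper, shapes are plain tuples, and symbolic dimensions are represented as strings that bind to the first concrete value they meet.

```python
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]  # an int is a static extent, a str is a symbolic variable
Shape = Tuple[Dim, ...]


def check_out_sinfo(
    param_shapes: List[Shape], num_args: int, out_shapes: List[Tuple[int, ...]]
) -> Dict[str, int]:
    """Check explicit output shapes against the trailing signature parameters.

    Returns the bindings chosen for symbolic dimensions; raises ValueError
    on any mismatch. Hypothetical helper for illustration only.
    """
    expected = param_shapes[num_args:]
    if len(expected) != len(out_shapes):
        raise ValueError(
            f"signature has {len(expected)} outputs, but {len(out_shapes)} were specified"
        )
    bindings: Dict[str, int] = {}
    for exp, got in zip(expected, out_shapes):
        if len(exp) != len(got):
            raise ValueError(f"rank mismatch: expected {exp}, got {got}")
        for e, g in zip(exp, got):
            if isinstance(e, str):
                # A symbolic dimension must bind to one value consistently.
                if bindings.setdefault(e, g) != g:
                    raise ValueError(f"symbolic dim {e!r} bound to {bindings[e]} and {g}")
            elif e != g:
                raise ValueError(f"static dim mismatch: expected {e}, got {g}")
    return bindings


# Symbolic output (n, m) accepts the explicit shape (1, 8).
print(check_out_sinfo([(2, 4), ("n", "m")], 1, [(1, 8)]))  # {'n': 1, 'm': 8}
```

A static output such as `(2, 4)` paired with an explicit `(4, 2)` would raise `ValueError` in this sketch, which is the class of inconsistency the explicit argument can otherwise introduce.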
Just want to note that it is not always possible to do such inference.

    class IRModule:
        @T.prim_func
        def reshape(A: Buffer((2, 4)), B: Buffer((n, m))):
            ...

        def main(A: Buffer((2, 4))):
            lv0 = R.call_tir(reshape, [A], R.Tensor((1, 8)))

For example, the above code is a valid TIR call, but needs the output sinfo to be explicitly specified. Because we have such cases, and
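One way to see why the reshape case resists inference is to look for symbolic dimensions in the outputs that no input defines. The sketch below is an illustration under simplifying assumptions, not TVM code: `free_output_vars` is a hypothetical helper, and symbolic dimensions are modelled as strings inside plain shape tuples.

```python
from typing import List, Tuple, Union

Dim = Union[int, str]  # an int is a static extent, a str is a symbolic variable
Shape = Tuple[Dim, ...]


def free_output_vars(input_shapes: List[Shape], output_shapes: List[Shape]) -> List[str]:
    """Return symbolic dims that appear in outputs but in no input shape.

    If this list is non-empty, the output shape cannot be derived from the
    arguments alone and has to be specified explicitly.
    """
    bound = {d for shape in input_shapes for d in shape if isinstance(d, str)}
    used = {d for shape in output_shapes for d in shape if isinstance(d, str)}
    return sorted(used - bound)


# reshape: A is (2, 4), B is (n, m). Neither n nor m is defined by the input.
print(free_output_vars([(2, 4)], [("n", "m")]))      # ['m', 'n']

# An elementwise op: the output reuses n from the input, so nothing is free.
print(free_output_vars([("n", 4)], [("n", 4)]))      # []
```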
That's a good point, and I agree that we should always be able to explicitly specify the output struct info, as output tensor shapes in TIR may define symbolic shapes. However, I don't think it should be a required argument. I've added a new test case, based on your example with Do the updated check/error messages address your concerns for this PR?
I think this is mainly a design consideration about what we view as the intended use of CreateCallTIR, in terms of the different expectations we have of the function's caller. I can see some merit in either auto-deduction or a call for explicitness. Given
While I think this would be an interesting point to discuss, I don't think it's relevant to this specific change. This PR keeps the exact same This change is solely in the front-end, for cases where an
Thanks for pointing out the frontend case. I still think being explicit is helpful and aims for a consistency check with good error messages. Having such an explicit argument makes the "intent" clear: with the explicit sinfo, we can write down the semantics in a clear fashion.

    def call_tir(func, args, out_sinfo):
        out = alloc_outputs(out_sinfo)
        func(*args, unpack_outputs(out))
        return out

Omitting the out_sinfo, while indeed OK in some cases, is not always derivable, and the intent is less clear. I know the argument can go the other way, to reduce the amount users type. In this particular case, having a good well-formedness check for consistency would help a lot in that direction.
Good point on the semantics. This change would add an additional step to the user-facing semantics of

    def call_tir(func, args, out_sinfo):
        if out_sinfo is None:
            out_sinfo = infer_out_sinfo(func, args)  # may throw
        out = alloc_outputs(out_sinfo)
        func(*args, unpack_outputs(out))
        return out

I suppose what I'm getting stuck on is the "intent" part. While there are exceptions, in the majority of cases, there's one and only one correct value for
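The two-step semantics in the pseudocode above can be made executable with a toy model. Everything here is an assumption made for illustration: `ToyPrimFunc` and `infer_out_lens` are hypothetical names, buffers are 1-D Python lists, and the "signature" is just a list of buffer lengths with inputs first.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyPrimFunc:
    param_lens: List[int]        # length of each 1-D buffer parameter, inputs first
    body: Callable[..., None]    # destination-passing body: writes into output buffers


def infer_out_lens(func: ToyPrimFunc, args: list) -> List[int]:
    """Trailing parameters beyond the provided arguments are treated as outputs."""
    out_lens = func.param_lens[len(args):]
    if not out_lens:
        raise ValueError("cannot infer outputs: no trailing parameters")
    return out_lens


def call_tir(func: ToyPrimFunc, args: list, out_lens: Optional[List[int]] = None):
    if out_lens is None:
        out_lens = infer_out_lens(func, args)  # may raise
    outs = [[0] * n for n in out_lens]         # allocate the output buffers
    func.body(*args, *outs)                    # callee fills the outputs in place
    return outs[0] if len(outs) == 1 else tuple(outs)


def add_one_body(a, b):
    for i in range(len(a)):
        b[i] = a[i] + 1


add_one = ToyPrimFunc(param_lens=[3, 3], body=add_one_body)
print(call_tir(add_one, [[1, 2, 3]]))        # inferred output:  [2, 3, 4]
print(call_tir(add_one, [[1, 2, 3]], [3]))   # explicit output:  [2, 3, 4]
```

In this model the explicit and inferred calls agree whenever the signature fully determines the output, which is the "one and only one correct value" situation the comment refers to.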
Agreed. I think for now, let's put this PR on hold, and I'll update the well-formed checker to verify consistency between the
Closed in favor of #17285