Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Allow out_sinfo to be omitted from R.call_tir #17216

Closed
wants to merge 3 commits into from

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, the Relax type produced by calling a TIR PrimFunc needed to be explicitly specified using the out_sinfo argument. These output shapes are required in order to allocate output tensors during the CallTIRRewrite lowering pass. However, specifying them explicitly, especially in hand-written functions, duplicates information that is already present in the PrimFunc signature, and introduces the potential for inconsistencies.

This commit updates the MakeCallTIR function to infer out_sinfo if not explicitly specified. This inference uses the number of relax arguments to identify output parameters in the signature of the PrimFunc, which then become the return values from R.call_tir. Currently, this inference of out_sinfo occurs when constructing the relax::Call object, after which the out_sinfo is always present in the Relax IR.

Prior to this commit, the Relax type produced by calling a TIR
PrimFunc needed to be explicitly specified using the `out_sinfo`
argument.  These output shapes are required in order to allocate
output tensors during the `CallTIRRewrite` lowering pass.  However,
specifying them explicitly, especially in hand-written functions,
duplicates information that is already present in the `PrimFunc`
signature, and introduces the potential for inconsistencies.

This commit updates the `MakeCallTIR` function to infer `out_sinfo` if
not explicitly specified.  This inference uses the number of relax
arguments to identify output parameters in the signature of the
`PrimFunc`, which then become the return values from `R.call_tir`.
Currently, this inference of `out_sinfo` occurs when constructing the
`relax::Call` object, after which the `out_sinfo` is always present in
the Relax IR.
@@ -331,8 +331,133 @@ RELAY_REGISTER_OP("relax.call_tir")
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
static Array<TensorStructInfo> InferCallTIROutputStructInfo(Expr func, Tuple args,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to note is that this is not always possible to do such inference. Since it is possible to have tir functions like reshape, where the output shape is being explicitly specified via the destination. For the particular low-level call_tir op. I think it is safer to always ask for the sinfo, then explicitly checks the consistency to avoid error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely agreed that we should check for consistency after generating the IR, and that's something I want to add to the well-formed checker as well. This specific PR would be to avoid inconsistency while generating the IR.

(And if we can't infer the output shape, then the output shape must still be be explicitly provided.)

@tqchen
Copy link
Member

tqchen commented Jul 30, 2024

Just want to note that it is not always possible to do such inference.

class IRModule:
    @T.prim_func
    def reshape(A : Buffer((2, 4)), B: Buffer((n, m)):
    
    def main(A: Buffer((2, 4))):
         lv0 = R.call_tir(reshape, [A], R.Tensor((1, 8)))

For example, the above code is a valid tir call, but needs the output sinfo to be explicitly specified. Because we have such cases, and call_tir is a lower level function, it is safer to always ask for sinfo, but checks its consistency with the corresponding prim_func signature if needed

@Lunderberg
Copy link
Contributor Author

For example, the above code is a valid tir call, but needs the output sinfo to be explicitly specified. Because we have such cases, and call_tir is a lower level function, it is safer to always ask for sinfo, but checks its consistency with the corresponding prim_func signature if needed

That's a good point, and I agree that we should always be able to explicitly specify the output struct info, as output tensor shapes in TIR may define symbolic shapes. However, I don't think it should a required argument.

I've added a new test case, based on your example with reshape, to validate the behavior when the output shape cannot be inferred. While the initial implementation did identify this failure and throw an error, the error message wasn't ideal. I've added an earlier check for non-inferable output shapes, so that the error message can direct the user to provide the out_sinfo field.

Does the udpated check/error messages address your concerns for this PR?

@tqchen
Copy link
Member

tqchen commented Jul 31, 2024

I think this is mainly a design consideration here on what do we view the intended use of CreateCallTIR, in terms of different expectations we have on caller of the function. I can see some merits on auto deduction or call for explicitness

Given call_tir is lower level, having "less automation" here during pass and have explicitly checking would ensure correctness while indeed asking pass writers to do a bit more. It is like explicitly annotating types when writing c++ code versus writing auto. I think encouraging pass writers to explicitly think about the DPS pattern and always provide the return argument helps to reduce uncertainty here. While I can indeed see some merits of automated decusion, given it is not always possible, I still prefer we have the explicitness and provide good amount of consistency checking

@Lunderberg
Copy link
Contributor Author

I think encouraging pass writers to explicitly think about the DPS pattern and always provide the return argument helps to reduce uncertainty here.

While I think this would be an interesting point to discuss, I don't think it's relevant to this specific change. This PR keeps the exact same out_sinfo in the C++ IR types, and still requires pass writers to explicitly provide the output info. The MakeCallTIR function is not exposed to the back-end C++ API, only through the front-end Python API.

This change is solely in the front-end, for cases where an IRModule is being hand-written. I'd like to make that use-case less error-prone.

@tqchen
Copy link
Member

tqchen commented Jul 31, 2024

Thanks for pointing out the frontend case, I still think being explicit is helpful and aims for a consistency check with good error messages. Having such explicit argument makes the "intent" clear, with the explicit sinfo, we can write down the semantics in a clear fashion

def call_tir(func, args, out_sinfo):
     out = alloc_outputs(out_sinfo)
     func(*args, unpack_outputs(out))
     return out

omitting the out_sinfo, while indeed ok in some cases, was not always derivable, and the intent was less clear. I know the arguments can go another way to reduce the amount users type. In this particular case, having good well form check about consistency would help a lot toward that direction

@Lunderberg
Copy link
Contributor Author

Having such explicit argument makes the "intent" clear, with the explicit sinfo, we can write down the semantics in a clear fashion

Good point on the semantics. This change would add an additional step to the user-facing semantics of R.call_tir.

def call_tir(func, args, out_sinfo):
    if out_sinfo is None:
        out_sinfo = infer_out_sinfo(func, args) # may throw
        
    out = alloc_outputs(out_sinfo)
    func(*args, unpack_outputs(out))
    return out

I suppose that I'm getting stuck on is the "intent" part. While there are exceptions, in the majority of cases, there's one and only one correct value for out_sinfo. Since the user doesn't have any choice in it, we can't infer any intention from the user about it. On the other hand, if the user has the option of omitting the out_sinfo, then we could distinguish between the intent of "use whichever output is valid" (e.g. R.call_tir(unary_abs, [x])) and "verify and use the output I expect" (e.g. R.call_tir(unary_abs, [x], R.Tensor([16],'float16'))).

In this particular case, having good well form check about consistency would help a lot toward that direction

Agreed. I think for now, let's put this PR on hold, and I'll update the well-formed checker to verify consistent between the R.call_tir callee and the input/output arguments. (Since that's a change that we both agree on, and covers many of the same error modes.)

@tqchen
Copy link
Member

tqchen commented Sep 23, 2024

closed in favor of #17285

@tqchen tqchen closed this Sep 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants