Update scaled_dot_product_attention to work with >6 inputs in latest torch version #2021

Merged

Conversation

@ZachNagengast (Contributor) commented Oct 20, 2023

In torch 2.1, scaled_dot_product_attention now has 7 inputs, whereas 2.0 has only 6.

2.0 docs: https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
2.1 docs: https://pytorch.org/docs/2.1/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
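
For reference, the two signatures side by side, as documented at the links above (2.1 adds the trailing optional scale keyword):

# torch 2.0 (6 inputs):
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
# torch 2.1 (7 inputs):
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)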

This PR allows both interfaces to work by simply ignoring the new input. The issue arises for any conversion that uses this op on torch 2.1:

ValueError: node hidden_states.11 (scaled_dot_product_attention) got 7 input(s), expected [6]

@TobyRoseman (Collaborator)

Thanks for the review @ZachNagengast. I actually have a fix for this issue locally that I was planning to put up for a PR soon. I was waiting to bundle it with a few other small changes.

We should raise an error if the 7th value (i.e. the scale) is not None. We should also raise an error if the number of parameters is more than 7. Feel free to make these changes, or I can put up my fix.
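
A minimal sketch of that validation, assuming the converter's inputs list holds MIL Vars with a .val attribute (the helper name and error messages are illustrative, not taken from the PR):

def _check_sdpa_inputs(inputs):
    # torch 2.0 passes 6 inputs; torch 2.1 appends a 7th, the optional scale.
    if len(inputs) > 7:
        raise ValueError(
            f"scaled_dot_product_attention got {len(inputs)} input(s), expected at most 7"
        )
    if len(inputs) == 7 and inputs[6] is not None and inputs[6].val is not None:
        # The converter ignores scale, so only the default (None) is supported.
        raise ValueError("scaled_dot_product_attention: non-None scale is not supported")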

@ZachNagengast (Contributor, Author)

Yeah, no problem, I'll add it here shortly.

@ZachNagengast (Contributor, Author)

@TobyRoseman This should be all set, although I'm not sure how to trigger CI here.

@TobyRoseman (Collaborator)

Thanks @ZachNagengast for the changes. It looks good to me.

I have to kick off the CI: https://gitlab.com/coremltools1/coremltools/-/pipelines/1046916275

@Zulqurnain24

The input issue is now fixed, but I am now hitting liuliu/swift-diffusion#48 due to AttributeError: 'Namespace' object has no attribute 'merge_chunks_in_pipeline_model'

@ZachNagengast (Contributor, Author)

@TobyRoseman Looks like the CI passed, anything else needed?

@Zulqurnain24 This is because chunk_mlprogram.py has no default value for the merge_chunks_in_pipeline_model argument, and it appears to be missing from wherever it was called in your script (https://github.com/apple/ml-stable-diffusion/blob/main/python_coreml_stable_diffusion/chunk_mlprogram.py#L317). You just need to set it to a default before calling the chunk script: https://github.com/apple/ml-stable-diffusion/blob/bea04420b5958935a975d3b7aeb071bbbaa9a097/python_coreml_stable_diffusion/torch2coreml.py#L908
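
For anyone hitting the same traceback, a hedged sketch of that workaround: default the attribute on the parsed args namespace before handing it to the chunk script. The attribute name comes from the traceback; the surrounding argparse setup and entry-point wiring are assumptions, not quoted from either repo.

import argparse

# `args` stands in for the Namespace that torch2coreml.py builds and passes along
args = argparse.Namespace()
# chunk_mlprogram.py reads this attribute but defines no default for it,
# so set one before invoking the chunker:
if not hasattr(args, "merge_chunks_in_pipeline_model"):
    args.merge_chunks_in_pipeline_model = False
# ...then hand `args` to chunk_mlprogram's entry point as before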

@TobyRoseman (Collaborator)

@ZachNagengast - nothing else is needed. I will merge it after the release is finished.

@TobyRoseman self-assigned this Oct 27, 2023
jakesabathia2 previously approved these changes Oct 27, 2023
@TobyRoseman (Collaborator)

@ZachNagengast - Our release has finished. I would merge this now, but there is a conflict. Can you fix the conflict? Then I'll kick off another CI run.

@ZachNagengast (Contributor, Author)

Ok great, I'll merge in main shortly

@ZachNagengast (Contributor, Author)

@TobyRoseman All set. There was code here doing something similar, but it didn't raise an error when scale was non-null, and it only required a minimum of 3 params instead of strictly 6 or 7. Hopefully this implementation is the preferred one.

@TobyRoseman (Collaborator)

Thanks @ZachNagengast - the change looks good.

Update CI: https://gitlab.com/coremltools1/coremltools/-/pipelines/1059479892

@ZachNagengast (Contributor, Author)

@TobyRoseman Unclear why the previous CI run failed, but I just updated from main again; perhaps you can rerun it.

@TobyRoseman (Collaborator)

@ZachNagengast - the CI failures look related to your change. I don't think updating from main is going to fix it.

FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-rank=2]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-rank=3]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-rank=4]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-rank=5]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-rank=2]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-rank=3]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-rank=4]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_different_input_ranks_no_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-rank=5]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_attn_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-seq_lengths=(5, 5)-bool_mask=False]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_attn_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-seq_lengths=(5, 5)-bool_mask=False]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_attn_mask[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-seq_lengths=(7, 5)-bool_mask=False]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_toy_xformer_with_sdpa[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')-mask_as_input=True]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_toy_xformer_with_sdpa[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')-mask_as_input=True]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestScaledDotProductAttention::test_dropout_early_error_out
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestTransformer::test_transformer_encoder[compute_unit=ComputeUnit.CPU_ONLY-backend=('mlprogram', 'fp16')]
FAILED ../../envs/coremltools-py3.9/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestTransformer::test_transformer_encoder[compute_unit=ComputeUnit.CPU_ONLY-backend=('neuralnetwork', 'fp32')]

I suggest you try to reproduce these failures locally. Or let me know if you'd like me to try my previously mentioned fix.

@ZachNagengast (Contributor, Author)

@TobyRoseman I see the issue: there were changes that came in from the merge that I didn't notice originally. Updated to restore the pre-merge version.

@TobyRoseman (Collaborator)

The current diff makes it very difficult to understand what changes you have actually made. Please do a rebase squash on top of current main, i.e. please update this PR so there is just one commit and its parent is the tip of main.

Handle new scaled param in scaled_dot_product_attention

Fix param name

Update docstring and error type

Update format

Restore original scaled_dot_product_attention

Fix is_causal for scaled dot product attn
@ZachNagengast force-pushed the fix-scaled_dot_product_attention-inputs branch from 571e162 to 4fa16a0 on November 9, 2023 at 20:12
@ZachNagengast (Contributor, Author)

@TobyRoseman Alright, I simplified the whole thing a bit. All that needed changing was a small fix in the original implementation, where is_causal was being set to is_causal.val; I've fixed that and rebased.

@TobyRoseman (Collaborator)

Thank you. The diff is much clearer.

dropout = 0.0 if len(inputs) < 5 else inputs[4]
is_causal = False if len(inputs) < 6 else inputs[5].val
if attn_mask is not None and is_causal:
inputs = _get_inputs(context, node, expected=[6, 7])
Collaborator

We should keep the current (less strict) requirement of min_expected=3 rather than expected=[6, 7]. We should also keep the logic from the deleted code above that sets those values to defaults when they are not present.
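
A sketch of the combination being requested: the lenient min_expected=3 arity check plus the defaulting logic, with a guard against a non-default scale. Variable names follow the diff above; the error message is illustrative.

# accept anywhere from 3 to 7 inputs, defaulting the optional trailing ones
inputs = _get_inputs(context, node, min_expected=3)
q, k, v = inputs[0], inputs[1], inputs[2]
attn_mask = None if len(inputs) < 4 else inputs[3]
dropout = 0.0 if len(inputs) < 5 else inputs[4]
is_causal = False if len(inputs) < 6 else inputs[5].val
# torch 2.1's 7th input (scale) is ignored by the converter, so only None is allowed
if len(inputs) == 7 and inputs[6] is not None and inputs[6].val is not None:
    raise ValueError("scaled_dot_product_attention: non-None scale is not supported")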

Contributor Author

Done!

@ZachNagengast (Contributor, Author)

@TobyRoseman Missed one of the is_causal changes from before, which caused this error:

>       if attn_mask is not None and is_causal.val:
E       AttributeError: 'bool' object has no attribute 'val'

Fixed now.
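
The root cause, sketched from the diff above: is_causal is already unwrapped to a plain Python bool when it is read out of the inputs, so a second .val access fails.

is_causal = False if len(inputs) < 6 else inputs[5].val  # .val unwraps the MIL Var into a bool

# buggy check from the earlier revision: a plain bool has no .val, hence the AttributeError
# if attn_mask is not None and is_causal.val:

# corrected check:
if attn_mask is not None and is_causal:
    ...  # the two options conflict; handling body elided here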

@TobyRoseman (Collaborator)

CI is green. Thanks for the submission @ZachNagengast.

@TobyRoseman merged commit e72db64 into apple:main on Nov 13, 2023