Fix param input order for cudagraph #1138

Open
wants to merge 4 commits into base: main

Conversation


@yifeis-nv yifeis-nv commented Aug 27, 2024

Description

I discovered that when I attempt to use CUDA graphs with pipeline parallelism, the gradients become incorrect, ultimately leading to NaNs. After debugging, I identified a small bug in TE's graph.py.

Fixes # (issue)

Since the make_graphed_callables function in TE builds the backward graph through torch.autograd.grad, the weights are also passed into torch.autograd.grad through its inputs argument. This requires that the order of inputs in torch.autograd.grad match the order used when capturing the forward graph; otherwise the backward pass produces incorrect gradients.
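As a minimal illustration (not TE code) of why the ordering matters: torch.autograd.grad returns gradients positionally, in the same order as its inputs sequence, so any bookkeeping that assumes one flattening order will silently pair gradients with the wrong tensors if the list was built in another order.

import torch

# torch.autograd.grad pairs each returned gradient with `inputs` by position.
w1 = torch.randn(3, requires_grad=True)
w2 = torch.randn(3, requires_grad=True)
out = (2 * w1 + 3 * w2).sum()

grads = torch.autograd.grad(out, inputs=(w1, w2), retain_graph=True)
grads_swapped = torch.autograd.grad(out, inputs=(w2, w1))

# grads[0] is d(out)/d(w1) == 2, while grads_swapped[0] is d(out)/d(w2) == 3.
# The values are correct either way, but code that copies them back into a
# fixed parameter list by position would now update the wrong weights.
print(grads[0][0].item(), grads_swapped[0][0].item())  # 2.0 3.0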

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Modify the input order of the weights inside the CUDA-graph-related module so that it matches the forward capture order.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

@timmoon10 timmoon10 left a comment


This fix seems plausible. It seems that make_graphed_callables expects sample_args to be ordered first by layer number, then by microbatch, then by model chunk:

per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
    fwd_idx[m_chunk] * num_layers + l_no
)
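(For context, a minimal sketch, not TE code, of the flat ordering this index formula implies, with made-up sizes; "microbatch" stands in for fwd_idx[m_chunk]. The layer index varies fastest, then the microbatch, then the model chunk.)

num_model_chunks, num_microbatches, num_layers = 2, 3, 2

flat_order = []
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            idx = m_chunk * num_microbatches * num_layers + microbatch * num_layers + l_no
            flat_order.append(idx)

# Indices come out as 0, 1, 2, ... only with this nesting, so sample_args
# (and the per-callable parameter list) must be flattened in the same order.
assert flat_order == list(range(num_model_chunks * num_microbatches * num_layers))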

However, I see some of our MLPerf wrappers order by microbatch, then layer number, then model chunk: https://gitlab-master.nvidia.com/dl/mlperf/optimized/-/blob/main/large_language_model/pytorch/custom_callbacks.py#L249-L254
Pinging @ksivaman.

Also, can you sign your commit to pass the DCO check?

@timmoon10 added the bug label on Aug 27, 2024
@yifeis-nv
Author

Thanks for the reminder! I have signed my commit.
Based on my understanding of the code, the ordering you referenced from MLPerf does not affect the capture order inside make_graphed_callables: the capture still proceeds first by layer number, then by microbatch, and finally by model chunk, so the issue described above would still occur. I understand that this is why the MLPerf code was modified to isolate the captures of different microbatches (which prevents sharing the memory pool and is likely to increase memory overhead):
https://gitlab-master.nvidia.com/dl/mlperf/optimized/-/blob/main/large_language_model/pytorch/custom_callbacks.py#L216-237
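For background, a minimal sketch of shared versus separate capture pools using the standard torch.cuda.graph API (this is not the MLPerf or TE code, and it assumes a CUDA device is available):

import torch

x = torch.randn(1024, device="cuda")

# Capturing several graphs into one pool lets later captures reuse memory
# reserved by earlier ones; capturing without a shared pool isolates each
# graph's allocations, which typically costs more memory overall.
pool = torch.cuda.graph_pool_handle()
g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    y1 = x * 2
with torch.cuda.graph(g2, pool=pool):
    y2 = x + 1

g_isolated = torch.cuda.CUDAGraph()
with torch.cuda.graph(g_isolated):  # no pool argument: gets its own pool
    y3 = x - 1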

@@ -171,8 +171,8 @@ def _make_graphed_callables(
             ]
         else:
             per_callable_module_params = []
-            for c in callables:
-                for i in range(num_microbatches):
+            for i in range(num_microbatches):
Collaborator

@vasunvidia vasunvidia Sep 5, 2024


The change doesn't appear to fully solve the bug; for example, this fix works only when the number of model chunks (num_model_chunks) is 1. The correct solution would be:

for m_chunk in range(num_model_chunks):
    for idx in range(num_microbatches):
        for l_no in range(num_layers):
            # Look up the callable for this (model chunk, layer) once so the
            # isinstance check and the parameter lookup use the same object.
            c = callables[m_chunk * num_layers + l_no]
            per_callable_module_params.append(
                tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
            )

Can you test if this fix works?

Author


Thanks for your input! This works for my situation.

@timmoon10
Collaborator

/te-ci pytorch

Labels
bug Something isn't working
3 participants