From 4235bfe38b45e9299876b04e644ecbade32979a1 Mon Sep 17 00:00:00 2001 From: yifeis-nv Date: Wed, 28 Aug 2024 10:57:12 +0800 Subject: [PATCH 1/3] Fix param input order for cudagraph Signed-off-by: yifeis-nv --- transformer_engine/pytorch/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..e8cdb67c6c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -171,8 +171,8 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): + for i in range(num_microbatches): + for c in callables: per_callable_module_params.append( tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () ) From f876a235d7c28dfa4d75218e41069eb2d958f043 Mon Sep 17 00:00:00 2001 From: yifeis-nv Date: Fri, 6 Sep 2024 15:13:16 +0800 Subject: [PATCH 2/3] Addtional Support for Muti-Chunks Signed-off-by: yifeis-nv --- transformer_engine/pytorch/graph.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e8cdb67c6c..3e818419ed 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -171,11 +171,12 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for i in range(num_microbatches): - for c in callables: - per_callable_module_params.append( - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - ) + for m_chunk in range(num_model_chunks): + for idx in range(num_microbatches): + for l_no in range(num_layers): + per_callable_module_params.append( + tuple(callables[m_chunk*num_layers + l_no].parameters()) if isinstance(c, torch.nn.Module) else () + ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] From 2ba9465aa8227da85d1b9e097bb8a0eafdfb5404 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:13:46 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 3e818419ed..7320122802 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -175,7 +175,9 @@ def _make_graphed_callables( for idx in range(num_microbatches): for l_no in range(num_layers): per_callable_module_params.append( - tuple(callables[m_chunk*num_layers + l_no].parameters()) if isinstance(c, torch.nn.Module) else () + tuple(callables[m_chunk * num_layers + l_no].parameters()) + if isinstance(c, torch.nn.Module) + else () ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [