diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..e8cdb67c6c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -171,8 +171,8 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): + for i in range(num_microbatches): + for c in callables: per_callable_module_params.append( tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () )