diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 66b147764b..bc7d8a5efb 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -307,6 +307,23 @@ done To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html). +We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.** + +The amp+bf16 training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 0 \ + --use-bf16 1 \ + --exp-dir zipformer/exp_amp_bf16 \ + --causal 0 \ + --full-libri 1 \ + --max-duration 1000 +``` + ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 164cc7bfd0..2a40b8d643 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -297,7 +297,7 @@ def forward(ctx, x: Tensor, dim: int): # (presumably) that op does not support float16, and autocast # is enabled. if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) + ans = ans.to(torch.get_autocast_gpu_dtype()) ctx.save_for_backward(ans) ctx.x_dtype = x.dtype ctx.dim = dim @@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) s = torch.sigmoid(x - 1.0) @@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1379,7 +1379,7 @@ def forward(ctx, x: Tensor) -> Tensor: d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod @@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function): def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1455,7 +1455,7 @@ def forward(ctx, x: Tensor) -> Tensor: d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9b6f4a93aa..9c1c7f5a78 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -521,6 +521,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + add_model_arguments(parser) return parser @@ -1027,7 +1034,9 @@ def save_bad_model(suffix: str = ""): batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,9 +1056,7 @@ def save_bad_model(suffix: str = ""): scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." - ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -1090,7 +1097,7 @@ def save_bad_model(suffix: str = ""): rank=rank, ) - if batch_idx % 100 == 0 and params.use_fp16: + if batch_idx % 100 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. @@ -1109,14 +1116,14 @@ def save_bad_model(suffix: str = ""): if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) if tb_writer is not None: @@ -1128,7 +1135,7 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: + if params.use_autocast: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) @@ -1204,9 +1211,25 @@ def run(rank, world_size, args): params.ctc_loss_scale = 1.0 else: assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, params.attention_decoder_loss_scale + params.ctc_loss_scale, + params.attention_decoder_loss_scale, ) + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + logging.info(params) logging.info("About to create model") @@ -1339,7 +1362,7 @@ def remove_short_and_long_utt(c: Cut): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, _ = compute_loss( params=params, model=model,