Skip to content

Commit

Permalink
zipformer BF16 training recipe (#1700)
Browse files Browse the repository at this point in the history
Support Zipformer AMP +BF16 training
  • Loading branch information
marcoyang1998 authored Aug 23, 2024
1 parent 3b434fe commit a6c02a4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 17 deletions.
17 changes: 17 additions & 0 deletions egs/librispeech/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,23 @@ done

To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).

We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.**

The amp+bf16 training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 0 \
--use-bf16 1 \
--exp-dir zipformer/exp_amp_bf16 \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```

##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

The tensorboard log can be found at
Expand Down
12 changes: 6 additions & 6 deletions egs/librispeech/ASR/zipformer/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def forward(ctx, x: Tensor, dim: int):
# (presumably) that op does not support float16, and autocast
# is enabled.
if torch.is_autocast_enabled():
ans = ans.to(torch.float16)
ans = ans.to(torch.get_autocast_gpu_dtype())
ctx.save_for_backward(ans)
ctx.x_dtype = x.dtype
ctx.dim = dim
Expand Down Expand Up @@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

s = torch.sigmoid(x - 1.0)
Expand Down Expand Up @@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
Expand Down Expand Up @@ -1379,7 +1379,7 @@ def forward(ctx, x: Tensor) -> Tensor:
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y

@staticmethod
Expand Down Expand Up @@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad

if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def forward(ctx, x: Tensor) -> Tensor:
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y

@staticmethod
Expand Down
47 changes: 36 additions & 11 deletions egs/librispeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,13 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--use-bf16",
type=str2bool,
default=False,
help="Whether to use bf16 in AMP.",
)

add_model_arguments(parser)

return parser
Expand Down Expand Up @@ -1027,7 +1034,9 @@ def save_bad_model(suffix: str = ""):
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss(
params=params,
model=model,
Expand All @@ -1047,9 +1056,7 @@ def save_bad_model(suffix: str = ""):
scaler.update()
optimizer.zero_grad()
except Exception as e:
logging.info(
f"Caught exception: {e}."
)
logging.info(f"Caught exception: {e}.")
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
Expand Down Expand Up @@ -1090,7 +1097,7 @@ def save_bad_model(suffix: str = ""):
rank=rank,
)

if batch_idx % 100 == 0 and params.use_fp16:
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
Expand All @@ -1109,14 +1116,14 @@ def save_bad_model(suffix: str = ""):

if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
)

if tb_writer is not None:
Expand All @@ -1128,7 +1135,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
if params.use_autocast:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
Expand Down Expand Up @@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale, params.attention_decoder_loss_scale
params.ctc_loss_scale,
params.attention_decoder_loss_scale,
)

if params.use_bf16: # amp + bf16
assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
assert not params.use_fp16, "You can only use either fp16 or bf16"
params.dtype = torch.bfloat16
params.use_autocast = True
elif params.use_fp16: # amp + fp16
params.dtype = torch.float16
params.use_autocast = True
else: # fp32
params.dtype = torch.float32
params.use_autocast = False

logging.info(f"Using dtype={params.dtype}")
logging.info(f"Use AMP={params.use_autocast}")

logging.info(params)

logging.info("About to create model")
Expand Down Expand Up @@ -1339,7 +1362,7 @@ def remove_short_and_long_utt(c: Cut):
params=params,
)

scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
Expand Down Expand Up @@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss(
params=params,
model=model,
Expand Down

0 comments on commit a6c02a4

Please sign in to comment.