diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
index 9823fac87..ca4ffcf8f 100644
--- a/k2/python/k2/rnnt_loss.py
+++ b/k2/python/k2/rnnt_loss.py
@@ -23,6 +23,26 @@
 from .mutual_information import mutual_information_recursion
 
 
+def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
+    """
+    Insert -inf's into `px` in appropriate places if `boundary` is not
+    None. If boundary == None and modified == False, px[:,:,-1] will
+    be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
+    to be -infinity.
+    Args:
+      px: a Tensor of shape [B][S][T+1] (this function is only
+          called if modified == False, see other docs for `modified`)
+          px is modified in-place and returned.
+      boundary: None, or a Tensor of shape [B][4] containing
+        [s_begin, t_begin, s_end, t_end]; we need only t_end.
+    """
+    if boundary is None:
+        return px
+    B, S, T1 = px.shape
+    boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1)
+    return px.scatter_(dim=2, index=boundary, value=float("-inf"))
+
+
 def get_rnnt_logprobs(
     lm: Tensor,
     am: Tensor,
@@ -135,21 +155,6 @@
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    if boundary is not None:
-        assert boundary.shape == (B, 4)
-        mask = (
-            torch.arange(0, T + 1, device=px_am.device)
-            .reshape(1, T + 1)
-            .expand(B, T + 1)
-        )
-        mask = mask < boundary[:, 3].reshape(B, 1)
-        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
-        px_am = torch.where(
-            mask,
-            px_am,
-            torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device),
-        )
-
     px_lm = torch.gather(
         lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
     )  # [B][S][1]
@@ -163,6 +168,7 @@
     py_lm = lm[:, :, termination_symbol].unsqueeze(2)  # [B][S+1][1]
     py = py_am + py_lm - normalizers
 
+    px = fix_for_boundary(px, boundary)
     return (px, py)
 
 
@@ -278,21 +284,6 @@
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    if boundary is not None:
-        assert boundary.shape == (B, 4)
-        mask = (
-            torch.arange(0, T + 1, device=px.device)
-            .reshape(1, T + 1)
-            .expand(B, T + 1)
-        )
-        mask = mask < boundary[:, 3].reshape(B, 1)
-        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
-        px = torch.where(
-            mask,
-            px,
-            torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
-        )
-
     px[:, :, :T] -= normalizers[:, :S, :]
 
     py = (
@@ -302,6 +293,7 @@
     px = px.contiguous()
     py = py.contiguous()
 
+    px = fix_for_boundary(px, boundary)
     return (px, py)
 
 
@@ -660,21 +652,6 @@
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    if boundary is not None:
-        assert boundary.shape == (B, 4)
-        mask = (
-            torch.arange(0, T + 1, device=px.device)
-            .reshape(1, T + 1)
-            .expand(B, T + 1)
-        )
-        mask = mask < boundary[:, 3].reshape(B, 1)
-        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
-        px = torch.where(
-            mask,
-            px,
-            torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
-        )
-
     py = joint[:, :, :, termination_symbol]  # (B, T, s_range)
     py = py - normalizers
 
@@ -699,6 +676,8 @@
 
     px = px.contiguous()
     py = py.contiguous()
+
+    px = fix_for_boundary(px, boundary)
     return (px, py)
 
 
@@ -887,21 +866,6 @@
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    if boundary is not None:
-        assert boundary.shape == (B, 4)
-        mask = (
-            torch.arange(0, T + 1, device=px_am.device)
-            .reshape(1, T + 1)
-            .expand(B, T + 1)
-        )
-        mask = mask < boundary[:, 3].reshape(B, 1)
-        mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
-        px_am = torch.where(
-            mask,
-            px_am,
-            torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device),
-        )
-
     px_lm = torch.gather(
         lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
     )  # [B][S][1]
@@ -945,6 +909,7 @@
         + py_amonly * am_only_scale
     )
 
+    px_interp = fix_for_boundary(px_interp, boundary)
    return (px_interp, py_interp)