Skip to content

Commit

Permalink
Use more efficient way to fix boundaries (#906)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool authored Jan 25, 2022
1 parent d6323d5 commit d3fbb1b
Showing 1 changed file with 25 additions and 60 deletions.
85 changes: 25 additions & 60 deletions k2/python/k2/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@
from .mutual_information import mutual_information_recursion


def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
"""
Insert -inf's into `px` in appropriate places if `boundary` is not
None. If boundary == None and modified == False, px[:,:,-1] will
be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
to be -infinity.
Args:
px: a Tensor of of shape [B][S][T+1] (this function is only
called if modified == False, see other docs for `modified`)
px is modified in-place and returned.
boundary: None, or a Tensor of shape [B][3] containing
[s_begin, t_begin, s_end, t_end]; we need only t_end.
"""
if boundary is None:
return px
B, S, T1 = px.shape
boundary = boundary[:, 3].reshape(B, 1, 1).expand(B, S, T1)
return px.scatter_(dim=2, index=boundary, value=float("-inf"))


def get_rnnt_logprobs(
lm: Tensor,
am: Tensor,
Expand Down Expand Up @@ -135,21 +155,6 @@ def get_rnnt_logprobs(
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..

if boundary is not None:
assert boundary.shape == (B, 4)
mask = (
torch.arange(0, T + 1, device=px_am.device)
.reshape(1, T + 1)
.expand(B, T + 1)
)
mask = mask < boundary[:, 3].reshape(B, 1)
mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
px_am = torch.where(
mask,
px_am,
torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device),
)

px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
Expand All @@ -163,6 +168,7 @@ def get_rnnt_logprobs(
py_lm = lm[:, :, termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers

px = fix_for_boundary(px, boundary)
return (px, py)


Expand Down Expand Up @@ -278,21 +284,6 @@ def get_rnnt_logprobs_joint(
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..

if boundary is not None:
assert boundary.shape == (B, 4)
mask = (
torch.arange(0, T + 1, device=px.device)
.reshape(1, T + 1)
.expand(B, T + 1)
)
mask = mask < boundary[:, 3].reshape(B, 1)
mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
px = torch.where(
mask,
px,
torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
)

px[:, :, :T] -= normalizers[:, :S, :]

py = (
Expand All @@ -302,6 +293,7 @@ def get_rnnt_logprobs_joint(
px = px.contiguous()
py = py.contiguous()

px = fix_for_boundary(px, boundary)
return (px, py)


Expand Down Expand Up @@ -660,21 +652,6 @@ def get_rnnt_logprobs_pruned(
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..

if boundary is not None:
assert boundary.shape == (B, 4)
mask = (
torch.arange(0, T + 1, device=px.device)
.reshape(1, T + 1)
.expand(B, T + 1)
)
mask = mask < boundary[:, 3].reshape(B, 1)
mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
px = torch.where(
mask,
px,
torch.tensor(float("-inf"), dtype=px.dtype, device=px.device),
)

py = joint[:, :, :, termination_symbol] # (B, T, s_range)
py = py - normalizers

Expand All @@ -699,6 +676,8 @@ def get_rnnt_logprobs_pruned(

px = px.contiguous()
py = py.contiguous()

px = fix_for_boundary(px, boundary)
return (px, py)


Expand Down Expand Up @@ -887,21 +866,6 @@ def get_rnnt_logprobs_smoothed(
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..

if boundary is not None:
assert boundary.shape == (B, 4)
mask = (
torch.arange(0, T + 1, device=px_am.device)
.reshape(1, T + 1)
.expand(B, T + 1)
)
mask = mask < boundary[:, 3].reshape(B, 1)
mask = mask.reshape(B, 1, T + 1).expand(B, S, T + 1)
px_am = torch.where(
mask,
px_am,
torch.tensor(float("-inf"), dtype=px_am.dtype, device=px_am.device),
)

px_lm = torch.gather(
lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
) # [B][S][1]
Expand Down Expand Up @@ -945,6 +909,7 @@ def get_rnnt_logprobs_smoothed(
+ py_amonly * am_only_scale
)

px_interp = fix_for_boundary(px_interp, boundary)
return (px_interp, py_interp)


Expand Down

0 comments on commit d3fbb1b

Please sign in to comment.