Fix building doc (#912)
* Fix building doc

* Fix flake8
pkufool authored Jan 29, 2022
1 parent 779a9bd commit 47c4b75
Showing 2 changed files with 14 additions and 13 deletions.
14 changes: 7 additions & 7 deletions docs/requirements.txt
@@ -1,8 +1,8 @@
-dataclasses
-graphviz
-recommonmark
-sphinx
-sphinx-autodoc-typehints
-sphinx_rtd_theme
-sphinxcontrib-bibtex
+dataclasses==0.6
+graphviz==0.19.1
+recommonmark==0.7.1
+sphinx==4.3.2
+sphinx-autodoc-typehints==1.12.0
+sphinx_rtd_theme==1.0.0
+sphinxcontrib-bibtex==2.4.1
 torch>=1.6.0
13 changes: 7 additions & 6 deletions k2/python/k2/mutual_information.py
@@ -143,8 +143,9 @@ def mutual_information_recursion(
     return_grad:
       Whether to return grads of ``px`` and ``py``, this grad standing for the
       occupation probability is the output of the backward with a
-      ``fake gradient`` input (all ones) This is useful to implement the
-      pruned version of rnnt loss.
+      ``fake gradient`` the ``fake gradient`` is the same as the gradient
+      you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``.
+      This is useful to implement the pruned version of rnnt loss.
     Returns:
       Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual
@@ -160,8 +161,8 @@ def mutual_information_recursion(
     where we handle edge cases by treating quantities with negative indexes
     as **-infinity**. The extension to cases where the boundaries are
-    specified should be obvious; it just works on shorter sequences with offsets
-    into ``px`` and ``py``.
+    specified should be obvious; it just works on shorter sequences with
+    offsets into ``px`` and ``py``.
     """
     assert px.ndim == 3
     B, S, T1 = px.shape
@@ -179,10 +180,10 @@ def mutual_information_recursion(
     assert px.is_contiguous()
     assert py.is_contiguous()
 
-    m, px_grad, py_grad = MutualInformationRecursionFunction.apply(
+    scores, px_grad, py_grad = MutualInformationRecursionFunction.apply(
         px, py, boundary, return_grad
     )
-    return (m, (px_grad, py_grad)) if return_grad else m
+    return (scores, (px_grad, py_grad)) if return_grad else scores
 
 
 def _inner_product(a: Tensor, b: Tensor) -> Tensor:
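The new docstring wording rests on a standard autograd identity: running backward with an all-ones "fake gradient" on ``scores`` produces the same grads as ``torch.autograd.grad(scores.sum(), [px, py])``. A minimal sketch of that equivalence, using a toy score function as a stand-in (the ``(px * py).sum`` computation is purely illustrative, not k2's mutual-information recursion):

```python
import torch

# Toy stand-in for the (px, py) -> scores computation; shape [B] like k2's output.
px = torch.randn(2, 3, requires_grad=True)
py = torch.randn(2, 3, requires_grad=True)
scores = (px * py).sum(dim=1)

# Route 1: gradient of the summed scores, as the docstring describes.
gx1, gy1 = torch.autograd.grad(scores.sum(), [px, py], retain_graph=True)

# Route 2: backward on the unsummed scores with an all-ones "fake gradient".
gx2, gy2 = torch.autograd.grad(
    scores, [px, py], grad_outputs=torch.ones_like(scores)
)

assert torch.allclose(gx1, gx2) and torch.allclose(gy1, gy2)
```

Because summing and then differentiating distributes the same all-ones cotangent over each element of ``scores``, the two routes are numerically identical.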
