Skip to content

Commit

Permalink
fix building docs (#933)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool authored Mar 16, 2022
1 parent f4b4247 commit 9a0d72c
Showing 1 changed file with 88 additions and 82 deletions.
170 changes: 88 additions & 82 deletions k2/python/k2/rnnt_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,132 +30,138 @@


class RnntDecodingStream(object):
"""Create a new rnnt decoding stream.
def __init__(self, fsa: Fsa) -> None:
"""Create a new rnnt decoding stream.
Every sequence (wave data) needs a decoding stream; this function is expected
to be called when a new sequence arrives. We support different decoding graphs
for different streams.
Every sequence (wave data) needs a decoding stream; this function is
expected to be called when a new sequence arrives. We support different
decoding graphs for different streams.
Args:
fsa:
The decoding graph used in this stream.
Args:
fsa:
The decoding graph used in this stream.
Returns:
A rnnt decoding stream object, which will be combined into
`RnntDecodingStreams` to do decoding together with other
sequences in parallel.
"""
def __init__(self, fsa: Fsa) -> None:
Returns:
A rnnt decoding stream object, which will be combined into
`RnntDecodingStreams` to do decoding together with other
sequences in parallel.
"""
self.fsa = fsa
self.stream = _k2.create_rnnt_decoding_stream(fsa.arcs)
self.device = fsa.device

"""Return a string representation of this object
For visualization and debug only.
"""
def __str__(self) -> str:
"""Return a string representation of this object
For visualization and debug only.
"""
return f"{self.stream}, device : {self.device}\n"


class RnntDecodingStreams(object):
"""
Combines multiple RnntDecodingStream objects into a single RnntDecodingStreams
object, so that all of these streams can be decoded in parallel.
Args:
src_streams:
A list of RnntDecodingStream objects to be combined.
config:
A configuration object which contains decoding parameters like
`vocab_size`, `decoder_history_len`, `beam`, `max_states`,
`max_contexts` etc.
Returns:
Return a RnntDecodingStreams object.
"""
def __init__(
self, src_streams: List[RnntDecodingStream], config: RnntDecodingConfig
) -> None:
"""
Combines multiple RnntDecodingStream objects into a single
RnntDecodingStreams object, so that all of these streams can be decoded
in parallel.
Args:
src_streams:
A list of RnntDecodingStream objects to be combined.
config:
A configuration object which contains decoding parameters like
`vocab_size`, `decoder_history_len`, `beam`, `max_states`,
`max_contexts` etc.
Returns:
Return a RnntDecodingStreams object.
"""
assert len(src_streams) > 0
self.num_streams = len(src_streams)
self.src_streams = src_streams
self.device = self.src_streams[0].device
streams = [x.stream for x in self.src_streams]
self.streams = _k2.RnntDecodingStreams(streams, config)

'''Return a string representation of this object
For visualization and debug only.
'''
def __str__(self) -> str:
"""Return a string representation of this object
For visualization and debug only.
"""
s = f"num_streams : {self.num_streams}\n"
for i in range(self.num_streams):
s += f"stream[{i}] : {self.src_streams[i]}"
return s

"""
This function must be called prior to evaluating the joiner network
for a particular frame. It tells the calling code which contexts
it must evaluate the joiner network for.
Returns:
Return a two-element tuple containing a RaggedShape and a tensor.
shape:
A RaggedShape with 2 axes, representing [stream][context].
contexts:
A tensor of shape [tot_contexts][decoder_history_len], where
tot_contexts == shape->TotSize(1) and decoder_history_len comes from
the config; the latter is the number of symbols in the context of the
decode network (assumed to be finite). The tensor contains the token ids
into the vocabulary (i.e. `0 <= value < vocab_size`).
"""
def get_contexts(self) -> Tuple[RaggedShape, Tensor]:
"""
This function must be called prior to evaluating the joiner network
for a particular frame. It tells the calling code which contexts
it must evaluate the joiner network for.
Returns:
Return a two-element tuple containing a RaggedShape and a tensor.
shape:
A RaggedShape with 2 axes, representing [stream][context].
contexts:
A tensor of shape [tot_contexts][decoder_history_len], where
tot_contexts == shape->TotSize(1) and decoder_history_len comes from
the config; the latter is the number of symbols in the context of
the decode network (assumed to be finite). The tensor contains the token
ids into the vocabulary (i.e. `0 <= value < vocab_size`).
"""
return self.streams.get_contexts()

"""
Advance decoding streams by one frame.
Args:
logprobs:
A tensor of shape [tot_contexts][num_symbols], containing log-probs of
symbols given the contexts output by `get_contexts()`. Will satisfy
logprobs.Dim0() == shape.TotSize(1).
"""
def advance(self, logprobs: Tensor) -> None:
"""
Advance decoding streams by one frame.
Args:
logprobs:
A tensor of shape [tot_contexts][num_symbols], containing log-probs
of symbols given the contexts output by `get_contexts()`. It
satisfies `logprobs.Dim0() == shape.TotSize(1)`, where `shape` is the
RaggedShape returned by `get_contexts()`.
"""
self.streams.advance(logprobs)

"""
Terminate the decoding process of the current RnntDecodingStreams object.
It updates the decoding states and stores the decoding results obtained so
far in each of the individual streams.
Note: This object cannot be used for decoding anymore after calling
terminate_and_flush_to_streams().
"""
def terminate_and_flush_to_streams(self) -> None:
"""
Terminate the decoding process of the current RnntDecodingStreams object.
It updates the decoding states and stores the decoding results
obtained so far in each of the individual streams.
Note:
This object cannot be used for decoding anymore after calling
terminate_and_flush_to_streams().
"""
self.streams.terminate_and_flush_to_streams()

"""
Generate the lattice Fsa from the decoding results obtained so far.
def format_output(self, num_frames: List[int]) -> Fsa:
"""
Generate the lattice Fsa from the decoding results obtained so far.
Note: The attributes of the generated lattice are the union of the attributes
Note:
The attributes of the generated lattice are the union of the attributes
of all the decoding graphs. For example, suppose a streams object contains
three individual streams, each with its own decoding graph: graph[0]
has attributes attr1, attr2; graph[1] has attributes attr1, attr3;
graph[2] has attributes attr3, attr4. Then the generated lattice has
attributes attr1, attr2, attr3, attr4.
Args:
num_frames:
A List containing the number of frames we want to gather for each stream
(note: the frames we have ever received for the corresponding stream).
It MUST satisfy `len(num_frames) == self.num_streams`.
Returns:
Return the lattice Fsa with all the attributes propagated. The returned
Fsa has 3 axes with `fsa.dim0==self.num_streams`.
"""
def format_output(self, num_frames: List[int]) -> Fsa:
Args:
num_frames:
A List containing the number of frames we want to gather for each
stream (note: the frames we have ever received for the corresponding
stream). It MUST satisfy `len(num_frames) == self.num_streams`.
Returns:
Return the lattice Fsa with all the attributes propagated.
The returned Fsa has 3 axes with `fsa.dim0==self.num_streams`.
"""
assert len(num_frames) == self.num_streams

ragged_arcs, out_map = self.streams.format_output(num_frames)
Expand Down

0 comments on commit 9a0d72c

Please sign in to comment.