diff --git a/k2/python/k2/rnnt_decode.py b/k2/python/k2/rnnt_decode.py index bd6a4bc66..8d4814f59 100644 --- a/k2/python/k2/rnnt_decode.py +++ b/k2/python/k2/rnnt_decode.py @@ -30,53 +30,54 @@ class RnntDecodingStream(object): - """Create a new rnnt decoding stream. + def __init__(self, fsa: Fsa) -> None: + """Create a new rnnt decoding stream. - Every sequence(wave data) need a decoding stream, this function is expected - to be called when a new sequence comes. We support different decoding graphs - for different streams. + Every sequence (wave data) needs a decoding stream; this function is + expected to be called when a new sequence comes. We support different + decoding graphs for different streams. - Args: - graph: - The decoding graph used in this stream. + Args: + fsa: + The decoding graph used in this stream. - Returns: - A rnnt decoding stream object, which will be combined into - `RnntDecodingStreams` to do decoding together with other - sequences in parallel. - """ - def __init__(self, fsa: Fsa) -> None: + Returns: + A rnnt decoding stream object, which will be combined into + `RnntDecodingStreams` to do decoding together with other + sequences in parallel. + """ self.fsa = fsa self.stream = _k2.create_rnnt_decoding_stream(fsa.arcs) self.device = fsa.device - """Return a string representation of this object - - For visualization and debug only. - """ def __str__(self) -> str: + """Return a string representation of this object + + For visualization and debug only. + """ return f"{self.stream}, device : {self.device}\n" class RnntDecodingStreams(object): - """ - Combines multiple RnntDecodingStream objects to create a RnntDecodingStreams - object, then all these RnntDecodingStreams can do decoding in parallel. - - Args: - src_streams: - A list of RnntDecodingStream object to be combined. - config: - A configuration object which contains decoding parameters like - `vocab-size`, `decoder_history_len`, `beam`, `max_states`, - `max_contexts` etc. 
- - Returns: - Return a RnntDecodingStreams object. - """ def __init__( self, src_streams: List[RnntDecodingStream], config: RnntDecodingConfig ) -> None: + """ + Combines multiple RnntDecodingStream objects to create a + RnntDecodingStreams object, so that all these streams can do decoding + in parallel. + + Args: + src_streams: + A list of RnntDecodingStream objects to be combined. + config: + A configuration object which contains decoding parameters like + `vocab_size`, `decoder_history_len`, `beam`, `max_states`, + `max_contexts` etc. + + Returns: + Return a RnntDecodingStreams object. + """ assert len(src_streams) > 0 self.num_streams = len(src_streams) self.src_streams = src_streams @@ -84,78 +85,83 @@ def __init__( streams = [x.stream for x in self.src_streams] self.streams = _k2.RnntDecodingStreams(streams, config) - '''Return a string representation of this object - - For visualization and debug only. - ''' def __str__(self) -> str: + """Return a string representation of this object + + For visualization and debug only. + """ s = f"num_streams : {self.num_streams}\n" for i in range(self.num_streams): s += f"stream[{i}] : {self.src_streams[i]}" return s - """ - This function must be called prior to evaluating the joiner network - for a particular frame. It tells the calling code which contexts - it must evaluate the joiner network for. - - Returns: - Return a two elements tuple containing a RaggedShape and a tensor. - shape: - A RaggedShape with 2 axes, representing [stream][context]. - contexts: - A tensor of shape [tot_contexts][decoder_history_len], where - tot_contexts == shape->TotSize(1) and decoder_history_len comes from - the config, it represents the number of symbols in the context of the - decode network (assumed to be finite). It contains the token ids - into the vocabulary(i.e. `0 <= value < vocab_size`). 
- """ def get_contexts(self) -> Tuple[RaggedShape, Tensor]: + """ + This function must be called prior to evaluating the joiner network + for a particular frame. It tells the calling code which contexts + it must evaluate the joiner network for. + + Returns: + Return a two-element tuple containing a RaggedShape and a tensor. + + shape: + A RaggedShape with 2 axes, representing [stream][context]. + + contexts: + A tensor of shape [tot_contexts][decoder_history_len], where + tot_contexts == shape->TotSize(1) and decoder_history_len comes from + the config; it represents the number of symbols in the context of + the decoder network (assumed to be finite). It contains the token ids + into the vocabulary (i.e. `0 <= value < vocab_size`). + """ return self.streams.get_contexts() - """ - Advance decoding streams by one frame. - - Args: - logprobs: - A tensor of shape [tot_contexts][num_symbols], containing log-probs of - symbols given the contexts output by `get_contexts()`. Will satisfy - logprobs.Dim0() == shape.TotSize(1). - """ def advance(self, logprobs: Tensor) -> None: + """ + Advance decoding streams by one frame. + + Args: + logprobs: + A tensor of shape [tot_contexts][num_symbols], containing log-probs + of symbols given the contexts output by `get_contexts()`. It + satisfies `logprobs.Dim0() == shape.TotSize(1)`, where `shape` + is returned by `get_contexts()`. + """ self.streams.advance(logprobs) - """ - Terminate the decoding process of current RnntDecodingStreams objects. - It will update the decoding states and store the decoding results currently - got to each of the individual streams. - - Note: We can not decode with this object anymore after calling - terminate_and_flush_to_streams(). - """ def terminate_and_flush_to_streams(self) -> None: + """ + Terminate the decoding process of this RnntDecodingStreams object. + It will update the decoding states and store the decoding results + obtained so far to each of the individual streams. 
+ + Note: + We cannot decode with this object anymore after calling + terminate_and_flush_to_streams(). + """ self.streams.terminate_and_flush_to_streams() - """ - Generate the lattice Fsa currently got. + def format_output(self, num_frames: List[int]) -> Fsa: + """ + Generate the lattice Fsa obtained so far. - Note: The attributes of the generated lattice is a union of the attributes + Note: + The attributes of the generated lattice are the union of attributes of all the decoding graphs. For example, a streams contains three individual stream, each stream has its own decoding graphs, graph[0] has attributes attr1, attr2; graph[1] has attributes attr1, attr3; graph[2] has attributes attr3, attr4; then the generated lattice has attributes attr1, attr2, attr3, attr4. - Args: - num_frames: - A List containing the number of frames we want to gather for each stream - (note: the frames we have ever received for the corresponding stream). - It MUST satisfy `len(num_frames) == self.num_streams`. - Returns: - Return the lattice Fsa with all the attributes propagated. The returned - Fsa has 3 axes with `fsa.dim0==self.num_streams`. - """ - def format_output(self, num_frames: List[int]) -> Fsa: + Args: + num_frames: + A list containing the number of frames we want to gather for each + stream (note: the frames we have ever received for the corresponding + stream). It MUST satisfy `len(num_frames) == self.num_streams`. + Returns: + Return the lattice Fsa with all the attributes propagated. + The returned Fsa has 3 axes with `fsa.dim0==self.num_streams`. + """ assert len(num_frames) == self.num_streams ragged_arcs, out_map = self.streams.format_output(num_frames)