| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813 |
- from abc import ABC, abstractmethod
- from typing import List, Optional, Tuple
- import torch
- from torchaudio.models import Emformer
- __all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
- class _TimeReduction(torch.nn.Module):
- r"""Coalesces frames along time dimension into a
- fewer number of frames with higher feature dimensionality.
- Args:
- stride (int): number of frames to merge for each output frame.
- """
- def __init__(self, stride: int) -> None:
- super().__init__()
- self.stride = stride
- def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- r"""Forward pass.
- B: batch size;
- T: maximum input sequence length in batch;
- D: feature dimension of each input sequence frame.
- Args:
- input (torch.Tensor): input sequences, with shape `(B, T, D)`.
- lengths (torch.Tensor): with shape `(B,)` and i-th element representing
- number of valid frames for i-th batch element in ``input``.
- Returns:
- (torch.Tensor, torch.Tensor):
- torch.Tensor
- output sequences, with shape
- `(B, T // stride, D * stride)`
- torch.Tensor
- output lengths, with shape `(B,)` and i-th element representing
- number of valid frames for i-th batch element in output sequences.
- """
- B, T, D = input.shape
- num_frames = T - (T % self.stride)
- input = input[:, :num_frames, :]
- lengths = lengths.div(self.stride, rounding_mode="trunc")
- T_max = num_frames // self.stride
- output = input.reshape(B, T_max, D * self.stride)
- output = output.contiguous()
- return output, lengths
- class _CustomLSTM(torch.nn.Module):
- r"""Custom long-short-term memory (LSTM) block that applies layer normalization
- to internal nodes.
- Args:
- input_dim (int): input dimension.
- hidden_dim (int): hidden dimension.
- layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
- layer_norm_epsilon (float, optional): value of epsilon to use in
- layer normalization layers (Default: 1e-5)
- """
- def __init__(
- self,
- input_dim: int,
- hidden_dim: int,
- layer_norm: bool = False,
- layer_norm_epsilon: float = 1e-5,
- ) -> None:
- super().__init__()
- self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
- self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
- if layer_norm:
- self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
- self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
- else:
- self.c_norm = torch.nn.Identity()
- self.g_norm = torch.nn.Identity()
- self.hidden_dim = hidden_dim
- def forward(
- self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- r"""Forward pass.
- B: batch size;
- T: maximum sequence length in batch;
- D: feature dimension of each input sequence element.
- Args:
- input (torch.Tensor): with shape `(T, B, D)`.
- state (List[torch.Tensor] or None): list of tensors
- representing internal state generated in preceding invocation
- of ``forward``.
- Returns:
- (torch.Tensor, List[torch.Tensor]):
- torch.Tensor
- output, with shape `(T, B, hidden_dim)`.
- List[torch.Tensor]
- list of tensors representing internal state generated
- in current invocation of ``forward``.
- """
- if state is None:
- B = input.size(1)
- h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
- c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
- else:
- h, c = state
- gated_input = self.x2g(input)
- outputs = []
- for gates in gated_input.unbind(0):
- gates = gates + self.p2g(h)
- gates = self.g_norm(gates)
- input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
- input_gate = input_gate.sigmoid()
- forget_gate = forget_gate.sigmoid()
- cell_gate = cell_gate.tanh()
- output_gate = output_gate.sigmoid()
- c = forget_gate * c + input_gate * cell_gate
- c = self.c_norm(c)
- h = output_gate * c.tanh()
- outputs.append(h)
- output = torch.stack(outputs, dim=0)
- state = [h, c]
- return output, state
- class _Transcriber(ABC):
- @abstractmethod
- def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- pass
- @abstractmethod
- def infer(
- self,
- input: torch.Tensor,
- lengths: torch.Tensor,
- states: Optional[List[List[torch.Tensor]]],
- ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
- pass
class _EmformerEncoder(torch.nn.Module, _Transcriber):
    r"""Emformer-based transcription network (encoder) for an RNN-T model.

    Input frames are linearly projected, time-reduced, passed through an Emformer,
    then linearly projected to ``output_dim`` and layer-normalized.

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension each input element is projected to
            before the time reduction block is applied.
        time_reduction_stride (int): factor by which the input sequence length is reduced.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function used in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()

        # Time reduction widens each frame by the stride and shortens the sequence by
        # the same factor, so the Emformer sees frame counts divided by the stride.
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        reduced_segment_length = segment_length // time_reduction_stride
        reduced_right_context_length = right_context_length // time_reduction_stride

        self.input_linear = torch.nn.Linear(input_dim, time_reduction_input_dim, bias=False)
        self.time_reduction = _TimeReduction(time_reduction_stride)
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            reduced_segment_length,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=reduced_right_context_length,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)`; i-th element is the number of
                valid frames of the i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output frame sequences.
        """
        reduced, reduced_lengths = self.time_reduction(self.input_linear(input), lengths)
        encoded, encoded_lengths = self.transformer(reduced, reduced_lengths)
        return self.layer_norm(self.output_linear(encoded)), encoded_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)`; i-th element is the number of
                valid frames of the i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in the preceding invocation
                of ``infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors representing internal
                    state generated in the current invocation of ``infer``.
        """
        reduced, reduced_lengths = self.time_reduction(self.input_linear(input), lengths)
        encoded, encoded_lengths, next_states = self.transformer.infer(reduced, reduced_lengths, states)
        return self.layer_norm(self.output_linear(encoded)), encoded_lengths, next_states
class _Predictor(torch.nn.Module):
    r"""Prediction network of a recurrent neural network transducer (RNN-T).

    Target tokens are embedded and layer-normalized, run through a stack of LSTM
    layers with dropout after each layer, then linearly projected to ``output_dim``
    and layer-normalized.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)

        # The first layer consumes embeddings; every later layer consumes the
        # hidden output of the layer before it.
        layers = []
        layer_input_dim = symbol_embedding_dim
        for _ in range(num_lstm_layers):
            layers.append(
                _CustomLSTM(
                    layer_input_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
            )
            layer_input_dim = lstm_hidden_dim
        self.lstm_layers = torch.nn.ModuleList(layers)

        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)`; i-th element is the number of
                valid frames of the i-th batch element in ``input``. Returned unchanged.
            state (List[List[torch.Tensor]] or None, optional): per-layer LSTM states
                generated in a preceding invocation of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output encoding sequences.
                List[List[torch.Tensor]]
                    output states; one list of tensors per LSTM layer, generated in
                    the current invocation of ``forward``.
        """
        # The LSTM layers operate time-major, so switch to (U, B) before embedding.
        hidden = self.input_layer_norm(self.embedding(input.permute(1, 0)))

        new_states: List[List[torch.Tensor]] = []
        for layer_idx, lstm in enumerate(self.lstm_layers):
            layer_state: Optional[List[torch.Tensor]] = None
            if state is not None:
                layer_state = state[layer_idx]
            hidden, layer_state_out = lstm(hidden, layer_state)
            hidden = self.dropout(hidden)
            new_states.append(layer_state_out)

        projected = self.output_layer_norm(self.linear(hidden))
        # Back to batch-major (B, U, output_dim) for the caller.
        return projected.permute(1, 0, 2), lengths, new_states
- class _Joiner(torch.nn.Module):
- r"""Recurrent neural network transducer (RNN-T) joint network.
- Args:
- input_dim (int): source and target input dimension.
- output_dim (int): output dimension.
- activation (str, optional): activation function to use in the joiner.
- Must be one of ("relu", "tanh"). (Default: "relu")
- """
- def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
- super().__init__()
- self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
- if activation == "relu":
- self.activation = torch.nn.ReLU()
- elif activation == "tanh":
- self.activation = torch.nn.Tanh()
- else:
- raise ValueError(f"Unsupported activation {activation}")
- def forward(
- self,
- source_encodings: torch.Tensor,
- source_lengths: torch.Tensor,
- target_encodings: torch.Tensor,
- target_lengths: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- r"""Forward pass for training.
- B: batch size;
- T: maximum source sequence length in batch;
- U: maximum target sequence length in batch;
- D: dimension of each source and target sequence encoding.
- Args:
- source_encodings (torch.Tensor): source encoding sequences, with
- shape `(B, T, D)`.
- source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
- valid sequence length of i-th batch element in ``source_encodings``.
- target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
- target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
- valid sequence length of i-th batch element in ``target_encodings``.
- Returns:
- (torch.Tensor, torch.Tensor, torch.Tensor):
- torch.Tensor
- joint network output, with shape `(B, T, U, output_dim)`.
- torch.Tensor
- output source lengths, with shape `(B,)` and i-th element representing
- number of valid elements along dim 1 for i-th batch element in joint network output.
- torch.Tensor
- output target lengths, with shape `(B,)` and i-th element representing
- number of valid elements along dim 2 for i-th batch element in joint network output.
- """
- joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
- activation_out = self.activation(joint_encodings)
- output = self.linear(activation_out)
- return output, source_lengths, target_lengths
class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model, composed of a transcription
    network, a prediction network and a joint network.

    Note:
        To build the model, please use one of the factory functions.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        Runs the transcription network on ``sources``, the prediction network on
        ``targets``, and the joint network on the two resulting encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                number of valid frames of the i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                number of valid frames of the i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in a preceding
                invocation of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)`; i-th element is the number of
                    valid elements along dim 1 of the i-th batch element in the joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)`; i-th element is the number of
                    valid elements along dim 2 of the i-th batch element in the joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors representing prediction network
                    internal state generated in the current invocation of ``forward``.
        """
        encoded_sources, encoded_source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        encoded_targets, encoded_target_lengths, next_predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        joint_out, joint_source_lengths, joint_target_lengths = self.joiner(
            source_encodings=encoded_sources,
            source_lengths=encoded_source_lengths,
            target_encodings=encoded_targets,
            target_lengths=encoded_target_lengths,
        )
        return joint_out, joint_source_lengths, joint_target_lengths, next_predictor_state

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        Delegates to the transcription network's ``infer`` method.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                number of valid frames of the i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in a preceding
                invocation of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors representing transcription network
                    internal state generated in the current invocation of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        Delegates to the transcription network's ``forward`` method.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                number of valid frames of the i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                number of valid frames of the i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in a preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)`; i-th element is the number of
                    valid elements of the i-th batch element in the output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors representing internal
                    state generated in the current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                valid sequence length of the i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)`; i-th element is the
                valid sequence length of the i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)`; i-th element is the number of
                    valid elements along dim 1 of the i-th batch element in the joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)`; i-th element is the number of
                    valid elements along dim 2 of the i-th batch element in the joint network output.
        """
        joint_out, joint_source_lengths, joint_target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return joint_out, joint_source_lengths, joint_target_lengths
def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
    lstm_hidden_dim: Optional[int] = None,
    joiner_activation: str = "relu",
) -> RNNT:
    r"""Builds Emformer-based recurrent neural network transducer (RNN-T) model.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
        frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.
        lstm_hidden_dim (int or None, optional): output dimension of each LSTM layer.
            If ``None``, ``symbol_embedding_dim`` is used. (Default: ``None``)
        joiner_activation (str, optional): activation function to use in the joint network.
            Must be one of ("relu", "tanh"). (Default: "relu")

    Returns:
        RNNT:
            Emformer RNN-T model.

    Raises:
        ValueError: if ``joiner_activation`` is not one of ("relu", "tanh").
    """
    # Existing callers never pass a hidden size, so fall back to the embedding
    # dimension in that case.
    predictor_hidden_dim = symbol_embedding_dim if lstm_hidden_dim is None else lstm_hidden_dim

    encoder = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=predictor_hidden_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    # Encoder and predictor both emit ``encoding_dim`` features; the joiner maps
    # their combination to one score per target symbol.
    joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
    return RNNT(encoder, predictor, joiner)
def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer RNN-T model.

    All hyperparameters other than the lexicon size are fixed by this function.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    # Transcription network (Emformer encoder) settings.
    encoder_config = dict(
        input_dim=80,
        encoding_dim=1024,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=2048,
        transformer_num_layers=20,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
    )
    # Prediction network (LSTM stack) settings.
    predictor_config = dict(
        symbol_embedding_dim=512,
        num_lstm_layers=3,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )
    return emformer_rnnt_model(num_symbols=num_symbols, **encoder_config, **predictor_config)
|