| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- from typing import Optional, Tuple
- import torch
- __all__ = ["Conformer"]
- def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
- batch_size = lengths.shape[0]
- max_length = int(torch.max(lengths).item())
- padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
- batch_size, max_length
- ) >= lengths.unsqueeze(1)
- return padding_mask
- class _ConvolutionModule(torch.nn.Module):
- r"""Conformer convolution module.
- Args:
- input_dim (int): input dimension.
- num_channels (int): number of depthwise convolution layer input channels.
- depthwise_kernel_size (int): kernel size of depthwise convolution layer.
- dropout (float, optional): dropout probability. (Default: 0.0)
- bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
- use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
- """
- def __init__(
- self,
- input_dim: int,
- num_channels: int,
- depthwise_kernel_size: int,
- dropout: float = 0.0,
- bias: bool = False,
- use_group_norm: bool = False,
- ) -> None:
- super().__init__()
- assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
- self.layer_norm = torch.nn.LayerNorm(input_dim)
- self.sequential = torch.nn.Sequential(
- torch.nn.Conv1d(
- input_dim,
- 2 * num_channels,
- 1,
- stride=1,
- padding=0,
- bias=bias,
- ),
- torch.nn.GLU(dim=1),
- torch.nn.Conv1d(
- num_channels,
- num_channels,
- depthwise_kernel_size,
- stride=1,
- padding=(depthwise_kernel_size - 1) // 2,
- groups=num_channels,
- bias=bias,
- ),
- torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
- if use_group_norm
- else torch.nn.BatchNorm1d(num_channels),
- torch.nn.SiLU(),
- torch.nn.Conv1d(
- num_channels,
- input_dim,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=bias,
- ),
- torch.nn.Dropout(dropout),
- )
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- r"""
- Args:
- input (torch.Tensor): with shape `(B, T, D)`.
- Returns:
- torch.Tensor: output, with shape `(B, T, D)`.
- """
- x = self.layer_norm(input)
- x = x.transpose(1, 2)
- x = self.sequential(x)
- return x.transpose(1, 2)
- class _FeedForwardModule(torch.nn.Module):
- r"""Positionwise feed forward layer.
- Args:
- input_dim (int): input dimension.
- hidden_dim (int): hidden dimension.
- dropout (float, optional): dropout probability. (Default: 0.0)
- """
- def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
- super().__init__()
- self.sequential = torch.nn.Sequential(
- torch.nn.LayerNorm(input_dim),
- torch.nn.Linear(input_dim, hidden_dim, bias=True),
- torch.nn.SiLU(),
- torch.nn.Dropout(dropout),
- torch.nn.Linear(hidden_dim, input_dim, bias=True),
- torch.nn.Dropout(dropout),
- )
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- r"""
- Args:
- input (torch.Tensor): with shape `(*, D)`.
- Returns:
- torch.Tensor: output, with shape `(*, D)`.
- """
- return self.sequential(input)
class ConformerLayer(torch.nn.Module):
    r"""Conformer layer that constitutes Conformer.

    A layer chains a half-weighted residual feed-forward module, a
    self-attention module and a convolution module (in either order, see
    ``convolution_first``), a second half-weighted residual feed-forward
    module, and a final layer normalization.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        # First feed-forward module.
        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        # Self-attention sub-block: pre-norm, attention, dropout.
        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        # Convolution sub-block; channel count equals the model dimension.
        self.conv_module = _ConvolutionModule(
            input_dim,
            input_dim,
            depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )

        # Second feed-forward module and output normalization.
        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        r"""Run the convolution module with a residual connection.

        The input is transposed on its first two dimensions before and after the
        convolution module, since the layer works on `(T, B, D)` tensors while
        the convolution module expects `(B, T, D)`.
        """
        batch_first = input.transpose(0, 1)
        convolved = self.conv_module(batch_first).transpose(0, 1)
        return input + convolved

    def _apply_self_attention(
        self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        r"""Run pre-norm self-attention with dropout and a residual connection."""
        normed = self.self_attn_layer_norm(input)
        attended, _ = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.self_attn_dropout(attended) + input

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # First feed-forward module, weighted by one half on the residual path.
        hidden = self.ffn1(input) * 0.5 + input

        if self.convolution_first:
            hidden = self._apply_convolution(hidden)

        hidden = self._apply_self_attention(hidden, key_padding_mask)

        if not self.convolution_first:
            hidden = self._apply_convolution(hidden)

        # Second feed-forward module, again half-weighted, then the final norm.
        hidden = self.ffn2(hidden) * 0.5 + hidden
        return self.final_layer_norm(hidden)
class Conformer(torch.nn.Module):
    r"""Implements the Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    [:footcite:`gulati2020conformer`].

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
        >>> output = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        # A stack of identically configured layers applied one after another.
        self.conformer_layers = torch.nn.ModuleList(
            [
                ConformerLayer(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        encoder_padding_mask = _lengths_to_padding_mask(lengths)

        # The mask is only as wide as the longest entry of ``lengths``. If the
        # batch is padded beyond that (T > max(lengths)), the attention layers
        # would reject the mask for not matching the number of frames, so mark
        # the surplus trailing frames as padding for every sequence.
        missing_frames = input.size(1) - encoder_padding_mask.size(1)
        if missing_frames > 0:
            surplus = torch.ones(
                (encoder_padding_mask.size(0), missing_frames),
                dtype=encoder_padding_mask.dtype,
                device=encoder_padding_mask.device,
            )
            encoder_padding_mask = torch.cat([encoder_padding_mask, surplus], dim=1)

        # The layers operate on time-major tensors: (B, T, D) -> (T, B, D).
        x = input.transpose(0, 1)
        for layer in self.conformer_layers:
            x = layer(x, encoder_padding_mask)
        # Back to batch-major; lengths are returned as given.
        return x.transpose(0, 1), lengths
|