conformer.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from typing import Optional, Tuple
  2. import torch
  3. __all__ = ["Conformer"]
  4. def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
  5. batch_size = lengths.shape[0]
  6. max_length = int(torch.max(lengths).item())
  7. padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
  8. batch_size, max_length
  9. ) >= lengths.unsqueeze(1)
  10. return padding_mask
  11. class _ConvolutionModule(torch.nn.Module):
  12. r"""Conformer convolution module.
  13. Args:
  14. input_dim (int): input dimension.
  15. num_channels (int): number of depthwise convolution layer input channels.
  16. depthwise_kernel_size (int): kernel size of depthwise convolution layer.
  17. dropout (float, optional): dropout probability. (Default: 0.0)
  18. bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
  19. use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
  20. """
  21. def __init__(
  22. self,
  23. input_dim: int,
  24. num_channels: int,
  25. depthwise_kernel_size: int,
  26. dropout: float = 0.0,
  27. bias: bool = False,
  28. use_group_norm: bool = False,
  29. ) -> None:
  30. super().__init__()
  31. assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
  32. self.layer_norm = torch.nn.LayerNorm(input_dim)
  33. self.sequential = torch.nn.Sequential(
  34. torch.nn.Conv1d(
  35. input_dim,
  36. 2 * num_channels,
  37. 1,
  38. stride=1,
  39. padding=0,
  40. bias=bias,
  41. ),
  42. torch.nn.GLU(dim=1),
  43. torch.nn.Conv1d(
  44. num_channels,
  45. num_channels,
  46. depthwise_kernel_size,
  47. stride=1,
  48. padding=(depthwise_kernel_size - 1) // 2,
  49. groups=num_channels,
  50. bias=bias,
  51. ),
  52. torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
  53. if use_group_norm
  54. else torch.nn.BatchNorm1d(num_channels),
  55. torch.nn.SiLU(),
  56. torch.nn.Conv1d(
  57. num_channels,
  58. input_dim,
  59. kernel_size=1,
  60. stride=1,
  61. padding=0,
  62. bias=bias,
  63. ),
  64. torch.nn.Dropout(dropout),
  65. )
  66. def forward(self, input: torch.Tensor) -> torch.Tensor:
  67. r"""
  68. Args:
  69. input (torch.Tensor): with shape `(B, T, D)`.
  70. Returns:
  71. torch.Tensor: output, with shape `(B, T, D)`.
  72. """
  73. x = self.layer_norm(input)
  74. x = x.transpose(1, 2)
  75. x = self.sequential(x)
  76. return x.transpose(1, 2)
  77. class _FeedForwardModule(torch.nn.Module):
  78. r"""Positionwise feed forward layer.
  79. Args:
  80. input_dim (int): input dimension.
  81. hidden_dim (int): hidden dimension.
  82. dropout (float, optional): dropout probability. (Default: 0.0)
  83. """
  84. def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
  85. super().__init__()
  86. self.sequential = torch.nn.Sequential(
  87. torch.nn.LayerNorm(input_dim),
  88. torch.nn.Linear(input_dim, hidden_dim, bias=True),
  89. torch.nn.SiLU(),
  90. torch.nn.Dropout(dropout),
  91. torch.nn.Linear(hidden_dim, input_dim, bias=True),
  92. torch.nn.Dropout(dropout),
  93. )
  94. def forward(self, input: torch.Tensor) -> torch.Tensor:
  95. r"""
  96. Args:
  97. input (torch.Tensor): with shape `(*, D)`.
  98. Returns:
  99. torch.Tensor: output, with shape `(*, D)`.
  100. """
  101. return self.sequential(input)
  102. class ConformerLayer(torch.nn.Module):
  103. r"""Conformer layer that constitutes Conformer.
  104. Args:
  105. input_dim (int): input dimension.
  106. ffn_dim (int): hidden layer dimension of feedforward network.
  107. num_attention_heads (int): number of attention heads.
  108. depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
  109. dropout (float, optional): dropout probability. (Default: 0.0)
  110. use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
  111. in the convolution module. (Default: ``False``)
  112. convolution_first (bool, optional): apply the convolution module ahead of
  113. the attention module. (Default: ``False``)
  114. """
  115. def __init__(
  116. self,
  117. input_dim: int,
  118. ffn_dim: int,
  119. num_attention_heads: int,
  120. depthwise_conv_kernel_size: int,
  121. dropout: float = 0.0,
  122. use_group_norm: bool = False,
  123. convolution_first: bool = False,
  124. ) -> None:
  125. super().__init__()
  126. self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
  127. self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
  128. self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
  129. self.self_attn_dropout = torch.nn.Dropout(dropout)
  130. self.conv_module = _ConvolutionModule(
  131. input_dim=input_dim,
  132. num_channels=input_dim,
  133. depthwise_kernel_size=depthwise_conv_kernel_size,
  134. dropout=dropout,
  135. bias=True,
  136. use_group_norm=use_group_norm,
  137. )
  138. self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
  139. self.final_layer_norm = torch.nn.LayerNorm(input_dim)
  140. self.convolution_first = convolution_first
  141. def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
  142. residual = input
  143. input = input.transpose(0, 1)
  144. input = self.conv_module(input)
  145. input = input.transpose(0, 1)
  146. input = residual + input
  147. return input
  148. def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
  149. r"""
  150. Args:
  151. input (torch.Tensor): input, with shape `(T, B, D)`.
  152. key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.
  153. Returns:
  154. torch.Tensor: output, with shape `(T, B, D)`.
  155. """
  156. residual = input
  157. x = self.ffn1(input)
  158. x = x * 0.5 + residual
  159. if self.convolution_first:
  160. x = self._apply_convolution(x)
  161. residual = x
  162. x = self.self_attn_layer_norm(x)
  163. x, _ = self.self_attn(
  164. query=x,
  165. key=x,
  166. value=x,
  167. key_padding_mask=key_padding_mask,
  168. need_weights=False,
  169. )
  170. x = self.self_attn_dropout(x)
  171. x = x + residual
  172. if not self.convolution_first:
  173. x = self._apply_convolution(x)
  174. residual = x
  175. x = self.ffn2(x)
  176. x = x * 0.5 + residual
  177. x = self.final_layer_norm(x)
  178. return x
  179. class Conformer(torch.nn.Module):
  180. r"""Implements the Conformer architecture introduced in
  181. *Conformer: Convolution-augmented Transformer for Speech Recognition*
  182. [:footcite:`gulati2020conformer`].
  183. Args:
  184. input_dim (int): input dimension.
  185. num_heads (int): number of attention heads in each Conformer layer.
  186. ffn_dim (int): hidden layer dimension of feedforward networks.
  187. num_layers (int): number of Conformer layers to instantiate.
  188. depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
  189. dropout (float, optional): dropout probability. (Default: 0.0)
  190. use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
  191. in the convolution module. (Default: ``False``)
  192. convolution_first (bool, optional): apply the convolution module ahead of
  193. the attention module. (Default: ``False``)
  194. Examples:
  195. >>> conformer = Conformer(
  196. >>> input_dim=80,
  197. >>> num_heads=4,
  198. >>> ffn_dim=128,
  199. >>> num_layers=4,
  200. >>> depthwise_conv_kernel_size=31,
  201. >>> )
  202. >>> lengths = torch.randint(1, 400, (10,)) # (batch,)
  203. >>> input = torch.rand(10, int(lengths.max()), input_dim) # (batch, num_frames, input_dim)
  204. >>> output = conformer(input, lengths)
  205. """
  206. def __init__(
  207. self,
  208. input_dim: int,
  209. num_heads: int,
  210. ffn_dim: int,
  211. num_layers: int,
  212. depthwise_conv_kernel_size: int,
  213. dropout: float = 0.0,
  214. use_group_norm: bool = False,
  215. convolution_first: bool = False,
  216. ):
  217. super().__init__()
  218. self.conformer_layers = torch.nn.ModuleList(
  219. [
  220. ConformerLayer(
  221. input_dim,
  222. ffn_dim,
  223. num_heads,
  224. depthwise_conv_kernel_size,
  225. dropout=dropout,
  226. use_group_norm=use_group_norm,
  227. convolution_first=convolution_first,
  228. )
  229. for _ in range(num_layers)
  230. ]
  231. )
  232. def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  233. r"""
  234. Args:
  235. input (torch.Tensor): with shape `(B, T, input_dim)`.
  236. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  237. number of valid frames for i-th batch element in ``input``.
  238. Returns:
  239. (torch.Tensor, torch.Tensor)
  240. torch.Tensor
  241. output frames, with shape `(B, T, input_dim)`
  242. torch.Tensor
  243. output lengths, with shape `(B,)` and i-th element representing
  244. number of valid frames for i-th batch element in output frames.
  245. """
  246. encoder_padding_mask = _lengths_to_padding_mask(lengths)
  247. x = input.transpose(0, 1)
  248. for layer in self.conformer_layers:
  249. x = layer(x, encoder_padding_mask)
  250. return x.transpose(0, 1), lengths