wav2letter.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from torch import nn, Tensor
  2. __all__ = [
  3. "Wav2Letter",
  4. ]
  5. class Wav2Letter(nn.Module):
  6. r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
  7. Recognition System* [:footcite:`collobert2016wav2letter`].
  8. :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}`
  9. Args:
  10. num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
  11. input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
  12. or ``mfcc`` (Default: ``waveform``).
  13. num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
  14. """
  15. def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
  16. super(Wav2Letter, self).__init__()
  17. acoustic_num_features = 250 if input_type == "waveform" else num_features
  18. acoustic_model = nn.Sequential(
  19. nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
  20. nn.ReLU(inplace=True),
  21. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  22. nn.ReLU(inplace=True),
  23. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  24. nn.ReLU(inplace=True),
  25. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  26. nn.ReLU(inplace=True),
  27. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  28. nn.ReLU(inplace=True),
  29. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  30. nn.ReLU(inplace=True),
  31. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  32. nn.ReLU(inplace=True),
  33. nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
  34. nn.ReLU(inplace=True),
  35. nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
  36. nn.ReLU(inplace=True),
  37. nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
  38. nn.ReLU(inplace=True),
  39. nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
  40. nn.ReLU(inplace=True),
  41. )
  42. if input_type == "waveform":
  43. waveform_model = nn.Sequential(
  44. nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
  45. nn.ReLU(inplace=True),
  46. )
  47. self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)
  48. if input_type in ["power_spectrum", "mfcc"]:
  49. self.acoustic_model = acoustic_model
  50. def forward(self, x: Tensor) -> Tensor:
  51. r"""
  52. Args:
  53. x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length).
  54. Returns:
  55. Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
  56. """
  57. x = self.acoustic_model(x)
  58. x = nn.functional.log_softmax(x, dim=1)
  59. return x