deepspeech.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import torch
  2. __all__ = ["DeepSpeech"]
  3. class FullyConnected(torch.nn.Module):
  4. """
  5. Args:
  6. n_feature: Number of input features
  7. n_hidden: Internal hidden unit size.
  8. """
  9. def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
  10. super(FullyConnected, self).__init__()
  11. self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
  12. self.relu_max_clip = relu_max_clip
  13. self.dropout = dropout
  14. def forward(self, x: torch.Tensor) -> torch.Tensor:
  15. x = self.fc(x)
  16. x = torch.nn.functional.relu(x)
  17. x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
  18. if self.dropout:
  19. x = torch.nn.functional.dropout(x, self.dropout, self.training)
  20. return x
  21. class DeepSpeech(torch.nn.Module):
  22. """
  23. DeepSpeech model architecture from *Deep Speech: Scaling up end-to-end speech recognition*
  24. [:footcite:`hannun2014deep`].
  25. Args:
  26. n_feature: Number of input features
  27. n_hidden: Internal hidden unit size.
  28. n_class: Number of output classes
  29. """
  30. def __init__(
  31. self,
  32. n_feature: int,
  33. n_hidden: int = 2048,
  34. n_class: int = 40,
  35. dropout: float = 0.0,
  36. ) -> None:
  37. super(DeepSpeech, self).__init__()
  38. self.n_hidden = n_hidden
  39. self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
  40. self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
  41. self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
  42. self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
  43. self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
  44. self.out = torch.nn.Linear(n_hidden, n_class)
  45. def forward(self, x: torch.Tensor) -> torch.Tensor:
  46. """
  47. Args:
  48. x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
  49. Returns:
  50. Tensor: Predictor tensor of dimension (batch, time, class).
  51. """
  52. # N x C x T x F
  53. x = self.fc1(x)
  54. # N x C x T x H
  55. x = self.fc2(x)
  56. # N x C x T x H
  57. x = self.fc3(x)
  58. # N x C x T x H
  59. x = x.squeeze(1)
  60. # N x T x H
  61. x = x.transpose(0, 1)
  62. # T x N x H
  63. x, _ = self.bi_rnn(x)
  64. # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
  65. x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
  66. # T x N x H
  67. x = self.fc4(x)
  68. # T x N x H
  69. x = self.out(x)
  70. # T x N x n_class
  71. x = x.permute(1, 0, 2)
  72. # N x T x n_class
  73. x = torch.nn.functional.log_softmax(x, dim=2)
  74. # N x T x n_class
  75. return x