wavernn.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. import math
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn, Tensor
  # Public API of this module.
  __all__ = ["ResBlock", "MelResNet", "Stretch2d", "UpsampleNetwork", "WaveRNN"]
  class ResBlock(nn.Module):
      r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].

      Two pointwise (kernel size 1) convolutions with batch norm, plus an identity
      skip connection, so the output has the same shape as the input.

      Args:
          n_freq: the number of bins in a spectrogram. (Default: ``128``)

      Examples
          >>> resblock = ResBlock()
          >>> input = torch.rand(10, 128, 512)  # a random spectrogram
          >>> output = resblock(input)  # shape: (10, 128, 512)
      """

      def __init__(self, n_freq: int = 128) -> None:
          super().__init__()
          # Layer order (and therefore state_dict indices 0..4) must stay fixed
          # so that pretrained checkpoints keep loading.
          layers = [
              nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
              nn.BatchNorm1d(n_freq),
              nn.ReLU(inplace=True),
              nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
              nn.BatchNorm1d(n_freq),
          ]
          self.resblock_model = nn.Sequential(*layers)

      def forward(self, specgram: Tensor) -> Tensor:
          r"""Apply the residual block.

          Args:
              specgram (Tensor): input of shape (n_batch, n_freq, n_time).

          Return:
              Tensor shape: (n_batch, n_freq, n_time)
          """
          skip = specgram
          transformed = self.resblock_model(specgram)
          return transformed + skip
  class MelResNet(nn.Module):
      r"""MelResNet layer: a stack of ResBlocks applied to a spectrogram.

      A ``kernel_size`` convolution (no padding) maps ``n_freq`` to ``n_hidden``
      channels and shortens time by ``kernel_size - 1``. ``n_res_block`` ResBlocks
      follow, then a pointwise convolution projects to ``n_output`` channels.

      Args:
          n_res_block: the number of ResBlock in stack. (Default: ``10``)
          n_freq: the number of bins in a spectrogram. (Default: ``128``)
          n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
          n_output: the number of output dimensions of melresnet. (Default: ``128``)
          kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

      Examples
          >>> melresnet = MelResNet()
          >>> input = torch.rand(10, 128, 512)  # a random spectrogram
          >>> output = melresnet(input)  # shape: (10, 128, 508)
      """

      def __init__(
          self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
      ) -> None:
          super().__init__()
          # The residual blocks are constructed before the stem so parameter
          # initialization consumes the RNG in the same order as before.
          res_blocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
          stem = [
              nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
              nn.BatchNorm1d(n_hidden),
              nn.ReLU(inplace=True),
          ]
          head = nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1)
          self.melresnet_model = nn.Sequential(*stem, *res_blocks, head)

      def forward(self, specgram: Tensor) -> Tensor:
          r"""Run the spectrogram through the MelResNet stack.

          Args:
              specgram (Tensor): input of shape (n_batch, n_freq, n_time).

          Return:
              Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
          """
          features = self.melresnet_model(specgram)
          return features
  class Stretch2d(nn.Module):
      r"""Upscale the frequency and time dimensions of a spectrogram by element repetition.

      Args:
          time_scale: the scale factor in time dimension
          freq_scale: the scale factor in frequency dimension

      Examples
          >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)
          >>> input = torch.rand(10, 100, 512)  # a random spectrogram
          >>> output = stretch2d(input)  # shape: (10, 500, 5120)
      """

      def __init__(self, time_scale: int, freq_scale: int) -> None:
          super().__init__()
          self.time_scale = time_scale
          self.freq_scale = freq_scale

      def forward(self, specgram: Tensor) -> Tensor:
          r"""Repeat every bin ``freq_scale`` times and every frame ``time_scale`` times.

          Args:
              specgram (Tensor): input of shape (..., n_freq, n_time).

          Return:
              Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
          """
          stretched = specgram.repeat_interleave(self.freq_scale, -2)
          stretched = stretched.repeat_interleave(self.time_scale, -1)
          return stretched
  class UpsampleNetwork(nn.Module):
      r"""Upscale the dimensions of a spectrogram.

      Returns two aligned streams:

      * the spectrogram upsampled in time by a cascade of repeat + moving-average
        convolution stages, one per entry of ``upsample_scales``;
      * the MelResNet features of the spectrogram, stretched in time by the
        product of ``upsample_scales``.

      Args:
          upsample_scales: the list of upsample scales.
          n_res_block: the number of ResBlock in stack. (Default: ``10``)
          n_freq: the number of bins in a spectrogram. (Default: ``128``)
          n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
          n_output: the number of output dimensions of melresnet. (Default: ``128``)
          kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

      Examples
          >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
          >>> input = torch.rand(10, 128, 10)  # a random spectrogram
          >>> output = upsamplenetwork(input)  # shape: (10, 128, 1536), (10, 128, 1536)
      """

      def __init__(
          self,
          upsample_scales: List[int],
          n_res_block: int = 10,
          n_freq: int = 128,
          n_hidden: int = 128,
          n_output: int = 128,
          kernel_size: int = 5,
      ) -> None:
          super().__init__()

          total_scale = 1
          for upsample_scale in upsample_scales:
              total_scale *= upsample_scale
          self.total_scale: int = total_scale

          # Number of upsampled samples dropped from the *start* of the upsampled
          # spectrogram. MelResNet's unpadded conv drops (kernel_size - 1) frames.
          # (kernel_size - 1) // 2 of them come from the start, so this trim aligns
          # both streams at their first frame.
          self.indent = (kernel_size - 1) // 2 * total_scale
          self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
          self.resnet_stretch = Stretch2d(total_scale, 1)

          up_layers = []
          for scale in upsample_scales:
              stretch = Stretch2d(scale, 1)
              conv = nn.Conv2d(
                  in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
              )
              # Moving-average smoothing of the repeated frames.
              torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
              up_layers.append(stretch)
              up_layers.append(conv)
          self.upsample_layers = nn.Sequential(*up_layers)

      def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
          r"""Pass the input through the UpsampleNetwork layer.

          Args:
              specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

          Return:
              Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
                            (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
              where total_scale is the product of all elements in upsample_scales.
          """
          resnet_output = self.resnet(specgram).unsqueeze(1)
          resnet_output = self.resnet_stretch(resnet_output)
          resnet_output = resnet_output.squeeze(1)

          specgram = specgram.unsqueeze(1)
          upsampling_output = self.upsample_layers(specgram)

          # Trim the upsampled spectrogram to the resnet stream's length, starting at
          # ``self.indent``. The end is derived from the resnet length rather than
          # written as ``-self.indent``. With kernel_size == 1 the indent is 0, and
          # ``x[0:-0]`` would be empty. Even kernel sizes need a one-frame-larger trim
          # on the right than on the left.
          n_valid = resnet_output.size(-1)
          upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : self.indent + n_valid]

          return upsampling_output, resnet_output
  class WaveRNN(nn.Module):
      r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

      The original implementation was introduced in *Efficient Neural Audio Synthesis*
      [:footcite:`kalchbrenner2018efficient`]. The input channels of waveform and spectrogram have to be 1.
      The product of `upsample_scales` must equal `hop_length`.

      Args:
          upsample_scales: the list of upsample scales.
          n_classes: the number of output classes.
          hop_length: the number of samples between the starts of consecutive frames.
          n_res_block: the number of ResBlock in stack. (Default: ``10``)
          n_rnn: the dimension of RNN layer. (Default: ``512``)
          n_fc: the dimension of fully connected layer. (Default: ``512``)
          kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
          n_freq: the number of bins in a spectrogram. (Default: ``128``)
          n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
          n_output: the number of output dimensions of melresnet. (Default: ``128``)

      Raises:
          ValueError: if the product of ``upsample_scales`` differs from ``hop_length``.

      Example
          >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
          >>> waveform, sample_rate = torchaudio.load(file)
          >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
          >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
          >>> output = wavernn(waveform, specgram)
          >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
      """

      def __init__(
          self,
          upsample_scales: List[int],
          n_classes: int,
          hop_length: int,
          n_res_block: int = 10,
          n_rnn: int = 512,
          n_fc: int = 512,
          kernel_size: int = 5,
          n_freq: int = 128,
          n_hidden: int = 128,
          n_output: int = 128,
      ) -> None:
          super().__init__()
          self.kernel_size = kernel_size
          # Per-side time padding applied to the spectrogram in ``infer``.
          self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
          self.n_rnn = n_rnn
          # The MelResNet output is split into four auxiliary feature groups of this width.
          self.n_aux = n_output // 4
          self.hop_length = hop_length
          self.n_classes = n_classes
          self.n_bits: int = int(math.log2(self.n_classes))

          total_scale = 1
          for upsample_scale in upsample_scales:
              total_scale *= upsample_scale
          if total_scale != self.hop_length:
              raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

          self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
          # Input per time step: 1 waveform sample + n_freq spectrogram bins + first aux group.
          self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
          self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
          self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
          self.relu1 = nn.ReLU(inplace=True)
          self.relu2 = nn.ReLU(inplace=True)
          self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
          self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
          self.fc3 = nn.Linear(n_fc, self.n_classes)

      def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
          r"""Pass the input through the WaveRNN model (teacher-forced, all steps at once).

          Args:
              waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
              specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

          Return:
              Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
          """
          assert waveform.size(1) == 1, "Require the input channel of waveform is 1"
          assert specgram.size(1) == 1, "Require the input channel of specgram is 1"
          # remove channel dimension until the end
          waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

          batch_size = waveform.size(0)
          h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
          h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)

          # output of upsample:
          # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
          # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
          specgram, aux = self.upsample(specgram)
          specgram = specgram.transpose(1, 2)
          aux = aux.transpose(1, 2)

          # Split the aux features into four consecutive channel groups of width n_aux.
          aux_idx = [self.n_aux * i for i in range(5)]
          a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
          a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
          a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
          a4 = aux[:, :, aux_idx[3] : aux_idx[4]]

          x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
          x = self.fc(x)
          res = x
          x, _ = self.rnn1(x, h1)

          x = x + res
          res = x
          x = torch.cat([x, a2], dim=-1)
          x, _ = self.rnn2(x, h2)

          x = x + res
          x = torch.cat([x, a3], dim=-1)
          x = self.fc1(x)
          x = self.relu1(x)

          x = torch.cat([x, a4], dim=-1)
          x = self.fc2(x)
          x = self.relu2(x)
          x = self.fc3(x)

          # bring back channel dimension
          return x.unsqueeze(1)

      @torch.jit.export
      def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
          r"""Inference method of WaveRNN.

          This function currently only supports multinomial sampling, which assumes the
          network is trained on cross entropy loss.

          Args:
              specgram (Tensor):
                  Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
              lengths (Tensor or None, optional):
                  Indicates the valid length of each audio in the batch.
                  Shape: `(batch, )`.
                  When the ``specgram`` contains spectrograms with different durations,
                  by providing ``lengths`` argument, the model will compute
                  the corresponding valid output lengths.
                  If ``None``, it is assumed that all the audio in ``waveforms``
                  have valid length. Default: ``None``.

          Returns:
              (Tensor, Optional[Tensor]):
              Tensor
                  The inferred waveform of size `(n_batch, 1, n_time)`, with the same
                  dtype and device as ``specgram``.
                  1 stands for a single channel.
              Tensor or None
                  If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                  is returned.
                  It indicates the valid length in time axis of the output Tensor.
          """
          device = specgram.device
          dtype = specgram.dtype

          # Pad time so the unpadded MelResNet conv does not shorten the output. For odd
          # kernel_size the total padding is kernel_size - 1, giving n_time * hop_length samples.
          specgram = F.pad(specgram, (self._pad, self._pad))
          specgram, aux = self.upsample(specgram)
          if lengths is not None:
              lengths = lengths * self.upsample.total_scale

          output: List[Tensor] = []
          b_size, _, seq_len = specgram.size()

          h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
          h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
          # Previous generated sample; starts at silence.
          x = torch.zeros((b_size, 1), device=device, dtype=dtype)

          aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

          # Autoregressive loop: one sample per upsampled time step.
          for i in range(seq_len):
              m_t = specgram[:, :, i]

              a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

              x = torch.cat([x, m_t, a1_t], dim=1)
              x = self.fc(x)
              _, h1 = self.rnn1(x.unsqueeze(1), h1)

              x = x + h1[0]
              inp = torch.cat([x, a2_t], dim=1)
              _, h2 = self.rnn2(inp.unsqueeze(1), h2)

              x = x + h2[0]
              x = torch.cat([x, a3_t], dim=1)
              x = F.relu(self.fc1(x))

              x = torch.cat([x, a4_t], dim=1)
              x = F.relu(self.fc2(x))

              logits = self.fc3(x)

              posterior = F.softmax(logits, dim=1)

              # Keep the sample in the model's dtype: a hard-coded float32 here would be
              # concatenated with half/double tensors on the next step and fed to self.fc.
              x = torch.multinomial(posterior, 1).to(dtype)
              # Transform label [0, n_classes - 1] to waveform [-1, 1]. When
              # n_classes == 2 ** n_bits this equals dividing by (2 ** n_bits - 1).
              x = 2 * x / (self.n_classes - 1.0) - 1.0

              output.append(x)

          return torch.stack(output).permute(1, 2, 0), lengths