tacotron2.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import warnings
  28. from typing import List, Optional, Tuple, Union
  29. import torch
  30. from torch import nn, Tensor
  31. from torch.nn import functional as F
  32. __all__ = [
  33. "Tacotron2",
  34. ]
  35. def _get_linear_layer(in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear") -> torch.nn.Linear:
  36. r"""Linear layer with xavier uniform initialization.
  37. Args:
  38. in_dim (int): Size of each input sample.
  39. out_dim (int): Size of each output sample.
  40. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
  41. w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
  42. for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)
  43. Returns:
  44. (torch.nn.Linear): The corresponding linear layer.
  45. """
  46. linear = torch.nn.Linear(in_dim, out_dim, bias=bias)
  47. torch.nn.init.xavier_uniform_(linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
  48. return linear
  49. def _get_conv1d_layer(
  50. in_channels: int,
  51. out_channels: int,
  52. kernel_size: int = 1,
  53. stride: int = 1,
  54. padding: Optional[Union[str, int, Tuple[int]]] = None,
  55. dilation: int = 1,
  56. bias: bool = True,
  57. w_init_gain: str = "linear",
  58. ) -> torch.nn.Conv1d:
  59. r"""1D convolution with xavier uniform initialization.
  60. Args:
  61. in_channels (int): Number of channels in the input image.
  62. out_channels (int): Number of channels produced by the convolution.
  63. kernel_size (int, optional): Number of channels in the input image. (Default: ``1``)
  64. stride (int, optional): Number of channels in the input image. (Default: ``1``)
  65. padding (str, int or tuple, optional): Padding added to both sides of the input.
  66. (Default: dilation * (kernel_size - 1) / 2)
  67. dilation (int, optional): Number of channels in the input image. (Default: ``1``)
  68. w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
  69. for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)
  70. Returns:
  71. (torch.nn.Conv1d): The corresponding Conv1D layer.
  72. """
  73. if padding is None:
  74. assert kernel_size % 2 == 1
  75. padding = int(dilation * (kernel_size - 1) / 2)
  76. conv1d = torch.nn.Conv1d(
  77. in_channels,
  78. out_channels,
  79. kernel_size=kernel_size,
  80. stride=stride,
  81. padding=padding,
  82. dilation=dilation,
  83. bias=bias,
  84. )
  85. torch.nn.init.xavier_uniform_(conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
  86. return conv1d
  87. def _get_mask_from_lengths(lengths: Tensor) -> Tensor:
  88. r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
  89. is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths.
  90. Args:
  91. lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).
  92. Returns:
  93. mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
  94. """
  95. max_len = torch.max(lengths).item()
  96. ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
  97. mask = (ids < lengths.unsqueeze(1)).byte()
  98. mask = torch.le(mask, 0)
  99. return mask
  100. class _LocationLayer(nn.Module):
  101. r"""Location layer used in the Attention model.
  102. Args:
  103. attention_n_filter (int): Number of filters for attention model.
  104. attention_kernel_size (int): Kernel size for attention model.
  105. attention_hidden_dim (int): Dimension of attention hidden representation.
  106. """
  107. def __init__(
  108. self,
  109. attention_n_filter: int,
  110. attention_kernel_size: int,
  111. attention_hidden_dim: int,
  112. ):
  113. super().__init__()
  114. padding = int((attention_kernel_size - 1) / 2)
  115. self.location_conv = _get_conv1d_layer(
  116. 2,
  117. attention_n_filter,
  118. kernel_size=attention_kernel_size,
  119. padding=padding,
  120. bias=False,
  121. stride=1,
  122. dilation=1,
  123. )
  124. self.location_dense = _get_linear_layer(
  125. attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh"
  126. )
  127. def forward(self, attention_weights_cat: Tensor) -> Tensor:
  128. r"""Location layer used in the Attention model.
  129. Args:
  130. attention_weights_cat (Tensor): Cumulative and previous attention weights
  131. with shape (n_batch, 2, max of ``text_lengths``).
  132. Returns:
  133. processed_attention (Tensor): Cumulative and previous attention weights
  134. with shape (n_batch, ``attention_hidden_dim``).
  135. """
  136. # (n_batch, attention_n_filter, text_lengths.max())
  137. processed_attention = self.location_conv(attention_weights_cat)
  138. processed_attention = processed_attention.transpose(1, 2)
  139. # (n_batch, text_lengths.max(), attention_hidden_dim)
  140. processed_attention = self.location_dense(processed_attention)
  141. return processed_attention
  142. class _Attention(nn.Module):
  143. r"""Locally sensitive attention model.
  144. Args:
  145. attention_rnn_dim (int): Number of hidden units for RNN.
  146. encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
  147. attention_hidden_dim (int): Dimension of attention hidden representation.
  148. attention_location_n_filter (int): Number of filters for Attention model.
  149. attention_location_kernel_size (int): Kernel size for Attention model.
  150. """
  151. def __init__(
  152. self,
  153. attention_rnn_dim: int,
  154. encoder_embedding_dim: int,
  155. attention_hidden_dim: int,
  156. attention_location_n_filter: int,
  157. attention_location_kernel_size: int,
  158. ) -> None:
  159. super().__init__()
  160. self.query_layer = _get_linear_layer(attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh")
  161. self.memory_layer = _get_linear_layer(
  162. encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh"
  163. )
  164. self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False)
  165. self.location_layer = _LocationLayer(
  166. attention_location_n_filter,
  167. attention_location_kernel_size,
  168. attention_hidden_dim,
  169. )
  170. self.score_mask_value = -float("inf")
  171. def _get_alignment_energies(self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor) -> Tensor:
  172. r"""Get the alignment vector.
  173. Args:
  174. query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
  175. processed_memory (Tensor): Processed Encoder outputs
  176. with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
  177. attention_weights_cat (Tensor): Cumulative and previous attention weights
  178. with shape (n_batch, 2, max of ``text_lengths``).
  179. Returns:
  180. alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``).
  181. """
  182. processed_query = self.query_layer(query.unsqueeze(1))
  183. processed_attention_weights = self.location_layer(attention_weights_cat)
  184. energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))
  185. alignment = energies.squeeze(2)
  186. return alignment
  187. def forward(
  188. self,
  189. attention_hidden_state: Tensor,
  190. memory: Tensor,
  191. processed_memory: Tensor,
  192. attention_weights_cat: Tensor,
  193. mask: Tensor,
  194. ) -> Tuple[Tensor, Tensor]:
  195. r"""Pass the input through the Attention model.
  196. Args:
  197. attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
  198. memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  199. processed_memory (Tensor): Processed Encoder outputs
  200. with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
  201. attention_weights_cat (Tensor): Previous and cumulative attention weights
  202. with shape (n_batch, current_num_frames * 2, max of ``text_lengths``).
  203. mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).
  204. Returns:
  205. attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
  206. attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
  207. """
  208. alignment = self._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)
  209. alignment = alignment.masked_fill(mask, self.score_mask_value)
  210. attention_weights = F.softmax(alignment, dim=1)
  211. attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
  212. attention_context = attention_context.squeeze(1)
  213. return attention_context, attention_weights
  214. class _Prenet(nn.Module):
  215. r"""Prenet Module. It is consists of ``len(output_size)`` linear layers.
  216. Args:
  217. in_dim (int): The size of each input sample.
  218. output_sizes (list): The output dimension of each linear layers.
  219. """
  220. def __init__(self, in_dim: int, out_sizes: List[int]) -> None:
  221. super().__init__()
  222. in_sizes = [in_dim] + out_sizes[:-1]
  223. self.layers = nn.ModuleList(
  224. [_get_linear_layer(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, out_sizes)]
  225. )
  226. def forward(self, x: Tensor) -> Tensor:
  227. r"""Pass the input through Prenet.
  228. Args:
  229. x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).
  230. Return:
  231. x (Tensor): Tensor with shape (n_batch, sizes[-1])
  232. """
  233. for linear in self.layers:
  234. x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
  235. return x
  236. class _Postnet(nn.Module):
  237. r"""Postnet Module.
  238. Args:
  239. n_mels (int): Number of mel bins.
  240. postnet_embedding_dim (int): Postnet embedding dimension.
  241. postnet_kernel_size (int): Postnet kernel size.
  242. postnet_n_convolution (int): Number of postnet convolutions.
  243. """
  244. def __init__(
  245. self,
  246. n_mels: int,
  247. postnet_embedding_dim: int,
  248. postnet_kernel_size: int,
  249. postnet_n_convolution: int,
  250. ):
  251. super().__init__()
  252. self.convolutions = nn.ModuleList()
  253. for i in range(postnet_n_convolution):
  254. in_channels = n_mels if i == 0 else postnet_embedding_dim
  255. out_channels = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
  256. init_gain = "linear" if i == (postnet_n_convolution - 1) else "tanh"
  257. num_features = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim
  258. self.convolutions.append(
  259. nn.Sequential(
  260. _get_conv1d_layer(
  261. in_channels,
  262. out_channels,
  263. kernel_size=postnet_kernel_size,
  264. stride=1,
  265. padding=int((postnet_kernel_size - 1) / 2),
  266. dilation=1,
  267. w_init_gain=init_gain,
  268. ),
  269. nn.BatchNorm1d(num_features),
  270. )
  271. )
  272. self.n_convs = len(self.convolutions)
  273. def forward(self, x: Tensor) -> Tensor:
  274. r"""Pass the input through Postnet.
  275. Args:
  276. x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
  277. Return:
  278. x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
  279. """
  280. for i, conv in enumerate(self.convolutions):
  281. if i < self.n_convs - 1:
  282. x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
  283. else:
  284. x = F.dropout(conv(x), 0.5, training=self.training)
  285. return x
  286. class _Encoder(nn.Module):
  287. r"""Encoder Module.
  288. Args:
  289. encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
  290. encoder_n_convolution (int): Number of convolution layers in the encoder.
  291. encoder_kernel_size (int): The kernel size in the encoder.
  292. Examples
  293. >>> encoder = _Encoder(3, 512, 5)
  294. >>> input = torch.rand(10, 20, 30)
  295. >>> output = encoder(input) # shape: (10, 30, 512)
  296. """
  297. def __init__(
  298. self,
  299. encoder_embedding_dim: int,
  300. encoder_n_convolution: int,
  301. encoder_kernel_size: int,
  302. ) -> None:
  303. super().__init__()
  304. self.convolutions = nn.ModuleList()
  305. for _ in range(encoder_n_convolution):
  306. conv_layer = nn.Sequential(
  307. _get_conv1d_layer(
  308. encoder_embedding_dim,
  309. encoder_embedding_dim,
  310. kernel_size=encoder_kernel_size,
  311. stride=1,
  312. padding=int((encoder_kernel_size - 1) / 2),
  313. dilation=1,
  314. w_init_gain="relu",
  315. ),
  316. nn.BatchNorm1d(encoder_embedding_dim),
  317. )
  318. self.convolutions.append(conv_layer)
  319. self.lstm = nn.LSTM(
  320. encoder_embedding_dim,
  321. int(encoder_embedding_dim / 2),
  322. 1,
  323. batch_first=True,
  324. bidirectional=True,
  325. )
  326. self.lstm.flatten_parameters()
  327. def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor:
  328. r"""Pass the input through the Encoder.
  329. Args:
  330. x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
  331. input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).
  332. Return:
  333. x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
  334. """
  335. for conv in self.convolutions:
  336. x = F.dropout(F.relu(conv(x)), 0.5, self.training)
  337. x = x.transpose(1, 2)
  338. input_lengths = input_lengths.cpu()
  339. x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
  340. outputs, _ = self.lstm(x)
  341. outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
  342. return outputs
  343. class _Decoder(nn.Module):
  344. r"""Decoder with Attention model.
  345. Args:
  346. n_mels (int): number of mel bins
  347. n_frames_per_step (int): number of frames processed per step, only 1 is supported
  348. encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
  349. decoder_rnn_dim (int): number of units in decoder LSTM
  350. decoder_max_step (int): maximum number of output mel spectrograms
  351. decoder_dropout (float): dropout probability for decoder LSTM
  352. decoder_early_stopping (bool): stop decoding when all samples are finished
  353. attention_rnn_dim (int): number of units in attention LSTM
  354. attention_hidden_dim (int): dimension of attention hidden representation
  355. attention_location_n_filter (int): number of filters for attention model
  356. attention_location_kernel_size (int): kernel size for attention model
  357. attention_dropout (float): dropout probability for attention LSTM
  358. prenet_dim (int): number of ReLU units in prenet layers
  359. gate_threshold (float): probability threshold for stop token
  360. """
  361. def __init__(
  362. self,
  363. n_mels: int,
  364. n_frames_per_step: int,
  365. encoder_embedding_dim: int,
  366. decoder_rnn_dim: int,
  367. decoder_max_step: int,
  368. decoder_dropout: float,
  369. decoder_early_stopping: bool,
  370. attention_rnn_dim: int,
  371. attention_hidden_dim: int,
  372. attention_location_n_filter: int,
  373. attention_location_kernel_size: int,
  374. attention_dropout: float,
  375. prenet_dim: int,
  376. gate_threshold: float,
  377. ) -> None:
  378. super().__init__()
  379. self.n_mels = n_mels
  380. self.n_frames_per_step = n_frames_per_step
  381. self.encoder_embedding_dim = encoder_embedding_dim
  382. self.attention_rnn_dim = attention_rnn_dim
  383. self.decoder_rnn_dim = decoder_rnn_dim
  384. self.prenet_dim = prenet_dim
  385. self.decoder_max_step = decoder_max_step
  386. self.gate_threshold = gate_threshold
  387. self.attention_dropout = attention_dropout
  388. self.decoder_dropout = decoder_dropout
  389. self.decoder_early_stopping = decoder_early_stopping
  390. self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim])
  391. self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
  392. self.attention_layer = _Attention(
  393. attention_rnn_dim,
  394. encoder_embedding_dim,
  395. attention_hidden_dim,
  396. attention_location_n_filter,
  397. attention_location_kernel_size,
  398. )
  399. self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True)
  400. self.linear_projection = _get_linear_layer(decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step)
  401. self.gate_layer = _get_linear_layer(
  402. decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid"
  403. )
  404. def _get_initial_frame(self, memory: Tensor) -> Tensor:
  405. r"""Gets all zeros frames to use as the first decoder input.
  406. Args:
  407. memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  408. Returns:
  409. decoder_input (Tensor): all zeros frames with shape
  410. (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``).
  411. """
  412. n_batch = memory.size(0)
  413. dtype = memory.dtype
  414. device = memory.device
  415. decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
  416. return decoder_input
  417. def _initialize_decoder_states(
  418. self, memory: Tensor
  419. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  420. r"""Initializes attention rnn states, decoder rnn states, attention
  421. weights, attention cumulative weights, attention context, stores memory
  422. and stores processed memory.
  423. Args:
  424. memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  425. Returns:
  426. attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  427. attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  428. decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  429. decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  430. attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
  431. attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
  432. attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
  433. processed_memory (Tensor): Processed encoder outputs
  434. with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
  435. """
  436. n_batch = memory.size(0)
  437. max_time = memory.size(1)
  438. dtype = memory.dtype
  439. device = memory.device
  440. attention_hidden = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
  441. attention_cell = torch.zeros(n_batch, self.attention_rnn_dim, dtype=dtype, device=device)
  442. decoder_hidden = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
  443. decoder_cell = torch.zeros(n_batch, self.decoder_rnn_dim, dtype=dtype, device=device)
  444. attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
  445. attention_weights_cum = torch.zeros(n_batch, max_time, dtype=dtype, device=device)
  446. attention_context = torch.zeros(n_batch, self.encoder_embedding_dim, dtype=dtype, device=device)
  447. processed_memory = self.attention_layer.memory_layer(memory)
  448. return (
  449. attention_hidden,
  450. attention_cell,
  451. decoder_hidden,
  452. decoder_cell,
  453. attention_weights,
  454. attention_weights_cum,
  455. attention_context,
  456. processed_memory,
  457. )
  458. def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor:
  459. r"""Prepares decoder inputs.
  460. Args:
  461. decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
  462. with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
  463. Returns:
  464. inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
  465. """
  466. # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels)
  467. decoder_inputs = decoder_inputs.transpose(1, 2)
  468. decoder_inputs = decoder_inputs.view(
  469. decoder_inputs.size(0),
  470. int(decoder_inputs.size(1) / self.n_frames_per_step),
  471. -1,
  472. )
  473. # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels)
  474. decoder_inputs = decoder_inputs.transpose(0, 1)
  475. return decoder_inputs
  476. def _parse_decoder_outputs(
  477. self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
  478. ) -> Tuple[Tensor, Tensor, Tensor]:
  479. r"""Prepares decoder outputs for output
  480. Args:
  481. mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
  482. gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
  483. alignments (Tensor): sequence of attention weights from the decoder
  484. with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)
  485. Returns:
  486. mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
  487. gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
  488. alignments (Tensor): sequence of attention weights from the decoder
  489. with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
  490. """
  491. # (mel_specgram_lengths.max(), n_batch, text_lengths.max())
  492. # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max())
  493. alignments = alignments.transpose(0, 1).contiguous()
  494. # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
  495. gate_outputs = gate_outputs.transpose(0, 1).contiguous()
  496. # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
  497. mel_specgram = mel_specgram.transpose(0, 1).contiguous()
  498. # decouple frames per step
  499. shape = (mel_specgram.shape[0], -1, self.n_mels)
  500. mel_specgram = mel_specgram.view(*shape)
  501. # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out)
  502. mel_specgram = mel_specgram.transpose(1, 2)
  503. return mel_specgram, gate_outputs, alignments
  504. def decode(
  505. self,
  506. decoder_input: Tensor,
  507. attention_hidden: Tensor,
  508. attention_cell: Tensor,
  509. decoder_hidden: Tensor,
  510. decoder_cell: Tensor,
  511. attention_weights: Tensor,
  512. attention_weights_cum: Tensor,
  513. attention_context: Tensor,
  514. memory: Tensor,
  515. processed_memory: Tensor,
  516. mask: Tensor,
  517. ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  518. r"""Decoder step using stored states, attention and memory
  519. Args:
  520. decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
  521. attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  522. attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  523. decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  524. decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  525. attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
  526. attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
  527. attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
  528. memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  529. processed_memory (Tensor): Processed Encoder outputs
  530. with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
  531. mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).
  532. Returns:
  533. decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
  534. gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
  535. attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  536. attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
  537. decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  538. decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
  539. attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
  540. attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
  541. attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
  542. """
  543. cell_input = torch.cat((decoder_input, attention_context), -1)
  544. attention_hidden, attention_cell = self.attention_rnn(cell_input, (attention_hidden, attention_cell))
  545. attention_hidden = F.dropout(attention_hidden, self.attention_dropout, self.training)
  546. attention_weights_cat = torch.cat((attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1)
  547. attention_context, attention_weights = self.attention_layer(
  548. attention_hidden, memory, processed_memory, attention_weights_cat, mask
  549. )
  550. attention_weights_cum += attention_weights
  551. decoder_input = torch.cat((attention_hidden, attention_context), -1)
  552. decoder_hidden, decoder_cell = self.decoder_rnn(decoder_input, (decoder_hidden, decoder_cell))
  553. decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training)
  554. decoder_hidden_attention_context = torch.cat((decoder_hidden, attention_context), dim=1)
  555. decoder_output = self.linear_projection(decoder_hidden_attention_context)
  556. gate_prediction = self.gate_layer(decoder_hidden_attention_context)
  557. return (
  558. decoder_output,
  559. gate_prediction,
  560. attention_hidden,
  561. attention_cell,
  562. decoder_hidden,
  563. decoder_cell,
  564. attention_weights,
  565. attention_weights_cum,
  566. attention_context,
  567. )
  568. def forward(
  569. self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor
  570. ) -> Tuple[Tensor, Tensor, Tensor]:
  571. r"""Decoder forward pass for training.
  572. Args:
  573. memory (Tensor): Encoder outputs
  574. with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  575. mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
  576. with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
  577. memory_lengths (Tensor): Encoder output lengths for attention masking
  578. (the same as ``text_lengths``) with shape (n_batch, ).
  579. Returns:
  580. mel_specgram (Tensor): Predicted mel spectrogram
  581. with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
  582. gate_outputs (Tensor): Predicted stop token for each timestep
  583. with shape (n_batch, max of ``mel_specgram_lengths``).
  584. alignments (Tensor): Sequence of attention weights from the decoder
  585. with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
  586. """
  587. decoder_input = self._get_initial_frame(memory).unsqueeze(0)
  588. decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth)
  589. decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
  590. decoder_inputs = self.prenet(decoder_inputs)
  591. mask = _get_mask_from_lengths(memory_lengths)
  592. (
  593. attention_hidden,
  594. attention_cell,
  595. decoder_hidden,
  596. decoder_cell,
  597. attention_weights,
  598. attention_weights_cum,
  599. attention_context,
  600. processed_memory,
  601. ) = self._initialize_decoder_states(memory)
  602. mel_outputs, gate_outputs, alignments = [], [], []
  603. while len(mel_outputs) < decoder_inputs.size(0) - 1:
  604. decoder_input = decoder_inputs[len(mel_outputs)]
  605. (
  606. mel_output,
  607. gate_output,
  608. attention_hidden,
  609. attention_cell,
  610. decoder_hidden,
  611. decoder_cell,
  612. attention_weights,
  613. attention_weights_cum,
  614. attention_context,
  615. ) = self.decode(
  616. decoder_input,
  617. attention_hidden,
  618. attention_cell,
  619. decoder_hidden,
  620. decoder_cell,
  621. attention_weights,
  622. attention_weights_cum,
  623. attention_context,
  624. memory,
  625. processed_memory,
  626. mask,
  627. )
  628. mel_outputs += [mel_output.squeeze(1)]
  629. gate_outputs += [gate_output.squeeze(1)]
  630. alignments += [attention_weights]
  631. mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
  632. torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments)
  633. )
  634. return mel_specgram, gate_outputs, alignments
  635. def _get_go_frame(self, memory: Tensor) -> Tensor:
  636. """Gets all zeros frames to use as the first decoder input
  637. args:
  638. memory (Tensor): Encoder outputs
  639. with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  640. returns:
  641. decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
  642. """
  643. n_batch = memory.size(0)
  644. dtype = memory.dtype
  645. device = memory.device
  646. decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
  647. return decoder_input
  648. @torch.jit.export
  649. def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  650. """Decoder inference
  651. Args:
  652. memory (Tensor): Encoder outputs
  653. with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
  654. memory_lengths (Tensor): Encoder output lengths for attention masking
  655. (the same as ``text_lengths``) with shape (n_batch, ).
  656. Returns:
  657. mel_specgram (Tensor): Predicted mel spectrogram
  658. with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
  659. mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, ))
  660. gate_outputs (Tensor): Predicted stop token for each timestep
  661. with shape (n_batch, max of ``mel_specgram_lengths``).
  662. alignments (Tensor): Sequence of attention weights from the decoder
  663. with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
  664. """
  665. batch_size, device = memory.size(0), memory.device
  666. decoder_input = self._get_go_frame(memory)
  667. mask = _get_mask_from_lengths(memory_lengths)
  668. (
  669. attention_hidden,
  670. attention_cell,
  671. decoder_hidden,
  672. decoder_cell,
  673. attention_weights,
  674. attention_weights_cum,
  675. attention_context,
  676. processed_memory,
  677. ) = self._initialize_decoder_states(memory)
  678. mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
  679. finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
  680. mel_specgrams: List[Tensor] = []
  681. gate_outputs: List[Tensor] = []
  682. alignments: List[Tensor] = []
  683. for _ in range(self.decoder_max_step):
  684. decoder_input = self.prenet(decoder_input)
  685. (
  686. mel_specgram,
  687. gate_output,
  688. attention_hidden,
  689. attention_cell,
  690. decoder_hidden,
  691. decoder_cell,
  692. attention_weights,
  693. attention_weights_cum,
  694. attention_context,
  695. ) = self.decode(
  696. decoder_input,
  697. attention_hidden,
  698. attention_cell,
  699. decoder_hidden,
  700. decoder_cell,
  701. attention_weights,
  702. attention_weights_cum,
  703. attention_context,
  704. memory,
  705. processed_memory,
  706. mask,
  707. )
  708. mel_specgrams.append(mel_specgram.unsqueeze(0))
  709. gate_outputs.append(gate_output.transpose(0, 1))
  710. alignments.append(attention_weights)
  711. mel_specgram_lengths[~finished] += 1
  712. finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
  713. if self.decoder_early_stopping and torch.all(finished):
  714. break
  715. decoder_input = mel_specgram
  716. if len(mel_specgrams) == self.decoder_max_step:
  717. warnings.warn(
  718. "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
  719. )
  720. mel_specgrams = torch.cat(mel_specgrams, dim=0)
  721. gate_outputs = torch.cat(gate_outputs, dim=0)
  722. alignments = torch.cat(alignments, dim=0)
  723. mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)
  724. return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
  725. class Tacotron2(nn.Module):
  726. r"""Tacotron2 model based on the implementation from
  727. `Nvidia <https://github.com/NVIDIA/DeepLearningExamples/>`_.
  728. The original implementation was introduced in
  729. *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
  730. [:footcite:`shen2018natural`].
  731. Args:
  732. mask_padding (bool, optional): Use mask padding (Default: ``False``).
  733. n_mels (int, optional): Number of mel bins (Default: ``80``).
  734. n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
  735. n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
  736. symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
  737. encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
  738. encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
  739. encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
  740. decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
  741. decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
  742. decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
  743. decoder_early_stopping (bool, optional): Continue decoding after all samples are finished (Default: ``True``).
  744. attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
  745. attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
  746. attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
  747. attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
  748. attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
  749. prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
  750. postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
  751. postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
  752. postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
  753. gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
  754. """
  755. def __init__(
  756. self,
  757. mask_padding: bool = False,
  758. n_mels: int = 80,
  759. n_symbol: int = 148,
  760. n_frames_per_step: int = 1,
  761. symbol_embedding_dim: int = 512,
  762. encoder_embedding_dim: int = 512,
  763. encoder_n_convolution: int = 3,
  764. encoder_kernel_size: int = 5,
  765. decoder_rnn_dim: int = 1024,
  766. decoder_max_step: int = 2000,
  767. decoder_dropout: float = 0.1,
  768. decoder_early_stopping: bool = True,
  769. attention_rnn_dim: int = 1024,
  770. attention_hidden_dim: int = 128,
  771. attention_location_n_filter: int = 32,
  772. attention_location_kernel_size: int = 31,
  773. attention_dropout: float = 0.1,
  774. prenet_dim: int = 256,
  775. postnet_n_convolution: int = 5,
  776. postnet_kernel_size: int = 5,
  777. postnet_embedding_dim: int = 512,
  778. gate_threshold: float = 0.5,
  779. ) -> None:
  780. super().__init__()
  781. self.mask_padding = mask_padding
  782. self.n_mels = n_mels
  783. self.n_frames_per_step = n_frames_per_step
  784. self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim)
  785. torch.nn.init.xavier_uniform_(self.embedding.weight)
  786. self.encoder = _Encoder(encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size)
  787. self.decoder = _Decoder(
  788. n_mels,
  789. n_frames_per_step,
  790. encoder_embedding_dim,
  791. decoder_rnn_dim,
  792. decoder_max_step,
  793. decoder_dropout,
  794. decoder_early_stopping,
  795. attention_rnn_dim,
  796. attention_hidden_dim,
  797. attention_location_n_filter,
  798. attention_location_kernel_size,
  799. attention_dropout,
  800. prenet_dim,
  801. gate_threshold,
  802. )
  803. self.postnet = _Postnet(n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution)
  804. def forward(
  805. self,
  806. tokens: Tensor,
  807. token_lengths: Tensor,
  808. mel_specgram: Tensor,
  809. mel_specgram_lengths: Tensor,
  810. ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  811. r"""Pass the input through the Tacotron2 model. This is in teacher
  812. forcing mode, which is generally used for training.
  813. The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
  814. The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
  815. Args:
  816. tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
  817. token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
  818. mel_specgram (Tensor): The target mel spectrogram
  819. with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
  820. mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
  821. Returns:
  822. [Tensor, Tensor, Tensor, Tensor]:
  823. Tensor
  824. Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
  825. Tensor
  826. Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
  827. Tensor
  828. The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
  829. Tensor
  830. Sequence of attention weights from the decoder with
  831. shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
  832. """
  833. embedded_inputs = self.embedding(tokens).transpose(1, 2)
  834. encoder_outputs = self.encoder(embedded_inputs, token_lengths)
  835. mel_specgram, gate_outputs, alignments = self.decoder(
  836. encoder_outputs, mel_specgram, memory_lengths=token_lengths
  837. )
  838. mel_specgram_postnet = self.postnet(mel_specgram)
  839. mel_specgram_postnet = mel_specgram + mel_specgram_postnet
  840. if self.mask_padding:
  841. mask = _get_mask_from_lengths(mel_specgram_lengths)
  842. mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
  843. mask = mask.permute(1, 0, 2)
  844. mel_specgram.masked_fill_(mask, 0.0)
  845. mel_specgram_postnet.masked_fill_(mask, 0.0)
  846. gate_outputs.masked_fill_(mask[:, 0, :], 1e3)
  847. return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
  848. @torch.jit.export
  849. def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
  850. r"""Using Tacotron2 for inference. The input is a batch of encoded
  851. sentences (``tokens``) and its corresponding lengths (``lengths``). The
  852. output is the generated mel spectrograms, its corresponding lengths, and
  853. the attention weights from the decoder.
  854. The input `tokens` should be padded with zeros to length max of ``lengths``.
  855. Args:
  856. tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
  857. lengths (Tensor or None, optional):
  858. The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
  859. If ``None``, it is assumed that the all the tokens are valid. Default: ``None``
  860. Returns:
  861. (Tensor, Tensor, Tensor):
  862. Tensor
  863. The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
  864. Tensor
  865. The length of the predicted mel spectrogram with shape `(n_batch, )`.
  866. Tensor
  867. Sequence of attention weights from the decoder with shape
  868. `(n_batch, max of mel_specgram_lengths, max of lengths)`.
  869. """
  870. n_batch, max_length = tokens.shape
  871. if lengths is None:
  872. lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
  873. assert lengths is not None # For TorchScript compiler
  874. embedded_inputs = self.embedding(tokens).transpose(1, 2)
  875. encoder_outputs = self.encoder(embedded_inputs, lengths)
  876. mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(encoder_outputs, lengths)
  877. mel_outputs_postnet = self.postnet(mel_specgram)
  878. mel_outputs_postnet = mel_specgram + mel_outputs_postnet
  879. alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
  880. return mel_outputs_postnet, mel_specgram_lengths, alignments