rnnt.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from torchaudio.models import Emformer
  5. __all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
  6. class _TimeReduction(torch.nn.Module):
  7. r"""Coalesces frames along time dimension into a
  8. fewer number of frames with higher feature dimensionality.
  9. Args:
  10. stride (int): number of frames to merge for each output frame.
  11. """
  12. def __init__(self, stride: int) -> None:
  13. super().__init__()
  14. self.stride = stride
  15. def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  16. r"""Forward pass.
  17. B: batch size;
  18. T: maximum input sequence length in batch;
  19. D: feature dimension of each input sequence frame.
  20. Args:
  21. input (torch.Tensor): input sequences, with shape `(B, T, D)`.
  22. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  23. number of valid frames for i-th batch element in ``input``.
  24. Returns:
  25. (torch.Tensor, torch.Tensor):
  26. torch.Tensor
  27. output sequences, with shape
  28. `(B, T // stride, D * stride)`
  29. torch.Tensor
  30. output lengths, with shape `(B,)` and i-th element representing
  31. number of valid frames for i-th batch element in output sequences.
  32. """
  33. B, T, D = input.shape
  34. num_frames = T - (T % self.stride)
  35. input = input[:, :num_frames, :]
  36. lengths = lengths.div(self.stride, rounding_mode="trunc")
  37. T_max = num_frames // self.stride
  38. output = input.reshape(B, T_max, D * self.stride)
  39. output = output.contiguous()
  40. return output, lengths
  41. class _CustomLSTM(torch.nn.Module):
  42. r"""Custom long-short-term memory (LSTM) block that applies layer normalization
  43. to internal nodes.
  44. Args:
  45. input_dim (int): input dimension.
  46. hidden_dim (int): hidden dimension.
  47. layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
  48. layer_norm_epsilon (float, optional): value of epsilon to use in
  49. layer normalization layers (Default: 1e-5)
  50. """
  51. def __init__(
  52. self,
  53. input_dim: int,
  54. hidden_dim: int,
  55. layer_norm: bool = False,
  56. layer_norm_epsilon: float = 1e-5,
  57. ) -> None:
  58. super().__init__()
  59. self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
  60. self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
  61. if layer_norm:
  62. self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
  63. self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
  64. else:
  65. self.c_norm = torch.nn.Identity()
  66. self.g_norm = torch.nn.Identity()
  67. self.hidden_dim = hidden_dim
  68. def forward(
  69. self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
  70. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  71. r"""Forward pass.
  72. B: batch size;
  73. T: maximum sequence length in batch;
  74. D: feature dimension of each input sequence element.
  75. Args:
  76. input (torch.Tensor): with shape `(T, B, D)`.
  77. state (List[torch.Tensor] or None): list of tensors
  78. representing internal state generated in preceding invocation
  79. of ``forward``.
  80. Returns:
  81. (torch.Tensor, List[torch.Tensor]):
  82. torch.Tensor
  83. output, with shape `(T, B, hidden_dim)`.
  84. List[torch.Tensor]
  85. list of tensors representing internal state generated
  86. in current invocation of ``forward``.
  87. """
  88. if state is None:
  89. B = input.size(1)
  90. h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
  91. c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
  92. else:
  93. h, c = state
  94. gated_input = self.x2g(input)
  95. outputs = []
  96. for gates in gated_input.unbind(0):
  97. gates = gates + self.p2g(h)
  98. gates = self.g_norm(gates)
  99. input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
  100. input_gate = input_gate.sigmoid()
  101. forget_gate = forget_gate.sigmoid()
  102. cell_gate = cell_gate.tanh()
  103. output_gate = output_gate.sigmoid()
  104. c = forget_gate * c + input_gate * cell_gate
  105. c = self.c_norm(c)
  106. h = output_gate * c.tanh()
  107. outputs.append(h)
  108. output = torch.stack(outputs, dim=0)
  109. state = [h, c]
  110. return output, state
  111. class _Transcriber(ABC):
  112. @abstractmethod
  113. def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  114. pass
  115. @abstractmethod
  116. def infer(
  117. self,
  118. input: torch.Tensor,
  119. lengths: torch.Tensor,
  120. states: Optional[List[List[torch.Tensor]]],
  121. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  122. pass
  123. class _EmformerEncoder(torch.nn.Module, _Transcriber):
  124. r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
  125. Args:
  126. input_dim (int): feature dimension of each input sequence element.
  127. output_dim (int): feature dimension of each output sequence element.
  128. segment_length (int): length of input segment expressed as number of frames.
  129. right_context_length (int): length of right context expressed as number of frames.
  130. time_reduction_input_dim (int): dimension to scale each element in input sequences to
  131. prior to applying time reduction block.
  132. time_reduction_stride (int): factor by which to reduce length of input sequence.
  133. transformer_num_heads (int): number of attention heads in each Emformer layer.
  134. transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
  135. transformer_num_layers (int): number of Emformer layers to instantiate.
  136. transformer_left_context_length (int): length of left context.
  137. transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
  138. transformer_activation (str, optional): activation function to use in each Emformer layer's
  139. feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
  140. transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
  141. transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
  142. strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
  143. transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
  144. """
  145. def __init__(
  146. self,
  147. *,
  148. input_dim: int,
  149. output_dim: int,
  150. segment_length: int,
  151. right_context_length: int,
  152. time_reduction_input_dim: int,
  153. time_reduction_stride: int,
  154. transformer_num_heads: int,
  155. transformer_ffn_dim: int,
  156. transformer_num_layers: int,
  157. transformer_left_context_length: int,
  158. transformer_dropout: float = 0.0,
  159. transformer_activation: str = "relu",
  160. transformer_max_memory_size: int = 0,
  161. transformer_weight_init_scale_strategy: str = "depthwise",
  162. transformer_tanh_on_mem: bool = False,
  163. ) -> None:
  164. super().__init__()
  165. self.input_linear = torch.nn.Linear(
  166. input_dim,
  167. time_reduction_input_dim,
  168. bias=False,
  169. )
  170. self.time_reduction = _TimeReduction(time_reduction_stride)
  171. transformer_input_dim = time_reduction_input_dim * time_reduction_stride
  172. self.transformer = Emformer(
  173. transformer_input_dim,
  174. transformer_num_heads,
  175. transformer_ffn_dim,
  176. transformer_num_layers,
  177. segment_length // time_reduction_stride,
  178. dropout=transformer_dropout,
  179. activation=transformer_activation,
  180. left_context_length=transformer_left_context_length,
  181. right_context_length=right_context_length // time_reduction_stride,
  182. max_memory_size=transformer_max_memory_size,
  183. weight_init_scale_strategy=transformer_weight_init_scale_strategy,
  184. tanh_on_mem=transformer_tanh_on_mem,
  185. )
  186. self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
  187. self.layer_norm = torch.nn.LayerNorm(output_dim)
  188. def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  189. r"""Forward pass for training.
  190. B: batch size;
  191. T: maximum input sequence length in batch;
  192. D: feature dimension of each input sequence frame (input_dim).
  193. Args:
  194. input (torch.Tensor): input frame sequences right-padded with right context, with
  195. shape `(B, T + right context length, D)`.
  196. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  197. number of valid frames for i-th batch element in ``input``.
  198. Returns:
  199. (torch.Tensor, torch.Tensor):
  200. torch.Tensor
  201. output frame sequences, with
  202. shape `(B, T // time_reduction_stride, output_dim)`.
  203. torch.Tensor
  204. output input lengths, with shape `(B,)` and i-th element representing
  205. number of valid elements for i-th batch element in output frame sequences.
  206. """
  207. input_linear_out = self.input_linear(input)
  208. time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
  209. transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
  210. output_linear_out = self.output_linear(transformer_out)
  211. layer_norm_out = self.layer_norm(output_linear_out)
  212. return layer_norm_out, transformer_lengths
  213. @torch.jit.export
  214. def infer(
  215. self,
  216. input: torch.Tensor,
  217. lengths: torch.Tensor,
  218. states: Optional[List[List[torch.Tensor]]],
  219. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  220. r"""Forward pass for inference.
  221. B: batch size;
  222. T: maximum input sequence segment length in batch;
  223. D: feature dimension of each input sequence frame (input_dim).
  224. Args:
  225. input (torch.Tensor): input frame sequence segments right-padded with right context, with
  226. shape `(B, T + right context length, D)`.
  227. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  228. number of valid frames for i-th batch element in ``input``.
  229. state (List[List[torch.Tensor]] or None): list of lists of tensors
  230. representing internal state generated in preceding invocation
  231. of ``infer``.
  232. Returns:
  233. (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
  234. torch.Tensor
  235. output frame sequences, with
  236. shape `(B, T // time_reduction_stride, output_dim)`.
  237. torch.Tensor
  238. output input lengths, with shape `(B,)` and i-th element representing
  239. number of valid elements for i-th batch element in output.
  240. List[List[torch.Tensor]]
  241. output states; list of lists of tensors
  242. representing internal state generated in current invocation
  243. of ``infer``.
  244. """
  245. input_linear_out = self.input_linear(input)
  246. time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
  247. (
  248. transformer_out,
  249. transformer_lengths,
  250. transformer_states,
  251. ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
  252. output_linear_out = self.output_linear(transformer_out)
  253. layer_norm_out = self.layer_norm(output_linear_out)
  254. return layer_norm_out, transformer_lengths, transformer_states
  255. class _Predictor(torch.nn.Module):
  256. r"""Recurrent neural network transducer (RNN-T) prediction network.
  257. Args:
  258. num_symbols (int): size of target token lexicon.
  259. output_dim (int): feature dimension of each output sequence element.
  260. symbol_embedding_dim (int): dimension of each target token embedding.
  261. num_lstm_layers (int): number of LSTM layers to instantiate.
  262. lstm_hidden_dim (int): output dimension of each LSTM layer.
  263. lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
  264. for LSTM layers. (Default: ``False``)
  265. lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
  266. LSTM layer normalization layers. (Default: 1e-5)
  267. lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
  268. """
  269. def __init__(
  270. self,
  271. num_symbols: int,
  272. output_dim: int,
  273. symbol_embedding_dim: int,
  274. num_lstm_layers: int,
  275. lstm_hidden_dim: int,
  276. lstm_layer_norm: bool = False,
  277. lstm_layer_norm_epsilon: float = 1e-5,
  278. lstm_dropout: float = 0.0,
  279. ) -> None:
  280. super().__init__()
  281. self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
  282. self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
  283. self.lstm_layers = torch.nn.ModuleList(
  284. [
  285. _CustomLSTM(
  286. symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
  287. lstm_hidden_dim,
  288. layer_norm=lstm_layer_norm,
  289. layer_norm_epsilon=lstm_layer_norm_epsilon,
  290. )
  291. for idx in range(num_lstm_layers)
  292. ]
  293. )
  294. self.dropout = torch.nn.Dropout(p=lstm_dropout)
  295. self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
  296. self.output_layer_norm = torch.nn.LayerNorm(output_dim)
  297. self.lstm_dropout = lstm_dropout
  298. def forward(
  299. self,
  300. input: torch.Tensor,
  301. lengths: torch.Tensor,
  302. state: Optional[List[List[torch.Tensor]]] = None,
  303. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  304. r"""Forward pass.
  305. B: batch size;
  306. U: maximum sequence length in batch;
  307. D: feature dimension of each input sequence element.
  308. Args:
  309. input (torch.Tensor): target sequences, with shape `(B, U)` and each element
  310. mapping to a target symbol, i.e. in range `[0, num_symbols)`.
  311. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  312. number of valid frames for i-th batch element in ``input``.
  313. state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
  314. representing internal state generated in preceding invocation
  315. of ``forward``. (Default: ``None``)
  316. Returns:
  317. (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
  318. torch.Tensor
  319. output encoding sequences, with shape `(B, U, output_dim)`
  320. torch.Tensor
  321. output lengths, with shape `(B,)` and i-th element representing
  322. number of valid elements for i-th batch element in output encoding sequences.
  323. List[List[torch.Tensor]]
  324. output states; list of lists of tensors
  325. representing internal state generated in current invocation of ``forward``.
  326. """
  327. input_tb = input.permute(1, 0)
  328. embedding_out = self.embedding(input_tb)
  329. input_layer_norm_out = self.input_layer_norm(embedding_out)
  330. lstm_out = input_layer_norm_out
  331. state_out: List[List[torch.Tensor]] = []
  332. for layer_idx, lstm in enumerate(self.lstm_layers):
  333. lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
  334. lstm_out = self.dropout(lstm_out)
  335. state_out.append(lstm_state_out)
  336. linear_out = self.linear(lstm_out)
  337. output_layer_norm_out = self.output_layer_norm(linear_out)
  338. return output_layer_norm_out.permute(1, 0, 2), lengths, state_out
  339. class _Joiner(torch.nn.Module):
  340. r"""Recurrent neural network transducer (RNN-T) joint network.
  341. Args:
  342. input_dim (int): source and target input dimension.
  343. output_dim (int): output dimension.
  344. activation (str, optional): activation function to use in the joiner.
  345. Must be one of ("relu", "tanh"). (Default: "relu")
  346. """
  347. def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
  348. super().__init__()
  349. self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
  350. if activation == "relu":
  351. self.activation = torch.nn.ReLU()
  352. elif activation == "tanh":
  353. self.activation = torch.nn.Tanh()
  354. else:
  355. raise ValueError(f"Unsupported activation {activation}")
  356. def forward(
  357. self,
  358. source_encodings: torch.Tensor,
  359. source_lengths: torch.Tensor,
  360. target_encodings: torch.Tensor,
  361. target_lengths: torch.Tensor,
  362. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  363. r"""Forward pass for training.
  364. B: batch size;
  365. T: maximum source sequence length in batch;
  366. U: maximum target sequence length in batch;
  367. D: dimension of each source and target sequence encoding.
  368. Args:
  369. source_encodings (torch.Tensor): source encoding sequences, with
  370. shape `(B, T, D)`.
  371. source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  372. valid sequence length of i-th batch element in ``source_encodings``.
  373. target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
  374. target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  375. valid sequence length of i-th batch element in ``target_encodings``.
  376. Returns:
  377. (torch.Tensor, torch.Tensor, torch.Tensor):
  378. torch.Tensor
  379. joint network output, with shape `(B, T, U, output_dim)`.
  380. torch.Tensor
  381. output source lengths, with shape `(B,)` and i-th element representing
  382. number of valid elements along dim 1 for i-th batch element in joint network output.
  383. torch.Tensor
  384. output target lengths, with shape `(B,)` and i-th element representing
  385. number of valid elements along dim 2 for i-th batch element in joint network output.
  386. """
  387. joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
  388. activation_out = self.activation(joint_encodings)
  389. output = self.linear(activation_out)
  390. return output, source_lengths, target_lengths
  391. class RNNT(torch.nn.Module):
  392. r"""torchaudio.models.RNNT()
  393. Recurrent neural network transducer (RNN-T) model.
  394. Note:
  395. To build the model, please use one of the factory functions.
  396. Args:
  397. transcriber (torch.nn.Module): transcription network.
  398. predictor (torch.nn.Module): prediction network.
  399. joiner (torch.nn.Module): joint network.
  400. """
  401. def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
  402. super().__init__()
  403. self.transcriber = transcriber
  404. self.predictor = predictor
  405. self.joiner = joiner
  406. def forward(
  407. self,
  408. sources: torch.Tensor,
  409. source_lengths: torch.Tensor,
  410. targets: torch.Tensor,
  411. target_lengths: torch.Tensor,
  412. predictor_state: Optional[List[List[torch.Tensor]]] = None,
  413. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  414. r"""Forward pass for training.
  415. B: batch size;
  416. T: maximum source sequence length in batch;
  417. U: maximum target sequence length in batch;
  418. D: feature dimension of each source sequence element.
  419. Args:
  420. sources (torch.Tensor): source frame sequences right-padded with right context, with
  421. shape `(B, T, D)`.
  422. source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  423. number of valid frames for i-th batch element in ``sources``.
  424. targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
  425. mapping to a target symbol.
  426. target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  427. number of valid frames for i-th batch element in ``targets``.
  428. predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
  429. representing prediction network internal state generated in preceding invocation
  430. of ``forward``. (Default: ``None``)
  431. Returns:
  432. (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
  433. torch.Tensor
  434. joint network output, with shape
  435. `(B, max output source length, max output target length, output_dim (number of target symbols))`.
  436. torch.Tensor
  437. output source lengths, with shape `(B,)` and i-th element representing
  438. number of valid elements along dim 1 for i-th batch element in joint network output.
  439. torch.Tensor
  440. output target lengths, with shape `(B,)` and i-th element representing
  441. number of valid elements along dim 2 for i-th batch element in joint network output.
  442. List[List[torch.Tensor]]
  443. output states; list of lists of tensors
  444. representing prediction network internal state generated in current invocation
  445. of ``forward``.
  446. """
  447. source_encodings, source_lengths = self.transcriber(
  448. input=sources,
  449. lengths=source_lengths,
  450. )
  451. target_encodings, target_lengths, predictor_state = self.predictor(
  452. input=targets,
  453. lengths=target_lengths,
  454. state=predictor_state,
  455. )
  456. output, source_lengths, target_lengths = self.joiner(
  457. source_encodings=source_encodings,
  458. source_lengths=source_lengths,
  459. target_encodings=target_encodings,
  460. target_lengths=target_lengths,
  461. )
  462. return (
  463. output,
  464. source_lengths,
  465. target_lengths,
  466. predictor_state,
  467. )
  468. @torch.jit.export
  469. def transcribe_streaming(
  470. self,
  471. sources: torch.Tensor,
  472. source_lengths: torch.Tensor,
  473. state: Optional[List[List[torch.Tensor]]],
  474. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  475. r"""Applies transcription network to sources in streaming mode.
  476. B: batch size;
  477. T: maximum source sequence segment length in batch;
  478. D: feature dimension of each source sequence frame.
  479. Args:
  480. sources (torch.Tensor): source frame sequence segments right-padded with right context, with
  481. shape `(B, T + right context length, D)`.
  482. source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  483. number of valid frames for i-th batch element in ``sources``.
  484. state (List[List[torch.Tensor]] or None): list of lists of tensors
  485. representing transcription network internal state generated in preceding invocation
  486. of ``transcribe_streaming``.
  487. Returns:
  488. (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
  489. torch.Tensor
  490. output frame sequences, with
  491. shape `(B, T // time_reduction_stride, output_dim)`.
  492. torch.Tensor
  493. output lengths, with shape `(B,)` and i-th element representing
  494. number of valid elements for i-th batch element in output.
  495. List[List[torch.Tensor]]
  496. output states; list of lists of tensors
  497. representing transcription network internal state generated in current invocation
  498. of ``transcribe_streaming``.
  499. """
  500. return self.transcriber.infer(sources, source_lengths, state)
  501. @torch.jit.export
  502. def transcribe(
  503. self,
  504. sources: torch.Tensor,
  505. source_lengths: torch.Tensor,
  506. ) -> Tuple[torch.Tensor, torch.Tensor]:
  507. r"""Applies transcription network to sources in non-streaming mode.
  508. B: batch size;
  509. T: maximum source sequence length in batch;
  510. D: feature dimension of each source sequence frame.
  511. Args:
  512. sources (torch.Tensor): source frame sequences right-padded with right context, with
  513. shape `(B, T + right context length, D)`.
  514. source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  515. number of valid frames for i-th batch element in ``sources``.
  516. Returns:
  517. (torch.Tensor, torch.Tensor):
  518. torch.Tensor
  519. output frame sequences, with
  520. shape `(B, T // time_reduction_stride, output_dim)`.
  521. torch.Tensor
  522. output lengths, with shape `(B,)` and i-th element representing
  523. number of valid elements for i-th batch element in output frame sequences.
  524. """
  525. return self.transcriber(sources, source_lengths)
  526. @torch.jit.export
  527. def predict(
  528. self,
  529. targets: torch.Tensor,
  530. target_lengths: torch.Tensor,
  531. state: Optional[List[List[torch.Tensor]]],
  532. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  533. r"""Applies prediction network to targets.
  534. B: batch size;
  535. U: maximum target sequence length in batch;
  536. D: feature dimension of each target sequence frame.
  537. Args:
  538. targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
  539. mapping to a target symbol, i.e. in range `[0, num_symbols)`.
  540. target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  541. number of valid frames for i-th batch element in ``targets``.
  542. state (List[List[torch.Tensor]] or None): list of lists of tensors
  543. representing internal state generated in preceding invocation
  544. of ``predict``.
  545. Returns:
  546. (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
  547. torch.Tensor
  548. output frame sequences, with shape `(B, U, output_dim)`.
  549. torch.Tensor
  550. output lengths, with shape `(B,)` and i-th element representing
  551. number of valid elements for i-th batch element in output.
  552. List[List[torch.Tensor]]
  553. output states; list of lists of tensors
  554. representing internal state generated in current invocation of ``predict``.
  555. """
  556. return self.predictor(input=targets, lengths=target_lengths, state=state)
  557. @torch.jit.export
  558. def join(
  559. self,
  560. source_encodings: torch.Tensor,
  561. source_lengths: torch.Tensor,
  562. target_encodings: torch.Tensor,
  563. target_lengths: torch.Tensor,
  564. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  565. r"""Applies joint network to source and target encodings.
  566. B: batch size;
  567. T: maximum source sequence length in batch;
  568. U: maximum target sequence length in batch;
  569. D: dimension of each source and target sequence encoding.
  570. Args:
  571. source_encodings (torch.Tensor): source encoding sequences, with
  572. shape `(B, T, D)`.
  573. source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  574. valid sequence length of i-th batch element in ``source_encodings``.
  575. target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
  576. target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  577. valid sequence length of i-th batch element in ``target_encodings``.
  578. Returns:
  579. (torch.Tensor, torch.Tensor, torch.Tensor):
  580. torch.Tensor
  581. joint network output, with shape `(B, T, U, output_dim)`.
  582. torch.Tensor
  583. output source lengths, with shape `(B,)` and i-th element representing
  584. number of valid elements along dim 1 for i-th batch element in joint network output.
  585. torch.Tensor
  586. output target lengths, with shape `(B,)` and i-th element representing
  587. number of valid elements along dim 2 for i-th batch element in joint network output.
  588. """
  589. output, source_lengths, target_lengths = self.joiner(
  590. source_encodings=source_encodings,
  591. source_lengths=source_lengths,
  592. target_encodings=target_encodings,
  593. target_lengths=target_lengths,
  594. )
  595. return output, source_lengths, target_lengths
  596. def emformer_rnnt_model(
  597. *,
  598. input_dim: int,
  599. encoding_dim: int,
  600. num_symbols: int,
  601. segment_length: int,
  602. right_context_length: int,
  603. time_reduction_input_dim: int,
  604. time_reduction_stride: int,
  605. transformer_num_heads: int,
  606. transformer_ffn_dim: int,
  607. transformer_num_layers: int,
  608. transformer_dropout: float,
  609. transformer_activation: str,
  610. transformer_left_context_length: int,
  611. transformer_max_memory_size: int,
  612. transformer_weight_init_scale_strategy: str,
  613. transformer_tanh_on_mem: bool,
  614. symbol_embedding_dim: int,
  615. num_lstm_layers: int,
  616. lstm_layer_norm: bool,
  617. lstm_layer_norm_epsilon: float,
  618. lstm_dropout: float,
  619. ) -> RNNT:
  620. r"""Builds Emformer-based recurrent neural network transducer (RNN-T) model.
  621. Note:
  622. For non-streaming inference, the expectation is for `transcribe` to be called on input
  623. sequences right-concatenated with `right_context_length` frames.
  624. For streaming inference, the expectation is for `transcribe_streaming` to be called
  625. on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
  626. frames.
  627. Args:
  628. input_dim (int): dimension of input sequence frames passed to transcription network.
  629. encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
  630. passed to joint network.
  631. num_symbols (int): cardinality of set of target tokens.
  632. segment_length (int): length of input segment expressed as number of frames.
  633. right_context_length (int): length of right context expressed as number of frames.
  634. time_reduction_input_dim (int): dimension to scale each element in input sequences to
  635. prior to applying time reduction block.
  636. time_reduction_stride (int): factor by which to reduce length of input sequence.
  637. transformer_num_heads (int): number of attention heads in each Emformer layer.
  638. transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
  639. transformer_num_layers (int): number of Emformer layers to instantiate.
  640. transformer_left_context_length (int): length of left context considered by Emformer.
  641. transformer_dropout (float): Emformer dropout probability.
  642. transformer_activation (str): activation function to use in each Emformer layer's
  643. feedforward network. Must be one of ("relu", "gelu", "silu").
  644. transformer_max_memory_size (int): maximum number of memory elements to use.
  645. transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
  646. strategy. Must be one of ("depthwise", "constant", ``None``).
  647. transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
  648. symbol_embedding_dim (int): dimension of each target token embedding.
  649. num_lstm_layers (int): number of LSTM layers to instantiate.
  650. lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
  651. lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
  652. lstm_dropout (float): LSTM dropout probability.
  653. Returns:
  654. RNNT:
  655. Emformer RNN-T model.
  656. """
  657. encoder = _EmformerEncoder(
  658. input_dim=input_dim,
  659. output_dim=encoding_dim,
  660. segment_length=segment_length,
  661. right_context_length=right_context_length,
  662. time_reduction_input_dim=time_reduction_input_dim,
  663. time_reduction_stride=time_reduction_stride,
  664. transformer_num_heads=transformer_num_heads,
  665. transformer_ffn_dim=transformer_ffn_dim,
  666. transformer_num_layers=transformer_num_layers,
  667. transformer_dropout=transformer_dropout,
  668. transformer_activation=transformer_activation,
  669. transformer_left_context_length=transformer_left_context_length,
  670. transformer_max_memory_size=transformer_max_memory_size,
  671. transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
  672. transformer_tanh_on_mem=transformer_tanh_on_mem,
  673. )
  674. predictor = _Predictor(
  675. num_symbols,
  676. encoding_dim,
  677. symbol_embedding_dim=symbol_embedding_dim,
  678. num_lstm_layers=num_lstm_layers,
  679. lstm_hidden_dim=symbol_embedding_dim,
  680. lstm_layer_norm=lstm_layer_norm,
  681. lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
  682. lstm_dropout=lstm_dropout,
  683. )
  684. joiner = _Joiner(encoding_dim, num_symbols)
  685. return RNNT(encoder, predictor, joiner)
  686. def emformer_rnnt_base(num_symbols: int) -> RNNT:
  687. r"""Builds basic version of Emformer RNN-T model.
  688. Args:
  689. num_symbols (int): The size of target token lexicon.
  690. Returns:
  691. RNNT:
  692. Emformer RNN-T model.
  693. """
  694. return emformer_rnnt_model(
  695. input_dim=80,
  696. encoding_dim=1024,
  697. num_symbols=num_symbols,
  698. segment_length=16,
  699. right_context_length=4,
  700. time_reduction_input_dim=128,
  701. time_reduction_stride=4,
  702. transformer_num_heads=8,
  703. transformer_ffn_dim=2048,
  704. transformer_num_layers=20,
  705. transformer_dropout=0.1,
  706. transformer_activation="gelu",
  707. transformer_left_context_length=30,
  708. transformer_max_memory_size=0,
  709. transformer_weight_init_scale_strategy="depthwise",
  710. transformer_tanh_on_mem=True,
  711. symbol_embedding_dim=512,
  712. num_lstm_layers=3,
  713. lstm_layer_norm=True,
  714. lstm_layer_norm_epsilon=1e-3,
  715. lstm_dropout=0.3,
  716. )