# rnnt_decoder.py -- beam search decoder for RNN-T models.
  1. from typing import Callable, Dict, List, Optional, Tuple
  2. import torch
  3. from torchaudio.models import RNNT
  4. __all__ = ["Hypothesis", "RNNTBeamSearch"]
  5. Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
  6. Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
  7. represented as tuple of (tokens, prediction network output, prediction network state, score).
  8. """
  9. def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
  10. return hypo[0]
  11. def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
  12. return hypo[1]
  13. def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
  14. return hypo[2]
  15. def _get_hypo_score(hypo: Hypothesis) -> float:
  16. return hypo[3]
  17. def _get_hypo_key(hypo: Hypothesis) -> str:
  18. return str(hypo[0])
  19. def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
  20. states: List[List[torch.Tensor]] = []
  21. for i in range(len(_get_hypo_state(hypos[0]))):
  22. batched_state_components: List[torch.Tensor] = []
  23. for j in range(len(_get_hypo_state(hypos[0])[i])):
  24. batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
  25. states.append(batched_state_components)
  26. return states
  27. def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
  28. idx_tensor = torch.tensor([idx], device=device)
  29. return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
  30. def _default_hypo_sort_key(hypo: Hypothesis) -> float:
  31. return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
  32. def _compute_updated_scores(
  33. hypos: List[Hypothesis],
  34. next_token_probs: torch.Tensor,
  35. beam_width: int,
  36. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  37. hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
  38. nonblank_scores = hypo_scores + next_token_probs[:, :-1] # [beam_width, num_tokens - 1]
  39. nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
  40. nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
  41. nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
  42. return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
  43. def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
  44. for i, elem in enumerate(hypo_list):
  45. if _get_hypo_key(hypo) == _get_hypo_key(elem):
  46. del hypo_list[i]
  47. break
  48. class RNNTBeamSearch(torch.nn.Module):
  49. r"""Beam search decoder for RNN-T model.
  50. Args:
  51. model (RNNT): RNN-T model to use.
  52. blank (int): index of blank token in vocabulary.
  53. temperature (float, optional): temperature to apply to joint network output.
  54. Larger values yield more uniform samples. (Default: 1.0)
  55. hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
  56. for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
  57. hypothesis score normalized by token sequence length. (Default: None)
  58. step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
  59. """
  60. def __init__(
  61. self,
  62. model: RNNT,
  63. blank: int,
  64. temperature: float = 1.0,
  65. hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
  66. step_max_tokens: int = 100,
  67. ) -> None:
  68. super().__init__()
  69. self.model = model
  70. self.blank = blank
  71. self.temperature = temperature
  72. if hypo_sort_key is None:
  73. self.hypo_sort_key = _default_hypo_sort_key
  74. else:
  75. self.hypo_sort_key = hypo_sort_key
  76. self.step_max_tokens = step_max_tokens
  77. def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
  78. if hypo is not None:
  79. token = _get_hypo_tokens(hypo)[-1]
  80. state = _get_hypo_state(hypo)
  81. else:
  82. token = self.blank
  83. state = None
  84. one_tensor = torch.tensor([1], device=device)
  85. pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
  86. init_hypo = (
  87. [token],
  88. pred_out[0].detach(),
  89. pred_state,
  90. 0.0,
  91. )
  92. return [init_hypo]
  93. def _gen_next_token_probs(
  94. self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
  95. ) -> torch.Tensor:
  96. one_tensor = torch.tensor([1], device=device)
  97. predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
  98. joined_out, _, _ = self.model.join(
  99. enc_out,
  100. one_tensor,
  101. predictor_out,
  102. torch.tensor([1] * len(hypos), device=device),
  103. ) # [beam_width, 1, 1, num_tokens]
  104. joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
  105. return joined_out[:, 0, 0]
  106. def _gen_b_hypos(
  107. self,
  108. b_hypos: List[Hypothesis],
  109. a_hypos: List[Hypothesis],
  110. next_token_probs: torch.Tensor,
  111. key_to_b_hypo: Dict[str, Hypothesis],
  112. ) -> List[Hypothesis]:
  113. for i in range(len(a_hypos)):
  114. h_a = a_hypos[i]
  115. append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
  116. if _get_hypo_key(h_a) in key_to_b_hypo:
  117. h_b = key_to_b_hypo[_get_hypo_key(h_a)]
  118. _remove_hypo(h_b, b_hypos)
  119. score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
  120. else:
  121. score = float(append_blank_score)
  122. h_b = (
  123. _get_hypo_tokens(h_a),
  124. _get_hypo_predictor_out(h_a),
  125. _get_hypo_state(h_a),
  126. score,
  127. )
  128. b_hypos.append(h_b)
  129. key_to_b_hypo[_get_hypo_key(h_b)] = h_b
  130. _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
  131. return [b_hypos[idx] for idx in sorted_idx]
  132. def _gen_a_hypos(
  133. self,
  134. a_hypos: List[Hypothesis],
  135. b_hypos: List[Hypothesis],
  136. next_token_probs: torch.Tensor,
  137. t: int,
  138. beam_width: int,
  139. device: torch.device,
  140. ) -> List[Hypothesis]:
  141. (
  142. nonblank_nbest_scores,
  143. nonblank_nbest_hypo_idx,
  144. nonblank_nbest_token,
  145. ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
  146. if len(b_hypos) < beam_width:
  147. b_nbest_score = -float("inf")
  148. else:
  149. b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
  150. base_hypos: List[Hypothesis] = []
  151. new_tokens: List[int] = []
  152. new_scores: List[float] = []
  153. for i in range(beam_width):
  154. score = float(nonblank_nbest_scores[i])
  155. if score > b_nbest_score:
  156. a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
  157. base_hypos.append(a_hypos[a_hypo_idx])
  158. new_tokens.append(int(nonblank_nbest_token[i]))
  159. new_scores.append(score)
  160. if base_hypos:
  161. new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
  162. else:
  163. new_hypos: List[Hypothesis] = []
  164. return new_hypos
  165. def _gen_new_hypos(
  166. self,
  167. base_hypos: List[Hypothesis],
  168. tokens: List[int],
  169. scores: List[float],
  170. t: int,
  171. device: torch.device,
  172. ) -> List[Hypothesis]:
  173. tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
  174. states = _batch_state(base_hypos)
  175. pred_out, _, pred_states = self.model.predict(
  176. tgt_tokens,
  177. torch.tensor([1] * len(base_hypos), device=device),
  178. states,
  179. )
  180. new_hypos: List[Hypothesis] = []
  181. for i, h_a in enumerate(base_hypos):
  182. new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
  183. new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
  184. return new_hypos
  185. def _search(
  186. self,
  187. enc_out: torch.Tensor,
  188. hypo: Optional[Hypothesis],
  189. beam_width: int,
  190. ) -> List[Hypothesis]:
  191. n_time_steps = enc_out.shape[1]
  192. device = enc_out.device
  193. a_hypos: List[Hypothesis] = []
  194. b_hypos = self._init_b_hypos(hypo, device)
  195. for t in range(n_time_steps):
  196. a_hypos = b_hypos
  197. b_hypos = torch.jit.annotate(List[Hypothesis], [])
  198. key_to_b_hypo: Dict[str, Hypothesis] = {}
  199. symbols_current_t = 0
  200. while a_hypos:
  201. next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
  202. next_token_probs = next_token_probs.cpu()
  203. b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
  204. if symbols_current_t == self.step_max_tokens:
  205. break
  206. a_hypos = self._gen_a_hypos(
  207. a_hypos,
  208. b_hypos,
  209. next_token_probs,
  210. t,
  211. beam_width,
  212. device,
  213. )
  214. if a_hypos:
  215. symbols_current_t += 1
  216. _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
  217. b_hypos = [b_hypos[idx] for idx in sorted_idx]
  218. return b_hypos
  219. def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
  220. r"""Performs beam search for the given input sequence.
  221. T: number of frames;
  222. D: feature dimension of each frame.
  223. Args:
  224. input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
  225. length (torch.Tensor): number of valid frames in input
  226. sequence, with shape () or (1,).
  227. beam_width (int): beam size to use during search.
  228. Returns:
  229. List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
  230. """
  231. assert input.dim() == 2 or (
  232. input.dim() == 3 and input.shape[0] == 1
  233. ), "input must be of shape (T, D) or (1, T, D)"
  234. if input.dim() == 2:
  235. input = input.unsqueeze(0)
  236. assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
  237. if input.dim() == 0:
  238. input = input.unsqueeze(0)
  239. enc_out, _ = self.model.transcribe(input, length)
  240. return self._search(enc_out, None, beam_width)
  241. @torch.jit.export
  242. def infer(
  243. self,
  244. input: torch.Tensor,
  245. length: torch.Tensor,
  246. beam_width: int,
  247. state: Optional[List[List[torch.Tensor]]] = None,
  248. hypothesis: Optional[Hypothesis] = None,
  249. ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
  250. r"""Performs beam search for the given input sequence in streaming mode.
  251. T: number of frames;
  252. D: feature dimension of each frame.
  253. Args:
  254. input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
  255. length (torch.Tensor): number of valid frames in input
  256. sequence, with shape () or (1,).
  257. beam_width (int): beam size to use during search.
  258. state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
  259. representing transcription network internal state generated in preceding
  260. invocation. (Default: ``None``)
  261. hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
  262. search with. (Default: ``None``)
  263. Returns:
  264. (List[Hypothesis], List[List[torch.Tensor]]):
  265. List[Hypothesis]
  266. top-``beam_width`` hypotheses found by beam search.
  267. List[List[torch.Tensor]]
  268. list of lists of tensors representing transcription network
  269. internal state generated in current invocation.
  270. """
  271. assert input.dim() == 2 or (
  272. input.dim() == 3 and input.shape[0] == 1
  273. ), "input must be of shape (T, D) or (1, T, D)"
  274. if input.dim() == 2:
  275. input = input.unsqueeze(0)
  276. assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
  277. if length.dim() == 0:
  278. length = length.unsqueeze(0)
  279. enc_out, _, state = self.model.transcribe_streaming(input, length, state)
  280. return self._search(enc_out, hypothesis, beam_width), state