| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340 |
- from typing import Callable, Dict, List, Optional, Tuple
- import torch
- from torchaudio.models import RNNT
# A hypothesis is kept as a plain 4-tuple rather than a class:
#   [0] emitted token ids, [1] prediction network output,
#   [2] prediction network state, [3] cumulative score.
# NOTE(review): the decoder below uses torch.jit APIs, so a scriptable tuple type is
# presumably why this is not a dataclass — confirm before changing the representation.
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
represented as tuple of (tokens, prediction network output, prediction network state, score).
"""

__all__ = ["Hypothesis", "RNNTBeamSearch"]
def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    """Return the token id sequence (element 0) of ``hypo``."""
    tokens, _pred_out, _state, _score = hypo
    return tokens
def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    """Return the prediction network output (element 1) stored in ``hypo``."""
    _tokens, pred_out, _state, _score = hypo
    return pred_out
def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    """Return the prediction network state (element 2) stored in ``hypo``."""
    _tokens, _pred_out, state, _score = hypo
    return state
def _get_hypo_score(hypo: Hypothesis) -> float:
    """Return the cumulative score (element 3) of ``hypo``."""
    _tokens, _pred_out, _state, score = hypo
    return score
def _get_hypo_key(hypo: Hypothesis) -> str:
    """Return a string key for ``hypo`` built from its token list.

    Hypotheses with identical token sequences map to the same key, which is how
    duplicate hypotheses are detected during search.
    """
    tokens, _pred_out, _state, _score = hypo
    return str(tokens)
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    """Concatenate the prediction network states of ``hypos`` along dim 0.

    The nesting (list of lists of tensors) is taken from ``hypos[0]``'s state, so
    ``hypos`` must be non-empty and every hypothesis is expected to share that
    structure. Entry ``[i][j]`` of the result is ``torch.cat`` of entry ``[i][j]``
    of each hypothesis' state, in ``hypos`` order.
    """
    template = _get_hypo_state(hypos[0])
    batched: List[List[torch.Tensor]] = []
    for group_idx, group in enumerate(template):
        merged_group: List[torch.Tensor] = []
        for comp_idx in range(len(group)):
            pieces = [_get_hypo_state(h)[group_idx][comp_idx] for h in hypos]
            merged_group.append(torch.cat(pieces))
        batched.append(merged_group)
    return batched
- def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
- idx_tensor = torch.tensor([idx], device=device)
- return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    """Rank by length-normalized score: cumulative score / (number of tokens + 1)."""
    num_tokens = len(_get_hypo_tokens(hypo))
    return _get_hypo_score(hypo) / (num_tokens + 1)
def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick the ``beam_width`` best (hypothesis, non-blank token) extensions.

    Each hypothesis score is added to the log-probabilities of every column of
    ``next_token_probs`` except the last (the last column is the one excluded as
    blank). The resulting matrix is flattened, its top ``beam_width`` entries are
    kept, and each flat index is decoded back into a (row, column) pair.

    Returns:
        Tuple of (selected scores, hypothesis indices, token indices).
    """
    base_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
    candidate_scores = base_scores + next_token_probs[:, :-1]  # [len(hypos), num_tokens - 1]
    num_nonblank = candidate_scores.shape[1]
    best_scores, flat_idx = candidate_scores.reshape(-1).topk(beam_width)
    hypo_idx = torch.div(flat_idx, num_nonblank, rounding_mode="trunc")
    token_idx = torch.remainder(flat_idx, num_nonblank)
    return best_scores, hypo_idx, token_idx
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    """Delete, in place, the first entry of ``hypo_list`` whose token key matches ``hypo``.

    Does nothing if no entry matches.
    """
    target = _get_hypo_key(hypo)
    for pos in range(len(hypo_list)):
        if _get_hypo_key(hypo_list[pos]) == target:
            del hypo_list[pos]
            break
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature
        # Fall back to length-normalized ranking when no custom key is supplied.
        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key
        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
        """Build the initial set of "B" hypotheses that seeds the first frame.

        With a seed ``hypo`` (streaming continuation), its stored prediction network
        output and state already account for its last token, because ``_gen_new_hypos``
        stores them after predicting that token. They are reused as-is rather than
        running the prediction network on the last token a second time. As before, the
        seed keeps only the last token and restarts its score at 0.0.

        Without a seed, the prediction network is primed with the blank token and no state.
        """
        if hypo is not None:
            return [
                (
                    [_get_hypo_tokens(hypo)[-1]],
                    _get_hypo_predictor_out(hypo),
                    _get_hypo_state(hypo),
                    0.0,
                )
            ]

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(
            torch.tensor([[self.blank]], device=device), one_tensor, None
        )
        return [([self.blank], pred_out[0].detach(), pred_state, 0.0)]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Compute next-token log-probabilities for each hypothesis on one frame.

        ``enc_out`` is the encoder output for a single frame (``_search`` passes
        ``enc_out[:, t : t + 1]``). The joint network combines it with each hypothesis'
        stored prediction network output. A temperature-scaled log-softmax over the
        token axis then gives one row of log-probabilities per hypothesis.

        NOTE(review): the search treats the LAST column of the result as blank (see
        ``_gen_b_hypos`` and ``_compute_updated_scores``), independently of
        ``self.blank``. This presumably requires ``blank == num_tokens - 1``; confirm.
        """
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # presumably [len(hypos), 1, 1, num_tokens]; dim=3 and [:, 0, 0] below assume rank 4
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Close every A hypothesis with a blank and merge it into the B set.

        If a B hypothesis with the same token sequence already exists, it is replaced by
        one whose score is ``logaddexp`` of both scores (probability mass is summed).
        ``b_hypos`` and ``key_to_b_hypo`` are updated in place.

        Returns:
            The B hypotheses sorted by ascending score.
        """
        for i, h_a in enumerate(a_hypos):
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            key = _get_hypo_key(h_a)
            if key in key_to_b_hypo:
                h_b = key_to_b_hypo[key]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[key] = h_b
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Extend A hypotheses with their best non-blank tokens.

        A candidate survives only if its score beats the ``beam_width``-th best B score.
        ``b_hypos`` is sorted ascending, so that score is ``b_hypos[-beam_width]``. When
        there are fewer than ``beam_width`` B hypotheses, every candidate survives.

        Returns:
            The new A hypotheses; empty when no candidate survives.
        """
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        # Bound the loop by the number of candidates actually returned.
        for i in range(nonblank_nbest_scores.size(0)):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        new_hypos: List[Hypothesis] = []
        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Append ``tokens[i]`` to ``base_hypos[i]`` with score ``scores[i]``.

        Runs the prediction network once on the whole batch, starting from the base
        hypotheses' batched states. Each new hypothesis stores its own slice of the
        resulting output and state, i.e. values computed after consuming the appended
        token. ``t`` is accepted but not used here.
        """
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[Hypothesis],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over every frame (``enc_out.shape[1]``) of ``enc_out``.

        Within a frame, A hypotheses are expanded repeatedly. Each one is closed with
        blank into the B set, and the non-blank extensions that can still enter the B
        beam become the next A set. Expansion stops when no A hypotheses remain, or
        after ``self.step_max_tokens`` rounds of non-blank emissions on that frame. The
        B set is then pruned to its ``beam_width`` best entries under
        ``self.hypo_sort_key``.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(hypo, device)
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                # Scores are later combined with tensors built by torch.tensor(...) without a
                # device argument (i.e. on CPU), so keep the probabilities on CPU as well.
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # topk raises when k exceeds the number of candidates. That happens e.g. with
            # step_max_tokens=0, which leaves one B hypothesis on the first frame.
            k = min(beam_width, len(b_hypos))
            _, sorted_idx = torch.tensor([self.hypo_sort_key(h) for h in b_hypos]).topk(k)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def _prepare_inputs(self, input: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Validate shapes and add the leading batch dim expected by the model.

        ``input`` of shape (T, D) becomes (1, T, D). ``length`` of shape () becomes (1,).
        """
        assert input.dim() == 2 or (
            input.dim() == 3 and input.shape[0] == 1
        ), "input must be of shape (T, D) or (1, T, D)"
        if input.dim() == 2:
            input = input.unsqueeze(0)

        assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
        if length.dim() == 0:
            length = length.unsqueeze(0)
        return input, length

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
        """
        # Shared with ``infer`` so both entry points normalize ``length`` the same way
        # (a 0-dim length is promoted to shape (1,)).
        input, length = self._prepare_inputs(input, length)
        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[Hypothesis] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.
        """
        input, length = self._prepare_inputs(input, length)
        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
|