rnnt_pipeline.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import json
  2. import math
  3. from abc import ABC, abstractmethod
  4. from dataclasses import dataclass
  5. from functools import partial
  6. from typing import Callable, List, Tuple
  7. import torch
  8. import torchaudio
  9. from torchaudio._internal import module_utils
  10. from torchaudio.models import emformer_rnnt_base, RNNT, RNNTBeamSearch
  11. __all__ = []
  12. _decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
  13. _gain = pow(10, 0.05 * _decibel)
  14. def _piecewise_linear_log(x):
  15. x[x > math.e] = torch.log(x[x > math.e])
  16. x[x <= math.e] = x[x <= math.e] / math.e
  17. return x
  18. class _FunctionalModule(torch.nn.Module):
  19. def __init__(self, functional):
  20. super().__init__()
  21. self.functional = functional
  22. def forward(self, input):
  23. return self.functional(input)
  24. class _GlobalStatsNormalization(torch.nn.Module):
  25. def __init__(self, global_stats_path):
  26. super().__init__()
  27. with open(global_stats_path) as f:
  28. blob = json.loads(f.read())
  29. self.register_buffer("mean", torch.tensor(blob["mean"]))
  30. self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))
  31. def forward(self, input):
  32. return (input - self.mean) * self.invstddev
  33. class _FeatureExtractor(ABC):
  34. @abstractmethod
  35. def __call__(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  36. """Generates features and length output from the given input tensor.
  37. Args:
  38. input (torch.Tensor): input tensor.
  39. Returns:
  40. (torch.Tensor, torch.Tensor):
  41. torch.Tensor:
  42. Features, with shape `(length, *)`.
  43. torch.Tensor:
  44. Length, with shape `(1,)`.
  45. """
  46. class _TokenProcessor(ABC):
  47. @abstractmethod
  48. def __call__(self, tokens: List[int], **kwargs) -> str:
  49. """Decodes given list of tokens to text sequence.
  50. Args:
  51. tokens (List[int]): list of tokens to decode.
  52. Returns:
  53. str:
  54. Decoded text sequence.
  55. """
  56. class _ModuleFeatureExtractor(torch.nn.Module, _FeatureExtractor):
  57. """``torch.nn.Module``-based feature extraction pipeline.
  58. Args:
  59. pipeline (torch.nn.Module): module that implements feature extraction logic.
  60. """
  61. def __init__(self, pipeline: torch.nn.Module) -> None:
  62. super().__init__()
  63. self.pipeline = pipeline
  64. def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  65. """Generates features and length output from the given input tensor.
  66. Args:
  67. input (torch.Tensor): input tensor.
  68. Returns:
  69. (torch.Tensor, torch.Tensor):
  70. torch.Tensor:
  71. Features, with shape `(length, *)`.
  72. torch.Tensor:
  73. Length, with shape `(1,)`.
  74. """
  75. features = self.pipeline(input)
  76. length = torch.tensor([features.shape[0]])
  77. return features, length
  78. class _SentencePieceTokenProcessor(_TokenProcessor):
  79. """SentencePiece-model-based token processor.
  80. Args:
  81. sp_model_path (str): path to SentencePiece model.
  82. """
  83. def __init__(self, sp_model_path: str) -> None:
  84. if not module_utils.is_module_available("sentencepiece"):
  85. raise RuntimeError("SentencePiece is not available. Please install it.")
  86. import sentencepiece as spm
  87. self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
  88. self.post_process_remove_list = {
  89. self.sp_model.unk_id(),
  90. self.sp_model.eos_id(),
  91. self.sp_model.pad_id(),
  92. }
  93. def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
  94. """Decodes given list of tokens to text sequence.
  95. Args:
  96. tokens (List[int]): list of tokens to decode.
  97. lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
  98. removed. (Default: ``True``).
  99. Returns:
  100. str:
  101. Decoded text sequence.
  102. """
  103. filtered_hypo_tokens = [
  104. token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list
  105. ]
  106. output_string = "".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace("\u2581", " ")
  107. if lstrip:
  108. return output_string.lstrip()
  109. else:
  110. return output_string
  111. @dataclass
  112. class RNNTBundle:
  113. """torchaudio.pipelines.RNNTBundle()
  114. Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
  115. inference with an RNN-T model.
  116. More specifically, the class provides methods that produce the featurization pipeline,
  117. decoder wrapping the specified RNN-T model, and output token post-processor that together
  118. constitute a complete end-to-end ASR inference pipeline that produces a text sequence
  119. given a raw waveform.
  120. It can support non-streaming (full-context) inference as well as streaming inference.
  121. Users should not directly instantiate objects of this class; rather, users should use the
  122. instances (representing pre-trained models) that exist within the module,
  123. e.g. :py:obj:`EMFORMER_RNNT_BASE_LIBRISPEECH`.
  124. Example
  125. >>> import torchaudio
  126. >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
  127. >>> import torch
  128. >>>
  129. >>> # Non-streaming inference.
  130. >>> # Build feature extractor, decoder with RNN-T model, and token processor.
  131. >>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
  132. 100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
  133. >>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
  134. Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
  135. 100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
  136. >>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
  137. 100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
  138. >>>
  139. >>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
  140. >>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
  141. >>> waveform = next(iter(dataset))[0].squeeze()
  142. >>>
  143. >>> with torch.no_grad():
  144. >>> # Produce mel-scale spectrogram features.
  145. >>> features, length = feature_extractor(waveform)
  146. >>>
  147. >>> # Generate top-10 hypotheses.
  148. >>> hypotheses = decoder(features, length, 10)
  149. >>>
  150. >>> # For top hypothesis, convert predicted tokens to text.
  151. >>> text = token_processor(hypotheses[0][0])
  152. >>> print(text)
  153. he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
  154. >>>
  155. >>>
  156. >>> # Streaming inference.
  157. >>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
  158. >>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
  159. >>> num_samples_segment_right_context = (
  160. >>> num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
  161. >>> )
  162. >>>
  163. >>> # Build streaming inference feature extractor.
  164. >>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
  165. >>>
  166. >>> # Process same waveform as before, this time sequentially across overlapping segments
  167. >>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
  168. >>> state, hypothesis = None, None
  169. >>> for idx in range(0, len(waveform), num_samples_segment):
  170. >>> segment = waveform[idx: idx + num_samples_segment_right_context]
  171. >>> segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
  172. >>> with torch.no_grad():
  173. >>> features, length = streaming_feature_extractor(segment)
  174. >>> hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
  175. >>> hypothesis = hypotheses[0]
  176. >>> transcript = token_processor(hypothesis[0])
  177. >>> if transcript:
  178. >>> print(transcript, end=" ", flush=True)
  179. he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
  180. """
  181. class FeatureExtractor(_FeatureExtractor):
  182. pass
  183. class TokenProcessor(_TokenProcessor):
  184. pass
  185. _rnnt_path: str
  186. _rnnt_factory_func: Callable[[], RNNT]
  187. _global_stats_path: str
  188. _sp_model_path: str
  189. _right_padding: int
  190. _blank: int
  191. _sample_rate: int
  192. _n_fft: int
  193. _n_mels: int
  194. _hop_length: int
  195. _segment_length: int
  196. _right_context_length: int
  197. def _get_model(self) -> RNNT:
  198. model = self._rnnt_factory_func()
  199. path = torchaudio.utils.download_asset(self._rnnt_path)
  200. state_dict = torch.load(path)
  201. model.load_state_dict(state_dict)
  202. model.eval()
  203. return model
  204. @property
  205. def sample_rate(self) -> int:
  206. """Sample rate (in cycles per second) of input waveforms.
  207. :type: int
  208. """
  209. return self._sample_rate
  210. @property
  211. def n_fft(self) -> int:
  212. """Size of FFT window to use.
  213. :type: int
  214. """
  215. return self._n_fft
  216. @property
  217. def n_mels(self) -> int:
  218. """Number of mel spectrogram features to extract from input waveforms.
  219. :type: int
  220. """
  221. return self._n_mels
  222. @property
  223. def hop_length(self) -> int:
  224. """Number of samples between successive frames in input expected by model.
  225. :type: int
  226. """
  227. return self._hop_length
  228. @property
  229. def segment_length(self) -> int:
  230. """Number of frames in segment in input expected by model.
  231. :type: int
  232. """
  233. return self._segment_length
  234. @property
  235. def right_context_length(self) -> int:
  236. """Number of frames in right contextual block in input expected by model.
  237. :type: int
  238. """
  239. return self._right_context_length
  240. def get_decoder(self) -> RNNTBeamSearch:
  241. """Constructs RNN-T decoder.
  242. Returns:
  243. RNNTBeamSearch
  244. """
  245. model = self._get_model()
  246. return RNNTBeamSearch(model, self._blank)
  247. def get_feature_extractor(self) -> FeatureExtractor:
  248. """Constructs feature extractor for non-streaming (full-context) ASR.
  249. Returns:
  250. FeatureExtractor
  251. """
  252. local_path = torchaudio.utils.download_asset(self._global_stats_path)
  253. return _ModuleFeatureExtractor(
  254. torch.nn.Sequential(
  255. torchaudio.transforms.MelSpectrogram(
  256. sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
  257. ),
  258. _FunctionalModule(lambda x: x.transpose(1, 0)),
  259. _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
  260. _GlobalStatsNormalization(local_path),
  261. _FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 0, 0, self._right_padding))),
  262. )
  263. )
  264. def get_streaming_feature_extractor(self) -> FeatureExtractor:
  265. """Constructs feature extractor for streaming (simultaneous) ASR.
  266. Returns:
  267. FeatureExtractor
  268. """
  269. local_path = torchaudio.utils.download_asset(self._global_stats_path)
  270. return _ModuleFeatureExtractor(
  271. torch.nn.Sequential(
  272. torchaudio.transforms.MelSpectrogram(
  273. sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
  274. ),
  275. _FunctionalModule(lambda x: x.transpose(1, 0)),
  276. _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
  277. _GlobalStatsNormalization(local_path),
  278. )
  279. )
  280. def get_token_processor(self) -> TokenProcessor:
  281. """Constructs token processor.
  282. Returns:
  283. TokenProcessor
  284. """
  285. local_path = torchaudio.utils.download_asset(self._sp_model_path)
  286. return _SentencePieceTokenProcessor(local_path)
# Bundle instance describing the Emformer RNN-T assets and feature settings.
# The paths below are handed to ``torchaudio.utils.download_asset`` by the
# bundle's methods.
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
    _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
    _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
    _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
    _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
    _right_padding=4,  # zero padding appended by the non-streaming feature extractor
    _blank=4096,  # equals num_symbols - 1, i.e. the last index of the 4097 symbols
    _sample_rate=16000,
    _n_fft=400,
    _n_mels=80,
    _hop_length=160,  # samples between successive frames
    _segment_length=16,  # frames per segment (16 * 160 = 2560 samples)
    _right_context_length=4,  # frames of right context (4 * 160 = 640 samples)
)
# A dataclass instance has no docstring of its own, so one is attached here.
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.
Please refer to :py:class:`RNNTBundle` for usage instructions.
"""