| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- import json
- import math
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from functools import partial
- from typing import Callable, List, Tuple
- import torch
- import torchaudio
- from torchaudio._internal import module_utils
- from torchaudio.models import emformer_rnnt_base, RNNT, RNNTBeamSearch
- __all__ = []
- _decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
- _gain = pow(10, 0.05 * _decibel)
- def _piecewise_linear_log(x):
- x[x > math.e] = torch.log(x[x > math.e])
- x[x <= math.e] = x[x <= math.e] / math.e
- return x
- class _FunctionalModule(torch.nn.Module):
- def __init__(self, functional):
- super().__init__()
- self.functional = functional
- def forward(self, input):
- return self.functional(input)
- class _GlobalStatsNormalization(torch.nn.Module):
- def __init__(self, global_stats_path):
- super().__init__()
- with open(global_stats_path) as f:
- blob = json.loads(f.read())
- self.register_buffer("mean", torch.tensor(blob["mean"]))
- self.register_buffer("invstddev", torch.tensor(blob["invstddev"]))
- def forward(self, input):
- return (input - self.mean) * self.invstddev
- class _FeatureExtractor(ABC):
- @abstractmethod
- def __call__(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- """Generates features and length output from the given input tensor.
- Args:
- input (torch.Tensor): input tensor.
- Returns:
- (torch.Tensor, torch.Tensor):
- torch.Tensor:
- Features, with shape `(length, *)`.
- torch.Tensor:
- Length, with shape `(1,)`.
- """
- class _TokenProcessor(ABC):
- @abstractmethod
- def __call__(self, tokens: List[int], **kwargs) -> str:
- """Decodes given list of tokens to text sequence.
- Args:
- tokens (List[int]): list of tokens to decode.
- Returns:
- str:
- Decoded text sequence.
- """
class _ModuleFeatureExtractor(torch.nn.Module, _FeatureExtractor):
    """Feature extractor backed by a ``torch.nn.Module`` pipeline.

    Args:
        pipeline (torch.nn.Module): module that implements feature extraction logic.
    """

    def __init__(self, pipeline: torch.nn.Module) -> None:
        super().__init__()
        self.pipeline = pipeline

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runs the pipeline on ``input`` and reports the size of the output's first dimension.

        Args:
            input (torch.Tensor): input tensor.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor:
                    Features, with shape `(length, *)`.

                torch.Tensor:
                    Length, with shape `(1,)`.
        """
        features = self.pipeline(input)
        # The length is the size of the leading dimension, wrapped in a 1-element tensor.
        num_frames = features.shape[0]
        return features, torch.tensor([num_frames])
class _SentencePieceTokenProcessor(_TokenProcessor):
    """Token processor that decodes ids with a SentencePiece model.

    Args:
        sp_model_path (str): path to SentencePiece model.

    Raises:
        RuntimeError: if the ``sentencepiece`` package is not available.
    """

    def __init__(self, sp_model_path: str) -> None:
        if not module_utils.is_module_available("sentencepiece"):
            raise RuntimeError("SentencePiece is not available. Please install it.")

        # Imported lazily so the module itself loads without sentencepiece installed.
        import sentencepiece as spm

        sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        self.sp_model = sp_model
        # Ids that are dropped from hypotheses before decoding.
        self.post_process_remove_list = {
            sp_model.unk_id(),
            sp_model.eos_id(),
            sp_model.pad_id(),
        }

    def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
        """Converts a list of token ids into text.

        The first element of ``tokens`` is always skipped, and the model's unk,
        eos and pad ids are filtered out of the remainder.
        NOTE(review): the skipped first element is presumably a leading
        blank/start symbol emitted by the decoder -- confirm against callers.

        Args:
            tokens (List[int]): list of tokens to decode.
            lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
                removed. (Default: ``True``).

        Returns:
            str:
                Decoded text sequence.
        """
        ignored = self.post_process_remove_list
        kept_ids = [tok for tok in tokens[1:] if tok not in ignored]
        pieces = self.sp_model.id_to_piece(kept_ids)
        # U+2581 is replaced by a plain space in the joined pieces.
        text = "".join(pieces).replace("\u2581", " ")
        return text.lstrip() if lstrip else text
@dataclass
class RNNTBundle:
    """torchaudio.pipelines.RNNTBundle()

    Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
    inference with an RNN-T model.

    More specifically, the class provides methods that produce the featurization pipeline,
    decoder wrapping the specified RNN-T model, and output token post-processor that together
    constitute a complete end-to-end ASR inference pipeline that produces a text sequence
    given a raw waveform.

    It can support non-streaming (full-context) inference as well as streaming inference.

    Users should not directly instantiate objects of this class; rather, users should use the
    instances (representing pre-trained models) that exist within the module,
    e.g. :py:obj:`EMFORMER_RNNT_BASE_LIBRISPEECH`.

    Example
        >>> import torchaudio
        >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
        >>> import torch
        >>>
        >>> # Non-streaming inference.
        >>> # Build feature extractor, decoder with RNN-T model, and token processor.
        >>> feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
        100%|███████████████████████████████| 3.81k/3.81k [00:00<00:00, 4.22MB/s]
        >>> decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
        Downloading: "https://download.pytorch.org/torchaudio/models/emformer_rnnt_base_librispeech.pt"
        100%|███████████████████████████████| 293M/293M [00:07<00:00, 42.1MB/s]
        >>> token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()
        100%|███████████████████████████████| 295k/295k [00:00<00:00, 25.4MB/s]
        >>>
        >>> # Instantiate LibriSpeech dataset; retrieve waveform for first sample.
        >>> dataset = torchaudio.datasets.LIBRISPEECH("/home/librispeech", url="test-clean")
        >>> waveform = next(iter(dataset))[0].squeeze()
        >>>
        >>> with torch.no_grad():
        >>>     # Produce mel-scale spectrogram features.
        >>>     features, length = feature_extractor(waveform)
        >>>
        >>>     # Generate top-10 hypotheses.
        >>>     hypotheses = decoder(features, length, 10)
        >>>
        >>> # For top hypothesis, convert predicted tokens to text.
        >>> text = token_processor(hypotheses[0][0])
        >>> print(text)
        he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
        >>>
        >>>
        >>> # Streaming inference.
        >>> hop_length = EMFORMER_RNNT_BASE_LIBRISPEECH.hop_length
        >>> num_samples_segment = EMFORMER_RNNT_BASE_LIBRISPEECH.segment_length * hop_length
        >>> num_samples_segment_right_context = (
        >>>     num_samples_segment + EMFORMER_RNNT_BASE_LIBRISPEECH.right_context_length * hop_length
        >>> )
        >>>
        >>> # Build streaming inference feature extractor.
        >>> streaming_feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_streaming_feature_extractor()
        >>>
        >>> # Process same waveform as before, this time sequentially across overlapping segments
        >>> # to simulate streaming inference. Note the usage of ``streaming_feature_extractor`` and ``decoder.infer``.
        >>> state, hypothesis = None, None
        >>> for idx in range(0, len(waveform), num_samples_segment):
        >>>     segment = waveform[idx: idx + num_samples_segment_right_context]
        >>>     segment = torch.nn.functional.pad(segment, (0, num_samples_segment_right_context - len(segment)))
        >>>     with torch.no_grad():
        >>>         features, length = streaming_feature_extractor(segment)
        >>>         hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        >>>     hypothesis = hypotheses[0]
        >>>     transcript = token_processor(hypothesis[0])
        >>>     if transcript:
        >>>         print(transcript, end=" ", flush=True)
        he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]
    """

    # Public aliases of the module-private interfaces, used in the return annotations below.
    class FeatureExtractor(_FeatureExtractor):
        pass

    class TokenProcessor(_TokenProcessor):
        pass

    # Asset path of the model weights, resolved with ``torchaudio.utils.download_asset``.
    _rnnt_path: str
    # Zero-argument callable that builds the (not yet loaded) RNN-T model.
    _rnnt_factory_func: Callable[[], RNNT]
    # Asset path of the JSON file holding the normalization statistics.
    _global_stats_path: str
    # Asset path of the SentencePiece model.
    _sp_model_path: str
    # Amount of end padding the non-streaming feature extractor applies to the
    # second-to-last dimension of the normalized features.
    _right_padding: int
    # Value passed to ``RNNTBeamSearch`` as its second positional argument.
    _blank: int
    # Backing values for the read-only properties below.
    _sample_rate: int
    _n_fft: int
    _n_mels: int
    _hop_length: int
    _segment_length: int
    _right_context_length: int

    def _get_model(self) -> RNNT:
        """Builds the model, loads the downloaded weights into it, and switches it to eval mode.

        Returns:
            RNNT
        """
        model = self._rnnt_factory_func()
        path = torchaudio.utils.download_asset(self._rnnt_path)
        # Map storages to CPU while deserializing so that a checkpoint saved from an
        # accelerator can still be read on a CPU-only host; ``load_state_dict`` copies
        # the values into the model's own parameters afterwards.
        # NOTE(review): ``torch.load`` unpickles the file; consider ``weights_only=True``
        # once the minimum supported torch version provides it.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def _feature_pipeline_stages(self, global_stats_path: str) -> List[torch.nn.Module]:
        """Returns the stages shared by the streaming and non-streaming feature extractors.

        In order: mel spectrogram, swap of the first two dimensions, gain followed by
        ``_piecewise_linear_log``, and global-statistics normalization.

        Args:
            global_stats_path (str): local path of the JSON statistics file.

        Returns:
            List[torch.nn.Module]
        """
        return [
            torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, hop_length=self.hop_length
            ),
            _FunctionalModule(lambda x: x.transpose(1, 0)),
            _FunctionalModule(lambda x: _piecewise_linear_log(x * _gain)),
            _GlobalStatsNormalization(global_stats_path),
        ]

    @property
    def sample_rate(self) -> int:
        """Sample rate (in cycles per second) of input waveforms.

        :type: int
        """
        return self._sample_rate

    @property
    def n_fft(self) -> int:
        """Size of FFT window to use.

        :type: int
        """
        return self._n_fft

    @property
    def n_mels(self) -> int:
        """Number of mel spectrogram features to extract from input waveforms.

        :type: int
        """
        return self._n_mels

    @property
    def hop_length(self) -> int:
        """Number of samples between successive frames in input expected by model.

        :type: int
        """
        return self._hop_length

    @property
    def segment_length(self) -> int:
        """Number of frames in segment in input expected by model.

        :type: int
        """
        return self._segment_length

    @property
    def right_context_length(self) -> int:
        """Number of frames in right contextual block in input expected by model.

        :type: int
        """
        return self._right_context_length

    def get_decoder(self) -> RNNTBeamSearch:
        """Constructs RNN-T decoder.

        Returns:
            RNNTBeamSearch
        """
        model = self._get_model()
        return RNNTBeamSearch(model, self._blank)

    def get_feature_extractor(self) -> FeatureExtractor:
        """Constructs feature extractor for non-streaming (full-context) ASR.

        Returns:
            FeatureExtractor
        """
        local_path = torchaudio.utils.download_asset(self._global_stats_path)
        stages = self._feature_pipeline_stages(local_path)
        # Only the non-streaming pipeline pads the end of the features.
        stages.append(_FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 0, 0, self._right_padding))))
        return _ModuleFeatureExtractor(torch.nn.Sequential(*stages))

    def get_streaming_feature_extractor(self) -> FeatureExtractor:
        """Constructs feature extractor for streaming (simultaneous) ASR.

        Returns:
            FeatureExtractor
        """
        local_path = torchaudio.utils.download_asset(self._global_stats_path)
        stages = self._feature_pipeline_stages(local_path)
        return _ModuleFeatureExtractor(torch.nn.Sequential(*stages))

    def get_token_processor(self) -> TokenProcessor:
        """Constructs token processor.

        Returns:
            TokenProcessor
        """
        local_path = torchaudio.utils.download_asset(self._sp_model_path)
        return _SentencePieceTokenProcessor(local_path)
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
    # Model: factory, weights asset, and the value handed to the decoder as ``blank``.
    _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
    _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
    _blank=4096,
    # Auxiliary assets: normalization statistics and SentencePiece model.
    _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
    _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
    # Mel spectrogram settings.
    _sample_rate=16000,
    _n_fft=400,
    _n_mels=80,
    _hop_length=160,
    # Segment, right-context and padding sizes.
    _segment_length=16,
    _right_context_length=4,
    _right_padding=4,
)
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.
Please refer to :py:class:`RNNTBundle` for usage instructions.
"""
|