tedlium.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import os
  2. from pathlib import Path
  3. from typing import Tuple, Union
  4. import torchaudio
  5. from torch import Tensor
  6. from torch.hub import download_url_to_file
  7. from torch.utils.data import Dataset
  8. from torchaudio.datasets.utils import extract_archive
  9. _RELEASE_CONFIGS = {
  10. "release1": {
  11. "folder_in_archive": "TEDLIUM_release1",
  12. "url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz",
  13. "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
  14. "data_path": "",
  15. "subset": "train",
  16. "supported_subsets": ["train", "test", "dev"],
  17. "dict": "TEDLIUM.150K.dic",
  18. },
  19. "release2": {
  20. "folder_in_archive": "TEDLIUM_release2",
  21. "url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz",
  22. "checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58",
  23. "data_path": "",
  24. "subset": "train",
  25. "supported_subsets": ["train", "test", "dev"],
  26. "dict": "TEDLIUM.152k.dic",
  27. },
  28. "release3": {
  29. "folder_in_archive": "TEDLIUM_release-3",
  30. "url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz",
  31. "checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb",
  32. "data_path": "data/",
  33. "subset": "train",
  34. "supported_subsets": ["train", "test", "dev"],
  35. "dict": "TEDLIUM.152k.dic",
  36. },
  37. }
  38. class TEDLIUM(Dataset):
  39. """
  40. Create a Dataset for *Tedlium* [:footcite:`rousseau2012tedlium`]. It supports releases 1,2 and 3.
  41. Args:
  42. root (str or Path): Path to the directory where the dataset is found or downloaded.
  43. release (str, optional): Release version.
  44. Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
  45. (default: ``"release1"``).
  46. subset (str, optional): The subset of dataset to use. Valid options are ``"train"``, ``"dev"``,
  47. and ``"test"``. Defaults to ``"train"``.
  48. download (bool, optional):
  49. Whether to download the dataset if it is not found at root path. (default: ``False``).
  50. audio_ext (str, optional): extension for audio file (default: ``".sph"``)
  51. """
  52. def __init__(
  53. self,
  54. root: Union[str, Path],
  55. release: str = "release1",
  56. subset: str = "train",
  57. download: bool = False,
  58. audio_ext: str = ".sph",
  59. ) -> None:
  60. self._ext_audio = audio_ext
  61. if release in _RELEASE_CONFIGS.keys():
  62. folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"]
  63. url = _RELEASE_CONFIGS[release]["url"]
  64. subset = subset if subset else _RELEASE_CONFIGS[release]["subset"]
  65. else:
  66. # Raise warning
  67. raise RuntimeError(
  68. "The release {} does not match any of the supported tedlium releases{} ".format(
  69. release,
  70. _RELEASE_CONFIGS.keys(),
  71. )
  72. )
  73. if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
  74. # Raise warning
  75. raise RuntimeError(
  76. "The subset {} does not match any of the supported tedlium subsets{} ".format(
  77. subset,
  78. _RELEASE_CONFIGS[release]["supported_subsets"],
  79. )
  80. )
  81. # Get string representation of 'root' in case Path object is passed
  82. root = os.fspath(root)
  83. basename = os.path.basename(url)
  84. archive = os.path.join(root, basename)
  85. basename = basename.split(".")[0]
  86. if release == "release3":
  87. if subset == "train":
  88. self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"])
  89. else:
  90. self._path = os.path.join(root, folder_in_archive, "legacy", subset)
  91. else:
  92. self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"], subset)
  93. if download:
  94. if not os.path.isdir(self._path):
  95. if not os.path.isfile(archive):
  96. checksum = _RELEASE_CONFIGS[release]["checksum"]
  97. download_url_to_file(url, archive, hash_prefix=checksum)
  98. extract_archive(archive)
  99. else:
  100. if not os.path.exists(self._path):
  101. raise RuntimeError(
  102. f"The path {self._path} doesn't exist. "
  103. "Please check the ``root`` path or set `download=True` to download it"
  104. )
  105. # Create list for all samples
  106. self._filelist = []
  107. stm_path = os.path.join(self._path, "stm")
  108. for file in sorted(os.listdir(stm_path)):
  109. if file.endswith(".stm"):
  110. stm_path = os.path.join(self._path, "stm", file)
  111. with open(stm_path) as f:
  112. l = len(f.readlines())
  113. file = file.replace(".stm", "")
  114. self._filelist.extend((file, line) for line in range(l))
  115. # Create dict path for later read
  116. self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"])
  117. self._phoneme_dict = None
  118. def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, int, int, int]:
  119. """Loads a TEDLIUM dataset sample given a file name and corresponding sentence name.
  120. Args:
  121. fileid (str): File id to identify both text and audio files corresponding to the sample
  122. line (int): Line identifier for the sample inside the text file
  123. path (str): Dataset root path
  124. Returns:
  125. (Tensor, int, str, int, int, int):
  126. ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
  127. """
  128. transcript_path = os.path.join(path, "stm", fileid)
  129. with open(transcript_path + ".stm") as f:
  130. transcript = f.readlines()[line]
  131. talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)
  132. wave_path = os.path.join(path, "sph", fileid)
  133. waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time)
  134. return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier)
  135. def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> [Tensor, int]:
  136. """Default load function used in TEDLIUM dataset, you can overwrite this function to customize functionality
  137. and load individual sentences from a full ted audio talk file.
  138. Args:
  139. path (str): Path to audio file
  140. start_time (int): Time in seconds where the sample sentence stars
  141. end_time (int): Time in seconds where the sample sentence finishes
  142. sample_rate (float, optional): Sampling rate
  143. Returns:
  144. [Tensor, int]: Audio tensor representation and sample rate
  145. """
  146. start_time = int(float(start_time) * sample_rate)
  147. end_time = int(float(end_time) * sample_rate)
  148. kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time}
  149. return torchaudio.load(path, **kwargs)
  150. def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
  151. """Load the n-th sample from the dataset.
  152. Args:
  153. n (int): The index of the sample to be loaded
  154. Returns:
  155. tuple: ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
  156. """
  157. fileid, line = self._filelist[n]
  158. return self._load_tedlium_item(fileid, line, self._path)
  159. def __len__(self) -> int:
  160. """TEDLIUM dataset custom function overwritting len default behaviour.
  161. Returns:
  162. int: TEDLIUM dataset length
  163. """
  164. return len(self._filelist)
  165. @property
  166. def phoneme_dict(self):
  167. """dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes.
  168. Note that some words have empty phonemes.
  169. """
  170. # Read phoneme dictionary
  171. if not self._phoneme_dict:
  172. self._phoneme_dict = {}
  173. with open(self._dict_path, "r", encoding="utf-8") as f:
  174. for line in f.readlines():
  175. content = line.strip().split()
  176. self._phoneme_dict[content[0]] = tuple(content[1:]) # content[1:] can be empty list
  177. return self._phoneme_dict.copy()