libritts.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import os
  2. from pathlib import Path
  3. from typing import Tuple, Union
  4. import torchaudio
  5. from torch import Tensor
  6. from torch.hub import download_url_to_file
  7. from torch.utils.data import Dataset
  8. from torchaudio.datasets.utils import extract_archive
  9. URL = "train-clean-100"
  10. FOLDER_IN_ARCHIVE = "LibriTTS"
  11. _CHECKSUMS = {
  12. "http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a", # noqa: E501
  13. "http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c", # noqa: E501
  14. "http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5", # noqa: E501
  15. "http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d", # noqa: E501
  16. "http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b", # noqa: E501
  17. "http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886", # noqa: E501
  18. "http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df", # noqa: E501
  19. }
  20. def load_libritts_item(
  21. fileid: str,
  22. path: str,
  23. ext_audio: str,
  24. ext_original_txt: str,
  25. ext_normalized_txt: str,
  26. ) -> Tuple[Tensor, int, str, str, int, int, str]:
  27. speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_")
  28. utterance_id = fileid
  29. normalized_text = utterance_id + ext_normalized_txt
  30. normalized_text = os.path.join(path, speaker_id, chapter_id, normalized_text)
  31. original_text = utterance_id + ext_original_txt
  32. original_text = os.path.join(path, speaker_id, chapter_id, original_text)
  33. file_audio = utterance_id + ext_audio
  34. file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
  35. # Load audio
  36. waveform, sample_rate = torchaudio.load(file_audio)
  37. # Load original text
  38. with open(original_text) as ft:
  39. original_text = ft.readline()
  40. # Load normalized text
  41. with open(normalized_text, "r") as ft:
  42. normalized_text = ft.readline()
  43. return (
  44. waveform,
  45. sample_rate,
  46. original_text,
  47. normalized_text,
  48. int(speaker_id),
  49. int(chapter_id),
  50. utterance_id,
  51. )
  52. class LIBRITTS(Dataset):
  53. """Create a Dataset for *LibriTTS* [:footcite:`Zen2019LibriTTSAC`].
  54. Args:
  55. root (str or Path): Path to the directory where the dataset is found or downloaded.
  56. url (str, optional): The URL to download the dataset from,
  57. or the type of the dataset to dowload.
  58. Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
  59. ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
  60. ``"train-other-500"``. (default: ``"train-clean-100"``)
  61. folder_in_archive (str, optional):
  62. The top-level directory of the dataset. (default: ``"LibriTTS"``)
  63. download (bool, optional):
  64. Whether to download the dataset if it is not found at root path. (default: ``False``).
  65. """
  66. _ext_original_txt = ".original.txt"
  67. _ext_normalized_txt = ".normalized.txt"
  68. _ext_audio = ".wav"
  69. def __init__(
  70. self,
  71. root: Union[str, Path],
  72. url: str = URL,
  73. folder_in_archive: str = FOLDER_IN_ARCHIVE,
  74. download: bool = False,
  75. ) -> None:
  76. if url in [
  77. "dev-clean",
  78. "dev-other",
  79. "test-clean",
  80. "test-other",
  81. "train-clean-100",
  82. "train-clean-360",
  83. "train-other-500",
  84. ]:
  85. ext_archive = ".tar.gz"
  86. base_url = "http://www.openslr.org/resources/60/"
  87. url = os.path.join(base_url, url + ext_archive)
  88. # Get string representation of 'root' in case Path object is passed
  89. root = os.fspath(root)
  90. basename = os.path.basename(url)
  91. archive = os.path.join(root, basename)
  92. basename = basename.split(".")[0]
  93. folder_in_archive = os.path.join(folder_in_archive, basename)
  94. self._path = os.path.join(root, folder_in_archive)
  95. if download:
  96. if not os.path.isdir(self._path):
  97. if not os.path.isfile(archive):
  98. checksum = _CHECKSUMS.get(url, None)
  99. download_url_to_file(url, archive, hash_prefix=checksum)
  100. extract_archive(archive)
  101. else:
  102. if not os.path.exists(self._path):
  103. raise RuntimeError(
  104. f"The path {self._path} doesn't exist. "
  105. "Please check the ``root`` path or set `download=True` to download it"
  106. )
  107. self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
  108. def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
  109. """Load the n-th sample from the dataset.
  110. Args:
  111. n (int): The index of the sample to be loaded
  112. Returns:
  113. (Tensor, int, str, str, str, int, int, str):
  114. ``(waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id)``
  115. """
  116. fileid = self._walker[n]
  117. return load_libritts_item(
  118. fileid,
  119. self._path,
  120. self._ext_audio,
  121. self._ext_original_txt,
  122. self._ext_normalized_txt,
  123. )
  124. def __len__(self) -> int:
  125. return len(self._walker)