ljspeech.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import csv
  2. import os
  3. from pathlib import Path
  4. from typing import Tuple, Union
  5. import torchaudio
  6. from torch import Tensor
  7. from torch.hub import download_url_to_file
  8. from torch.utils.data import Dataset
  9. from torchaudio.datasets.utils import extract_archive
  10. _RELEASE_CONFIGS = {
  11. "release1": {
  12. "folder_in_archive": "wavs",
  13. "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
  14. "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
  15. }
  16. }
  17. class LJSPEECH(Dataset):
  18. """Create a Dataset for *LJSpeech-1.1* [:footcite:`ljspeech17`].
  19. Args:
  20. root (str or Path): Path to the directory where the dataset is found or downloaded.
  21. url (str, optional): The URL to download the dataset from.
  22. (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
  23. folder_in_archive (str, optional):
  24. The top-level directory of the dataset. (default: ``"wavs"``)
  25. download (bool, optional):
  26. Whether to download the dataset if it is not found at root path. (default: ``False``).
  27. """
  28. def __init__(
  29. self,
  30. root: Union[str, Path],
  31. url: str = _RELEASE_CONFIGS["release1"]["url"],
  32. folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
  33. download: bool = False,
  34. ) -> None:
  35. self._parse_filesystem(root, url, folder_in_archive, download)
  36. def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
  37. root = Path(root)
  38. basename = os.path.basename(url)
  39. archive = root / basename
  40. basename = Path(basename.split(".tar.bz2")[0])
  41. folder_in_archive = basename / folder_in_archive
  42. self._path = root / folder_in_archive
  43. self._metadata_path = root / basename / "metadata.csv"
  44. if download:
  45. if not os.path.isdir(self._path):
  46. if not os.path.isfile(archive):
  47. checksum = _RELEASE_CONFIGS["release1"]["checksum"]
  48. download_url_to_file(url, archive, hash_prefix=checksum)
  49. extract_archive(archive)
  50. else:
  51. if not os.path.exists(self._path):
  52. raise RuntimeError(
  53. f"The path {self._path} doesn't exist. "
  54. "Please check the ``root`` path or set `download=True` to download it"
  55. )
  56. with open(self._metadata_path, "r", newline="") as metadata:
  57. flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
  58. self._flist = list(flist)
  59. def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
  60. """Load the n-th sample from the dataset.
  61. Args:
  62. n (int): The index of the sample to be loaded
  63. Returns:
  64. (Tensor, int, str, str):
  65. ``(waveform, sample_rate, transcript, normalized_transcript)``
  66. """
  67. line = self._flist[n]
  68. fileid, transcript, normalized_transcript = line
  69. fileid_audio = self._path / (fileid + ".wav")
  70. # Load audio
  71. waveform, sample_rate = torchaudio.load(fileid_audio)
  72. return (
  73. waveform,
  74. sample_rate,
  75. transcript,
  76. normalized_transcript,
  77. )
  78. def __len__(self) -> int:
  79. return len(self._flist)