# librilight_limited.py (3.5 KB)
  1. import os
  2. from pathlib import Path
  3. from typing import List, Tuple, Union
  4. from torch import Tensor
  5. from torch.hub import download_url_to_file
  6. from torch.utils.data import Dataset
  7. from torchaudio.datasets.librispeech import load_librispeech_item
  8. from torchaudio.datasets.utils import extract_archive
  # Name of the dataset directory expected under ``root``; also the stem of the
  # ``.tgz`` archive file that is looked up / downloaded into ``root``.
  9. _ARCHIVE_NAME = "librispeech_finetuning"
  # Remote location the archive is fetched from when it is not already on disk.
  10. _URL = "https://dl.fbaipublicfiles.com/librilight/data/librispeech_finetuning.tgz"
  # Passed as ``hash_prefix`` to ``download_url_to_file`` when fetching ``_URL``
  # (64 hex characters, presumably a SHA-256 digest of the archive -- TODO confirm).
  11. _CHECKSUM = "5d1efdc777b548194d7e09ba89126e2188026df9fd57aa57eb14408d2b2342af"
  12. def _get_fileids_paths(path, subset, _ext_audio) -> List[Tuple[str, str]]:
  13. """Get the file names and the corresponding file paths without `speaker_id`
  14. and `chapter_id` directories.
  15. The format of path is like:
  16. {root}/{_ARCHIVE_NAME}/1h/[0-5]/[clean, other] or
  17. {root}/{_ARCHIVE_NAME}/9h/[clean, other]
  18. """
  19. if subset == "10min":
  20. files_paths = [
  21. (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
  22. for p in Path(path).glob("1h/0/*/*/*/*" + _ext_audio)
  23. ]
  24. elif subset in ["1h", "10h"]:
  25. files_paths = [
  26. (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
  27. for p in Path(path).glob("1h/*/*/*/*/*" + _ext_audio)
  28. ]
  29. if subset == "10h":
  30. files_paths += [
  31. (os.path.join(os.path.dirname(p), "..", ".."), str(p.stem))
  32. for p in Path(path).glob("9h/*/*/*/*" + _ext_audio)
  33. ]
  34. else:
  35. raise ValueError(f"Unsupported subset value. Found {subset}.")
  36. files_paths = sorted(files_paths, key=lambda x: x[0] + x[1])
  37. return files_paths
  38. class LibriLightLimited(Dataset):
  39. """Create a Dataset for LibriLightLimited, which is the supervised subset of
  40. LibriLight dataset.
  41. Args:
  42. root (str or Path): Path to the directory where the dataset is found or downloaded.
  43. subset (str, optional): The subset to use. Options: [``10min``, ``1h``, ``10h``]
  44. (Default: ``10min``).
  45. download (bool, optional):
  46. Whether to download the dataset if it is not found at root path. (default: ``False``).
  47. """
  48. _ext_txt = ".trans.txt"
  49. _ext_audio = ".flac"
  50. def __init__(
  51. self,
  52. root: Union[str, Path],
  53. subset: str = "10min",
  54. download: bool = False,
  55. ) -> None:
  56. assert subset in ["10min", "1h", "10h"], "`subset` must be one of ['10min', '1h', '10h']"
  57. root = os.fspath(root)
  58. self._path = os.path.join(root, _ARCHIVE_NAME)
  59. archive = os.path.join(root, f"{_ARCHIVE_NAME}.tgz")
  60. if not os.path.isdir(self._path):
  61. if not download:
  62. raise RuntimeError("Dataset not found. Please use `download=True` to download")
  63. if not os.path.isfile(archive):
  64. download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
  65. extract_archive(archive)
  66. self._fileids_paths = _get_fileids_paths(self._path, subset, self._ext_audio)
  67. def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
  68. """Load the n-th sample from the dataset.
  69. Args:
  70. n (int): The index of the sample to be loaded
  71. Returns:
  72. (Tensor, int, str, int, int, int):
  73. ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
  74. """
  75. file_path, fileid = self._fileids_paths[n]
  76. return load_librispeech_item(fileid, file_path, self._ext_audio, self._ext_txt)
  77. def __len__(self) -> int:
  78. return len(self._fileids_paths)