dr_vctk.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from pathlib import Path
  2. from typing import Dict, Tuple, Union
  3. import torchaudio
  4. from torch import Tensor
  5. from torch.hub import download_url_to_file
  6. from torch.utils.data import Dataset
  7. from torchaudio.datasets.utils import extract_archive
  8. _URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
  9. _CHECKSUM = "781f12f4406ed36ed27ae3bce55da47ba176e2d8bae67319e389e07b2c9bd769"
  10. _SUPPORTED_SUBSETS = {"train", "test"}
  11. class DR_VCTK(Dataset):
  12. """Create a dataset for *Device Recorded VCTK (Small subset version)* [:footcite:`Sarfjoo2018DeviceRV`].
  13. Args:
  14. root (str or Path): Root directory where the dataset's top level directory is found.
  15. subset (str): The subset to use. Can be one of ``"train"`` and ``"test"``. (default: ``"train"``).
  16. download (bool):
  17. Whether to download the dataset if it is not found at root path. (default: ``False``).
  18. url (str): The URL to download the dataset from.
  19. (default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)
  20. """
  21. def __init__(
  22. self,
  23. root: Union[str, Path],
  24. subset: str = "train",
  25. *,
  26. download: bool = False,
  27. url: str = _URL,
  28. ) -> None:
  29. if subset not in _SUPPORTED_SUBSETS:
  30. raise RuntimeError(
  31. f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
  32. )
  33. root = Path(root).expanduser()
  34. archive = root / "DR-VCTK.zip"
  35. self._subset = subset
  36. self._path = root / "DR-VCTK" / "DR-VCTK"
  37. self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
  38. self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
  39. self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"
  40. if not self._path.is_dir():
  41. if not archive.is_file():
  42. if not download:
  43. raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
  44. download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
  45. extract_archive(archive, root)
  46. self._config = self._load_config(self._config_filepath)
  47. self._filename_list = sorted(self._config)
  48. def _load_config(self, filepath: str) -> Dict[str, Tuple[str, int]]:
  49. # Skip header
  50. skip_rows = 2 if self._subset == "train" else 1
  51. config = {}
  52. with open(filepath) as f:
  53. for i, line in enumerate(f):
  54. if i < skip_rows or not line:
  55. continue
  56. filename, source, channel_id = line.strip().split("\t")
  57. config[filename] = (source, int(channel_id))
  58. return config
  59. def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
  60. speaker_id, utterance_id = filename.split(".")[0].split("_")
  61. source, channel_id = self._config[filename]
  62. file_clean_audio = self._clean_audio_dir / filename
  63. file_noisy_audio = self._noisy_audio_dir / filename
  64. waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
  65. waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
  66. return (
  67. waveform_clean,
  68. sample_rate_clean,
  69. waveform_noisy,
  70. sample_rate_noisy,
  71. speaker_id,
  72. utterance_id,
  73. source,
  74. channel_id,
  75. )
  76. def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
  77. """Load the n-th sample from the dataset.
  78. Args:
  79. n (int): The index of the sample to be loaded
  80. Returns:
  81. (Tensor, int, Tensor, int, str, str, str, int):
  82. ``(waveform_clean, sample_rate_clean, waveform_noisy, sample_rate_noisy, speaker_id,\
  83. utterance_id, source, channel_id)``
  84. """
  85. filename = self._filename_list[n]
  86. return self._load_dr_vctk_item(filename)
  87. def __len__(self) -> int:
  88. return len(self._filename_list)