librimix.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from pathlib import Path
  2. from typing import List, Tuple, Union
  3. import torch
  4. import torchaudio
  5. from torch.utils.data import Dataset
  6. SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
  7. class LibriMix(Dataset):
  8. r"""Create the *LibriMix* [:footcite:`cosentino2020librimix`] dataset.
  9. Args:
  10. root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
  11. ``Libri3Mix`` is stored.
  12. subset (str, optional): The subset to use. Options: [``train-360``, ``train-100``,
  13. ``dev``, and ``test``] (Default: ``train-360``).
  14. num_speakers (int, optional): The number of speakers, which determines the directories
  15. to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
  16. N source audios. (Default: 2)
  17. sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
  18. which subdirectory the audio are fetched. If any of the audio has a different sample
  19. rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000)
  20. task (str, optional): the task of LibriMix.
  21. Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
  22. (Default: ``sep_clean``)
  23. Note:
  24. The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
  25. """
  26. def __init__(
  27. self,
  28. root: Union[str, Path],
  29. subset: str = "train-360",
  30. num_speakers: int = 2,
  31. sample_rate: int = 8000,
  32. task: str = "sep_clean",
  33. ):
  34. self.root = Path(root) / f"Libri{num_speakers}Mix"
  35. if sample_rate == 8000:
  36. self.root = self.root / "wav8k/min" / subset
  37. elif sample_rate == 16000:
  38. self.root = self.root / "wav16k/min" / subset
  39. else:
  40. raise ValueError(f"Unsupported sample rate. Found {sample_rate}.")
  41. self.sample_rate = sample_rate
  42. self.task = task
  43. self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
  44. self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
  45. self.files = [p.name for p in self.mix_dir.glob("*wav")]
  46. self.files.sort()
  47. def _load_audio(self, path) -> torch.Tensor:
  48. waveform, sample_rate = torchaudio.load(path)
  49. if sample_rate != self.sample_rate:
  50. raise ValueError(
  51. f"The dataset contains audio file of sample rate {sample_rate}, "
  52. f"but the requested sample rate is {self.sample_rate}."
  53. )
  54. return waveform
  55. def _load_sample(self, filename) -> SampleType:
  56. mixed = self._load_audio(str(self.mix_dir / filename))
  57. srcs = []
  58. for i, dir_ in enumerate(self.src_dirs):
  59. src = self._load_audio(str(dir_ / filename))
  60. if mixed.shape != src.shape:
  61. raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
  62. srcs.append(src)
  63. return self.sample_rate, mixed, srcs
  64. def __len__(self) -> int:
  65. return len(self.files)
  66. def __getitem__(self, key: int) -> SampleType:
  67. """Load the n-th sample from the dataset.
  68. Args:
  69. key (int): The index of the sample to be loaded
  70. Returns:
  71. (int, Tensor, List[Tensor]): ``(sample_rate, mix_waveform, list_of_source_waveforms)``
  72. """
  73. return self._load_sample(self.files[key])