quesst14.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
  10. URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
  11. _CHECKSUM = "4f869e06bc066bbe9c5dde31dbd3909a0870d70291110ebbb38878dcbc2fc5e4"
  12. _LANGUAGES = [
  13. "albanian",
  14. "basque",
  15. "czech",
  16. "nnenglish",
  17. "romanian",
  18. "slovak",
  19. ]
  20. class QUESST14(Dataset):
  21. """Create *QUESST14* [:footcite:`Mir2015QUESST2014EQ`] Dataset
  22. Args:
  23. root (str or Path): Root directory where the dataset's top level directory is found
  24. subset (str): Subset of the dataset to use. Options: [``"docs"``, ``"dev"``, ``"eval"``].
  25. language (str or None, optional): Language to get dataset for.
  26. Options: [``None``, ``albanian``, ``basque``, ``czech``, ``nnenglish``, ``romanian``, ``slovak``].
  27. If ``None``, dataset consists of all languages. (default: ``"nnenglish"``)
  28. download (bool, optional): Whether to download the dataset if it is not found at root path.
  29. (default: ``False``)
  30. """
  31. def __init__(
  32. self,
  33. root: Union[str, Path],
  34. subset: str,
  35. language: Optional[str] = "nnenglish",
  36. download: bool = False,
  37. ) -> None:
  38. assert subset in ["docs", "dev", "eval"], "`subset` must be one of ['docs', 'dev', 'eval']"
  39. assert language is None or language in _LANGUAGES, f"`language` must be None or one of {str(_LANGUAGES)}"
  40. # Get string representation of 'root'
  41. root = os.fspath(root)
  42. basename = os.path.basename(URL)
  43. archive = os.path.join(root, basename)
  44. basename = basename.rsplit(".", 2)[0]
  45. self._path = os.path.join(root, basename)
  46. if not os.path.isdir(self._path):
  47. if not os.path.isfile(archive):
  48. if not download:
  49. raise RuntimeError("Dataset not found. Please use `download=True` to download")
  50. download_url_to_file(URL, archive, hash_prefix=_CHECKSUM)
  51. extract_archive(archive, root)
  52. if subset == "docs":
  53. self.data = filter_audio_paths(self._path, language, "language_key_utterances.lst")
  54. elif subset == "dev":
  55. self.data = filter_audio_paths(self._path, language, "language_key_dev.lst")
  56. elif subset == "eval":
  57. self.data = filter_audio_paths(self._path, language, "language_key_eval.lst")
  58. def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, str]:
  59. audio_path = self.data[n]
  60. wav, sample_rate = torchaudio.load(audio_path)
  61. return wav, sample_rate, audio_path.with_suffix("").name
  62. def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
  63. """Load the n-th sample from the dataset.
  64. Args:
  65. n (int): The index of the sample to be loaded
  66. Returns:
  67. (Tensor, int, str): ``(waveform, sample_rate, file_name)``
  68. """
  69. return self._load_sample(n)
  70. def __len__(self) -> int:
  71. return len(self.data)
  72. def filter_audio_paths(
  73. path: str,
  74. language: str,
  75. lst_name: str,
  76. ):
  77. """Extract audio paths for the given language."""
  78. audio_paths = []
  79. path = Path(path)
  80. with open(path / "scoring" / lst_name) as f:
  81. for line in f:
  82. audio_path, lang = line.strip().split()
  83. if language is not None and lang != language:
  84. continue
  85. audio_path = re.sub(r"^.*?\/", "", audio_path)
  86. audio_paths.append(path / audio_path)
  87. return audio_paths