commonvoice.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import csv
  2. import os
  3. from pathlib import Path
  4. from typing import Dict, List, Tuple, Union
  5. import torchaudio
  6. from torch import Tensor
  7. from torch.utils.data import Dataset
  8. def load_commonvoice_item(
  9. line: List[str], header: List[str], path: str, folder_audio: str, ext_audio: str
  10. ) -> Tuple[Tensor, int, Dict[str, str]]:
  11. # Each line as the following data:
  12. # client_id, path, sentence, up_votes, down_votes, age, gender, accent
  13. assert header[1] == "path"
  14. fileid = line[1]
  15. filename = os.path.join(path, folder_audio, fileid)
  16. if not filename.endswith(ext_audio):
  17. filename += ext_audio
  18. waveform, sample_rate = torchaudio.load(filename)
  19. dic = dict(zip(header, line))
  20. return waveform, sample_rate, dic
  21. class COMMONVOICE(Dataset):
  22. """Create a Dataset for *CommonVoice* [:footcite:`ardila2020common`].
  23. Args:
  24. root (str or Path): Path to the directory where the dataset is located.
  25. (Where the ``tsv`` file is present.)
  26. tsv (str, optional):
  27. The name of the tsv file used to construct the metadata, such as
  28. ``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``,
  29. ``"validated.tsv"`` and ``"other.tsv"``. (default: ``"train.tsv"``)
  30. """
  31. _ext_txt = ".txt"
  32. _ext_audio = ".mp3"
  33. _folder_audio = "clips"
  34. def __init__(self, root: Union[str, Path], tsv: str = "train.tsv") -> None:
  35. # Get string representation of 'root' in case Path object is passed
  36. self._path = os.fspath(root)
  37. self._tsv = os.path.join(self._path, tsv)
  38. with open(self._tsv, "r") as tsv_:
  39. walker = csv.reader(tsv_, delimiter="\t")
  40. self._header = next(walker)
  41. self._walker = list(walker)
  42. def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]:
  43. """Load the n-th sample from the dataset.
  44. Args:
  45. n (int): The index of the sample to be loaded
  46. Returns:
  47. (Tensor, int, Dict[str, str]): ``(waveform, sample_rate, dictionary)``, where dictionary
  48. is built from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``,
  49. ``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``.
  50. """
  51. line = self._walker[n]
  52. return load_commonvoice_item(line, self._header, self._path, self._folder_audio, self._ext_audio)
  53. def __len__(self) -> int:
  54. return len(self._walker)