| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import csv
- import os
- from pathlib import Path
- from typing import Tuple, Union
- import torchaudio
- from torch import Tensor
- from torch.hub import download_url_to_file
- from torch.utils.data import Dataset
- from torchaudio.datasets.utils import extract_archive
- _RELEASE_CONFIGS = {
- "release1": {
- "folder_in_archive": "wavs",
- "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
- "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
- }
- }
class LJSPEECH(Dataset):
    """Create a Dataset for *LJSpeech-1.1* [:footcite:`ljspeech17`].

    The file list is read once, at construction time, from the ``metadata.csv``
    file of the release folder; audio is loaded lazily in :meth:`__getitem__`.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"wavs"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        RuntimeError: If ``download`` is ``False`` and the audio folder does not exist.
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:
        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(
        self,
        root: Union[str, Path],
        url: str,
        folder_in_archive: str,
        download: bool,
    ) -> None:
        """Resolve dataset paths, optionally fetch the archive, and load the file list.

        Sets ``self._path`` (audio folder), ``self._metadata_path`` and
        ``self._flist`` (the rows of ``metadata.csv``).

        Args:
            root (str or Path): Directory that holds (or will hold) the archive and release folder.
            url (str): Archive URL; its basename determines the archive and release folder names.
            folder_in_archive (str): Audio sub-directory inside the release folder.
            download (bool): Download and extract the archive when the audio folder is missing.

        Raises:
            RuntimeError: If ``download`` is ``False`` and the audio folder does not exist.
        """
        root = Path(root)

        # "<root>/LJSpeech-1.1.tar.bz2" for the default URL.
        archive_name = os.path.basename(url)
        archive = root / archive_name

        # Release folder name is the archive name with the ".tar.bz2" suffix removed.
        release_dir = Path(archive_name.split(".tar.bz2")[0])
        self._path = root / release_dir / folder_in_archive
        self._metadata_path = root / release_dir / "metadata.csv"

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    # NOTE(review): the release1 checksum is used whatever ``url`` is,
                    # so a custom URL must serve a file matching that hash.
                    checksum = _RELEASE_CONFIGS["release1"]["checksum"]
                    download_url_to_file(url, archive, hash_prefix=checksum)
                # An archive that is already present is reused rather than re-downloaded.
                # No target directory is passed; presumably the helper extracts next to
                # the archive - confirm against ``extract_archive``.
                extract_archive(archive)
        elif not os.path.exists(self._path):
            raise RuntimeError(
                f"The path {self._path} doesn't exist. "
                "Please check the ``root`` path or set `download=True` to download it"
            )

        # Decode explicitly as UTF-8: the LJSpeech release documents its transcripts as
        # UTF-8, and relying on the platform default encoding can mis-decode or fail on
        # the non-ASCII characters they contain.
        with open(self._metadata_path, "r", newline="", encoding="utf-8") as metadata:
            # Fields are pipe-separated and never quoted, so quote characters inside
            # transcripts are kept verbatim.
            rows = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
            self._flist = list(rows)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            (Tensor, int, str, str):
            ``(waveform, sample_rate, transcript, normalized_transcript)``
        """
        # Each metadata row is expected to hold exactly three fields.
        fileid, transcript, normalized_transcript = self._flist[n]
        audio_path = self._path / (fileid + ".wav")

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        return (
            waveform,
            sample_rate,
            transcript,
            normalized_transcript,
        )

    def __len__(self) -> int:
        """Return the number of rows read from ``metadata.csv``."""
        return len(self._flist)
|