cmudict.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import os
  2. import re
  3. from pathlib import Path
  4. from typing import Iterable, List, Tuple, Union
  5. from torch.hub import download_url_to_file
  6. from torch.utils.data import Dataset
  7. _CHECKSUMS = {
  8. "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501
  9. "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501
  10. }
  11. _PUNCTUATIONS = set(
  12. [
  13. "!EXCLAMATION-POINT",
  14. '"CLOSE-QUOTE',
  15. '"DOUBLE-QUOTE',
  16. '"END-OF-QUOTE',
  17. '"END-QUOTE',
  18. '"IN-QUOTES',
  19. '"QUOTE',
  20. '"UNQUOTE',
  21. "#HASH-MARK",
  22. "#POUND-SIGN",
  23. "#SHARP-SIGN",
  24. "%PERCENT",
  25. "&AMPERSAND",
  26. "'END-INNER-QUOTE",
  27. "'END-QUOTE",
  28. "'INNER-QUOTE",
  29. "'QUOTE",
  30. "'SINGLE-QUOTE",
  31. "(BEGIN-PARENS",
  32. "(IN-PARENTHESES",
  33. "(LEFT-PAREN",
  34. "(OPEN-PARENTHESES",
  35. "(PAREN",
  36. "(PARENS",
  37. "(PARENTHESES",
  38. ")CLOSE-PAREN",
  39. ")CLOSE-PARENTHESES",
  40. ")END-PAREN",
  41. ")END-PARENS",
  42. ")END-PARENTHESES",
  43. ")END-THE-PAREN",
  44. ")PAREN",
  45. ")PARENS",
  46. ")RIGHT-PAREN",
  47. ")UN-PARENTHESES",
  48. "+PLUS",
  49. ",COMMA",
  50. "--DASH",
  51. "-DASH",
  52. "-HYPHEN",
  53. "...ELLIPSIS",
  54. ".DECIMAL",
  55. ".DOT",
  56. ".FULL-STOP",
  57. ".PERIOD",
  58. ".POINT",
  59. "/SLASH",
  60. ":COLON",
  61. ";SEMI-COLON",
  62. ";SEMI-COLON(1)",
  63. "?QUESTION-MARK",
  64. "{BRACE",
  65. "{LEFT-BRACE",
  66. "{OPEN-BRACE",
  67. "}CLOSE-BRACE",
  68. "}RIGHT-BRACE",
  69. ]
  70. )
  71. def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
  72. _alt_re = re.compile(r"\([0-9]+\)")
  73. cmudict: List[Tuple[str, List[str]]] = list()
  74. for line in lines:
  75. if not line or line.startswith(";;;"): # ignore comments
  76. continue
  77. word, phones = line.strip().split(" ")
  78. if word in _PUNCTUATIONS:
  79. if exclude_punctuations:
  80. continue
  81. # !EXCLAMATION-POINT -> !
  82. # --DASH -> --
  83. # ...ELLIPSIS -> ...
  84. if word.startswith("..."):
  85. word = "..."
  86. elif word.startswith("--"):
  87. word = "--"
  88. else:
  89. word = word[0]
  90. # if a word have multiple pronunciations, there will be (number) appended to it
  91. # for example, DATAPOINTS and DATAPOINTS(1),
  92. # the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS
  93. word = re.sub(_alt_re, "", word)
  94. phones = phones.split(" ")
  95. cmudict.append((word, phones))
  96. return cmudict
  97. class CMUDict(Dataset):
  98. """Create a Dataset for *CMU Pronouncing Dictionary* [:footcite:`cmudict`] (CMUDict).
  99. Args:
  100. root (str or Path): Path to the directory where the dataset is found or downloaded.
  101. exclude_punctuations (bool, optional):
  102. When enabled, exclude the pronounciation of punctuations, such as
  103. `!EXCLAMATION-POINT` and `#HASH-MARK`.
  104. download (bool, optional):
  105. Whether to download the dataset if it is not found at root path. (default: ``False``).
  106. url (str, optional):
  107. The URL to download the dictionary from.
  108. (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
  109. url_symbols (str, optional):
  110. The URL to download the list of symbols from.
  111. (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
  112. """
  113. def __init__(
  114. self,
  115. root: Union[str, Path],
  116. exclude_punctuations: bool = True,
  117. *,
  118. download: bool = False,
  119. url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
  120. url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
  121. ) -> None:
  122. self.exclude_punctuations = exclude_punctuations
  123. self._root_path = Path(root)
  124. if not os.path.isdir(self._root_path):
  125. raise RuntimeError(f"The root directory does not exist; {root}")
  126. dict_file = self._root_path / os.path.basename(url)
  127. symbol_file = self._root_path / os.path.basename(url_symbols)
  128. if not os.path.exists(dict_file):
  129. if not download:
  130. raise RuntimeError(
  131. "The dictionary file is not found in the following location. "
  132. f"Set `download=True` to download it. {dict_file}"
  133. )
  134. checksum = _CHECKSUMS.get(url, None)
  135. download_url_to_file(url, dict_file, checksum)
  136. if not os.path.exists(symbol_file):
  137. if not download:
  138. raise RuntimeError(
  139. "The symbol file is not found in the following location. "
  140. f"Set `download=True` to download it. {symbol_file}"
  141. )
  142. checksum = _CHECKSUMS.get(url_symbols, None)
  143. download_url_to_file(url_symbols, symbol_file, checksum)
  144. with open(symbol_file, "r") as text:
  145. self._symbols = [line.strip() for line in text.readlines()]
  146. with open(dict_file, "r", encoding="latin-1") as text:
  147. self._dictionary = _parse_dictionary(text.readlines(), exclude_punctuations=self.exclude_punctuations)
  148. def __getitem__(self, n: int) -> Tuple[str, List[str]]:
  149. """Load the n-th sample from the dataset.
  150. Args:
  151. n (int): The index of the sample to be loaded.
  152. Returns:
  153. (str, List[str]): The corresponding word and phonemes ``(word, [phonemes])``.
  154. """
  155. return self._dictionary[n]
  156. def __len__(self) -> int:
  157. return len(self._dictionary)
  158. @property
  159. def symbols(self) -> List[str]:
  160. """list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`."""
  161. return self._symbols.copy()