utils.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import hashlib
  2. import logging
  3. import os
  4. import tarfile
  5. import urllib
  6. import urllib.request
  7. import warnings
  8. import zipfile
  9. from typing import Any, Iterable, List, Optional
  10. from torch.utils.model_zoo import tqdm
  11. def stream_url(
  12. url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
  13. ) -> Iterable:
  14. """Stream url by chunk
  15. Args:
  16. url (str): Url.
  17. start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
  18. block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
  19. progress_bar (bool, optional): Display a progress bar (Default: ``True``).
  20. """
  21. # If we already have the whole file, there is no need to download it again
  22. req = urllib.request.Request(url, method="HEAD")
  23. with urllib.request.urlopen(req) as response:
  24. url_size = int(response.info().get("Content-Length", -1))
  25. if url_size == start_byte:
  26. return
  27. req = urllib.request.Request(url)
  28. if start_byte:
  29. req.headers["Range"] = "bytes={}-".format(start_byte)
  30. with urllib.request.urlopen(req) as upointer, tqdm(
  31. unit="B",
  32. unit_scale=True,
  33. unit_divisor=1024,
  34. total=url_size,
  35. disable=not progress_bar,
  36. ) as pbar:
  37. num_bytes = 0
  38. while True:
  39. chunk = upointer.read(block_size)
  40. if not chunk:
  41. break
  42. yield chunk
  43. num_bytes += len(chunk)
  44. pbar.update(len(chunk))
  45. def download_url(
  46. url: str,
  47. download_folder: str,
  48. filename: Optional[str] = None,
  49. hash_value: Optional[str] = None,
  50. hash_type: str = "sha256",
  51. progress_bar: bool = True,
  52. resume: bool = False,
  53. ) -> None:
  54. """Download file to disk.
  55. Args:
  56. url (str): Url.
  57. download_folder (str): Folder to download file.
  58. filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
  59. (Default: ``None``).
  60. hash_value (str or None, optional): Hash for url (Default: ``None``).
  61. hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
  62. progress_bar (bool, optional): Display a progress bar (Default: ``True``).
  63. resume (bool, optional): Enable resuming download (Default: ``False``).
  64. """
  65. warnings.warn("download_url is deprecated and will be removed in the v0.12 release.")
  66. req = urllib.request.Request(url, method="HEAD")
  67. req_info = urllib.request.urlopen(req).info()
  68. # Detect filename
  69. filename = filename or req_info.get_filename() or os.path.basename(url)
  70. filepath = os.path.join(download_folder, filename)
  71. if resume and os.path.exists(filepath):
  72. mode = "ab"
  73. local_size: Optional[int] = os.path.getsize(filepath)
  74. elif not resume and os.path.exists(filepath):
  75. raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
  76. else:
  77. mode = "wb"
  78. local_size = None
  79. if hash_value and local_size == int(req_info.get("Content-Length", -1)):
  80. with open(filepath, "rb") as file_obj:
  81. if validate_file(file_obj, hash_value, hash_type):
  82. return
  83. raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
  84. with open(filepath, mode) as fpointer:
  85. for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
  86. fpointer.write(chunk)
  87. with open(filepath, "rb") as file_obj:
  88. if hash_value and not validate_file(file_obj, hash_value, hash_type):
  89. raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
  90. def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
  91. """Validate a given file object with its hash.
  92. Args:
  93. file_obj: File object to read from.
  94. hash_value (str): Hash for url.
  95. hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
  96. Returns:
  97. bool: return True if its a valid file, else False.
  98. """
  99. if hash_type == "sha256":
  100. hash_func = hashlib.sha256()
  101. elif hash_type == "md5":
  102. hash_func = hashlib.md5()
  103. else:
  104. raise ValueError
  105. while True:
  106. # Read by chunk to avoid filling memory
  107. chunk = file_obj.read(1024**2)
  108. if not chunk:
  109. break
  110. hash_func.update(chunk)
  111. return hash_func.hexdigest() == hash_value
  112. def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
  113. """Extract archive.
  114. Args:
  115. from_path (str): the path of the archive.
  116. to_path (str or None, optional): the root path of the extraced files (directory of from_path)
  117. (Default: ``None``)
  118. overwrite (bool, optional): overwrite existing files (Default: ``False``)
  119. Returns:
  120. List[str]: List of paths to extracted files even if not overwritten.
  121. Examples:
  122. >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
  123. >>> from_path = './validation.tar.gz'
  124. >>> to_path = './'
  125. >>> torchaudio.datasets.utils.download_from_url(url, from_path)
  126. >>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
  127. """
  128. if to_path is None:
  129. to_path = os.path.dirname(from_path)
  130. try:
  131. with tarfile.open(from_path, "r") as tar:
  132. logging.info("Opened tar file {}.".format(from_path))
  133. files = []
  134. for file_ in tar: # type: Any
  135. file_path = os.path.join(to_path, file_.name)
  136. if file_.isfile():
  137. files.append(file_path)
  138. if os.path.exists(file_path):
  139. logging.info("{} already extracted.".format(file_path))
  140. if not overwrite:
  141. continue
  142. tar.extract(file_, to_path)
  143. return files
  144. except tarfile.ReadError:
  145. pass
  146. try:
  147. with zipfile.ZipFile(from_path, "r") as zfile:
  148. logging.info("Opened zip file {}.".format(from_path))
  149. files = zfile.namelist()
  150. for file_ in files:
  151. file_path = os.path.join(to_path, file_)
  152. if os.path.exists(file_path):
  153. logging.info("{} already extracted.".format(file_path))
  154. if not overwrite:
  155. continue
  156. zfile.extract(file_, to_path)
  157. return files
  158. except zipfile.BadZipFile:
  159. pass
  160. raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")