| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- import hashlib
- import logging
- import os
- import tarfile
- import urllib
- import urllib.request
- import warnings
- import zipfile
- from typing import Any, Iterable, List, Optional
- from torch.utils.model_zoo import tqdm
def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream the content of ``url`` chunk by chunk.

    A ``HEAD`` request is issued first to learn the remote size. If ``start_byte`` equals
    that size, the file is considered already downloaded and nothing is yielded.

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point; when non-zero it is
            sent as an HTTP ``Range: bytes=<start_byte>-`` header (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).

    Yields:
        bytes: Successive chunks of at most ``block_size`` bytes.
    """
    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        # Content-Length is optional; -1 means "unknown", which tqdm expects as None.
        total=url_size if url_size >= 0 else None,
        disable=not progress_bar,
    ) as pbar:
        if start_byte:
            # ``total`` is the size of the whole file, so account for the bytes already on disk.
            pbar.update(start_byte)
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            pbar.update(len(chunk))
def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is taken from the
            response's ``Content-Disposition`` filename, falling back to the url's basename
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).

    Raises:
        RuntimeError: If the target file already exists and ``resume`` is ``False``, or if
            ``hash_value`` is given and the file's hash does not match it.
    """
    warnings.warn("download_url is deprecated and will be removed in the v0.12 release.", stacklevel=2)

    # Only the headers are needed; close the HEAD response instead of leaking the connection.
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        req_info = response.info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    if os.path.exists(filepath):
        if not resume:
            raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
        # Append to the partial file; stream_url will request only the missing bytes.
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)
    else:
        mode = "wb"
        local_size = None

    # A local file already at the remote size can be verified without downloading anything.
    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    if hash_value:
        with open(filepath, "rb") as file_obj:
            if not validate_file(file_obj, hash_value, hash_type):
                raise RuntimeError(
                    "The hash of {} does not match. Delete the file manually and retry.".format(filepath)
                )
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    The object is read to exhaustion from its current position in 1 MiB chunks.

    Args:
        file_obj: File object to read from (its ``read`` must return bytes).
        hash_value (str): Expected hexadecimal digest; compared case-insensitively.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if its a valid file, else False.

    Raises:
        ValueError: If ``hash_type`` is not one of the supported types.
    """
    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError('Unsupported hash type {!r}; expected "sha256" or "md5".'.format(hash_type))

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024**2)
        if not chunk:
            break
        hash_func.update(chunk)

    # hexdigest() is always lowercase; accept expected hashes written in upper case too.
    return hash_func.hexdigest() == hash_value.lower()
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract a tar (optionally compressed) or zip archive.

    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files. If ``None``, the
            directory containing ``from_path`` is used (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        List[str]: List of paths to extracted files even if not overwritten. For tar archives
        these are the regular-file members joined with ``to_path``; for zip archives they are
        the member names exactly as listed in the archive.

    Raises:
        RuntimeError: If a tar member would be written outside ``to_path`` (path traversal).
        NotImplementedError: If ``from_path`` is neither a tar nor a zip archive.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchaudio.datasets.utils.download_url(url, './')
        >>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Only the open call decides the format; errors raised while extracting a recognised
    # archive must propagate instead of being reported as "unsupported format".
    try:
        tar = tarfile.open(from_path, "r")
    except tarfile.ReadError:
        tar = None
    if tar is not None:
        with tar:
            logging.info("Opened tar file {}.".format(from_path))
            root = os.path.realpath(to_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                # Archives usually come from the network: refuse members such as "../x" or
                # "/etc/x" (or paths through an already-extracted symlink) that would land
                # outside ``to_path``. realpath resolves links extracted so far.
                target = os.path.realpath(file_path)
                if os.path.commonpath([root, target]) != root:
                    raise RuntimeError("Refusing to extract {} outside of {}.".format(file_.name, to_path))
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files

    try:
        zfile = zipfile.ZipFile(from_path, "r")
    except zipfile.BadZipFile:
        zfile = None
    if zfile is not None:
        with zfile:
            logging.info("Opened zip file {}.".format(from_path))
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("{} already extracted.".format(file_path))
                    if not overwrite:
                        continue
                # ZipFile.extract already strips absolute and ".." path components.
                zfile.extract(file_, to_path)
            return files

    raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
|