download.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import hashlib
  2. import logging
  3. from os import PathLike
  4. from pathlib import Path
  5. from typing import Union
  6. import torch
  7. _LG = logging.getLogger(__name__)
  8. def _get_local_path(key):
  9. path = Path(torch.hub.get_dir()) / "torchaudio" / Path(key)
  10. path.parent.mkdir(parents=True, exist_ok=True)
  11. return path
  12. def _download(key, path, progress):
  13. url = f"https://download.pytorch.org/torchaudio/{key}"
  14. torch.hub.download_url_to_file(url, path, progress=progress)
  15. def _get_hash(path, hash, chunk_size=1028):
  16. m = hashlib.sha256()
  17. with open(path, "rb") as file:
  18. data = file.read(chunk_size)
  19. while data:
  20. m.update(data)
  21. data = file.read(chunk_size)
  22. return m.hexdigest()
  23. def download_asset(
  24. key: str,
  25. hash: str = "",
  26. path: Union[str, PathLike] = "",
  27. *,
  28. progress: bool = True,
  29. ) -> str:
  30. """Download and store torchaudio assets to local file system.
  31. If a file exists at the download path, then that path is returned with or without
  32. hash validation.
  33. Args:
  34. key (str): The asset identifier.
  35. hash (str, optional):
  36. The value of SHA256 hash of the asset. If provided, it is used to verify
  37. the downloaded / cached object. If not provided, then no hash validation
  38. is performed. This means if a file exists at the download path, then the path
  39. is returned as-is without verifying the identity of the file.
  40. path (path-like object, optional):
  41. By default, the downloaded asset is saved in a directory under
  42. :py:func:`torch.hub.get_dir` and intermediate directories based on the given `key`
  43. are created.
  44. This argument can be used to overwrite the target location.
  45. When this argument is provided, all the intermediate directories have to be
  46. created beforehand.
  47. progress (bool): Whether to show progress bar for downloading. Default: ``True``.
  48. Note:
  49. Currently the valid key values are the route on ``download.pytorch.org/torchaudio``,
  50. but this is an implementation detail.
  51. Returns:
  52. str: The path to the asset on the local file system.
  53. """
  54. path = path or _get_local_path(key)
  55. if path.exists():
  56. _LG.info("The local file (%s) exists. Skipping the download.", path)
  57. else:
  58. _LG.info("Downloading %s to %s", key, path)
  59. _download(key, path, progress=progress)
  60. if hash:
  61. _LG.info("Verifying the hash value.")
  62. digest = _get_hash(path, hash)
  63. if digest != hash:
  64. raise ValueError(
  65. f"The hash value of the downloaded file ({path}), '{digest}' does not match "
  66. f"the provided hash value, '{hash}'."
  67. )
  68. _LG.info("Hash validated.")
  69. return str(path)