utils.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """Defines utilities for switching audio backends"""
  2. import warnings
  3. from typing import List, Optional
  4. import torchaudio
  5. from torchaudio._internal import module_utils as _mod_utils
  6. from . import no_backend, soundfile_backend, sox_io_backend
  7. __all__ = [
  8. "list_audio_backends",
  9. "get_audio_backend",
  10. "set_audio_backend",
  11. ]
  12. def list_audio_backends() -> List[str]:
  13. """List available backends
  14. Returns:
  15. List[str]: The list of available backends.
  16. """
  17. backends = []
  18. if _mod_utils.is_module_available("soundfile"):
  19. backends.append("soundfile")
  20. if _mod_utils.is_sox_available():
  21. backends.append("sox_io")
  22. return backends
  23. def set_audio_backend(backend: Optional[str]):
  24. """Set the backend for I/O operation
  25. Args:
  26. backend (str or None): Name of the backend.
  27. One of ``"sox_io"`` or ``"soundfile"`` based on availability
  28. of the system. If ``None`` is provided the current backend is unassigned.
  29. """
  30. if backend is not None and backend not in list_audio_backends():
  31. raise RuntimeError(f'Backend "{backend}" is not one of ' f"available backends: {list_audio_backends()}.")
  32. if backend is None:
  33. module = no_backend
  34. elif backend == "sox_io":
  35. module = sox_io_backend
  36. elif backend == "soundfile":
  37. module = soundfile_backend
  38. else:
  39. raise NotImplementedError(f'Unexpected backend "{backend}"')
  40. for func in ["save", "load", "info"]:
  41. setattr(torchaudio, func, getattr(module, func))
  42. def _init_audio_backend():
  43. backends = list_audio_backends()
  44. if "sox_io" in backends:
  45. set_audio_backend("sox_io")
  46. elif "soundfile" in backends:
  47. set_audio_backend("soundfile")
  48. else:
  49. warnings.warn("No audio backend is available.")
  50. set_audio_backend(None)
  51. def get_audio_backend() -> Optional[str]:
  52. """Get the name of the current backend
  53. Returns:
  54. Optional[str]: The name of the current backend or ``None`` if no backend is assigned.
  55. """
  56. if torchaudio.load == no_backend.load:
  57. return None
  58. if torchaudio.load == sox_io_backend.load:
  59. return "sox_io"
  60. if torchaudio.load == soundfile_backend.load:
  61. return "soundfile"
  62. raise ValueError("Unknown backend.")