_compat.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from typing import Dict, Optional, Tuple
  2. import torch
  3. import torchaudio
  4. from torchaudio.backend.common import AudioMetaData
# Note: must comply with TorchScript syntax -- annotations are required; no f-strings and no globals.
  6. def _info_audio(
  7. s: torch.classes.torchaudio.ffmpeg_StreamReader,
  8. ):
  9. i = s.find_best_audio_stream()
  10. sinfo = s.get_src_stream_info(i)
  11. return AudioMetaData(
  12. int(sinfo[7]),
  13. sinfo[5],
  14. sinfo[8],
  15. sinfo[6],
  16. sinfo[1].upper(),
  17. )
  18. def info_audio(
  19. src: str,
  20. format: Optional[str],
  21. ) -> AudioMetaData:
  22. s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
  23. return _info_audio(s)
  24. def info_audio_fileobj(
  25. src,
  26. format: Optional[str],
  27. ) -> AudioMetaData:
  28. s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096)
  29. return _info_audio(s)
  30. def _get_load_filter(
  31. frame_offset: int = 0,
  32. num_frames: int = -1,
  33. convert: bool = True,
  34. ) -> Optional[str]:
  35. if frame_offset < 0:
  36. raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
  37. if num_frames == 0 or num_frames < -1:
  38. raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))
  39. # All default values -> no filter
  40. if frame_offset == 0 and num_frames == -1 and not convert:
  41. return None
  42. # Only convert
  43. aformat = "aformat=sample_fmts=fltp"
  44. if frame_offset == 0 and num_frames == -1 and convert:
  45. return aformat
  46. # At least one of frame_offset or num_frames has non-default value
  47. if num_frames > 0:
  48. atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
  49. else:
  50. atrim = "atrim=start_sample={}".format(frame_offset)
  51. if not convert:
  52. return atrim
  53. return "{},{}".format(atrim, aformat)
# Note: must comply with TorchScript syntax -- annotations are required; no f-strings and no globals.
  55. def _load_audio(
  56. s: torch.classes.torchaudio.ffmpeg_StreamReader,
  57. frame_offset: int = 0,
  58. num_frames: int = -1,
  59. convert: bool = True,
  60. channels_first: bool = True,
  61. ) -> Tuple[torch.Tensor, int]:
  62. i = s.find_best_audio_stream()
  63. sinfo = s.get_src_stream_info(i)
  64. sample_rate = int(sinfo[7])
  65. option: Dict[str, str] = {}
  66. s.add_audio_stream(i, -1, -1, _get_load_filter(frame_offset, num_frames, convert), None, option)
  67. s.process_all_packets()
  68. waveform = s.pop_chunks()[0]
  69. if waveform is None:
  70. raise RuntimeError("Failed to decode audio.")
  71. assert waveform is not None
  72. if channels_first:
  73. waveform = waveform.T
  74. return waveform, sample_rate
  75. def load_audio(
  76. src: str,
  77. frame_offset: int = 0,
  78. num_frames: int = -1,
  79. convert: bool = True,
  80. channels_first: bool = True,
  81. format: Optional[str] = None,
  82. ) -> Tuple[torch.Tensor, int]:
  83. s = torch.classes.torchaudio.ffmpeg_StreamReader(src, format, None)
  84. return _load_audio(s, frame_offset, num_frames, convert, channels_first)
  85. def load_audio_fileobj(
  86. src: str,
  87. frame_offset: int = 0,
  88. num_frames: int = -1,
  89. convert: bool = True,
  90. channels_first: bool = True,
  91. format: Optional[str] = None,
  92. ) -> Tuple[torch.Tensor, int]:
  93. s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096)
  94. return _load_audio(s, frame_offset, num_frames, convert, channels_first)