soundfile_backend.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. """The new soundfile backend which will become default in 0.8.0 onward"""
  2. import warnings
  3. from typing import Optional, Tuple
  4. import torch
  5. from torchaudio._internal import module_utils as _mod_utils
  6. from .common import AudioMetaData
  7. if _mod_utils.is_soundfile_available():
  8. import soundfile
  9. # Mapping from soundfile subtype to number of bits per sample.
  10. # This is mostly heuristical and the value is set to 0 when it is irrelevant
  11. # (lossy formats) or when it can't be inferred.
  12. # For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
  13. # According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
  14. # the default seems to be 8 bits but it can be compressed further to 4 bits.
  15. # The dict is inspired from
  16. # https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
  17. _SUBTYPE_TO_BITS_PER_SAMPLE = {
  18. "PCM_S8": 8, # Signed 8 bit data
  19. "PCM_16": 16, # Signed 16 bit data
  20. "PCM_24": 24, # Signed 24 bit data
  21. "PCM_32": 32, # Signed 32 bit data
  22. "PCM_U8": 8, # Unsigned 8 bit data (WAV and RAW only)
  23. "FLOAT": 32, # 32 bit float data
  24. "DOUBLE": 64, # 64 bit float data
  25. "ULAW": 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
  26. "ALAW": 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
  27. "IMA_ADPCM": 0, # IMA ADPCM.
  28. "MS_ADPCM": 0, # Microsoft ADPCM.
  29. "GSM610": 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
  30. "VOX_ADPCM": 0, # OKI / Dialogix ADPCM
  31. "G721_32": 0, # 32kbs G721 ADPCM encoding.
  32. "G723_24": 0, # 24kbs G723 ADPCM encoding.
  33. "G723_40": 0, # 40kbs G723 ADPCM encoding.
  34. "DWVW_12": 12, # 12 bit Delta Width Variable Word encoding.
  35. "DWVW_16": 16, # 16 bit Delta Width Variable Word encoding.
  36. "DWVW_24": 24, # 24 bit Delta Width Variable Word encoding.
  37. "DWVW_N": 0, # N bit Delta Width Variable Word encoding.
  38. "DPCM_8": 8, # 8 bit differential PCM (XI only)
  39. "DPCM_16": 16, # 16 bit differential PCM (XI only)
  40. "VORBIS": 0, # Xiph Vorbis encoding. (lossy)
  41. "ALAC_16": 16, # Apple Lossless Audio Codec (16 bit).
  42. "ALAC_20": 20, # Apple Lossless Audio Codec (20 bit).
  43. "ALAC_24": 24, # Apple Lossless Audio Codec (24 bit).
  44. "ALAC_32": 32, # Apple Lossless Audio Codec (32 bit).
  45. }
  46. def _get_bit_depth(subtype):
  47. if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
  48. warnings.warn(
  49. f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
  50. "attribute will be set to 0. If you are seeing this warning, please "
  51. "report by opening an issue on github (after checking for existing/closed ones). "
  52. "You may otherwise ignore this warning."
  53. )
  54. return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0)
  55. _SUBTYPE_TO_ENCODING = {
  56. "PCM_S8": "PCM_S",
  57. "PCM_16": "PCM_S",
  58. "PCM_24": "PCM_S",
  59. "PCM_32": "PCM_S",
  60. "PCM_U8": "PCM_U",
  61. "FLOAT": "PCM_F",
  62. "DOUBLE": "PCM_F",
  63. "ULAW": "ULAW",
  64. "ALAW": "ALAW",
  65. "VORBIS": "VORBIS",
  66. }
  67. def _get_encoding(format: str, subtype: str):
  68. if format == "FLAC":
  69. return "FLAC"
  70. return _SUBTYPE_TO_ENCODING.get(subtype, "UNKNOWN")
  71. @_mod_utils.requires_soundfile()
  72. def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
  73. """Get signal information of an audio file.
  74. Note:
  75. ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
  76. ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
  77. which has a restriction on type annotation due to TorchScript compiler compatiblity.
  78. Args:
  79. filepath (path-like object or file-like object):
  80. Source of audio data.
  81. format (str or None, optional):
  82. Not used. PySoundFile does not accept format hint.
  83. Returns:
  84. AudioMetaData: meta data of the given audio.
  85. """
  86. sinfo = soundfile.info(filepath)
  87. return AudioMetaData(
  88. sinfo.samplerate,
  89. sinfo.frames,
  90. sinfo.channels,
  91. bits_per_sample=_get_bit_depth(sinfo.subtype),
  92. encoding=_get_encoding(sinfo.format, sinfo.subtype),
  93. )
  94. _SUBTYPE2DTYPE = {
  95. "PCM_S8": "int8",
  96. "PCM_U8": "uint8",
  97. "PCM_16": "int16",
  98. "PCM_32": "int32",
  99. "FLOAT": "float32",
  100. "DOUBLE": "float64",
  101. }
  102. @_mod_utils.requires_soundfile()
  103. def load(
  104. filepath: str,
  105. frame_offset: int = 0,
  106. num_frames: int = -1,
  107. normalize: bool = True,
  108. channels_first: bool = True,
  109. format: Optional[str] = None,
  110. ) -> Tuple[torch.Tensor, int]:
  111. """Load audio data from file.
  112. Note:
  113. The formats this function can handle depend on the soundfile installation.
  114. This function is tested on the following formats;
  115. * WAV
  116. * 32-bit floating-point
  117. * 32-bit signed integer
  118. * 16-bit signed integer
  119. * 8-bit unsigned integer
  120. * FLAC
  121. * OGG/VORBIS
  122. * SPHERE
  123. By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
  124. ``float32`` dtype, and the shape of `[channel, time]`.
  125. .. warning::
  126. ``normalize`` argument does not perform volume normalization.
  127. It only converts the sample type to `torch.float32` from the native sample
  128. type.
  129. When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
  130. signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``,
  131. this function can return integer Tensor, where the samples are expressed within the whole range
  132. of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM,
  133. ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not
  134. support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors.
  135. ``normalize`` argument has no effect on 32-bit floating-point WAV and other formats, such as
  136. ``flac`` and ``mp3``.
  137. For these formats, this function always returns ``float32`` Tensor with values.
  138. Note:
  139. ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
  140. ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
  141. which has a restriction on type annotation due to TorchScript compiler compatiblity.
  142. Args:
  143. filepath (path-like object or file-like object):
  144. Source of audio data.
  145. frame_offset (int, optional):
  146. Number of frames to skip before start reading data.
  147. num_frames (int, optional):
  148. Maximum number of frames to read. ``-1`` reads all the remaining samples,
  149. starting from ``frame_offset``.
  150. This function may return the less number of frames if there is not enough
  151. frames in the given file.
  152. normalize (bool, optional):
  153. When ``True``, this function converts the native sample type to ``float32``.
  154. Default: ``True``.
  155. If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
  156. integer type.
  157. This argument has no effect for formats other than integer WAV type.
  158. channels_first (bool, optional):
  159. When True, the returned Tensor has dimension `[channel, time]`.
  160. Otherwise, the returned Tensor's dimension is `[time, channel]`.
  161. format (str or None, optional):
  162. Not used. PySoundFile does not accept format hint.
  163. Returns:
  164. (torch.Tensor, int): Resulting Tensor and sample rate.
  165. If the input file has integer wav format and normalization is off, then it has
  166. integer type, else ``float32`` type. If ``channels_first=True``, it has
  167. `[channel, time]` else `[time, channel]`.
  168. """
  169. with soundfile.SoundFile(filepath, "r") as file_:
  170. if file_.format != "WAV" or normalize:
  171. dtype = "float32"
  172. elif file_.subtype not in _SUBTYPE2DTYPE:
  173. raise ValueError(f"Unsupported subtype: {file_.subtype}")
  174. else:
  175. dtype = _SUBTYPE2DTYPE[file_.subtype]
  176. frames = file_._prepare_read(frame_offset, None, num_frames)
  177. waveform = file_.read(frames, dtype, always_2d=True)
  178. sample_rate = file_.samplerate
  179. waveform = torch.from_numpy(waveform)
  180. if channels_first:
  181. waveform = waveform.t()
  182. return waveform, sample_rate
  183. def _get_subtype_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int):
  184. if not encoding:
  185. if not bits_per_sample:
  186. subtype = {
  187. torch.uint8: "PCM_U8",
  188. torch.int16: "PCM_16",
  189. torch.int32: "PCM_32",
  190. torch.float32: "FLOAT",
  191. torch.float64: "DOUBLE",
  192. }.get(dtype)
  193. if not subtype:
  194. raise ValueError(f"Unsupported dtype for wav: {dtype}")
  195. return subtype
  196. if bits_per_sample == 8:
  197. return "PCM_U8"
  198. return f"PCM_{bits_per_sample}"
  199. if encoding == "PCM_S":
  200. if not bits_per_sample:
  201. return "PCM_32"
  202. if bits_per_sample == 8:
  203. raise ValueError("wav does not support 8-bit signed PCM encoding.")
  204. return f"PCM_{bits_per_sample}"
  205. if encoding == "PCM_U":
  206. if bits_per_sample in (None, 8):
  207. return "PCM_U8"
  208. raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
  209. if encoding == "PCM_F":
  210. if bits_per_sample in (None, 32):
  211. return "FLOAT"
  212. if bits_per_sample == 64:
  213. return "DOUBLE"
  214. raise ValueError("wav only supports 32/64-bit float PCM encoding.")
  215. if encoding == "ULAW":
  216. if bits_per_sample in (None, 8):
  217. return "ULAW"
  218. raise ValueError("wav only supports 8-bit mu-law encoding.")
  219. if encoding == "ALAW":
  220. if bits_per_sample in (None, 8):
  221. return "ALAW"
  222. raise ValueError("wav only supports 8-bit a-law encoding.")
  223. raise ValueError(f"wav does not support {encoding}.")
  224. def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
  225. if encoding in (None, "PCM_S"):
  226. return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
  227. if encoding in ("PCM_U", "PCM_F"):
  228. raise ValueError(f"sph does not support {encoding} encoding.")
  229. if encoding == "ULAW":
  230. if bits_per_sample in (None, 8):
  231. return "ULAW"
  232. raise ValueError("sph only supports 8-bit for mu-law encoding.")
  233. if encoding == "ALAW":
  234. return "ALAW"
  235. raise ValueError(f"sph does not support {encoding}.")
  236. def _get_subtype(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int):
  237. if format == "wav":
  238. return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
  239. if format == "flac":
  240. if encoding:
  241. raise ValueError("flac does not support encoding.")
  242. if not bits_per_sample:
  243. return "PCM_16"
  244. if bits_per_sample > 24:
  245. raise ValueError("flac does not support bits_per_sample > 24.")
  246. return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
  247. if format in ("ogg", "vorbis"):
  248. if encoding or bits_per_sample:
  249. raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
  250. return "VORBIS"
  251. if format == "sph":
  252. return _get_subtype_for_sphere(encoding, bits_per_sample)
  253. if format in ("nis", "nist"):
  254. return "PCM_16"
  255. raise ValueError(f"Unsupported format: {format}")
  256. @_mod_utils.requires_soundfile()
  257. def save(
  258. filepath: str,
  259. src: torch.Tensor,
  260. sample_rate: int,
  261. channels_first: bool = True,
  262. compression: Optional[float] = None,
  263. format: Optional[str] = None,
  264. encoding: Optional[str] = None,
  265. bits_per_sample: Optional[int] = None,
  266. ):
  267. """Save audio data to file.
  268. Note:
  269. The formats this function can handle depend on the soundfile installation.
  270. This function is tested on the following formats;
  271. * WAV
  272. * 32-bit floating-point
  273. * 32-bit signed integer
  274. * 16-bit signed integer
  275. * 8-bit unsigned integer
  276. * FLAC
  277. * OGG/VORBIS
  278. * SPHERE
  279. Note:
  280. ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts
  281. ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend,
  282. which has a restriction on type annotation due to TorchScript compiler compatiblity.
  283. Args:
  284. filepath (str or pathlib.Path): Path to audio file.
  285. src (torch.Tensor): Audio data to save. must be 2D tensor.
  286. sample_rate (int): sampling rate
  287. channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
  288. otherwise `[time, channel]`.
  289. compression (float of None, optional): Not used.
  290. It is here only for interface compatibility reson with "sox_io" backend.
  291. format (str or None, optional): Override the audio format.
  292. When ``filepath`` argument is path-like object, audio format is
  293. inferred from file extension. If the file extension is missing or
  294. different, you can specify the correct format with this argument.
  295. When ``filepath`` argument is file-like object,
  296. this argument is required.
  297. Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
  298. ``"flac"`` and ``"sph"``.
  299. encoding (str or None, optional): Changes the encoding for supported formats.
  300. This argument is effective only for supported formats, sush as
  301. ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;
  302. - ``"PCM_S"`` (signed integer Linear PCM)
  303. - ``"PCM_U"`` (unsigned integer Linear PCM)
  304. - ``"PCM_F"`` (floating point PCM)
  305. - ``"ULAW"`` (mu-law)
  306. - ``"ALAW"`` (a-law)
  307. bits_per_sample (int or None, optional): Changes the bit depth for the
  308. supported formats.
  309. When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
  310. you can change the bit depth.
  311. Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
  312. Supported formats/encodings/bit depth/compression are:
  313. ``"wav"``
  314. - 32-bit floating-point PCM
  315. - 32-bit signed integer PCM
  316. - 24-bit signed integer PCM
  317. - 16-bit signed integer PCM
  318. - 8-bit unsigned integer PCM
  319. - 8-bit mu-law
  320. - 8-bit a-law
  321. Note:
  322. Default encoding/bit depth is determined by the dtype of
  323. the input Tensor.
  324. ``"flac"``
  325. - 8-bit
  326. - 16-bit (default)
  327. - 24-bit
  328. ``"ogg"``, ``"vorbis"``
  329. - Doesn't accept changing configuration.
  330. ``"sph"``
  331. - 8-bit signed integer PCM
  332. - 16-bit signed integer PCM
  333. - 24-bit signed integer PCM
  334. - 32-bit signed integer PCM (default)
  335. - 8-bit mu-law
  336. - 8-bit a-law
  337. - 16-bit a-law
  338. - 24-bit a-law
  339. - 32-bit a-law
  340. """
  341. if src.ndim != 2:
  342. raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
  343. if compression is not None:
  344. warnings.warn(
  345. '`save` function of "soundfile" backend does not support "compression" parameter. '
  346. "The argument is silently ignored."
  347. )
  348. if hasattr(filepath, "write"):
  349. if format is None:
  350. raise RuntimeError("`format` is required when saving to file object.")
  351. ext = format.lower()
  352. else:
  353. ext = str(filepath).split(".")[-1].lower()
  354. if bits_per_sample not in (None, 8, 16, 24, 32, 64):
  355. raise ValueError("Invalid bits_per_sample.")
  356. if bits_per_sample == 24:
  357. warnings.warn(
  358. "Saving audio with 24 bits per sample might warp samples near -1. "
  359. "Using 16 bits per sample might be able to avoid this."
  360. )
  361. subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
  362. # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
  363. # so we extend the extensions manually here
  364. if ext in ["nis", "nist", "sph"] and format is None:
  365. format = "NIST"
  366. if channels_first:
  367. src = src.t()
  368. soundfile.write(file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format)