functional.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162
  1. # -*- coding: utf-8 -*-
  2. import io
  3. import math
  4. import warnings
  5. from collections.abc import Sequence
  6. from typing import Optional, Tuple, Union, List
  7. import torch
  8. import torchaudio
  9. from torch import Tensor
  10. from torchaudio._internal import module_utils as _mod_utils
  11. __all__ = [
  12. "spectrogram",
  13. "inverse_spectrogram",
  14. "griffinlim",
  15. "amplitude_to_DB",
  16. "DB_to_amplitude",
  17. "compute_deltas",
  18. "compute_kaldi_pitch",
  19. "melscale_fbanks",
  20. "linear_fbanks",
  21. "create_dct",
  22. "compute_deltas",
  23. "detect_pitch_frequency",
  24. "DB_to_amplitude",
  25. "mu_law_encoding",
  26. "mu_law_decoding",
  27. "phase_vocoder",
  28. "mask_along_axis",
  29. "mask_along_axis_iid",
  30. "sliding_window_cmn",
  31. "spectral_centroid",
  32. "apply_codec",
  33. "resample",
  34. "edit_distance",
  35. "pitch_shift",
  36. "rnnt_loss",
  37. "psd",
  38. "mvdr_weights_souden",
  39. "mvdr_weights_rtf",
  40. "rtf_evd",
  41. "rtf_power",
  42. "apply_beamforming",
  43. ]
  44. def spectrogram(
  45. waveform: Tensor,
  46. pad: int,
  47. window: Tensor,
  48. n_fft: int,
  49. hop_length: int,
  50. win_length: int,
  51. power: Optional[float],
  52. normalized: bool,
  53. center: bool = True,
  54. pad_mode: str = "reflect",
  55. onesided: bool = True,
  56. return_complex: Optional[bool] = None,
  57. ) -> Tensor:
  58. r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
  59. The spectrogram can be either magnitude-only or complex.
  60. .. devices:: CPU CUDA
  61. .. properties:: Autograd TorchScript
  62. Args:
  63. waveform (Tensor): Tensor of audio of dimension `(..., time)`
  64. pad (int): Two sided padding of signal
  65. window (Tensor): Window tensor that is applied/multiplied to each frame/window
  66. n_fft (int): Size of FFT
  67. hop_length (int): Length of hop between STFT windows
  68. win_length (int): Window size
  69. power (float or None): Exponent for the magnitude spectrogram,
  70. (must be > 0) e.g., 1 for energy, 2 for power, etc.
  71. If None, then the complex spectrum is returned instead.
  72. normalized (bool): Whether to normalize by magnitude after stft
  73. center (bool, optional): whether to pad :attr:`waveform` on both sides so
  74. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
  75. Default: ``True``
  76. pad_mode (string, optional): controls the padding method used when
  77. :attr:`center` is ``True``. Default: ``"reflect"``
  78. onesided (bool, optional): controls whether to return half of results to
  79. avoid redundancy. Default: ``True``
  80. return_complex (bool, optional):
  81. Deprecated and not used.
  82. Returns:
  83. Tensor: Dimension `(..., freq, time)`, freq is
  84. ``n_fft // 2 + 1`` and ``n_fft`` is the number of
  85. Fourier bins, and time is the number of window hops (n_frame).
  86. """
  87. if return_complex is not None:
  88. warnings.warn(
  89. "`return_complex` argument is now deprecated and is not effective."
  90. "`torchaudio.functional.spectrogram(power=None)` always returns a tensor with "
  91. "complex dtype. Please remove the argument in the function call."
  92. )
  93. if pad > 0:
  94. # TODO add "with torch.no_grad():" back when JIT supports it
  95. waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
  96. # pack batch
  97. shape = waveform.size()
  98. waveform = waveform.reshape(-1, shape[-1])
  99. # default values are consistent with librosa.core.spectrum._spectrogram
  100. spec_f = torch.stft(
  101. input=waveform,
  102. n_fft=n_fft,
  103. hop_length=hop_length,
  104. win_length=win_length,
  105. window=window,
  106. center=center,
  107. pad_mode=pad_mode,
  108. normalized=False,
  109. onesided=onesided,
  110. return_complex=True,
  111. )
  112. # unpack batch
  113. spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])
  114. if normalized:
  115. spec_f /= window.pow(2.0).sum().sqrt()
  116. if power is not None:
  117. if power == 1.0:
  118. return spec_f.abs()
  119. return spec_f.abs().pow(power)
  120. return spec_f
  121. def inverse_spectrogram(
  122. spectrogram: Tensor,
  123. length: Optional[int],
  124. pad: int,
  125. window: Tensor,
  126. n_fft: int,
  127. hop_length: int,
  128. win_length: int,
  129. normalized: bool,
  130. center: bool = True,
  131. pad_mode: str = "reflect",
  132. onesided: bool = True,
  133. ) -> Tensor:
  134. r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
  135. complex-valued spectrogram.
  136. .. devices:: CPU CUDA
  137. .. properties:: Autograd TorchScript
  138. Args:
  139. spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
  140. length (int or None): The output length of the waveform.
  141. pad (int): Two sided padding of signal. It is only effective when ``length`` is provided.
  142. window (Tensor): Window tensor that is applied/multiplied to each frame/window
  143. n_fft (int): Size of FFT
  144. hop_length (int): Length of hop between STFT windows
  145. win_length (int): Window size
  146. normalized (bool): Whether the stft output was normalized by magnitude
  147. center (bool, optional): whether the waveform was padded on both sides so
  148. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
  149. Default: ``True``
  150. pad_mode (string, optional): controls the padding method used when
  151. :attr:`center` is ``True``. This parameter is provided for compatibility with the
  152. spectrogram function and is not used. Default: ``"reflect"``
  153. onesided (bool, optional): controls whether spectrogram was done in onesided mode.
  154. Default: ``True``
  155. Returns:
  156. Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
  157. """
  158. if not spectrogram.is_complex():
  159. raise ValueError("Expected `spectrogram` to be complex dtype.")
  160. if normalized:
  161. spectrogram = spectrogram * window.pow(2.0).sum().sqrt()
  162. # pack batch
  163. shape = spectrogram.size()
  164. spectrogram = spectrogram.reshape(-1, shape[-2], shape[-1])
  165. # default values are consistent with librosa.core.spectrum._spectrogram
  166. waveform = torch.istft(
  167. input=spectrogram,
  168. n_fft=n_fft,
  169. hop_length=hop_length,
  170. win_length=win_length,
  171. window=window,
  172. center=center,
  173. normalized=False,
  174. onesided=onesided,
  175. length=length + 2 * pad if length is not None else None,
  176. return_complex=False,
  177. )
  178. if length is not None and pad > 0:
  179. # remove padding from front and back
  180. waveform = waveform[:, pad:-pad]
  181. # unpack batch
  182. waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
  183. return waveform
  184. def _get_complex_dtype(real_dtype: torch.dtype):
  185. if real_dtype == torch.double:
  186. return torch.cdouble
  187. if real_dtype == torch.float:
  188. return torch.cfloat
  189. if real_dtype == torch.half:
  190. return torch.complex32
  191. raise ValueError(f"Unexpected dtype {real_dtype}")
  192. def griffinlim(
  193. specgram: Tensor,
  194. window: Tensor,
  195. n_fft: int,
  196. hop_length: int,
  197. win_length: int,
  198. power: float,
  199. n_iter: int,
  200. momentum: float,
  201. length: Optional[int],
  202. rand_init: bool,
  203. ) -> Tensor:
  204. r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
  205. .. devices:: CPU CUDA
  206. .. properties:: Autograd TorchScript
  207. Implementation ported from
  208. *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`]
  209. and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].
  210. Args:
  211. specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
  212. where freq is ``n_fft // 2 + 1``.
  213. window (Tensor): Window tensor that is applied/multiplied to each frame/window
  214. n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
  215. hop_length (int): Length of hop between STFT windows. (
  216. Default: ``win_length // 2``)
  217. win_length (int): Window size. (Default: ``n_fft``)
  218. power (float): Exponent for the magnitude spectrogram,
  219. (must be > 0) e.g., 1 for energy, 2 for power, etc.
  220. n_iter (int): Number of iteration for phase recovery process.
  221. momentum (float): The momentum parameter for fast Griffin-Lim.
  222. Setting this to 0 recovers the original Griffin-Lim method.
  223. Values near 1 can lead to faster convergence, but above 1 may not converge.
  224. length (int or None): Array length of the expected output.
  225. rand_init (bool): Initializes phase randomly if True, to zero otherwise.
  226. Returns:
  227. Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
  228. """
  229. assert momentum < 1, "momentum={} > 1 can be unstable".format(momentum)
  230. assert momentum >= 0, "momentum={} < 0".format(momentum)
  231. # pack batch
  232. shape = specgram.size()
  233. specgram = specgram.reshape([-1] + list(shape[-2:]))
  234. specgram = specgram.pow(1 / power)
  235. # initialize the phase
  236. if rand_init:
  237. angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
  238. else:
  239. angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
  240. # And initialize the previous iterate to 0
  241. tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
  242. for _ in range(n_iter):
  243. # Invert with our current estimate of the phases
  244. inverse = torch.istft(
  245. specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
  246. )
  247. # Rebuild the spectrogram
  248. rebuilt = torch.stft(
  249. input=inverse,
  250. n_fft=n_fft,
  251. hop_length=hop_length,
  252. win_length=win_length,
  253. window=window,
  254. center=True,
  255. pad_mode="reflect",
  256. normalized=False,
  257. onesided=True,
  258. return_complex=True,
  259. )
  260. # Update our phase estimates
  261. angles = rebuilt
  262. if momentum:
  263. angles = angles - tprev.mul_(momentum / (1 + momentum))
  264. angles = angles.div(angles.abs().add(1e-16))
  265. # Store the previous iterate
  266. tprev = rebuilt
  267. # Return the final phase estimates
  268. waveform = torch.istft(
  269. specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
  270. )
  271. # unpack batch
  272. waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
  273. return waveform
  274. def amplitude_to_DB(
  275. x: Tensor, multiplier: float, amin: float, db_multiplier: float, top_db: Optional[float] = None
  276. ) -> Tensor:
  277. r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
  278. .. devices:: CPU CUDA
  279. .. properties:: Autograd TorchScript
  280. The output of each tensor in a batch depends on the maximum value of that tensor,
  281. and so may return different values for an audio clip split into snippets vs. a full clip.
  282. Args:
  283. x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
  284. the form `(..., freq, time)`. Batched inputs should include a channel dimension and
  285. have the form `(batch, channel, freq, time)`.
  286. multiplier (float): Use 10. for power and 20. for amplitude
  287. amin (float): Number to clamp ``x``
  288. db_multiplier (float): Log10(max(reference value and amin))
  289. top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
  290. is 80. (Default: ``None``)
  291. Returns:
  292. Tensor: Output tensor in decibel scale
  293. """
  294. x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
  295. x_db -= multiplier * db_multiplier
  296. if top_db is not None:
  297. # Expand batch
  298. shape = x_db.size()
  299. packed_channels = shape[-3] if x_db.dim() > 2 else 1
  300. x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])
  301. x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))
  302. # Repack batch
  303. x_db = x_db.reshape(shape)
  304. return x_db
  305. def DB_to_amplitude(x: Tensor, ref: float, power: float) -> Tensor:
  306. r"""Turn a tensor from the decibel scale to the power/amplitude scale.
  307. .. devices:: CPU CUDA
  308. .. properties:: TorchScript
  309. Args:
  310. x (Tensor): Input tensor before being converted to power/amplitude scale.
  311. ref (float): Reference which the output will be scaled by.
  312. power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
  313. Returns:
  314. Tensor: Output tensor in power/amplitude scale.
  315. """
  316. return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
  317. def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
  318. r"""Convert Hz to Mels.
  319. Args:
  320. freqs (float): Frequencies in Hz
  321. mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
  322. Returns:
  323. mels (float): Frequency in Mels
  324. """
  325. if mel_scale not in ["slaney", "htk"]:
  326. raise ValueError('mel_scale should be one of "htk" or "slaney".')
  327. if mel_scale == "htk":
  328. return 2595.0 * math.log10(1.0 + (freq / 700.0))
  329. # Fill in the linear part
  330. f_min = 0.0
  331. f_sp = 200.0 / 3
  332. mels = (freq - f_min) / f_sp
  333. # Fill in the log-scale part
  334. min_log_hz = 1000.0
  335. min_log_mel = (min_log_hz - f_min) / f_sp
  336. logstep = math.log(6.4) / 27.0
  337. if freq >= min_log_hz:
  338. mels = min_log_mel + math.log(freq / min_log_hz) / logstep
  339. return mels
  340. def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
  341. """Convert mel bin numbers to frequencies.
  342. Args:
  343. mels (Tensor): Mel frequencies
  344. mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
  345. Returns:
  346. freqs (Tensor): Mels converted in Hz
  347. """
  348. if mel_scale not in ["slaney", "htk"]:
  349. raise ValueError('mel_scale should be one of "htk" or "slaney".')
  350. if mel_scale == "htk":
  351. return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
  352. # Fill in the linear scale
  353. f_min = 0.0
  354. f_sp = 200.0 / 3
  355. freqs = f_min + f_sp * mels
  356. # And now the nonlinear scale
  357. min_log_hz = 1000.0
  358. min_log_mel = (min_log_hz - f_min) / f_sp
  359. logstep = math.log(6.4) / 27.0
  360. log_t = mels >= min_log_mel
  361. freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
  362. return freqs
  363. def _create_triangular_filterbank(
  364. all_freqs: Tensor,
  365. f_pts: Tensor,
  366. ) -> Tensor:
  367. """Create a triangular filter bank.
  368. Args:
  369. all_freqs (Tensor): STFT freq points of size (`n_freqs`).
  370. f_pts (Tensor): Filter mid points of size (`n_filter`).
  371. Returns:
  372. fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
  373. """
  374. # Adopted from Librosa
  375. # calculate the difference between each filter mid point and each stft freq point in hertz
  376. f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
  377. slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
  378. # create overlapping triangles
  379. zero = torch.zeros(1)
  380. down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
  381. up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
  382. fb = torch.max(zero, torch.min(down_slopes, up_slopes))
  383. return fb
  384. def melscale_fbanks(
  385. n_freqs: int,
  386. f_min: float,
  387. f_max: float,
  388. n_mels: int,
  389. sample_rate: int,
  390. norm: Optional[str] = None,
  391. mel_scale: str = "htk",
  392. ) -> Tensor:
  393. r"""Create a frequency bin conversion matrix.
  394. .. devices:: CPU
  395. .. properties:: TorchScript
  396. Note:
  397. For the sake of the numerical compatibility with librosa, not all the coefficients
  398. in the resulting filter bank has magnitude of 1.
  399. .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
  400. :alt: Visualization of generated filter bank
  401. Args:
  402. n_freqs (int): Number of frequencies to highlight/apply
  403. f_min (float): Minimum frequency (Hz)
  404. f_max (float): Maximum frequency (Hz)
  405. n_mels (int): Number of mel filterbanks
  406. sample_rate (int): Sample rate of the audio waveform
  407. norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
  408. (area normalization). (Default: ``None``)
  409. mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
  410. Returns:
  411. Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
  412. meaning number of frequencies to highlight/apply to x the number of filterbanks.
  413. Each column is a filterbank so that assuming there is a matrix A of
  414. size (..., ``n_freqs``), the applied result would be
  415. ``A * melscale_fbanks(A.size(-1), ...)``.
  416. """
  417. if norm is not None and norm != "slaney":
  418. raise ValueError("norm must be one of None or 'slaney'")
  419. # freq bins
  420. all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
  421. # calculate mel freq bins
  422. m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
  423. m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
  424. m_pts = torch.linspace(m_min, m_max, n_mels + 2)
  425. f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
  426. # create filterbank
  427. fb = _create_triangular_filterbank(all_freqs, f_pts)
  428. if norm is not None and norm == "slaney":
  429. # Slaney-style mel is scaled to be approx constant energy per channel
  430. enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
  431. fb *= enorm.unsqueeze(0)
  432. if (fb.max(dim=0).values == 0.0).any():
  433. warnings.warn(
  434. "At least one mel filterbank has all zero values. "
  435. f"The value for `n_mels` ({n_mels}) may be set too high. "
  436. f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
  437. )
  438. return fb
  439. def linear_fbanks(
  440. n_freqs: int,
  441. f_min: float,
  442. f_max: float,
  443. n_filter: int,
  444. sample_rate: int,
  445. ) -> Tensor:
  446. r"""Creates a linear triangular filterbank.
  447. .. devices:: CPU
  448. .. properties:: TorchScript
  449. Note:
  450. For the sake of the numerical compatibility with librosa, not all the coefficients
  451. in the resulting filter bank has magnitude of 1.
  452. .. image:: https://download.pytorch.org/torchaudio/doc-assets/lin_fbanks.png
  453. :alt: Visualization of generated filter bank
  454. Args:
  455. n_freqs (int): Number of frequencies to highlight/apply
  456. f_min (float): Minimum frequency (Hz)
  457. f_max (float): Maximum frequency (Hz)
  458. n_filter (int): Number of (linear) triangular filter
  459. sample_rate (int): Sample rate of the audio waveform
  460. Returns:
  461. Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``)
  462. meaning number of frequencies to highlight/apply to x the number of filterbanks.
  463. Each column is a filterbank so that assuming there is a matrix A of
  464. size (..., ``n_freqs``), the applied result would be
  465. ``A * linear_fbanks(A.size(-1), ...)``.
  466. """
  467. # freq bins
  468. all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
  469. # filter mid-points
  470. f_pts = torch.linspace(f_min, f_max, n_filter + 2)
  471. # create filterbank
  472. fb = _create_triangular_filterbank(all_freqs, f_pts)
  473. return fb
  474. def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
  475. r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
  476. normalized depending on norm.
  477. .. devices:: CPU
  478. .. properties:: TorchScript
  479. Args:
  480. n_mfcc (int): Number of mfc coefficients to retain
  481. n_mels (int): Number of mel filterbanks
  482. norm (str or None): Norm to use (either 'ortho' or None)
  483. Returns:
  484. Tensor: The transformation matrix, to be right-multiplied to
  485. row-wise data of size (``n_mels``, ``n_mfcc``).
  486. """
  487. # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
  488. n = torch.arange(float(n_mels))
  489. k = torch.arange(float(n_mfcc)).unsqueeze(1)
  490. dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
  491. if norm is None:
  492. dct *= 2.0
  493. else:
  494. assert norm == "ortho"
  495. dct[0] *= 1.0 / math.sqrt(2.0)
  496. dct *= math.sqrt(2.0 / float(n_mels))
  497. return dct.t()
  498. def mu_law_encoding(x: Tensor, quantization_channels: int) -> Tensor:
  499. r"""Encode signal based on mu-law companding.
  500. .. devices:: CPU CUDA
  501. .. properties:: TorchScript
  502. For more info see the
  503. `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
  504. This algorithm expects the signal has been scaled to between -1 and 1 and
  505. returns a signal encoded with values from 0 to quantization_channels - 1.
  506. Args:
  507. x (Tensor): Input tensor
  508. quantization_channels (int): Number of channels
  509. Returns:
  510. Tensor: Input after mu-law encoding
  511. """
  512. mu = quantization_channels - 1.0
  513. if not x.is_floating_point():
  514. warnings.warn(
  515. "The input Tensor must be of floating type. \
  516. This will be an error in the v0.12 release."
  517. )
  518. x = x.to(torch.float)
  519. mu = torch.tensor(mu, dtype=x.dtype)
  520. x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
  521. x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
  522. return x_mu
  523. def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
  524. r"""Decode mu-law encoded signal.
  525. .. devices:: CPU CUDA
  526. .. properties:: TorchScript
  527. For more info see the
  528. `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
  529. This expects an input with values between 0 and quantization_channels - 1
  530. and returns a signal scaled between -1 and 1.
  531. Args:
  532. x_mu (Tensor): Input tensor
  533. quantization_channels (int): Number of channels
  534. Returns:
  535. Tensor: Input after mu-law decoding
  536. """
  537. mu = quantization_channels - 1.0
  538. if not x_mu.is_floating_point():
  539. x_mu = x_mu.to(torch.float)
  540. mu = torch.tensor(mu, dtype=x_mu.dtype)
  541. x = ((x_mu) / mu) * 2 - 1.0
  542. x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
  543. return x
  544. def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) -> Tensor:
  545. r"""Given a STFT tensor, speed up in time without modifying pitch by a factor of ``rate``.
  546. .. devices:: CPU CUDA
  547. .. properties:: Autograd TorchScript
  548. Args:
  549. complex_specgrams (Tensor):
  550. A tensor of dimension `(..., freq, num_frame)` with complex dtype.
  551. rate (float): Speed-up factor
  552. phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
  553. Returns:
  554. Tensor:
  555. Stretched spectrogram. The resulting tensor is of the same dtype as the input
  556. spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
  557. Example
  558. >>> freq, hop_length = 1025, 512
  559. >>> # (channel, freq, time)
  560. >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
  561. >>> rate = 1.3 # Speed up by 30%
  562. >>> phase_advance = torch.linspace(
  563. >>> 0, math.pi * hop_length, freq)[..., None]
  564. >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
  565. >>> x.shape # with 231 == ceil(300 / 1.3)
  566. torch.Size([2, 1025, 231])
  567. """
  568. if rate == 1.0:
  569. return complex_specgrams
  570. # pack batch
  571. shape = complex_specgrams.size()
  572. complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
  573. # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
  574. # Note torch.real is a view so it does not incur any memory copy.
  575. real_dtype = torch.real(complex_specgrams).dtype
  576. time_steps = torch.arange(0, complex_specgrams.size(-1), rate, device=complex_specgrams.device, dtype=real_dtype)
  577. alphas = time_steps % 1.0
  578. phase_0 = complex_specgrams[..., :1].angle()
  579. # Time Padding
  580. complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])
  581. # (new_bins, freq, 2)
  582. complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
  583. complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())
  584. angle_0 = complex_specgrams_0.angle()
  585. angle_1 = complex_specgrams_1.angle()
  586. norm_0 = complex_specgrams_0.abs()
  587. norm_1 = complex_specgrams_1.abs()
  588. phase = angle_1 - angle_0 - phase_advance
  589. phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))
  590. # Compute Phase Accum
  591. phase = phase + phase_advance
  592. phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
  593. phase_acc = torch.cumsum(phase, -1)
  594. mag = alphas * norm_1 + (1 - alphas) * norm_0
  595. complex_specgrams_stretch = torch.polar(mag, phase_acc)
  596. # unpack batch
  597. complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
  598. return complex_specgrams_stretch
  599. def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
  600. if p == 1.0:
  601. return mask_param
  602. else:
  603. return min(mask_param, int(axis_length * p))
  604. def mask_along_axis_iid(
  605. specgrams: Tensor,
  606. mask_param: int,
  607. mask_value: float,
  608. axis: int,
  609. p: float = 1.0,
  610. ) -> Tensor:
  611. r"""Apply a mask along ``axis``.
  612. .. devices:: CPU CUDA
  613. .. properties:: Autograd TorchScript
  614. Mask will be applied from indices ``[v_0, v_0 + v)``,
  615. where ``v`` is sampled from ``uniform(0, max_v)`` and
  616. ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``,
  617. with ``max_v = mask_param`` when ``p = 1.0`` and
  618. ``max_v = min(mask_param, floor(specgrams.size(axis) * p))`` otherwise.
  619. Args:
  620. specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
  621. mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
  622. mask_value (float): Value to assign to the masked columns
  623. axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
  624. p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
  625. Returns:
  626. Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
  627. """
  628. if axis not in [2, 3]:
  629. raise ValueError("Only Frequency and Time masking are supported")
  630. if not 0.0 <= p <= 1.0:
  631. raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
  632. mask_param = _get_mask_param(mask_param, p, specgrams.shape[axis])
  633. if mask_param < 1:
  634. return specgrams
  635. device = specgrams.device
  636. dtype = specgrams.dtype
  637. value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
  638. min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
  639. # Create broadcastable mask
  640. mask_start = min_value.long()[..., None, None]
  641. mask_end = (min_value.long() + value.long())[..., None, None]
  642. mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
  643. # Per batch example masking
  644. specgrams = specgrams.transpose(axis, -1)
  645. specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
  646. specgrams = specgrams.transpose(axis, -1)
  647. return specgrams
  648. def mask_along_axis(
  649. specgram: Tensor,
  650. mask_param: int,
  651. mask_value: float,
  652. axis: int,
  653. p: float = 1.0,
  654. ) -> Tensor:
  655. r"""Apply a mask along ``axis``.
  656. .. devices:: CPU CUDA
  657. .. properties:: Autograd TorchScript
  658. Mask will be applied from indices ``[v_0, v_0 + v)``,
  659. where ``v`` is sampled from ``uniform(0, max_v)`` and
  660. ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
  661. ``max_v = mask_param`` when ``p = 1.0`` and
  662. ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
  663. otherwise.
  664. All examples will have the same mask interval.
  665. Args:
  666. specgram (Tensor): Real spectrogram `(channel, freq, time)`
  667. mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
  668. mask_value (float): Value to assign to the masked columns
  669. axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
  670. p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
  671. Returns:
  672. Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
  673. """
  674. if axis not in [1, 2]:
  675. raise ValueError("Only Frequency and Time masking are supported")
  676. if not 0.0 <= p <= 1.0:
  677. raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
  678. mask_param = _get_mask_param(mask_param, p, specgram.shape[axis])
  679. if mask_param < 1:
  680. return specgram
  681. # pack batch
  682. shape = specgram.size()
  683. specgram = specgram.reshape([-1] + list(shape[-2:]))
  684. value = torch.rand(1) * mask_param
  685. min_value = torch.rand(1) * (specgram.size(axis) - value)
  686. mask_start = (min_value.long()).squeeze()
  687. mask_end = (min_value.long() + value.long()).squeeze()
  688. mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
  689. mask = (mask >= mask_start) & (mask < mask_end)
  690. if axis == 1:
  691. mask = mask.unsqueeze(-1)
  692. assert mask_end - mask_start < mask_param
  693. specgram = specgram.masked_fill(mask, mask_value)
  694. # unpack batch
  695. specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
  696. return specgram
  697. def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate") -> Tensor:
  698. r"""Compute delta coefficients of a tensor, usually a spectrogram:
  699. .. devices:: CPU CUDA
  700. .. properties:: TorchScript
  701. .. math::
  702. d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}
  703. where :math:`d_t` is the deltas at time :math:`t`,
  704. :math:`c_t` is the spectrogram coeffcients at time :math:`t`,
  705. :math:`N` is ``(win_length-1)//2``.
  706. Args:
  707. specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
  708. win_length (int, optional): The window length used for computing delta (Default: ``5``)
  709. mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
  710. Returns:
  711. Tensor: Tensor of deltas of dimension `(..., freq, time)`
  712. Example
  713. >>> specgram = torch.randn(1, 40, 1000)
  714. >>> delta = compute_deltas(specgram)
  715. >>> delta2 = compute_deltas(delta)
  716. """
  717. device = specgram.device
  718. dtype = specgram.dtype
  719. # pack batch
  720. shape = specgram.size()
  721. specgram = specgram.reshape(1, -1, shape[-1])
  722. assert win_length >= 3
  723. n = (win_length - 1) // 2
  724. # twice sum of integer squared
  725. denom = n * (n + 1) * (2 * n + 1) / 3
  726. specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
  727. kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)
  728. output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
  729. # unpack batch
  730. output = output.reshape(shape)
  731. return output
  732. def _compute_nccf(waveform: Tensor, sample_rate: int, frame_time: float, freq_low: int) -> Tensor:
  733. r"""
  734. Compute Normalized Cross-Correlation Function (NCCF).
  735. .. math::
  736. \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},
  737. where
  738. :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
  739. :math:`w` is the waveform,
  740. :math:`N` is the length of a frame,
  741. :math:`b_i` is the beginning of frame :math:`i`,
  742. :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
  743. """
  744. EPSILON = 10 ** (-9)
  745. # Number of lags to check
  746. lags = int(math.ceil(sample_rate / freq_low))
  747. frame_size = int(math.ceil(sample_rate * frame_time))
  748. waveform_length = waveform.size()[-1]
  749. num_of_frames = int(math.ceil(waveform_length / frame_size))
  750. p = lags + num_of_frames * frame_size - waveform_length
  751. waveform = torch.nn.functional.pad(waveform, (0, p))
  752. # Compute lags
  753. output_lag = []
  754. for lag in range(1, lags + 1):
  755. s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
  756. s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
  757. output_frames = (
  758. (s1 * s2).sum(-1)
  759. / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
  760. / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
  761. )
  762. output_lag.append(output_frames.unsqueeze(-1))
  763. nccf = torch.cat(output_lag, -1)
  764. return nccf
  765. def _combine_max(a: Tuple[Tensor, Tensor], b: Tuple[Tensor, Tensor], thresh: float = 0.99) -> Tuple[Tensor, Tensor]:
  766. """
  767. Take value from first if bigger than a multiplicative factor of the second, elementwise.
  768. """
  769. mask = a[0] > thresh * b[0]
  770. values = mask * a[0] + ~mask * b[0]
  771. indices = mask * a[1] + ~mask * b[1]
  772. return values, indices
  773. def _find_max_per_frame(nccf: Tensor, sample_rate: int, freq_high: int) -> Tensor:
  774. r"""
  775. For each frame, take the highest value of NCCF,
  776. apply centered median smoothing, and convert to frequency.
  777. Note: If the max among all the lags is very close
  778. to the first half of lags, then the latter is taken.
  779. """
  780. lag_min = int(math.ceil(sample_rate / freq_high))
  781. # Find near enough max that is smallest
  782. best = torch.max(nccf[..., lag_min:], -1)
  783. half_size = nccf.shape[-1] // 2
  784. half = torch.max(nccf[..., lag_min:half_size], -1)
  785. best = _combine_max(half, best)
  786. indices = best[1]
  787. # Add back minimal lag
  788. indices += lag_min
  789. # Add 1 empirical calibration offset
  790. indices += 1
  791. return indices
  792. def _median_smoothing(indices: Tensor, win_length: int) -> Tensor:
  793. r"""
  794. Apply median smoothing to the 1D tensor over the given window.
  795. """
  796. # Centered windowed
  797. pad_length = (win_length - 1) // 2
  798. # "replicate" padding in any dimension
  799. indices = torch.nn.functional.pad(indices, (pad_length, 0), mode="constant", value=0.0)
  800. indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
  801. roll = indices.unfold(-1, win_length, 1)
  802. values, _ = torch.median(roll, -1)
  803. return values
  804. def detect_pitch_frequency(
  805. waveform: Tensor,
  806. sample_rate: int,
  807. frame_time: float = 10 ** (-2),
  808. win_length: int = 30,
  809. freq_low: int = 85,
  810. freq_high: int = 3400,
  811. ) -> Tensor:
  812. r"""Detect pitch frequency.
  813. .. devices:: CPU CUDA
  814. .. properties:: TorchScript
  815. It is implemented using normalized cross-correlation function and median smoothing.
  816. Args:
  817. waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
  818. sample_rate (int): The sample rate of the waveform (Hz)
  819. frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
  820. win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
  821. freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
  822. freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
  823. Returns:
  824. Tensor: Tensor of freq of dimension `(..., frame)`
  825. """
  826. # pack batch
  827. shape = list(waveform.size())
  828. waveform = waveform.reshape([-1] + shape[-1:])
  829. nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
  830. indices = _find_max_per_frame(nccf, sample_rate, freq_high)
  831. indices = _median_smoothing(indices, win_length)
  832. # Convert indices to frequency
  833. EPSILON = 10 ** (-9)
  834. freq = sample_rate / (EPSILON + indices.to(torch.float))
  835. # unpack batch
  836. freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
  837. return freq
  838. def sliding_window_cmn(
  839. specgram: Tensor,
  840. cmn_window: int = 600,
  841. min_cmn_window: int = 100,
  842. center: bool = False,
  843. norm_vars: bool = False,
  844. ) -> Tensor:
  845. r"""
  846. Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
  847. .. devices:: CPU CUDA
  848. .. properties:: TorchScript
  849. Args:
  850. specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`
  851. cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
  852. min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
  853. Only applicable if center == false, ignored if center==true (int, default = 100)
  854. center (bool, optional): If true, use a window centered on the current frame
  855. (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
  856. norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
  857. Returns:
  858. Tensor: Tensor matching input shape `(..., freq, time)`
  859. """
  860. input_shape = specgram.shape
  861. num_frames, num_feats = input_shape[-2:]
  862. specgram = specgram.view(-1, num_frames, num_feats)
  863. num_channels = specgram.shape[0]
  864. dtype = specgram.dtype
  865. device = specgram.device
  866. last_window_start = last_window_end = -1
  867. cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
  868. cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
  869. cmn_specgram = torch.zeros(num_channels, num_frames, num_feats, dtype=dtype, device=device)
  870. for t in range(num_frames):
  871. window_start = 0
  872. window_end = 0
  873. if center:
  874. window_start = t - cmn_window // 2
  875. window_end = window_start + cmn_window
  876. else:
  877. window_start = t - cmn_window
  878. window_end = t + 1
  879. if window_start < 0:
  880. window_end -= window_start
  881. window_start = 0
  882. if not center:
  883. if window_end > t:
  884. window_end = max(t + 1, min_cmn_window)
  885. if window_end > num_frames:
  886. window_start -= window_end - num_frames
  887. window_end = num_frames
  888. if window_start < 0:
  889. window_start = 0
  890. if last_window_start == -1:
  891. input_part = specgram[:, window_start : window_end - window_start, :]
  892. cur_sum += torch.sum(input_part, 1)
  893. if norm_vars:
  894. cur_sumsq += torch.cumsum(input_part**2, 1)[:, -1, :]
  895. else:
  896. if window_start > last_window_start:
  897. frame_to_remove = specgram[:, last_window_start, :]
  898. cur_sum -= frame_to_remove
  899. if norm_vars:
  900. cur_sumsq -= frame_to_remove**2
  901. if window_end > last_window_end:
  902. frame_to_add = specgram[:, last_window_end, :]
  903. cur_sum += frame_to_add
  904. if norm_vars:
  905. cur_sumsq += frame_to_add**2
  906. window_frames = window_end - window_start
  907. last_window_start = window_start
  908. last_window_end = window_end
  909. cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
  910. if norm_vars:
  911. if window_frames == 1:
  912. cmn_specgram[:, t, :] = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
  913. else:
  914. variance = cur_sumsq
  915. variance = variance / window_frames
  916. variance -= (cur_sum**2) / (window_frames**2)
  917. variance = torch.pow(variance, -0.5)
  918. cmn_specgram[:, t, :] *= variance
  919. cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
  920. if len(input_shape) == 2:
  921. cmn_specgram = cmn_specgram.squeeze(0)
  922. return cmn_specgram
  923. def spectral_centroid(
  924. waveform: Tensor,
  925. sample_rate: int,
  926. pad: int,
  927. window: Tensor,
  928. n_fft: int,
  929. hop_length: int,
  930. win_length: int,
  931. ) -> Tensor:
  932. r"""Compute the spectral centroid for each channel along the time axis.
  933. .. devices:: CPU CUDA
  934. .. properties:: Autograd TorchScript
  935. The spectral centroid is defined as the weighted average of the
  936. frequency values, weighted by their magnitude.
  937. Args:
  938. waveform (Tensor): Tensor of audio of dimension `(..., time)`
  939. sample_rate (int): Sample rate of the audio waveform
  940. pad (int): Two sided padding of signal
  941. window (Tensor): Window tensor that is applied/multiplied to each frame/window
  942. n_fft (int): Size of FFT
  943. hop_length (int): Length of hop between STFT windows
  944. win_length (int): Window size
  945. Returns:
  946. Tensor: Dimension `(..., time)`
  947. """
  948. specgram = spectrogram(
  949. waveform,
  950. pad=pad,
  951. window=window,
  952. n_fft=n_fft,
  953. hop_length=hop_length,
  954. win_length=win_length,
  955. power=1.0,
  956. normalized=False,
  957. )
  958. freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1))
  959. freq_dim = -2
  960. return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
  961. @_mod_utils.requires_sox()
  962. def apply_codec(
  963. waveform: Tensor,
  964. sample_rate: int,
  965. format: str,
  966. channels_first: bool = True,
  967. compression: Optional[float] = None,
  968. encoding: Optional[str] = None,
  969. bits_per_sample: Optional[int] = None,
  970. ) -> Tensor:
  971. r"""
  972. Apply codecs as a form of augmentation.
  973. .. devices:: CPU
  974. Args:
  975. waveform (Tensor): Audio data. Must be 2 dimensional. See also ```channels_first```.
  976. sample_rate (int): Sample rate of the audio waveform.
  977. format (str): File format.
  978. channels_first (bool, optional):
  979. When True, both the input and output Tensor have dimension `(channel, time)`.
  980. Otherwise, they have dimension `(time, channel)`.
  981. compression (float or None, optional): Used for formats other than WAV.
  982. For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
  983. encoding (str or None, optional): Changes the encoding for the supported formats.
  984. For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
  985. bits_per_sample (int or None, optional): Changes the bit depth for the supported formats.
  986. For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
  987. Returns:
  988. Tensor: Resulting Tensor.
  989. If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
  990. """
  991. bytes = io.BytesIO()
  992. torchaudio.backend.sox_io_backend.save(
  993. bytes, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
  994. )
  995. bytes.seek(0)
  996. augmented, sr = torchaudio.backend.sox_io_backend.load(bytes, channels_first=channels_first, format=format)
  997. if sr != sample_rate:
  998. augmented = resample(augmented, sr, sample_rate)
  999. return augmented
  1000. @_mod_utils.requires_kaldi()
  1001. def compute_kaldi_pitch(
  1002. waveform: torch.Tensor,
  1003. sample_rate: float,
  1004. frame_length: float = 25.0,
  1005. frame_shift: float = 10.0,
  1006. min_f0: float = 50,
  1007. max_f0: float = 400,
  1008. soft_min_f0: float = 10.0,
  1009. penalty_factor: float = 0.1,
  1010. lowpass_cutoff: float = 1000,
  1011. resample_frequency: float = 4000,
  1012. delta_pitch: float = 0.005,
  1013. nccf_ballast: float = 7000,
  1014. lowpass_filter_width: int = 1,
  1015. upsample_filter_width: int = 5,
  1016. max_frames_latency: int = 0,
  1017. frames_per_chunk: int = 0,
  1018. simulate_first_pass_online: bool = False,
  1019. recompute_frame: int = 500,
  1020. snip_edges: bool = True,
  1021. ) -> torch.Tensor:
  1022. """Extract pitch based on method described in *A pitch extraction algorithm tuned
  1023. for automatic speech recognition* [:footcite:`6854049`].
  1024. .. devices:: CPU
  1025. .. properties:: TorchScript
  1026. This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.
  1027. Args:
  1028. waveform (Tensor):
  1029. The input waveform of shape `(..., time)`.
  1030. sample_rate (float):
  1031. Sample rate of `waveform`.
  1032. frame_length (float, optional):
  1033. Frame length in milliseconds. (default: 25.0)
  1034. frame_shift (float, optional):
  1035. Frame shift in milliseconds. (default: 10.0)
  1036. min_f0 (float, optional):
  1037. Minimum F0 to search for (Hz) (default: 50.0)
  1038. max_f0 (float, optional):
  1039. Maximum F0 to search for (Hz) (default: 400.0)
  1040. soft_min_f0 (float, optional):
  1041. Minimum f0, applied in soft way, must not exceed min-f0 (default: 10.0)
  1042. penalty_factor (float, optional):
  1043. Cost factor for FO change. (default: 0.1)
  1044. lowpass_cutoff (float, optional):
  1045. Cutoff frequency for LowPass filter (Hz) (default: 1000)
  1046. resample_frequency (float, optional):
  1047. Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
  1048. (default: 4000)
  1049. delta_pitch( float, optional):
  1050. Smallest relative change in pitch that our algorithm measures. (default: 0.005)
  1051. nccf_ballast (float, optional):
  1052. Increasing this factor reduces NCCF for quiet frames (default: 7000)
  1053. lowpass_filter_width (int, optional):
  1054. Integer that determines filter width of lowpass filter, more gives sharper filter.
  1055. (default: 1)
  1056. upsample_filter_width (int, optional):
  1057. Integer that determines filter width when upsampling NCCF. (default: 5)
  1058. max_frames_latency (int, optional):
  1059. Maximum number of frames of latency that we allow pitch tracking to introduce into
  1060. the feature processing (affects output only if ``frames_per_chunk > 0`` and
  1061. ``simulate_first_pass_online=True``) (default: 0)
  1062. frames_per_chunk (int, optional):
  1063. The number of frames used for energy normalization. (default: 0)
  1064. simulate_first_pass_online (bool, optional):
  1065. If true, the function will output features that correspond to what an online decoder
  1066. would see in the first pass of decoding -- not the final version of the features,
  1067. which is the default. (default: False)
  1068. Relevant if ``frames_per_chunk > 0``.
  1069. recompute_frame (int, optional):
  1070. Only relevant for compatibility with online pitch extraction.
  1071. A non-critical parameter; the frame at which we recompute some of the forward pointers,
  1072. after revising our estimate of the signal energy.
  1073. Relevant if ``frames_per_chunk > 0``. (default: 500)
  1074. snip_edges (bool, optional):
  1075. If this is set to false, the incomplete frames near the ending edge won't be snipped,
  1076. so that the number of frames is the file size divided by the frame-shift.
  1077. This makes different types of features give the same number of frames. (default: True)
  1078. Returns:
  1079. Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
  1080. corresponds to pitch and NCCF.
  1081. """
  1082. shape = waveform.shape
  1083. waveform = waveform.reshape(-1, shape[-1])
  1084. result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
  1085. waveform,
  1086. sample_rate,
  1087. frame_length,
  1088. frame_shift,
  1089. min_f0,
  1090. max_f0,
  1091. soft_min_f0,
  1092. penalty_factor,
  1093. lowpass_cutoff,
  1094. resample_frequency,
  1095. delta_pitch,
  1096. nccf_ballast,
  1097. lowpass_filter_width,
  1098. upsample_filter_width,
  1099. max_frames_latency,
  1100. frames_per_chunk,
  1101. simulate_first_pass_online,
  1102. recompute_frame,
  1103. snip_edges,
  1104. )
  1105. result = result.reshape(shape[:-1] + result.shape[-2:])
  1106. return result
  1107. def _get_sinc_resample_kernel(
  1108. orig_freq: int,
  1109. new_freq: int,
  1110. gcd: int,
  1111. lowpass_filter_width: int = 6,
  1112. rolloff: float = 0.99,
  1113. resampling_method: str = "sinc_interpolation",
  1114. beta: Optional[float] = None,
  1115. device: torch.device = torch.device("cpu"),
  1116. dtype: Optional[torch.dtype] = None,
  1117. ):
  1118. if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
  1119. raise Exception(
  1120. "Frequencies must be of integer type to ensure quality resampling computation. "
  1121. "To work around this, manually convert both frequencies to integer values "
  1122. "that maintain their resampling rate ratio before passing them into the function. "
  1123. "Example: To downsample a 44100 hz waveform by a factor of 8, use "
  1124. "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
  1125. "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
  1126. )
  1127. if resampling_method not in ["sinc_interpolation", "kaiser_window"]:
  1128. raise ValueError("Invalid resampling method: {}".format(resampling_method))
  1129. orig_freq = int(orig_freq) // gcd
  1130. new_freq = int(new_freq) // gcd
  1131. assert lowpass_filter_width > 0
  1132. kernels = []
  1133. base_freq = min(orig_freq, new_freq)
  1134. # This will perform antialiasing filtering by removing the highest frequencies.
  1135. # At first I thought I only needed this when downsampling, but when upsampling
  1136. # you will get edge artifacts without this, as the edge is equivalent to zero padding,
  1137. # which will add high freq artifacts.
  1138. base_freq *= rolloff
  1139. # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
  1140. # using the sinc interpolation formula:
  1141. # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
  1142. # We can then sample the function x(t) with a different sample rate:
  1143. # y[j] = x(j / new_freq)
  1144. # or,
  1145. # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
  1146. # We see here that y[j] is the convolution of x[i] with a specific filter, for which
  1147. # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
  1148. # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
  1149. # Indeed:
  1150. # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
  1151. # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
  1152. # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
  1153. # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
  1154. # This will explain the F.conv1d after, with a stride of orig_freq.
  1155. width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
  1156. # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
  1157. # they will have a lot of almost zero values to the left or to the right...
  1158. # There is probably a way to evaluate those filters more efficiently, but this is kept for
  1159. # future work.
  1160. idx_dtype = dtype if dtype is not None else torch.float64
  1161. idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
  1162. for i in range(new_freq):
  1163. t = (-i / new_freq + idx / orig_freq) * base_freq
  1164. t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
  1165. # we do not use built in torch windows here as we need to evaluate the window
  1166. # at specific positions, not over a regular grid.
  1167. if resampling_method == "sinc_interpolation":
  1168. window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
  1169. else:
  1170. # kaiser_window
  1171. if beta is None:
  1172. beta = 14.769656459379492
  1173. beta_tensor = torch.tensor(float(beta))
  1174. window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
  1175. t *= math.pi
  1176. kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
  1177. kernel.mul_(window)
  1178. kernels.append(kernel)
  1179. scale = base_freq / orig_freq
  1180. kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
  1181. if dtype is None:
  1182. kernels = kernels.to(dtype=torch.float32)
  1183. return kernels, width
  1184. def _apply_sinc_resample_kernel(
  1185. waveform: Tensor,
  1186. orig_freq: int,
  1187. new_freq: int,
  1188. gcd: int,
  1189. kernel: Tensor,
  1190. width: int,
  1191. ):
  1192. if not waveform.is_floating_point():
  1193. raise TypeError(f"Expected floating point type for waveform tensor, but received {waveform.dtype}.")
  1194. orig_freq = int(orig_freq) // gcd
  1195. new_freq = int(new_freq) // gcd
  1196. # pack batch
  1197. shape = waveform.size()
  1198. waveform = waveform.view(-1, shape[-1])
  1199. num_wavs, length = waveform.shape
  1200. waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
  1201. resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
  1202. resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
  1203. target_length = int(math.ceil(new_freq * length / orig_freq))
  1204. resampled = resampled[..., :target_length]
  1205. # unpack batch
  1206. resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
  1207. return resampled
  1208. def resample(
  1209. waveform: Tensor,
  1210. orig_freq: int,
  1211. new_freq: int,
  1212. lowpass_filter_width: int = 6,
  1213. rolloff: float = 0.99,
  1214. resampling_method: str = "sinc_interpolation",
  1215. beta: Optional[float] = None,
  1216. ) -> Tensor:
  1217. r"""Resamples the waveform at the new frequency using bandlimited interpolation. [:footcite:`RESAMPLE`].
  1218. .. devices:: CPU CUDA
  1219. .. properties:: Autograd TorchScript
  1220. Note:
  1221. ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
  1222. more efficient computation if resampling multiple waveforms with the same resampling parameters.
  1223. Args:
  1224. waveform (Tensor): The input signal of dimension `(..., time)`
  1225. orig_freq (int): The original frequency of the signal
  1226. new_freq (int): The desired frequency
  1227. lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
  1228. but less efficient. (Default: ``6``)
  1229. rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
  1230. Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
  1231. resampling_method (str, optional): The resampling method to use.
  1232. Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
  1233. beta (float or None, optional): The shape parameter used for kaiser window.
  1234. Returns:
  1235. Tensor: The waveform at the new frequency of dimension `(..., time).`
  1236. """
  1237. assert orig_freq > 0.0 and new_freq > 0.0
  1238. if orig_freq == new_freq:
  1239. return waveform
  1240. gcd = math.gcd(int(orig_freq), int(new_freq))
  1241. kernel, width = _get_sinc_resample_kernel(
  1242. orig_freq,
  1243. new_freq,
  1244. gcd,
  1245. lowpass_filter_width,
  1246. rolloff,
  1247. resampling_method,
  1248. beta,
  1249. waveform.device,
  1250. waveform.dtype,
  1251. )
  1252. resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
  1253. return resampled
  1254. @torch.jit.unused
  1255. def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
  1256. """
  1257. Calculate the word level edit (Levenshtein) distance between two sequences.
  1258. .. devices:: CPU
  1259. The function computes an edit distance allowing deletion, insertion and
  1260. substitution. The result is an integer.
  1261. For most applications, the two input sequences should be the same type. If
  1262. two strings are given, the output is the edit distance between the two
  1263. strings (character edit distance). If two lists of strings are given, the
  1264. output is the edit distance between sentences (word edit distance). Users
  1265. may want to normalize the output by the length of the reference sequence.
  1266. Args:
  1267. seq1 (Sequence): the first sequence to compare.
  1268. seq2 (Sequence): the second sequence to compare.
  1269. Returns:
  1270. int: The distance between the first and second sequences.
  1271. """
  1272. len_sent2 = len(seq2)
  1273. dold = list(range(len_sent2 + 1))
  1274. dnew = [0 for _ in range(len_sent2 + 1)]
  1275. for i in range(1, len(seq1) + 1):
  1276. dnew[0] = i
  1277. for j in range(1, len_sent2 + 1):
  1278. if seq1[i - 1] == seq2[j - 1]:
  1279. dnew[j] = dold[j - 1]
  1280. else:
  1281. substitution = dold[j - 1] + 1
  1282. insertion = dnew[j - 1] + 1
  1283. deletion = dold[j] + 1
  1284. dnew[j] = min(substitution, insertion, deletion)
  1285. dnew, dold = dold, dnew
  1286. return int(dold[-1])
  1287. def pitch_shift(
  1288. waveform: Tensor,
  1289. sample_rate: int,
  1290. n_steps: int,
  1291. bins_per_octave: int = 12,
  1292. n_fft: int = 512,
  1293. win_length: Optional[int] = None,
  1294. hop_length: Optional[int] = None,
  1295. window: Optional[Tensor] = None,
  1296. ) -> Tensor:
  1297. """
  1298. Shift the pitch of a waveform by ``n_steps`` steps.
  1299. .. devices:: CPU CUDA
  1300. .. properties:: TorchScript
  1301. Args:
  1302. waveform (Tensor): The input waveform of shape `(..., time)`.
  1303. sample_rate (int): Sample rate of `waveform`.
  1304. n_steps (int): The (fractional) steps to shift `waveform`.
  1305. bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
  1306. n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
  1307. win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
  1308. hop_length (int or None, optional): Length of hop between STFT windows. If None, then
  1309. ``win_length // 4`` is used (Default: ``None``).
  1310. window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window.
  1311. If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).
  1312. Returns:
  1313. Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
  1314. """
  1315. waveform_stretch = _stretch_waveform(
  1316. waveform,
  1317. n_steps,
  1318. bins_per_octave,
  1319. n_fft,
  1320. win_length,
  1321. hop_length,
  1322. window,
  1323. )
  1324. rate = 2.0 ** (-float(n_steps) / bins_per_octave)
  1325. waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
  1326. return _fix_waveform_shape(waveform_shift, waveform.size())
  1327. def _stretch_waveform(
  1328. waveform: Tensor,
  1329. n_steps: int,
  1330. bins_per_octave: int = 12,
  1331. n_fft: int = 512,
  1332. win_length: Optional[int] = None,
  1333. hop_length: Optional[int] = None,
  1334. window: Optional[Tensor] = None,
  1335. ) -> Tensor:
  1336. """
  1337. Pitch shift helper function to preprocess and stretch waveform before resampling step.
  1338. Args:
  1339. See pitch_shift arg descriptions.
  1340. Returns:
  1341. Tensor: The preprocessed waveform stretched prior to resampling.
  1342. """
  1343. if hop_length is None:
  1344. hop_length = n_fft // 4
  1345. if win_length is None:
  1346. win_length = n_fft
  1347. if window is None:
  1348. window = torch.hann_window(window_length=win_length, device=waveform.device)
  1349. # pack batch
  1350. shape = waveform.size()
  1351. waveform = waveform.reshape(-1, shape[-1])
  1352. ori_len = shape[-1]
  1353. rate = 2.0 ** (-float(n_steps) / bins_per_octave)
  1354. spec_f = torch.stft(
  1355. input=waveform,
  1356. n_fft=n_fft,
  1357. hop_length=hop_length,
  1358. win_length=win_length,
  1359. window=window,
  1360. center=True,
  1361. pad_mode="reflect",
  1362. normalized=False,
  1363. onesided=True,
  1364. return_complex=True,
  1365. )
  1366. phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
  1367. spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
  1368. len_stretch = int(round(ori_len / rate))
  1369. waveform_stretch = torch.istft(
  1370. spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
  1371. )
  1372. return waveform_stretch
  1373. def _fix_waveform_shape(
  1374. waveform_shift: Tensor,
  1375. shape: List[int],
  1376. ) -> Tensor:
  1377. """
  1378. PitchShift helper function to process after resampling step to fix the shape back.
  1379. Args:
  1380. waveform_shift(Tensor): The waveform after stretch and resample
  1381. shape (List[int]): The shape of initial waveform
  1382. Returns:
  1383. Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
  1384. """
  1385. ori_len = shape[-1]
  1386. shift_len = waveform_shift.size()[-1]
  1387. if shift_len > ori_len:
  1388. waveform_shift = waveform_shift[..., :ori_len]
  1389. else:
  1390. waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])
  1391. # unpack batch
  1392. waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
  1393. return waveform_shift
  1394. def rnnt_loss(
  1395. logits: Tensor,
  1396. targets: Tensor,
  1397. logit_lengths: Tensor,
  1398. target_lengths: Tensor,
  1399. blank: int = -1,
  1400. clamp: float = -1,
  1401. reduction: str = "mean",
  1402. ):
  1403. """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
  1404. [:footcite:`graves2012sequence`].
  1405. .. devices:: CPU CUDA
  1406. .. properties:: Autograd TorchScript
  1407. The RNN Transducer loss extends the CTC loss by defining a distribution over output
  1408. sequences of all lengths, and by jointly modelling both input-output and output-output
  1409. dependencies.
  1410. Args:
  1411. logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)`
  1412. containing output from joiner
  1413. targets (Tensor): Tensor of dimension `(batch, max target length)` containing targets with zero padded
  1414. logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
  1415. target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
  1416. blank (int, optional): blank label (Default: ``-1``)
  1417. clamp (float, optional): clamp for gradients (Default: ``-1``)
  1418. reduction (string, optional): Specifies the reduction to apply to the output:
  1419. ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
  1420. Returns:
  1421. Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`,
  1422. otherwise scalar.
  1423. """
  1424. if reduction not in ["none", "mean", "sum"]:
  1425. raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
  1426. if blank < 0: # reinterpret blank index if blank < 0.
  1427. blank = logits.shape[-1] + blank
  1428. costs, _ = torch.ops.torchaudio.rnnt_loss(
  1429. logits=logits,
  1430. targets=targets,
  1431. logit_lengths=logit_lengths,
  1432. target_lengths=target_lengths,
  1433. blank=blank,
  1434. clamp=clamp,
  1435. )
  1436. if reduction == "mean":
  1437. return costs.mean()
  1438. elif reduction == "sum":
  1439. return costs.sum()
  1440. return costs
  1441. def psd(
  1442. specgram: Tensor,
  1443. mask: Optional[Tensor] = None,
  1444. normalize: bool = True,
  1445. eps: float = 1e-10,
  1446. ) -> Tensor:
  1447. """Compute cross-channel power spectral density (PSD) matrix.
  1448. .. devices:: CPU CUDA
  1449. .. properties:: Autograd TorchScript
  1450. Args:
  1451. specgram (torch.Tensor): Multi-channel complex-valued spectrum.
  1452. Tensor with dimensions `(..., channel, freq, time)`.
  1453. mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
  1454. Tensor with dimensions `(..., freq, time)`. (Default: ``None``)
  1455. normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
  1456. eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
  1457. Returns:
  1458. torch.Tensor: The complex-valued PSD matrix of the input spectrum.
  1459. Tensor with dimensions `(..., freq, channel, channel)`
  1460. """
  1461. specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
  1462. # outer product:
  1463. # (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
  1464. psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])
  1465. if mask is not None:
  1466. assert (
  1467. mask.shape[:-1] == specgram.shape[:-2] and mask.shape[-1] == specgram.shape[-1]
  1468. ), "The dimensions of mask except the channel dimension should be the same as specgram."
  1469. f"Found {mask.shape} for mask and {specgram.shape} for specgram."
  1470. # Normalized mask along time dimension:
  1471. if normalize:
  1472. mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
  1473. psd = psd * mask[..., None, None]
  1474. psd = psd.sum(dim=-3)
  1475. return psd
  1476. def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
  1477. r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
  1478. Args:
  1479. input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
  1480. dim1 (int, optional): The first dimension of the diagonal matrix.
  1481. (Default: ``-1``)
  1482. dim2 (int, optional): The second dimension of the diagonal matrix.
  1483. (Default: ``-2``)
  1484. Returns:
  1485. Tensor: The trace of the input Tensor.
  1486. """
  1487. assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
  1488. assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
  1489. input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
  1490. return input.sum(dim=-1)
  1491. def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
  1492. """Perform Tikhonov regularization (only modifying real part).
  1493. Args:
  1494. mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
  1495. reg (float, optional): Regularization factor. (Default: 1e-8)
  1496. eps (float, optional): Value to avoid the correlation matrix is all-zero. (Default: ``1e-8``)
  1497. Returns:
  1498. Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
  1499. """
  1500. # Add eps
  1501. C = mat.size(-1)
  1502. eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
  1503. epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
  1504. # in case that correlation_matrix is all-zero
  1505. epsilon = epsilon + eps
  1506. mat = mat + epsilon * eye[..., :, :]
  1507. return mat
  1508. def _assert_psd_matrices(psd_s: torch.Tensor, psd_n: torch.Tensor) -> None:
  1509. """Assertion checks of the PSD matrices of target speech and noise.
  1510. Args:
  1511. psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
  1512. Tensor with dimensions `(..., freq, channel, channel)`.
  1513. psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
  1514. Tensor with dimensions `(..., freq, channel, channel)`.
  1515. """
  1516. assert (
  1517. psd_s.ndim >= 3 and psd_n.ndim >= 3
  1518. ), "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n."
  1519. "Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
  1520. assert (
  1521. psd_s.is_complex() and psd_n.is_complex()
  1522. ), "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
  1523. f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
  1524. assert (
  1525. psd_s.shape == psd_n.shape
  1526. ), f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
  1527. assert (
  1528. psd_s.shape[-1] == psd_s.shape[-2]
  1529. ), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
  1530. def mvdr_weights_souden(
  1531. psd_s: Tensor,
  1532. psd_n: Tensor,
  1533. reference_channel: Union[int, Tensor],
  1534. diagonal_loading: bool = True,
  1535. diag_eps: float = 1e-7,
  1536. eps: float = 1e-8,
  1537. ) -> Tensor:
  1538. r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
  1539. by the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
  1540. .. devices:: CPU CUDA
  1541. .. properties:: Autograd TorchScript
  1542. Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
  1543. the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
  1544. reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
  1545. :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
  1546. .. math::
  1547. \textbf{w}_{\text{MVDR}}(f) =
  1548. \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
  1549. {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
  1550. Args:
  1551. psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
  1552. Tensor with dimensions `(..., freq, channel, channel)`.
  1553. psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
  1554. Tensor with dimensions `(..., freq, channel, channel)`.
  1555. reference_channel (int or torch.Tensor): Specifies the reference channel.
  1556. If the dtype is ``int``, it represents the reference channel index.
  1557. If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
  1558. is one-hot.
  1559. diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
  1560. (Default: ``True``)
  1561. diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
  1562. It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
  1563. eps (float, optional): Value to add to the denominator in the beamforming weight formula.
  1564. (Default: ``1e-8``)
  1565. Returns:
  1566. torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
  1567. """
  1568. _assert_psd_matrices(psd_s, psd_n)
  1569. if diagonal_loading:
  1570. psd_n = _tik_reg(psd_n, reg=diag_eps)
  1571. numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
  1572. # ws: (..., C, C) / (...,) -> (..., C, C)
  1573. ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
  1574. if torch.jit.isinstance(reference_channel, int):
  1575. beamform_weights = ws[..., :, reference_channel]
  1576. elif torch.jit.isinstance(reference_channel, Tensor):
  1577. reference_channel = reference_channel.to(psd_n.dtype)
  1578. # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
  1579. beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
  1580. else:
  1581. raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
  1582. return beamform_weights
  1583. def mvdr_weights_rtf(
  1584. rtf: Tensor,
  1585. psd_n: Tensor,
  1586. reference_channel: Optional[Union[int, Tensor]] = None,
  1587. diagonal_loading: bool = True,
  1588. diag_eps: float = 1e-7,
  1589. eps: float = 1e-8,
  1590. ) -> Tensor:
  1591. r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
  1592. based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
  1593. .. devices:: CPU CUDA
  1594. .. properties:: Autograd TorchScript
  1595. Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
  1596. the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
  1597. reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
  1598. :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
  1599. .. math::
  1600. \textbf{w}_{\text{MVDR}}(f) =
  1601. \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
  1602. {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
  1603. where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
  1604. Args:
  1605. rtf (torch.Tensor): The complex-valued RTF vector of target speech.
  1606. Tensor with dimensions `(..., freq, channel)`.
  1607. psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
  1608. Tensor with dimensions `(..., freq, channel, channel)`.
  1609. reference_channel (int or torch.Tensor): Specifies the reference channel.
  1610. If the dtype is ``int``, it represents the reference channel index.
  1611. If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
  1612. is one-hot.
  1613. diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
  1614. (Default: ``True``)
  1615. diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
  1616. It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
  1617. eps (float, optional): Value to add to the denominator in the beamforming weight formula.
  1618. (Default: ``1e-8``)
  1619. Returns:
  1620. torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
  1621. """
  1622. assert rtf.ndim >= 2, f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}."
  1623. assert psd_n.ndim >= 3, f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}."
  1624. assert (
  1625. rtf.is_complex() and psd_n.is_complex()
  1626. ), "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
  1627. f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
  1628. assert (
  1629. rtf.shape == psd_n.shape[:-1]
  1630. ), "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same."
  1631. f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
  1632. assert (
  1633. psd_n.shape[-1] == psd_n.shape[-2]
  1634. ), f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}."
  1635. if diagonal_loading:
  1636. psd_n = _tik_reg(psd_n, reg=diag_eps)
  1637. # numerator = psd_n.inv() @ stv
  1638. numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1) # (..., freq, channel)
  1639. # denominator = stv^H @ psd_n.inv() @ stv
  1640. denominator = torch.einsum("...d,...d->...", [rtf.conj(), numerator])
  1641. beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
  1642. # normalize the numerator
  1643. if reference_channel is not None:
  1644. if torch.jit.isinstance(reference_channel, int):
  1645. scale = rtf[..., reference_channel].conj()
  1646. elif torch.jit.isinstance(reference_channel, Tensor):
  1647. reference_channel = reference_channel.to(psd_n.dtype)
  1648. scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
  1649. else:
  1650. raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
  1651. beamform_weights = beamform_weights * scale[..., None]
  1652. return beamform_weights
  1653. def rtf_evd(psd_s: Tensor) -> Tensor:
  1654. r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.
  1655. .. devices:: CPU CUDA
  1656. .. properties:: TorchScript
  1657. Args:
  1658. psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
  1659. Tensor of dimension `(..., freq, channel, channel)`
  1660. Returns:
  1661. Tensor: The estimated complex-valued RTF of target speech.
  1662. Tensor of dimension `(..., freq, channel)`
  1663. """
  1664. assert psd_s.is_complex(), f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}."
  1665. assert (
  1666. psd_s.shape[-1] == psd_s.shape[-2]
  1667. ), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
  1668. _, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order
  1669. rtf = v[..., -1] # choose the eigenvector with max eigenvalue
  1670. return rtf
  1671. def rtf_power(
  1672. psd_s: Tensor,
  1673. psd_n: Tensor,
  1674. reference_channel: Union[int, Tensor],
  1675. n_iter: int = 3,
  1676. diagonal_loading: bool = True,
  1677. diag_eps: float = 1e-7,
  1678. ) -> Tensor:
  1679. r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.
  1680. .. devices:: CPU CUDA
  1681. .. properties:: Autograd TorchScript
  1682. Args:
  1683. psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
  1684. Tensor with dimensions `(..., freq, channel, channel)`.
  1685. psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
  1686. Tensor with dimensions `(..., freq, channel, channel)`.
  1687. reference_channel (int or torch.Tensor): Specifies the reference channel.
  1688. If the dtype is ``int``, it represents the reference channel index.
  1689. If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
  1690. is one-hot.
  1691. diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
  1692. (Default: ``True``)
  1693. diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
  1694. It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
  1695. Returns:
  1696. torch.Tensor: The estimated complex-valued RTF of target speech.
  1697. Tensor of dimension `(..., freq, channel)`.
  1698. """
  1699. _assert_psd_matrices(psd_s, psd_n)
  1700. assert n_iter > 0, "The number of iteration must be greater than 0."
  1701. # Apply diagonal loading to psd_n to improve robustness.
  1702. if diagonal_loading:
  1703. psd_n = _tik_reg(psd_n, reg=diag_eps)
  1704. # phi is regarded as the first iteration
  1705. phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
  1706. if torch.jit.isinstance(reference_channel, int):
  1707. rtf = phi[..., reference_channel]
  1708. elif torch.jit.isinstance(reference_channel, Tensor):
  1709. reference_channel = reference_channel.to(psd_n.dtype)
  1710. rtf = torch.einsum("...c,...c->...", [phi, reference_channel[..., None, None, :]])
  1711. else:
  1712. raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
  1713. rtf = rtf.unsqueeze(-1) # (..., freq, channel, 1)
  1714. if n_iter >= 2:
  1715. # The number of iterations in the for loop is `n_iter - 2`
  1716. # because the `phi` above and `torch.matmul(psd_s, rtf)` are regarded as
  1717. # two iterations.
  1718. for _ in range(n_iter - 2):
  1719. rtf = torch.matmul(phi, rtf)
  1720. rtf = torch.matmul(psd_s, rtf)
  1721. else:
  1722. # if there is only one iteration, the rtf is the psd_s[..., referenc_channel]
  1723. # which is psd_n @ phi @ ref_channel
  1724. rtf = torch.matmul(psd_n, rtf)
  1725. return rtf.squeeze(-1)
  1726. def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
  1727. r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.
  1728. .. devices:: CPU CUDA
  1729. .. properties:: Autograd TorchScript
  1730. .. math::
  1731. \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
  1732. where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin,
  1733. :math:`\textbf{Y}` is the multi-channel spectrum for the :math:`f`-th frequency bin.
  1734. Args:
  1735. beamform_weights (Tensor): The complex-valued beamforming weight matrix.
  1736. Tensor of dimension `(..., freq, channel)`
  1737. specgram (Tensor): The multi-channel complex-valued noisy spectrum.
  1738. Tensor of dimension `(..., channel, freq, time)`
  1739. Returns:
  1740. Tensor: The single-channel complex-valued enhanced spectrum.
  1741. Tensor of dimension `(..., freq, time)`
  1742. """
  1743. assert (
  1744. beamform_weights.shape[:-2] == specgram.shape[:-3]
  1745. ), "The dimensions except the last two dimensions of beamform_weights should be the same "
  1746. "as the dimensions except the last three dimensions of specgram."
  1747. f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
  1748. assert (
  1749. beamform_weights.is_complex() and specgram.is_complex()
  1750. ), "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``."
  1751. f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
  1752. # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
  1753. specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
  1754. return specgram_enhanced