model.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216
  1. from typing import List, Optional, Tuple
  2. import torch
  3. from torch import Tensor
  4. from torch.nn import Module
  5. from . import components
  class Wav2Vec2Model(Module):
      """torchaudio.models.Wav2Vec2Model(feature_extractor: torch.nn.Module, encoder: torch.nn.Module, aux: Optional[torch.nn.Module] = None)

      Encoder model used in *wav2vec 2.0* [:footcite:`baevski2020wav2vec`].

      The model chains up to three stages: ``feature_extractor`` (raw audio to frame
      features plus their valid lengths), ``encoder`` (frame features to per-frame
      outputs), and, in :py:meth:`forward` only, an optional ``aux`` module applied
      to the encoder output.

      Note:
          To build the model, please use one of the factory functions.

      Args:
          feature_extractor (torch.nn.Module):
              Feature extractor that extracts feature vectors from raw audio Tensor.

          encoder (torch.nn.Module):
              Encoder that converts the audio features into the sequence of probability
              distribution (in negative log-likelihood) over labels.

          aux (torch.nn.Module or None, optional):
              Auxiliary module. If provided, the output from encoder is passed to this module.
      """  # noqa: E501

      def __init__(
          self,
          feature_extractor: Module,
          encoder: Module,
          aux: Optional[Module] = None,
      ):
          super().__init__()
          # Registration order is kept stable: it determines state_dict key order.
          self.feature_extractor = feature_extractor
          self.encoder = encoder
          self.aux = aux

      @torch.jit.export
      def extract_features(
          self,
          waveforms: Tensor,
          lengths: Optional[Tensor] = None,
          num_layers: Optional[int] = None,
      ) -> Tuple[List[Tensor], Optional[Tensor]]:
          """Extract feature vectors from raw waveforms.

          Returns the outputs of the intermediate layers of the transformer block
          in the encoder. The ``aux`` module is NOT applied here.

          Args:
              waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
              lengths (Tensor or None, optional):
                  Indicates the valid length of each audio in the batch.
                  Shape: `(batch, )`.
                  When ``waveforms`` contains audios of different durations,
                  providing ``lengths`` lets the model compute the corresponding
                  valid output lengths and apply the proper mask in the
                  transformer attention layer.
                  If ``None``, the entire audio waveform length is assumed valid.
              num_layers (int or None, optional):
                  If given, limit the number of intermediate layers to go through.
                  Providing `1` stops the computation after going through one
                  intermediate layer. If not given, the outputs from all the
                  intermediate layers are returned.

          Returns:
              (List[Tensor], Optional[Tensor]):
              List of Tensors
                  Features from requested layers.
                  Each Tensor is of shape: `(batch, time frame, feature dimension)`
              Tensor or None
                  If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                  is returned.
                  It indicates the valid length in time axis of each feature Tensor.
          """
          features, valid_lengths = self.feature_extractor(waveforms, lengths)
          layer_outputs = self.encoder.extract_features(features, valid_lengths, num_layers)
          return layer_outputs, valid_lengths

      def forward(
          self,
          waveforms: Tensor,
          lengths: Optional[Tensor] = None,
      ) -> Tuple[Tensor, Optional[Tensor]]:
          """Compute the sequence of probability distribution over labels.

          Args:
              waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
              lengths (Tensor or None, optional):
                  Indicates the valid length of each audio in the batch.
                  Shape: `(batch, )`.
                  When ``waveforms`` contains audios of different durations,
                  providing ``lengths`` lets the model compute the corresponding
                  valid output lengths and apply the proper mask in the
                  transformer attention layer.
                  If ``None``, all the audio in ``waveforms`` is assumed to have
                  valid length. Default: ``None``.

          Returns:
              (Tensor, Optional[Tensor]):
              Tensor
                  The sequences of probability distribution (in logit) over labels.
                  Shape: `(batch, frames, num labels)`.
              Tensor or None
                  If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                  is returned.
                  It indicates the valid length in time axis of the output Tensor.
          """
          features, valid_lengths = self.feature_extractor(waveforms, lengths)
          logits = self.encoder(features, valid_lengths)
          if self.aux is not None:
              logits = self.aux(logits)
          return logits, valid_lengths
  class HuBERTPretrainModel(Module):
      """HuBERT pre-train model for training from scratch.

      Note:
          To build the model, please use one of the factory functions in
          `[hubert_pretrain_base, hubert_pretrain_large, hubert_pretrain_xlarge]`.

      Args:
          wav2vec2 (Wav2Vec2Model):
              Backbone model. Its ``feature_extractor`` converts raw audio into feature
              vectors, and its ``encoder`` (the ``_preprocess`` step followed by the
              ``transformer``) processes the masked features.
          mask_generator (torch.nn.Module):
              Mask generator that generates the mask for masked prediction during the training.
          logit_generator (torch.nn.Module):
              Logit generator that predicts the logits of the masked and unmasked inputs.
          feature_grad_mult (float or None):
              The factor to scale the convolutional feature extraction layer gradients by.
              Must be ``None`` or strictly between 0 and 1.
              If ``None``, the gradients of feature extraction layers are not affected.
              The scale factor will not affect the forward pass.

      Raises:
          ValueError: If ``feature_grad_mult`` is neither ``None`` nor in the open interval (0, 1).
      """

      def __init__(
          self,
          wav2vec2: Wav2Vec2Model,
          mask_generator: Module,
          logit_generator: Module,
          feature_grad_mult: Optional[float],
      ):
          super().__init__()
          self.wav2vec2 = wav2vec2
          self.mask_generator = mask_generator
          self.logit_generator = logit_generator
          # Explicit check instead of ``assert`` so the validation survives ``python -O``.
          if feature_grad_mult is not None and not 0.0 < feature_grad_mult < 1.0:
              raise ValueError(
                  f"The value of `feature_grad_mult` must be ``None`` or between (0, 1). Found {feature_grad_mult}"
              )
          self.feature_grad_mult = feature_grad_mult

      def forward(
          self,
          waveforms: Tensor,
          labels: Tensor,
          audio_lengths: Optional[Tensor] = None,
      ) -> Tuple[Tensor, Tensor, Tensor]:
          """Compute the sequence of probability distribution over labels.

          Args:
              waveforms (Tensor): Audio tensor of dimension `[batch, frames]`.
              labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`,
                  where ``labels.shape[1]`` must equal the number of frames produced by the
                  transformer (i.e. after feature extraction), not the number of audio samples.
              audio_lengths (Tensor or None, optional):
                  Indicates the valid length of each audio in the batch.
                  Shape: `[batch, ]`.
                  When ``waveforms`` contains audios of different durations,
                  providing ``audio_lengths`` lets the model compute the corresponding
                  valid output lengths and apply the proper mask in the
                  transformer attention layer.
                  If ``None``, all the audio in ``waveforms`` is assumed to have
                  valid length. Default: ``None``.

          Returns:
              (Tensor, Tensor, Tensor):
              Tensor
                  The masked sequences of probability distribution (in logit).
                  Shape: `(masked_frames, num labels)`.
              Tensor
                  The unmasked sequence of probability distribution (in logit).
                  Shape: `(unmasked_frames, num labels)`.
              Tensor
                  The mean of the squared extracted features (computed in float32),
                  used for an additional penalty loss. A scalar (0-dimensional) Tensor.

          Raises:
              ValueError: If the label length does not match the model output length.
          """
          x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
          # Scale only the gradient of the feature extractor; the forward values are unchanged.
          if self.feature_grad_mult is not None and self.feature_grad_mult < 1.0:
              x = components.GradMultiply.apply(x, self.feature_grad_mult)
          features_pen = x.float().pow(2).mean()
          if lengths is not None:
              padding_mask = components._get_padding_mask(x, lengths)
          else:
              padding_mask = None
          x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
          x, mask = self.mask_generator(x, padding_mask)
          x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
          # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
          if x.shape[1] != labels.shape[1]:
              raise ValueError("The length of label must match that of HuBERT model output")
          if padding_mask is not None:
              # Positions where padding_mask is True are excluded from both selections.
              mask_m = torch.logical_and(~padding_mask, mask)
              mask_u = torch.logical_and(~padding_mask, ~mask_m)
          else:
              mask_m = mask
              mask_u = ~mask_m
          logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)
          return logit_m, logit_u, features_pen
  def wav2vec2_model(
      extractor_mode: str,
      extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
      extractor_conv_bias: bool,
      encoder_embed_dim: int,
      encoder_projection_dropout: float,
      encoder_pos_conv_kernel: int,
      encoder_pos_conv_groups: int,
      encoder_num_layers: int,
      encoder_num_heads: int,
      encoder_attention_dropout: float,
      encoder_ff_interm_features: int,
      encoder_ff_interm_dropout: float,
      encoder_dropout: float,
      encoder_layer_norm_first: bool,
      encoder_layer_drop: float,
      aux_num_out: Optional[int],
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model

      Build a custom Wav2Vec2Model

      Note:
          The "feature extractor" below corresponds to
          `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
          in the original ``fairseq`` implementation.
          This is referred to as "(convolutional) feature encoder" in the *wav2vec 2.0*
          [:footcite:`baevski2020wav2vec`] paper.

          The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
          and this is referred to as "Transformer" in the paper.

      Args:
          extractor_mode (str): Operation mode of feature extractor.
              Valid values are ``"group_norm"`` or ``"layer_norm"``.
              If ``"group_norm"``, then a single normalization is applied
              in the first convolution block. Otherwise, all the convolution
              blocks will have layer normalization.

              This option corresponds to ``extractor_mode`` from ``fairseq``.
          extractor_conv_layer_config (list of integer tuples or None):
              Configuration of convolution layers in feature extractor.
              List of convolution configuration,
              i.e. ``[(output_channel, kernel_size, stride), ...]``

              If ``None`` is provided, then the following default value is used.

              .. code-block:: python

                 [
                   (512, 10, 5),
                   (512, 3, 2),
                   (512, 3, 2),
                   (512, 3, 2),
                   (512, 3, 2),
                   (512, 2, 2),
                   (512, 2, 2),
                 ]

              This option corresponds to ``conv_feature_layers`` from ``fairseq``.

          extractor_conv_bias (bool):
              Whether to include bias term to each convolution operation.

              This option corresponds to ``conv_bias`` from ``fairseq``.

          encoder_embed_dim (int):
              The dimension of embedding in encoder.

              This option corresponds to ``encoder_embed_dim`` from ``fairseq``.

          encoder_projection_dropout (float):
              The dropout probability applied after the input feature is projected
              to ``encoder_embed_dim``.

              This option corresponds to ``dropout_input`` from ``fairseq``.

          encoder_pos_conv_kernel (int):
              The kernel size of convolutional positional embeddings.

              This option corresponds to ``conv_pos`` from ``fairseq``.

          encoder_pos_conv_groups (int):
              The number of groups of convolutional positional embeddings.

              This option corresponds to ``conv_pos_groups`` from ``fairseq``.

          encoder_num_layers (int):
              The number of self attention layers in transformer block.

              This option corresponds to ``encoder_layers`` from ``fairseq``.

          encoder_num_heads (int):
              The number of heads in self attention layers.

              This option corresponds to ``encoder_attention_heads`` from ``fairseq``.

          encoder_attention_dropout (float):
              The dropout probability applied after softmax in self-attention layer.

              This option corresponds to ``attention_dropout`` from ``fairseq``.

          encoder_ff_interm_features (int):
              The dimension of hidden features in feed forward layer.

              This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.

          encoder_ff_interm_dropout (float):
              The dropout probability applied in feedforward layer.

              This option corresponds to ``activation_dropout`` from ``fairseq``.

          encoder_dropout (float):
              The dropout probability applied at the end of feed forward layer.

              This option corresponds to ``dropout`` from ``fairseq``.

          encoder_layer_norm_first (bool):
              Control the order of layer norm in transformer layer and each encoder layer.
              If True, in transformer layer, layer norm is applied before features are fed
              to encoder layers. In encoder layer, two layer norms are applied before and after
              self attention.
              If False, in transformer layer, layer norm is applied after features are fed
              to encoder layers. In encoder layer, two layer norms are applied after self
              attention, before and after feed forward.

              This option corresponds to ``layer_norm_first`` from ``fairseq``.

          encoder_layer_drop (float):
              Probability to drop each encoder layer during training.

              This option corresponds to ``layerdrop`` from ``fairseq``.

          aux_num_out (int or None):
              When provided, attach an extra linear layer on top of encoder, which can be
              used for fine-tuning.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      conv_config = extractor_conv_layer_config
      if conv_config is None:
          # Seven-layer default, identical to the list shown in the docstring.
          conv_config = [
              (512, 10, 5),
              (512, 3, 2),
              (512, 3, 2),
              (512, 3, 2),
              (512, 3, 2),
              (512, 2, 2),
              (512, 2, 2),
          ]

      # Sub-modules are created in a fixed order (extractor, encoder, aux) so that
      # parameter initialisation consumes the RNG in the same sequence every time.
      feature_extractor = components._get_feature_extractor(extractor_mode, conv_config, extractor_conv_bias)

      # The encoder's input width is the channel count of the last conv layer.
      last_conv_channels = conv_config[-1][0]
      encoder = components._get_encoder(
          in_features=last_conv_channels,
          embed_dim=encoder_embed_dim,
          dropout_input=encoder_projection_dropout,
          pos_conv_kernel=encoder_pos_conv_kernel,
          pos_conv_groups=encoder_pos_conv_groups,
          num_layers=encoder_num_layers,
          num_heads=encoder_num_heads,
          attention_dropout=encoder_attention_dropout,
          ff_interm_features=encoder_ff_interm_features,
          ff_interm_dropout=encoder_ff_interm_dropout,
          dropout=encoder_dropout,
          layer_norm_first=encoder_layer_norm_first,
          layer_drop=encoder_layer_drop,
      )

      aux = (
          None
          if aux_num_out is None
          else torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out)
      )
      return Wav2Vec2Model(feature_extractor, encoder, aux)
  def wav2vec2_base(
      encoder_projection_dropout: float = 0.1,
      encoder_attention_dropout: float = 0.1,
      encoder_ff_interm_dropout: float = 0.1,
      encoder_dropout: float = 0.1,
      encoder_layer_drop: float = 0.1,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

      Fixed architecture: ``"group_norm"`` feature extractor without conv bias,
      768-dim embeddings, 12 transformer layers with 12 attention heads,
      3072-dim feed-forward, ``encoder_layer_norm_first=False``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for "base").
          extractor_mode="group_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=False,
          encoder_embed_dim=768,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=12,
          encoder_num_heads=12,
          encoder_ff_interm_features=3072,
          encoder_layer_norm_first=False,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  def wav2vec2_large(
      encoder_projection_dropout: float = 0.1,
      encoder_attention_dropout: float = 0.1,
      encoder_ff_interm_dropout: float = 0.1,
      encoder_dropout: float = 0.1,
      encoder_layer_drop: float = 0.1,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

      Fixed architecture: ``"group_norm"`` feature extractor without conv bias,
      1024-dim embeddings, 24 transformer layers with 16 attention heads,
      4096-dim feed-forward, ``encoder_layer_norm_first=False``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for "large").
          extractor_mode="group_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=False,
          encoder_embed_dim=1024,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=24,
          encoder_num_heads=16,
          encoder_ff_interm_features=4096,
          encoder_layer_norm_first=False,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  def wav2vec2_large_lv60k(
      encoder_projection_dropout: float = 0.1,
      encoder_attention_dropout: float = 0.0,
      encoder_ff_interm_dropout: float = 0.1,
      encoder_dropout: float = 0.0,
      encoder_layer_drop: float = 0.1,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """wav2vec2_large_lv60k(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]

      Fixed architecture: ``"layer_norm"`` feature extractor with conv bias,
      1024-dim embeddings, 24 transformer layers with 16 attention heads,
      4096-dim feed-forward, ``encoder_layer_norm_first=True``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for "large lv-60k").
          extractor_mode="layer_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=True,
          encoder_embed_dim=1024,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=24,
          encoder_num_heads=16,
          encoder_ff_interm_features=4096,
          encoder_layer_norm_first=True,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  def hubert_base(
      encoder_projection_dropout: float = 0.1,
      encoder_attention_dropout: float = 0.1,
      encoder_ff_interm_dropout: float = 0.0,
      encoder_dropout: float = 0.1,
      encoder_layer_drop: float = 0.05,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build HuBERT model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

      Fixed architecture: ``"group_norm"`` feature extractor without conv bias,
      768-dim embeddings, 12 transformer layers with 12 attention heads,
      3072-dim feed-forward, ``encoder_layer_norm_first=False``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for HuBERT "base").
          extractor_mode="group_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=False,
          encoder_embed_dim=768,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=12,
          encoder_num_heads=12,
          encoder_ff_interm_features=3072,
          encoder_layer_norm_first=False,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  def hubert_large(
      encoder_projection_dropout: float = 0.0,
      encoder_attention_dropout: float = 0.0,
      encoder_ff_interm_dropout: float = 0.0,
      encoder_dropout: float = 0.0,
      encoder_layer_drop: float = 0.0,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build HuBERT model with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

      Fixed architecture: ``"layer_norm"`` feature extractor without conv bias,
      1024-dim embeddings, 24 transformer layers with 16 attention heads,
      4096-dim feed-forward, ``encoder_layer_norm_first=True``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for HuBERT "large").
          extractor_mode="layer_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=False,
          encoder_embed_dim=1024,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=24,
          encoder_num_heads=16,
          encoder_ff_interm_features=4096,
          encoder_layer_norm_first=True,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  def hubert_xlarge(
      encoder_projection_dropout: float = 0.0,
      encoder_attention_dropout: float = 0.0,
      encoder_ff_interm_dropout: float = 0.0,
      encoder_dropout: float = 0.0,
      encoder_layer_drop: float = 0.0,
      aux_num_out: Optional[int] = None,
  ) -> Wav2Vec2Model:
      # Overriding the signature so that the return type is correct on Sphinx
      """hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model

      Build HuBERT model with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]

      Fixed architecture: ``"layer_norm"`` feature extractor without conv bias,
      1280-dim embeddings, 48 transformer layers with 16 attention heads,
      5120-dim feed-forward, ``encoder_layer_norm_first=True``.

      Args:
          encoder_projection_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_attention_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_ff_interm_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_dropout (float):
              See :py:func:`wav2vec2_model`.
          encoder_layer_drop (float):
              See :py:func:`wav2vec2_model`.
          aux_num_out (int or None, optional):
              See :py:func:`wav2vec2_model`.

      Returns:
          Wav2Vec2Model:
              The resulting model.
      """  # noqa: E501
      return wav2vec2_model(
          # Architecture hyper-parameters (fixed for HuBERT "extra large").
          extractor_mode="layer_norm",
          extractor_conv_layer_config=None,  # None selects wav2vec2_model's default conv stack
          extractor_conv_bias=False,
          encoder_embed_dim=1280,
          encoder_pos_conv_kernel=128,
          encoder_pos_conv_groups=16,
          encoder_num_layers=48,
          encoder_num_heads=16,
          encoder_ff_interm_features=5120,
          encoder_layer_norm_first=True,
          # Regularisation and output head (caller-tunable).
          encoder_projection_dropout=encoder_projection_dropout,
          encoder_attention_dropout=encoder_attention_dropout,
          encoder_ff_interm_dropout=encoder_ff_interm_dropout,
          encoder_dropout=encoder_dropout,
          encoder_layer_drop=encoder_layer_drop,
          aux_num_out=aux_num_out,
      )
  592. def hubert_pretrain_model(
  593. extractor_mode: str,
  594. extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
  595. extractor_conv_bias: bool,
  596. encoder_embed_dim: int,
  597. encoder_projection_dropout: float,
  598. encoder_pos_conv_kernel: int,
  599. encoder_pos_conv_groups: int,
  600. encoder_num_layers: int,
  601. encoder_num_heads: int,
  602. encoder_attention_dropout: float,
  603. encoder_ff_interm_features: int,
  604. encoder_ff_interm_dropout: float,
  605. encoder_dropout: float,
  606. encoder_layer_norm_first: bool,
  607. encoder_layer_drop: float,
  608. mask_prob: float,
  609. mask_selection: str,
  610. mask_other: float,
  611. mask_length: int,
  612. no_mask_overlap: bool,
  613. mask_min_space: int,
  614. mask_channel_prob: float,
  615. mask_channel_selection: str,
  616. mask_channel_other: float,
  617. mask_channel_length: int,
  618. no_mask_channel_overlap: bool,
  619. mask_channel_min_space: int,
  620. skip_masked: bool,
  621. skip_nomask: bool,
  622. num_classes: int,
  623. final_dim: int,
  624. feature_grad_mult: Optional[float],
  625. ) -> HuBERTPretrainModel:
  626. # Overriding the signature so that the return type is correct on Sphinx
  627. """hubert_pretrain_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, mask_prob: float, mask_selection: str, mask_other: float, mask_length: int, no_mask_overlap: bool, mask_min_space: int, mask_channel_prob: float, mask_channel_selection: str, mask_channel_other: float, mask_channel_length: int, no_mask_channel_overlap: bool, mask_channel_min_space: int, skip_masked: bool, skip_nomask: bool, num_classes: int, final_dim: int) -> torchaudio.models.HuBERTPretrainModel
  628. Build a custom HuBERTPretrainModel for training from scratch
  629. Note:
  630. The "feature extractor" below corresponds to
  631. `ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
  632. in the original ``fairseq`` implementation.
  633. This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0*
  634. [:footcite:`baevski2020wav2vec`] paper.
  635. The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
  636. and this is referred as "Transformer" in the paper.
  637. Args:
  638. extractor_mode (str): Operation mode of feature extractor.
  639. Valid values are ``"group_norm"`` or ``"layer_norm"``.
  640. If ``"group_norm"``, then a single normalization is applied
  641. in the first convolution block. Otherwise, all the convolution
  642. blocks will have layer normalization.
  643. This option corresponds to ``extractor_mode`` from ``fairseq``.
  644. extractor_conv_layer_config (list of integer tuples or None):
  645. Configuration of convolution layers in feature extractor.
  646. List of convolution configuration,
  647. i.e. ``[(output_channel, kernel_size, stride), ...]``
  648. If ``None`` is provided, then the following default value is used.
  649. .. code-block:: python
  650. [
  651. (512, 10, 5),
  652. (512, 3, 2),
  653. (512, 3, 2),
  654. (512, 3, 2),
  655. (512, 3, 2),
  656. (512, 2, 2),
  657. (512, 2, 2),
  658. ]
  659. This option corresponds to ``conv_feature_layers`` from ``fairseq``.
  660. extractor_conv_bias (bool):
  661. Whether to include bias term to each convolution operation.
  662. This option corresponds to ``conv_bias`` from ``fairseq``.
  663. encoder_embed_dim (int):
  664. The dimension of embedding in encoder.
  665. This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
  666. encoder_projection_dropout (float):
  667. The dropout probability applied after the input feature is projected
  668. to ``encoder_embed_dim``.
  669. This option corresponds to ``dropout_input`` from ``fairseq``.
  670. encoder_pos_conv_kernel (int):
  671. The kernel size of convolutional positional embeddings.
  672. This option corresponds to ``conv_pos`` from ``fairseq``.
  673. encoder_pos_conv_groups (int):
  674. The number of groups of convolutional positional embeddings.
  675. This option corresponds to ``conv_pos_groups`` from ``fairseq``.
  676. encoder_num_layers (int):
  677. The number of self attention layers in transformer block.
  678. This option corresponds to ``encoder_layers`` from ``fairseq``.
  679. encoder_num_heads (int):
  680. The number of heads in self attention layers.
  681. This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
  682. encoder_attention_dropout (float):
  683. The dropout probability applied after softmax in self-attention layer.
  684. This option corresponds to ``attention_dropout`` from ``fairseq``.
  685. encoder_ff_interm_features (int):
  686. The dimension of hidden features in feed forward layer.
  687. This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
  688. encoder_ff_interm_dropout (float):
  689. The dropout probability applied in feedforward layer.
  690. This option correspinds to ``activation_dropout`` from ``fairseq``.
  691. encoder_dropout (float):
  692. The dropout probability applied at the end of feed forward layer.
  693. This option corresponds to ``dropout`` from ``fairseq``.
  694. encoder_layer_norm_first (bool):
  695. Control the order of layer norm in transformer layer and each encoder layer.
  696. If True, in transformer layer, layer norm is applied before features are fed
  697. to encoder layers. In encoder layer, two layer norms are applied before and after
  698. self attention.
  699. If False, in transformer layer, layer norm is applied after features are fed
  700. to encoder layers. In encoder layer, two layer norms are applied after self
  701. attention, before and after feed forward.
  702. This option corresponds to ``layer_norm_first`` from ``fairseq``.
  703. encoder_layer_drop (float):
  704. Probability to drop each encoder layer during training.
  705. This option corresponds to ``layerdrop`` from ``fairseq``.
  706. mask_prob (float):
  707. Probability for each token to be chosen as start of the span to be masked. this will be multiplied by
  708. number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
  709. However due to overlaps, the actual number will be smaller (unless no_overlap is True).
  710. This option corresponds to ``mask_prob`` from ``fairseq``.
  711. mask_selection (str):
  712. How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
  713. This option corresponds to ``mask_selection`` from ``fairseq``.
  714. mask_other (float):
  715. Secondary mask argument (used for more complex distributions).
  716. This option corresponds to ``mask_other`` from ``fairseq``.
  717. mask_length (int):
  718. The lengths of the mask.
  719. This option corresponds to ``mask_length`` from ``fairseq``.
  720. no_mask_overlap (bool):
  721. Whether to allow masks to overlap.
  722. This option corresponds to ``no_mask_overlap`` from ``fairseq``.
  723. mask_min_space (int):
  724. Minimum space between spans (if no overlap is enabled).
  725. This option corresponds to ``mask_min_space`` from ``fairseq``.
  726. mask_channel_prob: (float):
  727. The probability of replacing a feature with 0.
  728. This option corresponds to ``mask_channel_prob`` from ``fairseq``.
  729. mask_channel_selection (str):
  730. How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
  731. This option corresponds to ``mask_channel_selection`` from ``fairseq``.
  732. mask_channel_other (float):
  733. Secondary mask argument for channel masking(used for more complex distributions).
  734. This option corresponds to ``mask_channel_other`` from ``fairseq``.
  735. mask_channel_length (int):
  736. Minimum space between spans (if no overlap is enabled) for channel masking.
  737. This option corresponds to ``mask_channel_length`` from ``fairseq``.
  738. no_mask_channel_overlap (bool):
  739. Whether to allow channel masks to overlap.
  740. This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``.
  741. mask_channel_min_space (int):
  742. Minimum space between spans for channel masking(if no overlap is enabled).
  743. This option corresponds to ``mask_channel_min_space`` from ``fairseq``.
  744. skip_masked (bool):
  745. If True, skip computing losses over masked frames.
  746. This option corresponds to ``skip_masked`` from ``fairseq``.
  747. skip_nomask (bool):
  748. If True, skip computing losses over unmasked frames.
  749. This option corresponds to ``skip_nomask`` from ``fairseq``.
  750. num_classes (int):
  751. The number of classes in the labels.
  752. final_dim (int):
  753. Project final representations and targets to `final_dim`.
  754. This option corresponds to ``final_dim`` from ``fairseq``.
  755. feature_grad_mult (float or None):
  756. The factor to scale the convolutional feature extraction layer gradients by.
  757. The scale factor will not affect the forward pass.
  758. This option corresponds to ``feature_grad_mult`` from ``fairseq``.
  759. Returns:
  760. HuBERTPretrainModel:
  761. The resulting model.
  762. """ # noqa: E501
  763. if extractor_conv_layer_config is None:
  764. extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
  765. feature_extractor = components._get_feature_extractor(
  766. extractor_mode, extractor_conv_layer_config, extractor_conv_bias
  767. )
  768. encoder = components._get_encoder(
  769. in_features=extractor_conv_layer_config[-1][0],
  770. embed_dim=encoder_embed_dim,
  771. dropout_input=encoder_projection_dropout,
  772. pos_conv_kernel=encoder_pos_conv_kernel,
  773. pos_conv_groups=encoder_pos_conv_groups,
  774. num_layers=encoder_num_layers,
  775. num_heads=encoder_num_heads,
  776. attention_dropout=encoder_attention_dropout,
  777. ff_interm_features=encoder_ff_interm_features,
  778. ff_interm_dropout=encoder_ff_interm_dropout,
  779. dropout=encoder_dropout,
  780. layer_norm_first=encoder_layer_norm_first,
  781. layer_drop=encoder_layer_drop,
  782. )
  783. wav2vec2 = Wav2Vec2Model(feature_extractor, encoder)
  784. mask_generator = components.MaskGenerator(
  785. encoder_embed_dim,
  786. mask_prob,
  787. mask_selection,
  788. mask_other,
  789. mask_length,
  790. no_mask_overlap,
  791. mask_min_space,
  792. mask_channel_prob,
  793. mask_channel_selection,
  794. mask_channel_other,
  795. mask_channel_length,
  796. no_mask_channel_overlap,
  797. mask_channel_min_space,
  798. )
  799. logit_generator = components.LogitGenerator(
  800. encoder_embed_dim,
  801. num_classes,
  802. final_dim,
  803. skip_masked,
  804. skip_nomask,
  805. )
  806. return HuBERTPretrainModel(
  807. wav2vec2=wav2vec2,
  808. mask_generator=mask_generator,
  809. logit_generator=logit_generator,
  810. feature_grad_mult=feature_grad_mult,
  811. )
  812. def hubert_pretrain_base(
  813. encoder_projection_dropout: float = 0.1,
  814. encoder_attention_dropout: float = 0.1,
  815. encoder_ff_interm_dropout: float = 0.0,
  816. encoder_dropout: float = 0.1,
  817. encoder_layer_drop: float = 0.05,
  818. mask_prob: float = 0.8,
  819. mask_channel_prob: float = 0.0,
  820. mask_channel_length: int = 10,
  821. feature_grad_mult: Optional[float] = 0.1,
  822. num_classes: int = 100,
  823. ) -> HuBERTPretrainModel:
  824. # Overriding the signature so that the return type is correct on Sphinx
  825. """hubert_pretrain_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = 0.1, num_classes: int = 100) -> torchaudio.models.HuBERTPretrainModel
  826. Build HuBERTPretrainModel model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
  827. Args:
  828. encoder_projection_dropout (float):
  829. See :py:func:`hubert_pretrain_model`.
  830. encoder_attention_dropout (float):
  831. See :py:func:`hubert_pretrain_model`.
  832. encoder_ff_interm_dropout (float):
  833. See :py:func:`hubert_pretrain_model`.
  834. encoder_dropout (float):
  835. See :py:func:`hubert_pretrain_model`.
  836. encoder_layer_drop (float):
  837. See :py:func:`hubert_pretrain_model`.
  838. mask_prob (float):
  839. See :py:func:`hubert_pretrain_model`.
  840. mask_channel_prob (float):
  841. See :py:func:`hubert_pretrain_model`.
  842. mask_channel_length (int):
  843. See :py:func:`hubert_pretrain_model`.
  844. feature_grad_mult (float or None):
  845. See :py:func:`hubert_pretrain_model`.
  846. num_classes (int, optional):
  847. See :py:func:`hubert_pretrain_model`.
  848. Returns:
  849. HuBERTPretrainModel:
  850. The resulting model.
  851. """ # noqa: E501
  852. return hubert_pretrain_model(
  853. extractor_mode="group_norm",
  854. extractor_conv_layer_config=None,
  855. extractor_conv_bias=False,
  856. encoder_embed_dim=768,
  857. encoder_projection_dropout=encoder_projection_dropout,
  858. encoder_pos_conv_kernel=128,
  859. encoder_pos_conv_groups=16,
  860. encoder_num_layers=12,
  861. encoder_num_heads=12,
  862. encoder_attention_dropout=encoder_attention_dropout,
  863. encoder_ff_interm_features=3072,
  864. encoder_ff_interm_dropout=encoder_ff_interm_dropout,
  865. encoder_dropout=encoder_dropout,
  866. encoder_layer_norm_first=False,
  867. encoder_layer_drop=encoder_layer_drop,
  868. mask_prob=mask_prob,
  869. mask_selection="static",
  870. mask_other=0.0,
  871. mask_length=10,
  872. no_mask_overlap=False,
  873. mask_min_space=1,
  874. mask_channel_prob=mask_channel_prob,
  875. mask_channel_selection="static",
  876. mask_channel_other=0.0,
  877. mask_channel_length=mask_channel_length,
  878. no_mask_channel_overlap=False,
  879. mask_channel_min_space=1,
  880. skip_masked=False,
  881. skip_nomask=False,
  882. num_classes=num_classes,
  883. final_dim=256,
  884. feature_grad_mult=feature_grad_mult,
  885. )
  886. def hubert_pretrain_large(
  887. encoder_projection_dropout: float = 0.0,
  888. encoder_attention_dropout: float = 0.0,
  889. encoder_ff_interm_dropout: float = 0.0,
  890. encoder_dropout: float = 0.0,
  891. encoder_layer_drop: float = 0.0,
  892. mask_prob: float = 0.8,
  893. mask_channel_prob: float = 0.0,
  894. mask_channel_length: int = 10,
  895. feature_grad_mult: Optional[float] = None,
  896. ) -> HuBERTPretrainModel:
  897. # Overriding the signature so that the return type is correct on Sphinx
  898. """hubert_pretrain_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = None) -> torchaudio.models.HuBERTPretrainModel
  899. Build HuBERTPretrainModel model for pre-training with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
  900. Args:
  901. encoder_projection_dropout (float):
  902. See :py:func:`hubert_pretrain_model`.
  903. encoder_attention_dropout (float):
  904. See :py:func:`hubert_pretrain_model`.
  905. encoder_ff_interm_dropout (float):
  906. See :py:func:`hubert_pretrain_model`.
  907. encoder_dropout (float):
  908. See :py:func:`hubert_pretrain_model`.
  909. encoder_layer_drop (float):
  910. See :py:func:`hubert_pretrain_model`.
  911. mask_prob (float):
  912. See :py:func:`hubert_pretrain_model`.
  913. mask_channel_prob (float):
  914. See :py:func:`hubert_pretrain_model`.
  915. mask_channel_length (int):
  916. See :py:func:`hubert_pretrain_model`.
  917. feature_grad_mult (float or None):
  918. See :py:func:`hubert_pretrain_model`.
  919. Returns:
  920. HuBERTPretrainModel:
  921. The resulting model.
  922. """ # noqa: E501
  923. return hubert_pretrain_model(
  924. extractor_mode="layer_norm",
  925. extractor_conv_layer_config=None,
  926. extractor_conv_bias=False,
  927. encoder_embed_dim=1024,
  928. encoder_projection_dropout=encoder_projection_dropout,
  929. encoder_pos_conv_kernel=128,
  930. encoder_pos_conv_groups=16,
  931. encoder_num_layers=24,
  932. encoder_num_heads=16,
  933. encoder_attention_dropout=encoder_attention_dropout,
  934. encoder_ff_interm_features=4096,
  935. encoder_ff_interm_dropout=encoder_ff_interm_dropout,
  936. encoder_dropout=encoder_dropout,
  937. encoder_layer_norm_first=True,
  938. encoder_layer_drop=encoder_layer_drop,
  939. mask_prob=mask_prob,
  940. mask_selection="static",
  941. mask_other=0.0,
  942. mask_length=10,
  943. no_mask_overlap=False,
  944. mask_min_space=1,
  945. mask_channel_prob=mask_channel_prob,
  946. mask_channel_selection="static",
  947. mask_channel_other=0.0,
  948. mask_channel_length=mask_channel_length,
  949. no_mask_channel_overlap=False,
  950. mask_channel_min_space=1,
  951. skip_masked=False,
  952. skip_nomask=False,
  953. num_classes=500,
  954. final_dim=768,
  955. feature_grad_mult=feature_grad_mult,
  956. )
  957. def hubert_pretrain_xlarge(
  958. encoder_projection_dropout: float = 0.0,
  959. encoder_attention_dropout: float = 0.0,
  960. encoder_ff_interm_dropout: float = 0.0,
  961. encoder_dropout: float = 0.0,
  962. encoder_layer_drop: float = 0.0,
  963. mask_prob: float = 0.8,
  964. mask_channel_prob: float = 0.0,
  965. mask_channel_length: int = 10,
  966. feature_grad_mult: Optional[float] = None,
  967. ) -> HuBERTPretrainModel:
  968. # Overriding the signature so that the return type is correct on Sphinx
  969. """hubert_pretrain_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, mask_prob: float = 0.8, mask_channel_prob: float = 0.0, mask_channel_length: int = 10, feature_grad_mult: Optional[float] = None) -> torchaudio.models.HuBERTPretrainModel
  970. Build HuBERTPretrainModel model for pre-training with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
  971. Args:
  972. encoder_projection_dropout (float):
  973. See :py:func:`hubert_pretrain_model`.
  974. encoder_attention_dropout (float):
  975. See :py:func:`hubert_pretrain_model`.
  976. encoder_ff_interm_dropout (float):
  977. See :py:func:`hubert_pretrain_model`.
  978. encoder_dropout (float):
  979. See :py:func:`hubert_pretrain_model`.
  980. encoder_layer_drop (float):
  981. See :py:func:`hubert_pretrain_model`.
  982. mask_prob (float):
  983. See :py:func:`hubert_pretrain_model`.
  984. mask_channel_prob (float):
  985. See :py:func:`hubert_pretrain_model`.
  986. mask_channel_length (int):
  987. See :py:func:`hubert_pretrain_model`.
  988. feature_grad_mult (float or None):
  989. See :py:func:`hubert_pretrain_model`.
  990. Returns:
  991. HuBERTPretrainModel:
  992. The resulting model.
  993. """ # noqa: E501
  994. return hubert_pretrain_model(
  995. extractor_mode="layer_norm",
  996. extractor_conv_layer_config=None,
  997. extractor_conv_bias=False,
  998. encoder_embed_dim=1280,
  999. encoder_projection_dropout=encoder_projection_dropout,
  1000. encoder_pos_conv_kernel=128,
  1001. encoder_pos_conv_groups=16,
  1002. encoder_num_layers=48,
  1003. encoder_num_heads=16,
  1004. encoder_attention_dropout=encoder_attention_dropout,
  1005. encoder_ff_interm_features=5120,
  1006. encoder_ff_interm_dropout=encoder_ff_interm_dropout,
  1007. encoder_dropout=encoder_dropout,
  1008. encoder_layer_norm_first=True,
  1009. encoder_layer_drop=encoder_layer_drop,
  1010. mask_prob=mask_prob,
  1011. mask_selection="static",
  1012. mask_other=0.0,
  1013. mask_length=10,
  1014. no_mask_overlap=False,
  1015. mask_min_space=1,
  1016. mask_channel_prob=mask_channel_prob,
  1017. mask_channel_selection="static",
  1018. mask_channel_other=0.0,
  1019. mask_channel_length=mask_channel_length,
  1020. no_mask_channel_overlap=False,
  1021. mask_channel_min_space=1,
  1022. skip_masked=False,
  1023. skip_nomask=False,
  1024. num_classes=500,
  1025. final_dim=1024,
  1026. feature_grad_mult=feature_grad_mult,
  1027. )