components.py 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051
  1. import logging
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from torch import nn, Tensor
  5. from torch.nn import Module, Parameter
  6. _LG = logging.getLogger(__name__)
  7. class LayerNorm(nn.LayerNorm):
  8. """Layer norm with transpose"""
  9. def forward(self, input: Tensor) -> Tensor:
  10. x = input.transpose(-2, -1)
  11. x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  12. x = x.transpose(-2, -1)
  13. return x
  14. class ConvLayerBlock(Module):
  15. """Convolution unit of FeatureExtractor"""
  16. def __init__(
  17. self,
  18. in_channels: int,
  19. out_channels: int,
  20. kernel_size: int,
  21. stride: int,
  22. bias: bool,
  23. layer_norm: Optional[Module],
  24. ):
  25. super().__init__()
  26. self.kernel_size = kernel_size
  27. self.stride = stride
  28. self.layer_norm = layer_norm
  29. self.conv = nn.Conv1d(
  30. in_channels=in_channels,
  31. out_channels=out_channels,
  32. kernel_size=kernel_size,
  33. stride=stride,
  34. bias=bias,
  35. )
  36. def forward(
  37. self,
  38. x: Tensor,
  39. length: Optional[Tensor],
  40. ) -> Tuple[Tensor, Optional[Tensor]]:
  41. """
  42. Args:
  43. x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
  44. length (Tensor or None, optional): Shape ``[batch, ]``.
  45. Returns:
  46. Tensor: Shape ``[batch, out_channels, out_frames]``.
  47. Optional[Tensor]: Shape ``[batch, ]``.
  48. """
  49. x = self.conv(x)
  50. if self.layer_norm is not None:
  51. x = self.layer_norm(x)
  52. x = nn.functional.gelu(x)
  53. if length is not None:
  54. length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
  55. # When input length is 0, the resulting length can be negative. So fix it here.
  56. length = torch.max(torch.zeros_like(length), length)
  57. return x, length
  58. class FeatureExtractor(Module):
  59. """Extract features from audio
  60. Args:
  61. conv_layers (nn.ModuleList):
  62. convolution layers
  63. """
  64. def __init__(
  65. self,
  66. conv_layers: nn.ModuleList,
  67. ):
  68. super().__init__()
  69. self.conv_layers = conv_layers
  70. def forward(
  71. self,
  72. x: Tensor,
  73. length: Optional[Tensor],
  74. ) -> Tuple[Tensor, Optional[Tensor]]:
  75. """
  76. Args:
  77. x (Tensor):
  78. Input Tensor representing a batch of audio,
  79. shape: ``[batch, time]``.
  80. length (Tensor or None, optional):
  81. Valid length of each input sample. shape: ``[batch, ]``.
  82. Returns:
  83. Tensor:
  84. The resulting feature, shape: ``[batch, frame, feature]``
  85. Optional[Tensor]:
  86. Valid length of each output sample. shape: ``[batch, ]``.
  87. """
  88. if x.ndim != 2:
  89. raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}")
  90. x = x.unsqueeze(1) # (batch, channel==1, frame)
  91. for layer in self.conv_layers:
  92. x, length = layer(x, length) # (batch, feature, frame)
  93. x = x.transpose(1, 2) # (batch, frame, feature)
  94. return x, length
  95. class FeatureProjection(Module):
  96. """Layer that connects FeatureExtractor and Encoder
  97. Projects features to encoder dimension.
  98. Args:
  99. in_features (int): Input feature dim.
  100. out_features (int): Output feature dim.
  101. dropout (float): Dropout probability.
  102. """
  103. def __init__(
  104. self,
  105. in_features: int,
  106. out_features: int,
  107. dropout: float,
  108. ):
  109. super().__init__()
  110. self.layer_norm = nn.LayerNorm(in_features)
  111. self.projection = nn.Linear(
  112. in_features,
  113. out_features,
  114. )
  115. self.dropout = nn.Dropout(dropout)
  116. def forward(self, x):
  117. """
  118. Args:
  119. x (Tensor):
  120. Feature Tensor. shape: ``[batch, frame, in_feature]``
  121. Returns:
  122. Tensor: Projected features. ``[batch, frame, out_feature]``.
  123. """
  124. x = self.layer_norm(x)
  125. x = self.projection(x)
  126. x = self.dropout(x)
  127. return x
  128. class ConvolutionalPositionalEmbedding(Module):
  129. """Positional embedding which is placed at the beginning of Transformer.
  130. Args:
  131. embed_dim (int): Feature dimension of the input Tensor.
  132. kernel_size (int): The number of frames to be use.
  133. groups (int): The number of groups in feature dimensions.
  134. """
  135. def __init__(
  136. self,
  137. embed_dim: int,
  138. kernel_size: int,
  139. groups: int,
  140. ):
  141. super().__init__()
  142. self.embed_dim = embed_dim
  143. self.conv = nn.Conv1d(
  144. in_channels=embed_dim,
  145. out_channels=embed_dim,
  146. kernel_size=kernel_size,
  147. padding=kernel_size // 2,
  148. groups=groups,
  149. )
  150. self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
  151. self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
  152. def __prepare_scriptable__(self):
  153. for hook in self.conv._forward_pre_hooks.values():
  154. # The hook we want to remove is an instance of WeightNorm class, so
  155. # normally we would do `if isinstance(...)` but this class is not accessible
  156. # because of shadowing, so we check the module name directly.
  157. # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
  158. if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
  159. _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
  160. torch.nn.utils.remove_weight_norm(self.conv)
  161. return self
  162. def forward(self, x):
  163. """
  164. Args:
  165. x (Tensor): shape ``[batch, frame, feature]``.
  166. Returns:
  167. Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
  168. """
  169. x = x.transpose(-2, -1)
  170. x = self.conv(x)
  171. if self.num_remove > 0:
  172. x = x[..., : -self.num_remove]
  173. x = torch.nn.functional.gelu(x)
  174. x = x.transpose(-2, -1)
  175. return x
  176. class SelfAttention(Module):
  177. """Multihead Self Attention module
  178. Args:
  179. embed_dim (int): Total dimension of the model.
  180. num_heads (int): The number of heads.
  181. dropout (float, optional):
  182. Dropout probabiliry on attn_output_weights. Default: ``0.0``
  183. """
  184. def __init__(
  185. self,
  186. embed_dim: int,
  187. num_heads: int,
  188. dropout: float = 0.0,
  189. ):
  190. super().__init__()
  191. head_dim = embed_dim // num_heads
  192. if head_dim * num_heads != embed_dim:
  193. raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")
  194. self.embed_dim = embed_dim
  195. self.num_heads = num_heads
  196. self.dropout = torch.nn.Dropout(dropout)
  197. self.head_dim = head_dim
  198. self.scaling = self.head_dim**-0.5
  199. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  200. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  201. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  202. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
  203. def forward(
  204. self,
  205. x: Tensor,
  206. attention_mask: Optional[Tensor] = None,
  207. ) -> Tensor:
  208. """
  209. Args:
  210. x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
  211. attention_mask (Tensor or None, optional):
  212. shape: ``[batch_size, 1, sequence_length, sequence_length]``
  213. Returns:
  214. Tensor: The resulting tensor. shape: ``[batch, sequence_length, embed_dim]``
  215. """
  216. if x.ndim != 3 or x.shape[2] != self.embed_dim:
  217. raise ValueError(
  218. f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
  219. )
  220. batch_size, length, embed_dim = x.size()
  221. if attention_mask is not None:
  222. shape_ = (batch_size, 1, length, length)
  223. if attention_mask.size() != shape_:
  224. raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")
  225. shape = (batch_size, length, self.num_heads, self.head_dim)
  226. q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
  227. k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L
  228. v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
  229. weights = self.scaling * (q @ k) # B, nH, L, L
  230. if attention_mask is not None:
  231. weights += attention_mask
  232. weights = torch.nn.functional.softmax(weights, dim=-1)
  233. weights = self.dropout(weights)
  234. output = weights @ v # B, nH, L, Hd
  235. output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
  236. output = self.out_proj(output)
  237. return output
  238. class FeedForward(Module):
  239. """Layer that follows attention layer in encoder layer."""
  240. def __init__(
  241. self,
  242. io_features: int,
  243. intermediate_features: int,
  244. intermediate_dropout: float,
  245. output_dropout: float,
  246. ):
  247. super().__init__()
  248. self.intermediate_dense = nn.Linear(io_features, intermediate_features)
  249. self.intermediate_dropout = nn.Dropout(intermediate_dropout)
  250. self.output_dense = nn.Linear(intermediate_features, io_features)
  251. self.output_dropout = nn.Dropout(output_dropout)
  252. def forward(self, x):
  253. """
  254. Args:
  255. x (Tensor): shape: `(batch, sequence_length, io_features)`
  256. Returns:
  257. x (Tensor): shape: `(batch, sequence_length, io_features)`
  258. """
  259. x = self.intermediate_dense(x)
  260. x = torch.nn.functional.gelu(x)
  261. x = self.intermediate_dropout(x)
  262. x = self.output_dense(x)
  263. x = self.output_dropout(x)
  264. return x
  265. class EncoderLayer(Module):
  266. """A layer unit in encoder. Combines multihead self attention and feed forward."""
  267. def __init__(
  268. self,
  269. attention: Module,
  270. dropout: float,
  271. layer_norm_first: bool,
  272. feed_forward: Module,
  273. ):
  274. super().__init__()
  275. self.attention = attention
  276. self.dropout = nn.Dropout(dropout)
  277. self.layer_norm = nn.LayerNorm(attention.embed_dim)
  278. self.layer_norm_first = layer_norm_first
  279. self.feed_forward = feed_forward
  280. self.final_layer_norm = nn.LayerNorm(attention.embed_dim)
  281. def forward(
  282. self,
  283. x: Tensor,
  284. attention_mask: Optional[Tensor] = None,
  285. ):
  286. """
  287. Args:
  288. x (Tensor): shape: `(batch, sequence_length, embed_dim)`
  289. attention_mask (Tensor or None, optional):
  290. shape: `(batch, 1, sequence_length, sequence_length)`
  291. """
  292. residual = x
  293. if self.layer_norm_first:
  294. x = self.layer_norm(x)
  295. x = self.attention(x, attention_mask)
  296. x = self.dropout(x)
  297. x = residual + x
  298. if self.layer_norm_first:
  299. x = x + self.feed_forward(self.final_layer_norm(x))
  300. else:
  301. x = self.layer_norm(x)
  302. x = self.final_layer_norm(x + self.feed_forward(x))
  303. return x
  304. class Transformer(Module):
  305. def __init__(
  306. self,
  307. pos_conv_embed: Module,
  308. dropout: float,
  309. layers: Module,
  310. layer_norm_first: bool,
  311. layer_drop: float,
  312. ):
  313. super().__init__()
  314. self.pos_conv_embed = pos_conv_embed
  315. self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
  316. self.layer_norm_first = layer_norm_first
  317. self.layer_drop = layer_drop
  318. self.dropout = nn.Dropout(dropout)
  319. self.layers = layers
  320. def _preprocess(self, x: Tensor):
  321. x = x + self.pos_conv_embed(x)
  322. if self.layer_norm_first:
  323. x = self.layer_norm(x)
  324. x = self.dropout(x)
  325. return x
  326. def forward(
  327. self,
  328. x: Tensor,
  329. attention_mask: Optional[Tensor] = None,
  330. ):
  331. x = self._preprocess(x)
  332. for layer in self.layers:
  333. if not (self.training and torch.rand(1).item() <= self.layer_drop):
  334. x = layer(x, attention_mask)
  335. if not self.layer_norm_first:
  336. x = self.layer_norm(x)
  337. return x
  338. def get_intermediate_outputs(
  339. self,
  340. x: Tensor,
  341. attention_mask: Optional[Tensor] = None,
  342. num_layers: Optional[int] = None,
  343. ) -> List[Tensor]:
  344. if num_layers is not None:
  345. if not 0 < num_layers <= len(self.layers):
  346. raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
  347. ret: List[Tensor] = []
  348. x = self._preprocess(x)
  349. for layer in self.layers:
  350. x = layer(x, attention_mask)
  351. ret.append(x)
  352. if num_layers is not None and len(ret) >= num_layers:
  353. return ret
  354. return ret
  355. class Encoder(Module):
  356. def __init__(
  357. self,
  358. feature_projection: Module,
  359. transformer: Module,
  360. ):
  361. super().__init__()
  362. self.feature_projection = feature_projection
  363. self.transformer = transformer
  364. def _preprocess(
  365. self,
  366. features: Tensor,
  367. lengths: Optional[Tensor] = None,
  368. ) -> Tuple[Tensor, Optional[Tensor]]:
  369. x = self.feature_projection(features)
  370. mask: Optional[Tensor] = None
  371. if lengths is not None:
  372. batch_size, max_len, _ = x.shape
  373. # create mask for padded elements and zero-out them
  374. mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
  375. x[mask] = 0.0
  376. # extend the mask to attention shape and set weight
  377. mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
  378. mask = mask.expand(batch_size, 1, max_len, max_len)
  379. return x, mask
  380. def forward(
  381. self,
  382. features: Tensor,
  383. lengths: Optional[Tensor] = None,
  384. ) -> Tensor:
  385. x, mask = self._preprocess(features, lengths)
  386. x = self.transformer(x, attention_mask=mask)
  387. return x
  388. def extract_features(
  389. self,
  390. features: Tensor,
  391. lengths: Optional[Tensor] = None,
  392. num_layers: Optional[int] = None,
  393. ) -> List[Tensor]:
  394. x, masks = self._preprocess(features, lengths)
  395. return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
  396. ################################################################################
  397. def _get_feature_extractor(
  398. norm_mode: str,
  399. shapes: List[Tuple[int, int, int]],
  400. bias: bool,
  401. ) -> FeatureExtractor:
  402. """
  403. Args:
  404. norm_mode (str):
  405. Either "group_norm" or "layer_norm".
  406. If "group_norm", then a single normalization is applied
  407. in the first convolution block. Otherwise, all the convolution
  408. blocks will have layer normalization.
  409. This option corresponds to "extractor_mode" from fairseq.
  410. Expected values are "group_norm" for Base arch, and
  411. "layer_norm" for Large arch.
  412. shapes (list of tuple of int):
  413. Configuration of convolution layers. List of convolution configuration,
  414. i.e. ``[(output_channel, kernel_size, stride), ...]``
  415. This option corresponds to "conv_feature_layers" from fairseq.
  416. Expected values are
  417. ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
  418. for all the architectures.
  419. bias (bool):
  420. Whether to include bias term to each convolution operation.
  421. This option corresponds to "conv_bias" from fairseq.
  422. Expected values are False for Base arch, and True for Large arch.
  423. See Also:
  424. * Original implementation
  425. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
  426. * "extractor_mode"
  427. - Def and base:
  428. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
  429. - Large:
  430. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
  431. * "conv_feature_layers"
  432. - Def, base and large:
  433. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
  434. * "conv_bias"
  435. - Def and base:
  436. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
  437. - Large:
  438. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
  439. """
  440. assert norm_mode in ["group_norm", "layer_norm"]
  441. blocks = []
  442. in_channels = 1
  443. for i, (out_channels, kernel_size, stride) in enumerate(shapes):
  444. normalization = None
  445. if norm_mode == "group_norm" and i == 0:
  446. normalization = nn.GroupNorm(
  447. num_groups=out_channels,
  448. num_channels=out_channels,
  449. affine=True,
  450. )
  451. elif norm_mode == "layer_norm":
  452. normalization = LayerNorm(
  453. normalized_shape=out_channels,
  454. elementwise_affine=True,
  455. )
  456. blocks.append(
  457. ConvLayerBlock(
  458. in_channels=in_channels,
  459. out_channels=out_channels,
  460. kernel_size=kernel_size,
  461. stride=stride,
  462. bias=bias,
  463. layer_norm=normalization,
  464. )
  465. )
  466. in_channels = out_channels
  467. return FeatureExtractor(nn.ModuleList(blocks))
  468. def _get_encoder(
  469. in_features: int,
  470. embed_dim: int,
  471. dropout_input: float,
  472. pos_conv_kernel: int,
  473. pos_conv_groups: int,
  474. num_layers: int,
  475. num_heads: int,
  476. attention_dropout: float,
  477. ff_interm_features: int,
  478. ff_interm_dropout: float,
  479. dropout: float,
  480. layer_norm_first: bool,
  481. layer_drop: float,
  482. ) -> Encoder:
  483. """
  484. Args:
  485. in_features (int): The number of input features.
  486. embed_dim (int):
  487. The dimension of embedding.
  488. This option corresponds to "encoder_embed_dim" from fairseq.
  489. Expected values are 768 for Base arch, and 1024 for Large arch.
  490. dropout_input (float):
  491. The dropout probability applied after the input feature is projected
  492. to ``embed_dim``.
  493. This option corresponds to "dropout_input" from fairseq.
  494. Expected values are 0.1 for both Base and Large arch.
  495. pos_conv_kernel (int):
  496. The kernel size of convolutional positional embeddings.
  497. This option corresponds to "conv_pos" from fairseq.
  498. Expected values are 128 for both Base and Large arch.
  499. pos_conv_groups (int):
  500. The number of groups of convolutional positional embeddings.
  501. This option corresponds to "conv_pos_groups" from fairseq.
  502. Expected values are 16 for both Base and Large arch.
  503. num_layers (int):
  504. The number of self attention layers in transformer block.
  505. This option corresponds to "encoder_layers" from fairseq.
  506. Expected values are 12 for Base and 24 for Large arch.
  507. num_heads (int):
  508. The number of heads in self attention layers.
  509. This option corresponds to "encoder_attention_heads" from fairseq.
  510. Expected values are 12 for Base and 16 for Large arch.
  511. attention_dropout (float):
  512. The dropout probability applied after softmax in self-attention layer.
  513. This option corresponds to "attention_dropout" from fairseq.
  514. Expected values are 0.1 for Base and 0.0 for Large arch.
  515. ff_interm_features (int):
  516. The dimension of hidden features in feed forward layer.
  517. This option corresponds to "encoder_ffn_embed_dim" from fairseq.
  518. Expected values are 3072 for Base and 4096 for Large arch.
  519. ff_interm_dropout (float):
  520. The dropout probability applied in feedforward layer.
  521. This option correspinds to "activation_dropout" from fairseq.
  522. Expected values are 0.1 for both Base and Large arch.
  523. dropout (float):
  524. The dropout probability applied at the end of feed forward layer.
  525. This option corresponds to "dropout" from fairseq.
  526. Expected values are 0.1 for Base and 0.0 for Large arch.
  527. layer_norm_first (bool):
  528. Control the order of layer norm in transformer layer and each encoder layer.
  529. If True, in transformer layer, layer norm is applied before features are fed
  530. to encoder layers. In encoder layer, two layer norms are applied before and after
  531. self attention.
  532. If False, in transformer layer, layer norm is applied after features are fed
  533. to encoder layers. In encoder layer, two layer norms are applied after self
  534. attention, before and after feed forward.
  535. This option corresponds to "layer_norm_first" from fairseq.
  536. Expected values are False for Base and True for Large arch.
  537. layer_drop (float):
  538. Probability to drop each encoder layer during training.
  539. This option corresponds to "layerdrop" from fairseq.
  540. Expected values are 0.1 for both Base and Large arch.
  541. See Also:
  542. * "encoder_embed_dim"
  543. - Def and base
  544. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
  545. - Large
  546. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
  547. * "dropout_input"
  548. - Def, base and large
  549. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
  550. * "conv_pos"
  551. - Def, base and large
  552. NOTE: The description is wrong.
  553. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
  554. - Usage
  555. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
  556. * "conv_pos_groups"
  557. - Def, base and large
  558. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
  559. * "encoder_layers"
  560. - Def and base
  561. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
  562. - Large
  563. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
  564. * "encoder_attention_heads"
  565. - Def and base
  566. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
  567. - Large
  568. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
  569. * "attention_dropout"
  570. - Def and base
  571. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
  572. - Large
  573. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
  574. * "encoder_ffn_embed_dim"
  575. - Def and base
  576. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
  577. - Large
  578. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
  579. * "activation_dropout"
  580. - Def
  581. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
  582. - Base
  583. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
  584. - Large
  585. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
  586. * "dropout"
  587. - Def and base
  588. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
  589. - Large
  590. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
  591. * "layer_norm_first"
  592. - Def and base
  593. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
  594. - Large
  595. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
  596. * "layerdrop"
  597. - Def
  598. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
  599. - Base
  600. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
  601. - Large
  602. https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
  603. """
  604. feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
  605. pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
  606. # Original impl
  607. # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
  608. encoder_layers = nn.ModuleList()
  609. for _ in range(num_layers):
  610. attention = SelfAttention(
  611. embed_dim=embed_dim,
  612. num_heads=num_heads,
  613. dropout=attention_dropout,
  614. )
  615. feed_forward = FeedForward(
  616. io_features=embed_dim,
  617. intermediate_features=ff_interm_features,
  618. intermediate_dropout=ff_interm_dropout,
  619. output_dropout=dropout,
  620. )
  621. encoder_layers.append(
  622. EncoderLayer(
  623. attention=attention,
  624. dropout=dropout,
  625. layer_norm_first=layer_norm_first,
  626. feed_forward=feed_forward,
  627. )
  628. )
  629. transformer = Transformer(
  630. pos_conv_embed=pos_conv,
  631. dropout=dropout,
  632. layers=encoder_layers,
  633. layer_norm_first=not layer_norm_first,
  634. layer_drop=layer_drop,
  635. )
  636. return Encoder(feature_projection, transformer)
  637. def _compute_mask_indices(
  638. shape: Tuple[int, int],
  639. padding_mask: Optional[Tensor],
  640. mask_prob: float,
  641. mask_length: int,
  642. mask_type: str = "static",
  643. mask_other: float = 0.0,
  644. min_masks: int = 0,
  645. no_overlap: bool = False,
  646. min_space: int = 0,
  647. ) -> Tensor:
  648. """Computes random mask spans for a given shape.
  649. Args:
  650. shape (int, int): The shape for which to compute masks.
  651. The first element is batch size and second is the number of frames.
  652. padding_mask (Tensor or None): The padding mask of the same dimension as shape,
  653. which will prevent masking padded elements.
  654. mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
  655. This will be multiplied by number of timesteps divided by length of mask span to mask
  656. approximately this percentage of all elements. However due to overlaps, the actual number
  657. will be smaller (unless no_overlap is True).
  658. mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
  659. ``static``: Fixed size
  660. ``uniform``: Sample from uniform distribution [mask_other, mask_length*2]
  661. ``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``.
  662. ``poisson``: Sample from possion distribution with lambda = ``mask_length``.
  663. min_masks (int): Minimum number of masked spans.
  664. no_overlap (bool): If false, will switch to an alternative recursive algorithm
  665. that prevents spans from overlapping.
  666. min_space (int): How many frames to keep unmasked between spans (Only used if no_overlap is True).
  667. Returns:
  668. (Tensor): The mask indices of dimension `[batch, frame]`.
  669. """
  670. batch_size, frame = shape
  671. mask = torch.full((batch_size, frame), False)
  672. # add a random number for probabilistic rounding
  673. all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1))
  674. all_num_mask = max(min_masks, all_num_mask)
  675. mask_idcs = []
  676. for i in range(batch_size):
  677. if padding_mask is not None:
  678. sz = frame - padding_mask[i].long().sum().item()
  679. # add a random number for probabilistic rounding
  680. num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1))
  681. num_mask = max(min_masks, num_mask)
  682. else:
  683. sz = frame
  684. num_mask = all_num_mask
  685. if mask_type == "static":
  686. lengths = torch.full((num_mask,), mask_length)
  687. elif mask_type == "uniform":
  688. lengths = torch.randint(mask_other, mask_length * 2 + 1, size=(num_mask,))
  689. elif mask_type == "normal":
  690. lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
  691. lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
  692. elif mask_type == "poisson":
  693. lengths = torch.poisson(mask_length, size=(num_mask,))
  694. lengths = torch.round(lengths).int()
  695. else:
  696. raise Exception(f"unknown mask selection: {mask_type}")
  697. if sum(lengths) == 0:
  698. lengths[0] = min(mask_length, sz - 1)
  699. if no_overlap:
  700. mask_idc = []
  701. def arrange(s, e, length, keep_length):
  702. span_start = torch.randint(s, e - length, size=(1,))
  703. mask_idc.extend(span_start + i for i in range(length))
  704. new_parts = []
  705. if span_start - s - min_space >= keep_length:
  706. new_parts.append((s, span_start - min_space + 1))
  707. if e - span_start - keep_length - min_space > keep_length:
  708. new_parts.append((span_start + length + min_space, e))
  709. return new_parts
  710. parts = [(0, sz)]
  711. min_length = min(lengths)
  712. for length in sorted(lengths, reverse=True):
  713. lens = torch.tensor([e - s for s, e in parts], dtype=torch.int)
  714. lens[lens < length + min_space] = 0
  715. l_sum = lens.sum()
  716. if l_sum == 0:
  717. break
  718. probs = lens / l_sum
  719. c = torch.distributions.categorical.Categorical(probs).sample()
  720. s, e = parts.pop(c)
  721. parts.extend(arrange(s, e, length, min_length))
  722. mask_idc = torch.tensor(mask_idc)
  723. else:
  724. min_len = min(lengths)
  725. if sz - min_len <= num_mask:
  726. min_len = sz - num_mask - 1
  727. mask_idc = torch.multinomial(torch.ones((sz - min_len,)), num_samples=num_mask, replacement=False)
  728. mask_idc = torch.tensor(
  729. [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
  730. )
  731. mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
  732. min_len = min([len(m) for m in mask_idcs])
  733. for i, mask_idc in enumerate(mask_idcs):
  734. if len(mask_idc) > min_len:
  735. mask_idc = torch.index_select(
  736. mask_idc,
  737. 0,
  738. torch.multinomial(
  739. torch.ones((mask_idc.shape[0],)),
  740. num_samples=min_len,
  741. replacement=False,
  742. ),
  743. )
  744. mask[i, mask_idc] = True
  745. return mask
  746. def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
  747. """Generate the padding mask given the padded input and the lengths Tensors.
  748. Args:
  749. input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`.
  750. lengths (Tensor): The lengths Tensor of dimension `[batch,]`.
  751. Returns:
  752. (Tensor): The padding mask.
  753. """
  754. batch_size, max_len, _ = input.shape
  755. mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
  756. return mask
  757. class MaskGenerator(Module):
  758. """Generate the masks for masked prediction.
  759. Args:
  760. encoder_embed_dim (int): The dimension of the transformer embedding output.
  761. mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
  762. This will be multiplied by number of timesteps divided by length of mask span to mask
  763. approximately this percentage of all elements. However due to overlaps, the actual number
  764. will be smaller (unless no_overlap is True).
  765. mask_selection (str): How to choose the mask length.
  766. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
  767. mask_other (float): Secondary mask argument (used for more complex distributions).
  768. mask_length (int): The lengths of the mask.
  769. no_mask_overlap (bool): Whether to allow masks to overlap.
  770. mask_min_space (int): Minimum space between spans (if no overlap is enabled).
  771. mask_channel_prob (float): The probability of replacing a feature with 0.
  772. mask_channel_selection (str): How to choose the mask length for channel masking.
  773. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
  774. mask_channel_other (float): Secondary mask argument for channel masking(used for more complex distributions).
  775. mask_channel_length (int): Minimum space between spans (if no overlap is enabled) for channel masking.
  776. no_mask_channel_overlap (bool): Whether to allow channel masks to overlap.
  777. mask_channel_min_space (int): Minimum space between spans for channel masking(if no overlap is enabled).
  778. """
  779. def __init__(
  780. self,
  781. encoder_embed_dim: int,
  782. mask_prob: float,
  783. mask_selection: str,
  784. mask_other: float,
  785. mask_length: int,
  786. no_mask_overlap: bool,
  787. mask_min_space: int,
  788. mask_channel_prob: float,
  789. mask_channel_selection: str,
  790. mask_channel_other: float,
  791. mask_channel_length: int,
  792. no_mask_channel_overlap: bool,
  793. mask_channel_min_space: int,
  794. ):
  795. super().__init__()
  796. self.mask_prob = mask_prob
  797. self.mask_selection = mask_selection
  798. self.mask_other = mask_other
  799. self.mask_length = mask_length
  800. self.no_mask_overlap = no_mask_overlap
  801. self.mask_min_space = mask_min_space
  802. self.mask_channel_prob = mask_channel_prob
  803. self.mask_channel_selection = mask_channel_selection
  804. self.mask_channel_other = mask_channel_other
  805. self.mask_channel_length = mask_channel_length
  806. self.no_mask_channel_overlap = no_mask_channel_overlap
  807. self.mask_channel_min_space = mask_channel_min_space
  808. self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim))
  809. torch.nn.init.uniform_(self.mask_embedding)
  810. def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
  811. """
  812. Args:
  813. x (Tensor): The encoded representations after feature extraction module.
  814. padding_mask (Tensor or None): The padding mask of the same dimension as shape,
  815. which will prevent masking padded elements.
  816. Returns:
  817. Tensor: The feature representations after masking.
  818. Tensor: The generated mask indices.
  819. """
  820. B, T, C = x.shape
  821. if self.mask_prob > 0:
  822. mask_indices = _compute_mask_indices(
  823. (B, T),
  824. padding_mask,
  825. self.mask_prob,
  826. self.mask_length,
  827. self.mask_selection,
  828. self.mask_other,
  829. min_masks=2,
  830. no_overlap=self.no_mask_overlap,
  831. min_space=self.mask_min_space,
  832. )
  833. mask_indices = mask_indices.to(x.device)
  834. x[mask_indices] = self.mask_embedding
  835. else:
  836. mask_indices = None
  837. if self.mask_channel_prob > 0:
  838. mask_channel_indices = _compute_mask_indices(
  839. (B, C),
  840. None,
  841. self.mask_channel_prob,
  842. self.mask_channel_length,
  843. self.mask_channel_selection,
  844. self.mask_channel_other,
  845. no_overlap=self.no_mask_channel_overlap,
  846. min_space=self.mask_channel_min_space,
  847. )
  848. mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
  849. x[mask_channel_indices] = 0
  850. return x, mask_indices
  851. def _compute_logits(
  852. proj_x: Tensor,
  853. target: Tensor,
  854. label_embeddings: Parameter,
  855. ) -> Tensor:
  856. """Compute the logits of the embeddings.
  857. Args:
  858. proj_x (Tensor): The projected masked representations of dimension `[batch, frame, final_dim]`.
  859. target (Tensor): The target Tensor of dimension `[batch, frame, final_dim]`.
  860. label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`.
  861. Returns:
  862. (Tensor): The logits of the inputs.
  863. """
  864. logit_temp = 0.1
  865. pos = torch.index_select(label_embeddings, 0, target.long())
  866. negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1)
  867. neg_is_pos = (pos == negs).all(-1)
  868. pos = pos.unsqueeze(0)
  869. targets = torch.cat([pos, negs], dim=0)
  870. logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x)
  871. logits /= logit_temp
  872. if neg_is_pos.any():
  873. logits[1:][neg_is_pos] = float("-inf")
  874. logits = logits.transpose(0, 1) # (num_x, num_cls+1)
  875. return logits
  876. class LogitGenerator(Module):
  877. """Generate the logits of masked and unmasked inputs.
  878. Args:
  879. encoder_embed_dim (int): The dimension of the transformer embedding output.
  880. num_classes (int): The number of classes in the labels.
  881. final_dim (int): Project final representations and targets to `final_dim`.
  882. skip_masked (bool): If True, skip computing losses over masked frames.
  883. skip_nomask (bool): If True, skip computing losses over unmasked frames.
  884. """
  885. def __init__(
  886. self,
  887. encoder_embed_dim: int,
  888. num_classes: int,
  889. final_dim: int,
  890. skip_masked: bool,
  891. skip_nomask: bool,
  892. ):
  893. super().__init__()
  894. self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim))
  895. torch.nn.init.uniform_(self.label_embeddings)
  896. self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim)
  897. self.skip_masked = skip_masked
  898. self.skip_nomask = skip_nomask
  899. def forward(self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor) -> Tuple[Tensor, Tensor]:
  900. """
  901. Args:
  902. x (Tensor): The feature representation of the last transformer layer.
  903. label (Tensor): The label Tensor of dimension `[batch, frame]`.
  904. mask_m (Tensor): The masked indices of dimension `[batch, frame]`.
  905. mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`.
  906. Returns:
  907. Tensor: The logits of masked frames. Tensor of dimension `[masked_frame, final_dim]`.
  908. Tensor: The logits of unmasked frames. Tensor of dimension `[unmasked_frame, final_dim]`.
  909. """
  910. proj_x = self.final_proj(x)
  911. if self.skip_masked:
  912. logit_m = None
  913. else:
  914. proj_x_m = proj_x[mask_m]
  915. label_m = label[mask_m]
  916. logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings)
  917. if self.skip_nomask:
  918. logit_u = None
  919. else:
  920. proj_x_u = proj_x[mask_u]
  921. label_u = label[mask_u]
  922. logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)
  923. return logit_m, logit_u
  924. class GradMultiply(torch.autograd.Function):
  925. @staticmethod
  926. def forward(ctx, x, scale):
  927. ctx.scale = scale
  928. res = x.new(x)
  929. return res
  930. @staticmethod
  931. def backward(ctx, grad):
  932. return grad * ctx.scale, None