| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051 |
- import logging
- from typing import List, Optional, Tuple
- import torch
- from torch import nn, Tensor
- from torch.nn import Module, Parameter
- _LG = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm):
    """Layer normalization applied along the second-to-last axis.

    The input is transposed so that the axis to normalize becomes the last
    one, normalized with this module's weight, bias and eps, and then
    transposed back to the original layout.
    """

    def forward(self, input: Tensor) -> Tensor:
        # ``layer_norm`` normalizes trailing dimensions, so swap the last two axes first.
        swapped = input.transpose(-2, -1)
        normalized = nn.functional.layer_norm(
            swapped,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normalized.transpose(-2, -1)
class ConvLayerBlock(Module):
    """Single convolution unit of FeatureExtractor.

    Applies a 1D convolution, an optional normalization module and GELU, and
    updates the valid lengths to match the convolution output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[Module],
    ):
        super().__init__()

        # Kept as attributes because ``forward`` needs them for the length update.
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
            length (Tensor or None, optional): Shape ``[batch, ]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
            Optional[Tensor]: Shape ``[batch, ]``.
        """
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = nn.functional.gelu(x)

        if length is not None:
            # Number of output frames of an unpadded convolution.
            frames = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
            # An input length of 0 yields a negative value here, so floor it at zero.
            length = torch.max(torch.zeros_like(frames), frames)
        return x, length
class FeatureExtractor(Module):
    """Extract features from audio

    Args:
        conv_layers (nn.ModuleList):
            convolution layers
    """

    def __init__(
        self,
        conv_layers: nn.ModuleList,
    ):
        super().__init__()
        self.conv_layers = conv_layers

    def forward(
        self,
        x: Tensor,
        length: Optional[Tensor],
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio,
                shape: ``[batch, time]``.
            length (Tensor or None, optional):
                Valid length of each input sample. shape: ``[batch, ]``.

        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``
            Optional[Tensor]:
                Valid length of each output sample. shape: ``[batch, ]``.

        Raises:
            ValueError: If ``x`` is not a 2D Tensor.
        """
        if x.ndim != 2:
            # The message must be an f-string, otherwise the shape placeholder is printed verbatim.
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time), but received {list(x.shape)}")

        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x, length = layer(x, length)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        return x, length
class FeatureProjection(Module):
    """Bridge between FeatureExtractor and Encoder.

    Normalizes the incoming features, projects them to the encoder dimension
    and applies dropout.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout: float,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        normalized = self.layer_norm(x)
        projected = self.projection(normalized)
        return self.dropout(projected)
class ConvolutionalPositionalEmbedding(Module):
    """Positional embedding which is placed at the beginning of Transformer.

    Args:
        embed_dim (int): Feature dimension of the input Tensor.
        kernel_size (int): The number of frames to be used.
        groups (int): The number of groups in feature dimensions.
    """

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        groups: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.conv = nn.Conv1d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
        # With an even kernel the padded convolution emits one extra frame,
        # which ``forward`` trims from the end.
        self.num_remove: int = 1 if kernel_size % 2 == 0 else 0

    def __prepare_scriptable__(self):
        """Strip weight normalization from the convolution and return ``self``."""
        # Iterate over a snapshot: ``remove_weight_norm`` unregisters the hook,
        # which mutates the dictionary that would otherwise be under iteration.
        for hook in list(self.conv._forward_pre_hooks.values()):
            # The hook we want to remove is an instance of WeightNorm class, so
            # normally we would do `if isinstance(...)` but this class is not accessible
            # because of shadowing, so we check the module name directly.
            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
                _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
                torch.nn.utils.remove_weight_norm(self.conv)
                # Only one weight-norm hook is registered on "weight"; nothing left to do.
                break
        return self

    def forward(self, x):
        """
        Args:
            x (Tensor): shape ``[batch, frame, feature]``.

        Returns:
            Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
        """
        x = x.transpose(-2, -1)
        x = self.conv(x)
        if self.num_remove > 0:
            x = x[..., : -self.num_remove]
        x = torch.nn.functional.gelu(x)
        x = x.transpose(-2, -1)
        return x
class SelfAttention(Module):
    """Multihead Self Attention module

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional):
            Dropout probability on attn_output_weights. Default: ``0.0``
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        head_dim = embed_dim // num_heads
        if head_dim * num_heads != embed_dim:
            raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = torch.nn.Dropout(dropout)
        self.head_dim = head_dim

        # Scores are scaled by 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5

        # The projections are created in this order (key, value, query, output).
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
            attention_mask (Tensor or None, optional):
                Additive mask, shape: ``[batch_size, 1, sequence_length, sequence_length]``

        Returns:
            Tensor: The resulting tensor. shape: ``[batch, sequence_length, embed_dim]``

        Raises:
            ValueError: If ``x`` or ``attention_mask`` does not have the expected shape.
        """
        if x.ndim != 3 or x.shape[2] != self.embed_dim:
            raise ValueError(
                f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}."
            )
        batch_size, length, embed_dim = x.size()
        if attention_mask is not None:
            shape_ = (batch_size, 1, length, length)
            if attention_mask.size() != shape_:
                raise ValueError(f"The expected attention mask shape is {shape_}. " f"Found {attention_mask.size()}.")

        # Split the embedding axis into (heads, head_dim).
        split = (batch_size, length, self.num_heads, self.head_dim)
        query = self.q_proj(x).view(*split).transpose(2, 1)  # B, nH, L, Hd
        key_t = self.k_proj(x).view(*split).permute(0, 2, 3, 1)  # B, nH, Hd, L
        value = self.v_proj(x).view(*split).transpose(2, 1)  # B, nH, L, Hd

        scores = self.scaling * (query @ key_t)  # B, nH, L, L
        if attention_mask is not None:
            scores += attention_mask

        probs = self.dropout(torch.nn.functional.softmax(scores, dim=-1))

        context = probs @ value  # B, nH, L, Hd
        merged = context.transpose(2, 1).reshape(batch_size, length, embed_dim)
        return self.out_proj(merged)
class FeedForward(Module):
    """Position-wise feed forward block that follows the attention layer in an encoder layer.

    Expands the features to ``intermediate_features`` with GELU and dropout,
    then projects back to ``io_features`` followed by a second dropout.
    """

    def __init__(
        self,
        io_features: int,
        intermediate_features: int,
        intermediate_dropout: float,
        output_dropout: float,
    ):
        super().__init__()
        self.intermediate_dense = nn.Linear(io_features, intermediate_features)
        self.intermediate_dropout = nn.Dropout(intermediate_dropout)
        self.output_dense = nn.Linear(intermediate_features, io_features)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        hidden = torch.nn.functional.gelu(self.intermediate_dense(x))
        hidden = self.intermediate_dropout(hidden)
        return self.output_dropout(self.output_dense(hidden))
class EncoderLayer(Module):
    """One encoder layer: multihead self attention followed by feed forward.

    When ``layer_norm_first`` is True the layer norms are applied before the
    attention and before the feed forward; otherwise they are applied after
    the attention residual and after the feed forward residual.
    """

    def __init__(
        self,
        attention: Module,
        dropout: float,
        layer_norm_first: bool,
        feed_forward: Module,
    ):
        super().__init__()
        self.attention = attention
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(attention.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.feed_forward = feed_forward
        self.final_layer_norm = nn.LayerNorm(attention.embed_dim)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
            attention_mask (Tensor or None, optional):
                shape: `(batch, 1, sequence_length, sequence_length)`
        """
        residual = x

        # Attention sub-block with residual connection.
        if self.layer_norm_first:
            x = self.layer_norm(x)
        x = residual + self.dropout(self.attention(x, attention_mask))

        # Feed forward sub-block with residual connection.
        if self.layer_norm_first:
            return x + self.feed_forward(self.final_layer_norm(x))

        x = self.layer_norm(x)
        return self.final_layer_norm(x + self.feed_forward(x))
class Transformer(Module):
    """Stack of encoder layers preceded by a convolutional positional embedding.

    Args:
        pos_conv_embed (Module): Positional embedding module; its ``embed_dim``
            attribute sizes the layer norm.
        dropout (float): Dropout probability applied after the positional embedding.
        layers (Module): Iterable container of encoder layers.
        layer_norm_first (bool): If True, the layer norm is applied before the
            encoder layers, otherwise after them.
        layer_drop (float): Probability of skipping each layer during training.
    """

    def __init__(
        self,
        pos_conv_embed: Module,
        dropout: float,
        layers: Module,
        layer_norm_first: bool,
        layer_drop: float,
    ):
        super().__init__()
        self.pos_conv_embed = pos_conv_embed
        self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
        self.layer_norm_first = layer_norm_first
        self.layer_drop = layer_drop
        self.dropout = nn.Dropout(dropout)
        self.layers = layers

    def _preprocess(self, x: Tensor):
        """Add the positional embedding, optionally normalize, then apply dropout."""
        x = x + self.pos_conv_embed(x)
        if self.layer_norm_first:
            x = self.layer_norm(x)
        return self.dropout(x)

    def forward(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
    ):
        """Run all encoder layers, randomly skipping layers while training."""
        x = self._preprocess(x)
        for layer in self.layers:
            # The random draw only happens in training mode.
            skip_layer = self.training and torch.rand(1).item() <= self.layer_drop
            if not skip_layer:
                x = layer(x, attention_mask)

        if not self.layer_norm_first:
            x = self.layer_norm(x)
        return x

    def get_intermediate_outputs(
        self,
        x: Tensor,
        attention_mask: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Return the output of each encoder layer, up to ``num_layers`` of them.

        No layer is skipped here and the final layer norm is not applied.

        Raises:
            ValueError: If ``num_layers`` is given and is not within ``[1, len(self.layers)]``.
        """
        if num_layers is not None and not 0 < num_layers <= len(self.layers):
            raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")

        outputs: List[Tensor] = []
        x = self._preprocess(x)
        for layer in self.layers:
            x = layer(x, attention_mask)
            outputs.append(x)
            if num_layers is not None and len(outputs) >= num_layers:
                break
        return outputs
class Encoder(Module):
    """Feature projection followed by a transformer.

    Args:
        feature_projection (Module): Module applied to the input features first.
        transformer (Module): Module called with the projected features and an
            optional attention mask.
    """

    def __init__(
        self,
        feature_projection: Module,
        transformer: Module,
    ):
        super().__init__()
        self.feature_projection = feature_projection
        self.transformer = transformer

    def _preprocess(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Project the features and, when ``lengths`` is given, build the attention mask.

        Padded frames of the projected features are zeroed in place. The returned
        mask is additive (``-10000.0`` at padded key positions, ``0`` elsewhere)
        with shape ``[batch, 1, max_len, max_len]``, or ``None`` without ``lengths``.
        """
        x = self.feature_projection(features)

        mask: Optional[Tensor] = None
        if lengths is not None:
            batch_size, max_len, _ = x.shape
            # Mark every frame at or beyond the valid length as padding and zero it out.
            frame_index = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len)
            is_padding = frame_index >= lengths[:, None]
            x[is_padding] = 0.0
            # Turn the boolean padding mask into additive attention weights
            # and broadcast it to the attention shape.
            additive = -10000.0 * is_padding[:, None, None, :].to(dtype=features.dtype)
            mask = additive.expand(batch_size, 1, max_len, max_len)
        return x, mask

    def forward(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``features``, masking out frames beyond ``lengths`` when provided."""
        projected, attention_mask = self._preprocess(features, lengths)
        return self.transformer(projected, attention_mask=attention_mask)

    def extract_features(
        self,
        features: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> List[Tensor]:
        """Return the transformer's intermediate outputs for ``features``."""
        projected, attention_mask = self._preprocess(features, lengths)
        return self.transformer.get_intermediate_outputs(
            projected,
            attention_mask=attention_mask,
            num_layers=num_layers,
        )
- ################################################################################
def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
) -> FeatureExtractor:
    """Build a FeatureExtractor from a list of convolution configurations.

    Args:
        norm_mode (str):
            Either "group_norm" or "layer_norm".
            If "group_norm", then a single normalization is applied
            in the first convolution block. Otherwise, all the convolution
            blocks will have layer normalization.
            This option corresponds to "extractor_mode" from fairseq.
            Expected values are "group_norm" for Base arch, and
            "layer_norm" for Large arch.
        shapes (list of tuple of int):
            Configuration of convolution layers. List of convolution configuration,
            i.e. ``[(output_channel, kernel_size, stride), ...]``
            This option corresponds to "conv_feature_layers" from fairseq.
            Expected values are
            ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
            for all the architectures.
        bias (bool):
            Whether to include bias term to each convolution operation.
            This option corresponds to "conv_bias" from fairseq.
            Expected values are False for Base arch, and True for Large arch.

    Returns:
        FeatureExtractor: Extractor wrapping one ConvLayerBlock per entry of ``shapes``.

    See Also:
        * Original implementation
            https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
        * "extractor_mode"
            - Def and base:
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
            - Large:
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
        * "conv_feature_layers"
            - Def, base and large:
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
        * "conv_bias"
            - Def and base:
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
            - Large:
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
    """
    assert norm_mode in ["group_norm", "layer_norm"]

    conv_blocks = []
    # The first block consumes the single-channel waveform; every later block
    # consumes the previous block's output channels.
    prev_channels = 1
    for index, (channels, kernel, step) in enumerate(shapes):
        norm = None
        if norm_mode == "layer_norm":
            norm = LayerNorm(
                normalized_shape=channels,
                elementwise_affine=True,
            )
        elif norm_mode == "group_norm" and index == 0:
            norm = nn.GroupNorm(
                num_groups=channels,
                num_channels=channels,
                affine=True,
            )
        block = ConvLayerBlock(
            in_channels=prev_channels,
            out_channels=channels,
            kernel_size=kernel,
            stride=step,
            bias=bias,
            layer_norm=norm,
        )
        conv_blocks.append(block)
        prev_channels = channels
    return FeatureExtractor(nn.ModuleList(conv_blocks))
def _get_encoder(
    in_features: int,
    embed_dim: int,
    dropout_input: float,
    pos_conv_kernel: int,
    pos_conv_groups: int,
    num_layers: int,
    num_heads: int,
    attention_dropout: float,
    ff_interm_features: int,
    ff_interm_dropout: float,
    dropout: float,
    layer_norm_first: bool,
    layer_drop: float,
) -> Encoder:
    """Build an Encoder (feature projection + transformer) from hyper-parameters.

    Args:
        in_features (int): The number of input features.
        embed_dim (int):
            The dimension of embedding.
            This option corresponds to "encoder_embed_dim" from fairseq.
            Expected values are 768 for Base arch, and 1024 for Large arch.
        dropout_input (float):
            The dropout probability applied after the input feature is projected
            to ``embed_dim``.
            This option corresponds to "dropout_input" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        pos_conv_kernel (int):
            The kernel size of convolutional positional embeddings.
            This option corresponds to "conv_pos" from fairseq.
            Expected values are 128 for both Base and Large arch.
        pos_conv_groups (int):
            The number of groups of convolutional positional embeddings.
            This option corresponds to "conv_pos_groups" from fairseq.
            Expected values are 16 for both Base and Large arch.
        num_layers (int):
            The number of self attention layers in transformer block.
            This option corresponds to "encoder_layers" from fairseq.
            Expected values are 12 for Base and 24 for Large arch.
        num_heads (int):
            The number of heads in self attention layers.
            This option corresponds to "encoder_attention_heads" from fairseq.
            Expected values are 12 for Base and 16 for Large arch.
        attention_dropout (float):
            The dropout probability applied after softmax in self-attention layer.
            This option corresponds to "attention_dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        ff_interm_features (int):
            The dimension of hidden features in feed forward layer.
            This option corresponds to "encoder_ffn_embed_dim" from fairseq.
            Expected values are 3072 for Base and 4096 for Large arch.
        ff_interm_dropout (float):
            The dropout probability applied in feedforward layer.
            This option corresponds to "activation_dropout" from fairseq.
            Expected values are 0.1 for both Base and Large arch.
        dropout (float):
            The dropout probability applied at the end of feed forward layer.
            It is also used for the dropout after self attention in each encoder
            layer and for the dropout after the positional embedding.
            This option corresponds to "dropout" from fairseq.
            Expected values are 0.1 for Base and 0.0 for Large arch.
        layer_norm_first (bool):
            Control the order of layer norm in transformer layer and each encoder layer.
            The flag is passed unchanged to every encoder layer and *inverted* to the
            transformer.
            If True, in each encoder layer the two layer norms are applied before
            self attention and before feed forward, and the transformer applies its
            layer norm after the encoder layers.
            If False, in each encoder layer the two layer norms are applied after
            self attention and after feed forward, and the transformer applies its
            layer norm before the encoder layers.
            This option corresponds to "layer_norm_first" from fairseq.
            Expected values are False for Base and True for Large arch.
        layer_drop (float):
            Probability to drop each encoder layer during training.
            This option corresponds to "layerdrop" from fairseq.
            Expected values are 0.1 for both Base and Large arch.

    Returns:
        Encoder: The assembled encoder.

    See Also:
        * "encoder_embed_dim"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
        * "dropout_input"
            - Def, base and large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
        * "conv_pos"
            - Def, base and large
                NOTE: The description is wrong.
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
            - Usage
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
        * "conv_pos_groups"
            - Def, base and large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
        * "encoder_layers"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
        * "encoder_attention_heads"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
        * "attention_dropout"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
        * "encoder_ffn_embed_dim"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
        * "activation_dropout"
            - Def
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
            - Base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
        * "dropout"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
        * "layer_norm_first"
            - Def and base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
        * "layerdrop"
            - Def
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
            - Base
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
            - Large
                https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
    """
    # Modules are constructed in a fixed order (projection, positional embedding,
    # then attention before feed forward for each layer) so that seeded weight
    # initialization stays reproducible.
    feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
    pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

    # Original impl
    # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
    encoder_layers = nn.ModuleList(
        [
            EncoderLayer(
                attention=SelfAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    dropout=attention_dropout,
                ),
                dropout=dropout,
                layer_norm_first=layer_norm_first,
                feed_forward=FeedForward(
                    io_features=embed_dim,
                    intermediate_features=ff_interm_features,
                    intermediate_dropout=ff_interm_dropout,
                    output_dropout=dropout,
                ),
            )
            for _ in range(num_layers)
        ]
    )
    transformer = Transformer(
        pos_conv_embed=pos_conv,
        dropout=dropout,
        layers=encoder_layers,
        # Transformer's flag has the opposite sense of this function's flag.
        layer_norm_first=not layer_norm_first,
        layer_drop=layer_drop,
    )
    return Encoder(feature_projection, transformer)
def _compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> Tensor:
    """Computes random mask spans for a given shape.

    Args:
        shape (int, int): The shape for which to compute masks.
            The first element is batch size and second is the number of frames.
        padding_mask (Tensor or None): The padding mask of the same dimension as shape,
            which will prevent masking padded elements.
        mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
            This will be multiplied by number of timesteps divided by length of mask span to mask
            approximately this percentage of all elements. However due to overlaps, the actual number
            will be smaller (unless no_overlap is True).
        mask_length (int): Length of each span for ``static``; otherwise the parameter of the
            length distribution as described under ``mask_type``.
        mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
            ``static``: Fixed size
            ``uniform``: Sample from uniform distribution [mask_other, mask_length*2]
            ``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``.
            ``poisson``: Sample from poisson distribution with lambda = ``mask_length``.
        mask_other (float): Secondary parameter of the length distribution: the lower bound for
            ``uniform`` (truncated to an integer) and the standard deviation for ``normal``.
        min_masks (int): Minimum number of masked spans.
        no_overlap (bool): If True, will switch to an alternative recursive algorithm
            that prevents spans from overlapping.
        min_space (int): How many frames to keep unmasked between spans (Only used if no_overlap is True).

    Returns:
        (Tensor): The boolean mask indices of dimension `[batch, frame]`. Every sample
            has the same number of masked frames.

    Raises:
        Exception: If ``mask_type`` is not one of the supported options.
    """
    batch_size, frame = shape
    mask = torch.full((batch_size, frame), False)
    # add a random number for probabilistic rounding
    all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1))

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(batch_size):
        if padding_mask is not None:
            # Only the non-padded frames of this sample are eligible for masking.
            sz = frame - padding_mask[i].long().sum().item()
            # add a random number for probabilistic rounding
            num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1))
            num_mask = max(min_masks, num_mask)
        else:
            sz = frame
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = torch.full((num_mask,), mask_length)
        elif mask_type == "uniform":
            # ``torch.randint`` needs an integer lower bound while ``mask_other`` is a float.
            lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
        elif mask_type == "normal":
            lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
            lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
        elif mask_type == "poisson":
            # ``torch.poisson`` draws one sample per element of a tensor of rates.
            lengths = torch.poisson(torch.full((num_mask,), float(mask_length)))
            lengths = torch.round(lengths).int()
        else:
            raise Exception(f"unknown mask selection: {mask_type}")

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside [s, e) and return the leftover segments
                # that are still large enough to host another span.
                span_start = torch.randint(s, e - length, size=(1,))
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            # Place the longest spans first, choosing a segment with probability
            # proportional to its length among those that can fit the span.
            for length in sorted(lengths, reverse=True):
                lens = torch.tensor([e - s for s, e in parts], dtype=torch.int)
                lens[lens < length + min_space] = 0
                l_sum = lens.sum()
                if l_sum == 0:
                    break
                probs = lens / l_sum
                c = torch.distributions.categorical.Categorical(probs).sample()
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = torch.tensor(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            # Sample distinct span start positions, then expand each into its span.
            mask_idc = torch.multinomial(torch.ones((sz - min_len,)), num_samples=num_mask, replacement=False)

            mask_idc = torch.tensor(
                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
            )

        # Drop indices that run past the valid length and de-duplicate overlaps.
        mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))

    # Subsample so that every sample masks the same number of frames.
    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = torch.index_select(
                mask_idc,
                0,
                torch.multinomial(
                    torch.ones((mask_idc.shape[0],)),
                    num_samples=min_len,
                    replacement=False,
                ),
            )
        mask[i, mask_idc] = True

    return mask
def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
    """Build the padding mask for a padded batch from its valid lengths.

    Args:
        input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`.
        lengths (Tensor): The lengths Tensor of dimension `[batch,]`.

    Returns:
        (Tensor): Boolean mask of dimension `[batch, max_len]`, True at positions
            at or beyond each sample's length.
    """
    batch_size, max_len, _ = input.shape
    # One row of frame positions per sample, created on the same device as ``lengths``.
    frame_index = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len)
    return frame_index >= lengths[:, None]
class MaskGenerator(Module):
    """Generate the masks for masked prediction.

    Args:
        encoder_embed_dim (int): The dimension of the transformer embedding output.
        mask_prob (float): Probability for each token to be chosen as start of the span to be masked.
            This will be multiplied by number of timesteps divided by length of mask span to mask
            approximately this percentage of all elements. However due to overlaps, the actual number
            will be smaller (unless no_overlap is True).
        mask_selection (str): How to choose the mask length.
            Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_other (float): Secondary mask argument (used for more complex distributions).
        mask_length (int): The lengths of the mask.
        no_mask_overlap (bool): Forwarded as ``no_overlap`` to the mask sampler;
            if True, masked spans are prevented from overlapping.
        mask_min_space (int): Minimum space between spans (if no overlap is enabled).
        mask_channel_prob (float): The probability of replacing a feature with 0.
        mask_channel_selection (str): How to choose the mask length for channel masking.
            Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_channel_other (float): Secondary mask argument for channel masking
            (used for more complex distributions).
        mask_channel_length (int): The lengths of the mask for channel masking.
        no_mask_channel_overlap (bool): Forwarded as ``no_overlap`` for channel masking;
            if True, channel mask spans are prevented from overlapping.
        mask_channel_min_space (int): Minimum space between spans for channel masking
            (if no overlap is enabled).
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        mask_prob: float,
        mask_selection: str,
        mask_other: float,
        mask_length: int,
        no_mask_overlap: bool,
        mask_min_space: int,
        mask_channel_prob: float,
        mask_channel_selection: str,
        mask_channel_other: float,
        mask_channel_length: int,
        no_mask_channel_overlap: bool,
        mask_channel_min_space: int,
    ):
        super().__init__()
        self.mask_prob = mask_prob
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.mask_length = mask_length
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.mask_channel_length = mask_channel_length
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        # Learnable vector written over every masked time step; the storage is
        # allocated uninitialised and then filled by ``uniform_``.
        self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim))
        torch.nn.init.uniform_(self.mask_embedding)

    def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
        """Apply time-step masking and channel masking to ``x``.

        Note:
            ``x`` is modified in place and the same Tensor object is returned.

        Args:
            x (Tensor): The encoded representations after feature extraction module,
                of dimension `[batch, frame, feature]`.
            padding_mask (Tensor or None): The padding mask handed to the mask sampler
                together with the `(batch, frame)` shape, which will prevent masking
                padded elements.

        Returns:
            Tensor: The feature representations after masking.
            Tensor or None: The generated time-step mask indices, or ``None``
            when ``mask_prob`` is not positive.
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = _compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = mask_indices.to(x.device)
            # Cast so the assignment also works when ``x`` is not float32
            # (e.g. half precision); this is a no-op when the dtypes already match.
            x[mask_indices] = self.mask_embedding.to(x.dtype)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = _compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # The channel mask is sampled per (batch, feature) and then repeated
            # over every time step before zeroing the selected features.
            mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
            x[mask_channel_indices] = 0

        return x, mask_indices
- def _compute_logits(
- proj_x: Tensor,
- target: Tensor,
- label_embeddings: Parameter,
- ) -> Tensor:
- """Compute the logits of the embeddings.
- Args:
- proj_x (Tensor): The projected masked representations of dimension `[batch, frame, final_dim]`.
- target (Tensor): The target Tensor of dimension `[batch, frame, final_dim]`.
- label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`.
- Returns:
- (Tensor): The logits of the inputs.
- """
- logit_temp = 0.1
- pos = torch.index_select(label_embeddings, 0, target.long())
- negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1)
- neg_is_pos = (pos == negs).all(-1)
- pos = pos.unsqueeze(0)
- targets = torch.cat([pos, negs], dim=0)
- logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x)
- logits /= logit_temp
- if neg_is_pos.any():
- logits[1:][neg_is_pos] = float("-inf")
- logits = logits.transpose(0, 1) # (num_x, num_cls+1)
- return logits
class LogitGenerator(Module):
    """Generate the logits of masked and unmasked inputs.

    Args:
        encoder_embed_dim (int): The dimension of the transformer embedding output.
        num_classes (int): The number of classes in the labels.
        final_dim (int): Project final representations and targets to `final_dim`.
        skip_masked (bool): If True, skip computing losses over masked frames.
        skip_nomask (bool): If True, skip computing losses over unmasked frames.
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        num_classes: int,
        final_dim: int,
        skip_masked: bool,
        skip_nomask: bool,
    ):
        super().__init__()
        # One trainable embedding per class; allocated uninitialised and then
        # filled by ``uniform_``.
        self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim))
        torch.nn.init.uniform_(self.label_embeddings)
        self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim)
        self.skip_masked = skip_masked
        self.skip_nomask = skip_nomask

    def forward(
        self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """
        Args:
            x (Tensor): The feature representation of the last transformer layer.
            label (Tensor): The label Tensor of dimension `[batch, frame]`.
            mask_m (Tensor): The masked indices of dimension `[batch, frame]`.
            mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`.

        Returns:
            Tensor or None: The logits of masked frames, as returned by
            ``_compute_logits`` for the frames selected by ``mask_m``.
            ``None`` when ``skip_masked`` is True.
            Tensor or None: The logits of unmasked frames, as returned by
            ``_compute_logits`` for the frames selected by ``mask_u``.
            ``None`` when ``skip_nomask`` is True.
        """
        proj_x = self.final_proj(x)

        if self.skip_masked:
            logit_m = None
        else:
            # Indexing with the mask keeps only the selected frames.
            proj_x_m = proj_x[mask_m]
            label_m = label[mask_m]
            logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings)

        if self.skip_nomask:
            logit_u = None
        else:
            proj_x_u = proj_x[mask_u]
            label_u = label[mask_u]
            logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)

        return logit_m, logit_u
class GradMultiply(torch.autograd.Function):
    """Autograd function that rescales the gradient flowing back to its input.

    The forward pass hands back ``x.new(x)``; the backward pass multiplies the
    incoming gradient by ``scale``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Stash the factor on the context so that backward can read it.
        ctx.scale = scale
        # NOTE(review): legacy constructor call kept as-is; it presumably yields
        # a tensor with the same values as ``x`` -- confirm before replacing it.
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # One entry per forward argument: the scaled gradient for ``x`` and
        # ``None`` for ``scale``.
        scaled_grad = grad * ctx.scale
        return scaled_grad, None
|