| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876 |
- import math
- from typing import List, Optional, Tuple
- import torch
- __all__ = ["Emformer"]
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    r"""Convert per-sequence lengths into a boolean padding mask.

    Args:
        lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid positions for i-th batch element.

    Returns:
        torch.Tensor: boolean mask with shape `(B, max(lengths))`; an element is ``True``
            where its position index is greater than or equal to the corresponding length,
            i.e. where the position is padding.
    """
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    # Position indices share the dtype and device of ``lengths`` so they compare directly.
    positions = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype)
    padding_mask = positions.expand(batch_size, max_length) >= lengths.unsqueeze(1)
    return padding_mask
def _gen_padding_mask(
    utterance: torch.Tensor,
    right_context: torch.Tensor,
    summary: torch.Tensor,
    lengths: torch.Tensor,
    mems: torch.Tensor,
    left_context_key: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    r"""Generate the key padding mask used by the attention module.

    Args:
        utterance (torch.Tensor): utterance frames; dimension 0 is the frame dimension.
        right_context (torch.Tensor): right context frames; dimension 1 is the batch dimension.
        summary (torch.Tensor): summary elements; only its size along dimension 0 is used.
        lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
        mems (torch.Tensor): memory elements; only its size along dimension 0 is used.
        left_context_key (torch.Tensor or None, optional): left context attention key; only
            its size along dimension 0 is used. (Default: ``None``)

    Returns:
        torch.Tensor or None: ``None`` when the batch size is 1; otherwise a boolean mask
            in which ``True`` marks key positions beyond each batch element's valid key length.
    """
    T = right_context.size(0) + utterance.size(0) + summary.size(0)
    B = right_context.size(1)
    if B == 1:
        # A single sequence is treated as having no padding to mask.
        padding_mask = None
    else:
        # Whatever remains of the query length after removing the longest utterance
        # and the summary is attributed to right context blocks.
        right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
        left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
        # Per-sequence key length: memory + right context + left context + valid utterance frames.
        klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
        padding_mask = _lengths_to_padding_mask(lengths=klengths)
    return padding_mask
def _get_activation_module(activation: str) -> torch.nn.Module:
    r"""Instantiate the activation module identified by ``activation``.

    Args:
        activation (str): one of "relu", "gelu", or "silu".

    Returns:
        torch.nn.Module: a new instance of the corresponding activation module.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return torch.nn.ReLU()
    if activation == "gelu":
        return torch.nn.GELU()
    if activation == "silu":
        return torch.nn.SiLU()
    raise ValueError(f"Unsupported activation {activation}")
def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
    r"""Compute the per-layer weight initialization gain for a scaling strategy.

    Args:
        weight_init_scale_strategy (str or None): one of "depthwise", "constant", or ``None``.
        num_layers (int): number of layers to produce gains for.

    Returns:
        List[float or None]: one entry per layer. ``None`` entries for strategy ``None``;
            ``1 / sqrt(layer_idx + 1)`` for "depthwise"; ``1 / sqrt(2)`` for "constant".

    Raises:
        ValueError: if ``weight_init_scale_strategy`` is not a supported value.
    """
    if weight_init_scale_strategy is None:
        return [None for _ in range(num_layers)]
    elif weight_init_scale_strategy == "depthwise":
        return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
    elif weight_init_scale_strategy == "constant":
        return [1.0 / math.sqrt(2) for _ in range(num_layers)]
    else:
        raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
def _gen_attention_mask_block(
    col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
) -> torch.Tensor:
    r"""Build one block of rows of an attention mask from column groups.

    Args:
        col_widths (List[int]): width of each column group.
        col_mask (List[bool]): for each column group, ``True`` to fill it with ones
            and ``False`` to fill it with zeros.
        num_rows (int): number of rows in the block.
        device (torch.device): device on which to allocate the block.

    Returns:
        torch.Tensor: tensor of ones and zeros with shape `(num_rows, sum(col_widths))`.
    """
    assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"

    mask_block = [
        torch.ones(num_rows, col_width, device=device)
        if is_ones_col
        else torch.zeros(num_rows, col_width, device=device)
        for col_width, is_ones_col in zip(col_widths, col_mask)
    ]
    # Column groups are laid side by side along the column dimension.
    return torch.cat(mask_block, dim=1)
class _EmformerAttention(torch.nn.Module):
    r"""Emformer layer attention module.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        if input_dim % num_heads != 0:
            raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.tanh_on_mem = tanh_on_mem
        self.negative_inf = negative_inf

        # Queries are scaled by 1 / sqrt(per-head dimension) before the dot product.
        self.scaling = (self.input_dim // self.num_heads) ** -0.5

        # A single projection produces key and value; its output is split in half.
        self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
        self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
        self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)

        # A gain of ``None`` (or any falsy value) keeps the default Linear initialization.
        if weight_init_gain:
            torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
            torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)

    def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Project ``mems`` and the leading part of ``input`` to attention key and value.

        The trailing ``mems.size(0) + 1`` elements of ``input`` are dropped before projection.
        This helper is not called by the other methods of this class.

        Args:
            input (torch.Tensor): 3-D input tensor; dimension 0 is sliced.
            mems (torch.Tensor): memory elements, prepended to the sliced input along dimension 0.

        Returns:
            (Tensor, Tensor): attention key and attention value.
        """
        T, _, _ = input.shape
        summary_length = mems.size(0) + 1
        right_ctx_utterance_block = input[: T - summary_length]
        mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
        key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
        return key, value

    def _gen_attention_probs(
        self,
        attention_weights: torch.Tensor,
        attention_mask: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        r"""Turn raw attention weights into attention probabilities.

        Args:
            attention_weights (torch.Tensor): raw weights, with shape
                `(B * num_heads, query length, key length)`.
            attention_mask (torch.Tensor): boolean mask with shape `(query length, key length)`;
                ``True`` entries are filled with ``negative_inf`` before softmax.
            padding_mask (torch.Tensor or None): mask with shape `(B, key length)`;
                ``True`` entries are filled with ``negative_inf`` for every head and query.

        Returns:
            torch.Tensor: attention probabilities with the same shape and dtype as
                ``attention_weights``, after dropout.
        """
        # Masking and softmax are carried out in float32.
        attention_weights_float = attention_weights.float()
        attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
        T = attention_weights.size(1)
        B = attention_weights.size(0) // self.num_heads
        if padding_mask is not None:
            # Expose the batch dimension so the per-sequence mask broadcasts over heads and queries.
            attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
            attention_weights_float = attention_weights_float.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
            )
            attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
        attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
        return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)

    def _forward_impl(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
        left_context_key: Optional[torch.Tensor] = None,
        left_context_val: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Shared attention computation behind ``forward`` and ``infer``.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask; ``True`` entries are masked out.
            left_context_key (torch.Tensor or None, optional): left context attention key.
                (Default: ``None``)
            left_context_val (torch.Tensor or None, optional): left context attention value.
                (Default: ``None``)

        Returns:
            (Tensor, Tensor, Tensor, Tensor):
                Tensor
                    output frames for right context and utterance, with shape `(R + T, B, D)`.
                Tensor
                    output for the summary queries (tanh-ed or clamped to [-10, 10]), with shape `(S, B, D)`.
                Tensor
                    attention key, including left context when provided.
                Tensor
                    attention value, including left context when provided.
        """
        B = utterance.size(1)
        T = right_context.size(0) + utterance.size(0) + summary.size(0)

        # Compute query with [right context, utterance, summary].
        query = self.emb_to_query(torch.cat([right_context, utterance, summary]))

        # Compute key and value with [mems, right context, utterance].
        key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)

        if left_context_key is not None and left_context_val is not None:
            # Splice the left context between [mems, right context] and [utterance].
            right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
            key = torch.cat(
                [
                    key[: mems.size(0) + right_context_blocks_length],
                    left_context_key,
                    key[mems.size(0) + right_context_blocks_length :],
                ],
            )
            value = torch.cat(
                [
                    value[: mems.size(0) + right_context_blocks_length],
                    left_context_val,
                    value[mems.size(0) + right_context_blocks_length :],
                ],
            )

        # Compute attention weights from query, key, and value.
        reshaped_query, reshaped_key, reshaped_value = [
            tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
            for tensor in [query, key, value]
        ]
        attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))

        # Compute padding mask.
        padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)

        # Compute attention probabilities.
        attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)

        # Compute attention.
        attention = torch.bmm(attention_probs, reshaped_value)
        assert attention.shape == (
            B * self.num_heads,
            T,
            self.input_dim // self.num_heads,
        )
        attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)

        # Apply output projection.
        output_right_context_mems = self.out_proj(attention)

        # The last ``summary_length`` query positions correspond to the summary elements.
        summary_length = summary.size(0)
        output_right_context = output_right_context_mems[: T - summary_length]
        output_mems = output_right_context_mems[T - summary_length :]
        if self.tanh_on_mem:
            output_mems = torch.tanh(output_mems)
        else:
            output_mems = torch.clamp(output_mems, min=-10, max=10)

        return output_right_context, output_mems, key, value

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
        # The output of the last summary query is dropped from the returned memory.
        return output, output_mems[:-1]

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        summary: torch.Tensor,
        mems: torch.Tensor,
        left_context_key: torch.Tensor,
        left_context_val: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        S: number of summary elements;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
            left_context_val (torch.Tensor): left context attention value computed from preceding invocation.

        Returns:
            (Tensor, Tensor, Tensor, and Tensor):
                Tensor
                    output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
                Tensor
                    attention key computed for left context and utterance.
                Tensor
                    attention value computed for left context and utterance.
        """
        query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
        key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
        attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
        # Only the last query row is restricted: it may not attend to the memory columns.
        attention_mask[-1, : mems.size(0)] = True
        output, output_mems, key, value = self._forward_impl(
            utterance,
            lengths,
            right_context,
            summary,
            mems,
            attention_mask,
            left_context_key=left_context_key,
            left_context_val=left_context_val,
        )
        # Strip the leading [mems, right context] rows; the remainder is left context + utterance.
        return (
            output,
            output_mems,
            key[mems.size(0) + right_context.size(0) :],
            value[mems.size(0) + right_context.size(0) :],
        )
class _EmformerLayer(torch.nn.Module):
    r"""Emformer layer that constitutes Emformer.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads.
        ffn_dim: (int): hidden layer dimension of feedforward network.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in feedforward network.
            Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_gain (float or None, optional): scale factor to apply when initializing
            attention module parameters. (Default: ``None``)
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_gain: Optional[float] = None,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        super().__init__()

        self.attention = _EmformerAttention(
            input_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            weight_init_gain=weight_init_gain,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
        self.dropout = torch.nn.Dropout(dropout)
        # Average-pools each segment of the utterance into one summary element.
        self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)

        activation_module = _get_activation_module(activation)
        self.pos_ff = torch.nn.Sequential(
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, ffn_dim),
            activation_module,
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ffn_dim, input_dim),
            torch.nn.Dropout(dropout),
        )
        self.layer_norm_input = torch.nn.LayerNorm(input_dim)
        self.layer_norm_output = torch.nn.LayerNorm(input_dim)

        self.left_context_length = left_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size
        self.input_dim = input_dim

        self.use_mem = max_memory_size > 0

    def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
        r"""Create the initial (all-zero) inference state for this layer.

        Returns:
            List[torch.Tensor]: ``[memory, left context key, left context value, past length]`` with
                shapes `(max_memory_size, B, D)`, `(left_context_length, B, D)`,
                `(left_context_length, B, D)`, and `(1, B)` (int32) respectively.
        """
        empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
        left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
        past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
        return [empty_memory, left_context_key, left_context_val, past_length]

    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Extract the populated tail of memory and left context from ``state``.

        The number of frames processed so far is read from the first element of ``state[3]``
        and determines how many trailing entries of each buffer are returned.

        Returns:
            (Tensor, Tensor, Tensor): memory elements, left context key, and left context value.
        """
        past_length = state[3][0][0].item()
        past_left_context_length = min(self.left_context_length, past_length)
        past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
        # Buffers are filled from the end, so valid entries occupy the tail.
        pre_mems = state[0][self.max_memory_size - past_mem_length :]
        lc_key = state[1][self.left_context_length - past_left_context_length :]
        lc_val = state[2][self.left_context_length - past_left_context_length :]
        return pre_mems, lc_key, lc_val

    def _pack_state(
        self,
        next_k: torch.Tensor,
        next_v: torch.Tensor,
        update_length: int,
        mems: torch.Tensor,
        state: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        r"""Update ``state`` in place with new key/value frames and memory, and return it.

        Args:
            next_k (torch.Tensor): attention key to append to the stored left context key.
            next_v (torch.Tensor): attention value to append to the stored left context value.
            update_length (int): amount to add to the stored past length.
            mems (torch.Tensor): memory elements to append to the stored memory.
            state (List[torch.Tensor]): state to update.

        Returns:
            List[torch.Tensor]: the same list object, with its four entries replaced.
        """
        new_k = torch.cat([state[1], next_k])
        new_v = torch.cat([state[2], next_v])
        # Keep only the most recent entries of each buffer.
        state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
        state[1] = new_k[new_k.shape[0] - self.left_context_length :]
        state[2] = new_v[new_v.shape[0] - self.left_context_length :]
        state[3] = state[3] + update_length
        return state

    def _process_attention_output(
        self,
        rc_output: torch.Tensor,
        utterance: torch.Tensor,
        right_context: torch.Tensor,
    ) -> torch.Tensor:
        r"""Apply residual connection, feedforward network, and output layer norm.

        Returns:
            torch.Tensor: processed frames ordered as [right context, utterance].
        """
        # Residual connection around attention.
        result = self.dropout(rc_output) + torch.cat([right_context, utterance])
        # Residual connection around the feedforward network.
        result = self.pos_ff(result) + result
        result = self.layer_norm_output(result)
        return result

    def _apply_pre_attention_layer_norm(
        self, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Layer-normalize right context and utterance jointly.

        Returns:
            (Tensor, Tensor): normalized utterance and normalized right context, in that order.
        """
        layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
        return (
            layer_norm_input[right_context.size(0) :],
            layer_norm_input[: right_context.size(0)],
        )

    def _apply_post_attention_ffn(
        self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Post-process attention output and split it back into utterance and right context.

        Returns:
            (Tensor, Tensor): output utterance and output right context, in that order.
        """
        rc_output = self._process_attention_output(rc_output, utterance, right_context)
        return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]

    def _apply_attention_forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Run the attention module in training mode.

        Returns:
            (Tensor, Tensor): attention output for [right context, utterance] and updated memory elements.

        Raises:
            ValueError: if ``attention_mask`` is ``None``.
        """
        if attention_mask is None:
            raise ValueError("attention_mask must be not None when for_inference is False")

        if self.use_mem:
            # Pool over the frame dimension: (T, B, D) -> (B, D, T) -> pool -> (S, B, D).
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m = self.attention(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=mems,
            attention_mask=attention_mask,
        )
        return rc_output, next_m

    def _apply_attention_infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        state: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        r"""Run the attention module in streaming inference mode.

        Returns:
            (Tensor, Tensor, List[Tensor]): attention output for [right context, utterance],
                updated memory elements, and updated layer state.
        """
        if state is None:
            state = self._init_state(utterance.size(1), device=utterance.device)
        pre_mems, lc_key, lc_val = self._unpack_state(state)
        if self.use_mem:
            summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            # Only the first summary element is used during inference.
            summary = summary[:1]
        else:
            summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
        rc_output, next_m, next_k, next_v = self.attention.infer(
            utterance=utterance,
            lengths=lengths,
            right_context=right_context,
            summary=summary,
            mems=pre_mems,
            left_context_key=lc_key,
            left_context_val=lc_val,
        )
        state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
        return rc_output, next_m, state

    def forward(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        mems: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
            attention_mask (torch.Tensor): attention mask for underlying attention module.

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems = self._apply_attention_forward(
            layer_norm_utterance,
            lengths,
            layer_norm_right_context,
            mems,
            attention_mask,
        )
        # The residual connection uses the un-normalized inputs.
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_mems

    @torch.jit.export
    def infer(
        self,
        utterance: torch.Tensor,
        lengths: torch.Tensor,
        right_context: torch.Tensor,
        state: Optional[List[torch.Tensor]],
        mems: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
        r"""Forward pass for inference.

        B: batch size;
        D: feature dimension of each frame;
        T: number of utterance frames;
        R: number of right context frames;
        M: number of memory elements.

        Args:
            utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``utterance``.
            right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
            state (List[torch.Tensor] or None): list of tensors representing layer internal state
                generated in preceding invocation of ``infer``.
            mems (torch.Tensor): memory elements, with shape `(M, B, D)`.

        Returns:
            (Tensor, Tensor, List[torch.Tensor], Tensor):
                Tensor
                    encoded utterance frames, with shape `(T, B, D)`.
                Tensor
                    updated right context frames, with shape `(R, B, D)`.
                List[Tensor]
                    list of tensors representing layer internal state
                    generated in current invocation of ``infer``.
                Tensor
                    updated memory elements, with shape `(M, B, D)`.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
        rc_output, output_mems, output_state = self._apply_attention_infer(
            layer_norm_utterance, lengths, layer_norm_right_context, mems, state
        )
        # The residual connection uses the un-normalized inputs.
        output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
        return output_utterance, output_right_context, output_state, output_mems
class _EmformerImpl(torch.nn.Module):
    r"""Shared implementation that drives a stack of Emformer layers.

    Args:
        emformer_layers (torch.nn.ModuleList): layers to apply in sequence.
        segment_length (int): length of each input segment.
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
    """

    def __init__(
        self,
        emformer_layers: torch.nn.ModuleList,
        segment_length: int,
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
    ):
        super().__init__()

        self.use_mem = max_memory_size > 0
        # Average-pools each segment of the utterance into one memory element.
        self.memory_op = torch.nn.AvgPool1d(
            kernel_size=segment_length,
            stride=segment_length,
            ceil_mode=True,
        )
        self.emformer_layers = emformer_layers
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length
        self.segment_length = segment_length
        self.max_memory_size = max_memory_size

    def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
        r"""Gather the right context frames of every segment into one tensor.

        Args:
            input (torch.Tensor): frames including the trailing right context; dimension 0
                is the frame dimension.

        Returns:
            torch.Tensor: concatenation, along dimension 0, of each segment's right context block.
        """
        T = input.shape[0]
        num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
        right_context_blocks = []
        # For all but the last segment, the right context starts where the next segment starts.
        for seg_idx in range(num_segs - 1):
            start = (seg_idx + 1) * self.segment_length
            end = start + self.right_context_length
            right_context_blocks.append(input[start:end])
        # The last segment's right context is the trailing frames of the input.
        right_context_blocks.append(input[T - self.right_context_length :])
        return torch.cat(right_context_blocks)

    def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
        r"""Compute the widths of the column groups of the attention mask for one segment.

        Args:
            seg_idx (int): index of the segment.
            utterance_length (int): number of utterance frames.

        Returns:
            List[int]: nine widths (memory, right context, and query segment groups, each split
                into before / current / after) when memory is used; otherwise the six widths
                for right context and query segment.
        """
        num_segs = math.ceil(utterance_length / self.segment_length)
        rc = self.right_context_length
        lc = self.left_context_length
        rc_start = seg_idx * rc
        rc_end = rc_start + rc
        # The visible span of the utterance extends ``lc`` frames to the left of the segment.
        seg_start = max(seg_idx * self.segment_length - lc, 0)
        seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
        rc_length = self.right_context_length * num_segs

        if self.use_mem:
            m_start = max(seg_idx - self.max_memory_size, 0)
            mem_length = num_segs - 1
            col_widths = [
                m_start,  # before memory
                seg_idx - m_start,  # memory
                mem_length - seg_idx,  # after memory
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]
        else:
            col_widths = [
                rc_start,  # before right context
                rc,  # right context
                rc_length - rc_end,  # after right context
                seg_start,  # before query segment
                seg_end - seg_start,  # query segment
                utterance_length - seg_end,  # after query segment
            ]

        return col_widths

    def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
        r"""Generate the attention mask used in training and non-streaming inference.

        Args:
            input (torch.Tensor): utterance frames; dimension 0 is the frame dimension.

        Returns:
            torch.Tensor: boolean mask whose rows are ordered as [right context, utterance]
                followed by summary rows when memory is used; ``True`` marks positions
                that must not be attended to.
        """
        utterance_length = input.size(0)
        num_segs = math.ceil(utterance_length / self.segment_length)

        rc_mask = []
        query_mask = []
        summary_mask = []

        if self.use_mem:
            num_cols = 9
            # memory, right context, query segment
            rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
            # right context, query segment
            s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
            masks_to_concat = [rc_mask, query_mask, summary_mask]
        else:
            num_cols = 6
            # right context, query segment
            rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
            s_cols_mask = None
            masks_to_concat = [rc_mask, query_mask]

        for seg_idx in range(num_segs):
            col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)

            rc_mask_block = _gen_attention_mask_block(
                col_widths, rc_q_cols_mask, self.right_context_length, input.device
            )
            rc_mask.append(rc_mask_block)

            query_mask_block = _gen_attention_mask_block(
                col_widths,
                rc_q_cols_mask,
                # The last segment may be shorter than ``segment_length``.
                min(
                    self.segment_length,
                    utterance_length - seg_idx * self.segment_length,
                ),
                input.device,
            )
            query_mask.append(query_mask_block)

            if s_cols_mask is not None:
                summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
                summary_mask.append(summary_mask_block)

        # Blocks hold 1 for visible positions; invert so that ``True`` means masked out.
        attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
        return attention_mask

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training and non-streaming inference.

        B: batch size;
        T: max number of input frames in batch;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, T + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid utterance frames for i-th batch element in ``input``.

        Returns:
            (Tensor, Tensor):
                Tensor
                    output frames, with shape `(B, T, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
        """
        # Layers operate on time-major tensors.
        input = input.permute(1, 0, 2)
        right_context = self._gen_right_context(input)
        utterance = input[: input.size(0) - self.right_context_length]
        attention_mask = self._gen_attention_mask(utterance)
        # Initial memory: one pooled element per segment, excluding the last segment.
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        for layer in self.emformer_layers:
            output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
        return output.permute(1, 0, 2), lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for streaming inference.

        B: batch size;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): utterance frames right-padded with right context frames, with
                shape `(B, segment_length + right_context_length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)

        Returns:
            (Tensor, Tensor, List[List[Tensor]]):
                Tensor
                    output frames, with shape `(B, segment_length, D)`.
                Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
                List[List[Tensor]]
                    output states; list of lists of tensors representing internal state
                    generated in current invocation of ``infer``.
        """
        assert input.size(1) == self.segment_length + self.right_context_length, (
            "Per configured segment_length and right_context_length"
            f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
            f", but got {input.size(1)}."
        )
        # Layers operate on time-major tensors.
        input = input.permute(1, 0, 2)
        right_context_start_idx = input.size(0) - self.right_context_length
        right_context = input[right_context_start_idx:]
        utterance = input[:right_context_start_idx]
        # Right context frames are not part of the output, so they are excluded from the lengths.
        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
        mems = (
            self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
            if self.use_mem
            else torch.empty(0).to(dtype=input.dtype, device=input.device)
        )
        output = utterance
        output_states: List[List[torch.Tensor]] = []
        for layer_idx, layer in enumerate(self.emformer_layers):
            output, right_context, output_state, mems = layer.infer(
                output,
                output_lengths,
                right_context,
                None if states is None else states[layer_idx],
                mems,
            )
            output_states.append(output_state)

        return output.permute(1, 0, 2), output_lengths, output_states
class Emformer(_EmformerImpl):
    r"""Implements the Emformer architecture introduced in
    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
    [:footcite:`shi2021emformer`].

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Emformer layer.
        ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        num_layers (int): number of Emformer layers to instantiate.
        segment_length (int): length of each input segment.
        dropout (float, optional): dropout probability. (Default: 0.0)
        activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        left_context_length (int, optional): length of left context. (Default: 0)
        right_context_length (int, optional): length of right context. (Default: 0)
        max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
        negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)

    Examples:
        >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
        >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
        >>> lengths = torch.randint(1, 200, (128,))  # batch
        >>> output, lengths = emformer(input, lengths)
        >>> input = torch.rand(128, 5, 512)
        >>> lengths = torch.ones(128) * 5
        >>> output, lengths, states = emformer.infer(input, lengths, None)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        segment_length: int,
        dropout: float = 0.0,
        activation: str = "relu",
        left_context_length: int = 0,
        right_context_length: int = 0,
        max_memory_size: int = 0,
        weight_init_scale_strategy: Optional[str] = "depthwise",
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
    ):
        # One initialization gain per layer, chosen by the scaling strategy.
        weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
        emformer_layers = torch.nn.ModuleList(
            [
                _EmformerLayer(
                    input_dim,
                    num_heads,
                    ffn_dim,
                    segment_length,
                    dropout=dropout,
                    activation=activation,
                    left_context_length=left_context_length,
                    max_memory_size=max_memory_size,
                    weight_init_gain=weight_init_gains[layer_idx],
                    tanh_on_mem=tanh_on_mem,
                    negative_inf=negative_inf,
                )
                for layer_idx in range(num_layers)
            ]
        )
        super().__init__(
            emformer_layers,
            segment_length,
            left_context_length=left_context_length,
            right_context_length=right_context_length,
            max_memory_size=max_memory_size,
        )
|