emformer.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876
  1. import math
  2. from typing import List, Optional, Tuple
  3. import torch
  4. __all__ = ["Emformer"]
  5. def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
  6. batch_size = lengths.shape[0]
  7. max_length = int(torch.max(lengths).item())
  8. padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
  9. batch_size, max_length
  10. ) >= lengths.unsqueeze(1)
  11. return padding_mask
  12. def _gen_padding_mask(
  13. utterance: torch.Tensor,
  14. right_context: torch.Tensor,
  15. summary: torch.Tensor,
  16. lengths: torch.Tensor,
  17. mems: torch.Tensor,
  18. left_context_key: Optional[torch.Tensor] = None,
  19. ) -> Optional[torch.Tensor]:
  20. T = right_context.size(0) + utterance.size(0) + summary.size(0)
  21. B = right_context.size(1)
  22. if B == 1:
  23. padding_mask = None
  24. else:
  25. right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
  26. left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
  27. klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
  28. padding_mask = _lengths_to_padding_mask(lengths=klengths)
  29. return padding_mask
  30. def _get_activation_module(activation: str) -> torch.nn.Module:
  31. if activation == "relu":
  32. return torch.nn.ReLU()
  33. elif activation == "gelu":
  34. return torch.nn.GELU()
  35. elif activation == "silu":
  36. return torch.nn.SiLU()
  37. else:
  38. raise ValueError(f"Unsupported activation {activation}")
  39. def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
  40. if weight_init_scale_strategy is None:
  41. return [None for _ in range(num_layers)]
  42. elif weight_init_scale_strategy == "depthwise":
  43. return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
  44. elif weight_init_scale_strategy == "constant":
  45. return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
  46. else:
  47. raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
  48. def _gen_attention_mask_block(
  49. col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
  50. ) -> torch.Tensor:
  51. assert len(col_widths) == len(col_mask), "Length of col_widths must match that of col_mask"
  52. mask_block = [
  53. torch.ones(num_rows, col_width, device=device)
  54. if is_ones_col
  55. else torch.zeros(num_rows, col_width, device=device)
  56. for col_width, is_ones_col in zip(col_widths, col_mask)
  57. ]
  58. return torch.cat(mask_block, dim=1)
  59. class _EmformerAttention(torch.nn.Module):
  60. r"""Emformer layer attention module.
  61. Args:
  62. input_dim (int): input dimension.
  63. num_heads (int): number of attention heads in each Emformer layer.
  64. dropout (float, optional): dropout probability. (Default: 0.0)
  65. weight_init_gain (float or None, optional): scale factor to apply when initializing
  66. attention module parameters. (Default: ``None``)
  67. tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
  68. negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
  69. """
  70. def __init__(
  71. self,
  72. input_dim: int,
  73. num_heads: int,
  74. dropout: float = 0.0,
  75. weight_init_gain: Optional[float] = None,
  76. tanh_on_mem: bool = False,
  77. negative_inf: float = -1e8,
  78. ):
  79. super().__init__()
  80. if input_dim % num_heads != 0:
  81. raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")
  82. self.input_dim = input_dim
  83. self.num_heads = num_heads
  84. self.dropout = dropout
  85. self.tanh_on_mem = tanh_on_mem
  86. self.negative_inf = negative_inf
  87. self.scaling = (self.input_dim // self.num_heads) ** -0.5
  88. self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
  89. self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
  90. self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)
  91. if weight_init_gain:
  92. torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
  93. torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
  94. def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  95. T, _, _ = input.shape
  96. summary_length = mems.size(0) + 1
  97. right_ctx_utterance_block = input[: T - summary_length]
  98. mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
  99. key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
  100. return key, value
  101. def _gen_attention_probs(
  102. self,
  103. attention_weights: torch.Tensor,
  104. attention_mask: torch.Tensor,
  105. padding_mask: Optional[torch.Tensor],
  106. ) -> torch.Tensor:
  107. attention_weights_float = attention_weights.float()
  108. attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
  109. T = attention_weights.size(1)
  110. B = attention_weights.size(0) // self.num_heads
  111. if padding_mask is not None:
  112. attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
  113. attention_weights_float = attention_weights_float.masked_fill(
  114. padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
  115. )
  116. attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
  117. attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
  118. return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
  119. def _forward_impl(
  120. self,
  121. utterance: torch.Tensor,
  122. lengths: torch.Tensor,
  123. right_context: torch.Tensor,
  124. summary: torch.Tensor,
  125. mems: torch.Tensor,
  126. attention_mask: torch.Tensor,
  127. left_context_key: Optional[torch.Tensor] = None,
  128. left_context_val: Optional[torch.Tensor] = None,
  129. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  130. B = utterance.size(1)
  131. T = right_context.size(0) + utterance.size(0) + summary.size(0)
  132. # Compute query with [right context, utterance, summary].
  133. query = self.emb_to_query(torch.cat([right_context, utterance, summary]))
  134. # Compute key and value with [mems, right context, utterance].
  135. key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)
  136. if left_context_key is not None and left_context_val is not None:
  137. right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
  138. key = torch.cat(
  139. [
  140. key[: mems.size(0) + right_context_blocks_length],
  141. left_context_key,
  142. key[mems.size(0) + right_context_blocks_length :],
  143. ],
  144. )
  145. value = torch.cat(
  146. [
  147. value[: mems.size(0) + right_context_blocks_length],
  148. left_context_val,
  149. value[mems.size(0) + right_context_blocks_length :],
  150. ],
  151. )
  152. # Compute attention weights from query, key, and value.
  153. reshaped_query, reshaped_key, reshaped_value = [
  154. tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
  155. for tensor in [query, key, value]
  156. ]
  157. attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))
  158. # Compute padding mask.
  159. padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)
  160. # Compute attention probabilities.
  161. attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)
  162. # Compute attention.
  163. attention = torch.bmm(attention_probs, reshaped_value)
  164. assert attention.shape == (
  165. B * self.num_heads,
  166. T,
  167. self.input_dim // self.num_heads,
  168. )
  169. attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)
  170. # Apply output projection.
  171. output_right_context_mems = self.out_proj(attention)
  172. summary_length = summary.size(0)
  173. output_right_context = output_right_context_mems[: T - summary_length]
  174. output_mems = output_right_context_mems[T - summary_length :]
  175. if self.tanh_on_mem:
  176. output_mems = torch.tanh(output_mems)
  177. else:
  178. output_mems = torch.clamp(output_mems, min=-10, max=10)
  179. return output_right_context, output_mems, key, value
  180. def forward(
  181. self,
  182. utterance: torch.Tensor,
  183. lengths: torch.Tensor,
  184. right_context: torch.Tensor,
  185. summary: torch.Tensor,
  186. mems: torch.Tensor,
  187. attention_mask: torch.Tensor,
  188. ) -> Tuple[torch.Tensor, torch.Tensor]:
  189. r"""Forward pass for training.
  190. B: batch size;
  191. D: feature dimension of each frame;
  192. T: number of utterance frames;
  193. R: number of right context frames;
  194. S: number of summary elements;
  195. M: number of memory elements.
  196. Args:
  197. utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
  198. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  199. number of valid frames for i-th batch element in ``utterance``.
  200. right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
  201. summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
  202. mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
  203. attention_mask (torch.Tensor): attention mask for underlying attention module.
  204. Returns:
  205. (Tensor, Tensor):
  206. Tensor
  207. output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
  208. Tensor
  209. updated memory elements, with shape `(M, B, D)`.
  210. """
  211. output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
  212. return output, output_mems[:-1]
  213. @torch.jit.export
  214. def infer(
  215. self,
  216. utterance: torch.Tensor,
  217. lengths: torch.Tensor,
  218. right_context: torch.Tensor,
  219. summary: torch.Tensor,
  220. mems: torch.Tensor,
  221. left_context_key: torch.Tensor,
  222. left_context_val: torch.Tensor,
  223. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  224. r"""Forward pass for inference.
  225. B: batch size;
  226. D: feature dimension of each frame;
  227. T: number of utterance frames;
  228. R: number of right context frames;
  229. S: number of summary elements;
  230. M: number of memory elements.
  231. Args:
  232. utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
  233. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  234. number of valid frames for i-th batch element in ``utterance``.
  235. right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
  236. summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
  237. mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
  238. left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
  239. left_context_val (torch.Tensor): left context attention value computed from preceding invocation.
  240. Returns:
  241. (Tensor, Tensor, Tensor, and Tensor):
  242. Tensor
  243. output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
  244. Tensor
  245. updated memory elements, with shape `(M, B, D)`.
  246. Tensor
  247. attention key computed for left context and utterance.
  248. Tensor
  249. attention value computed for left context and utterance.
  250. """
  251. query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
  252. key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
  253. attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
  254. attention_mask[-1, : mems.size(0)] = True
  255. output, output_mems, key, value = self._forward_impl(
  256. utterance,
  257. lengths,
  258. right_context,
  259. summary,
  260. mems,
  261. attention_mask,
  262. left_context_key=left_context_key,
  263. left_context_val=left_context_val,
  264. )
  265. return (
  266. output,
  267. output_mems,
  268. key[mems.size(0) + right_context.size(0) :],
  269. value[mems.size(0) + right_context.size(0) :],
  270. )
  271. class _EmformerLayer(torch.nn.Module):
  272. r"""Emformer layer that constitutes Emformer.
  273. Args:
  274. input_dim (int): input dimension.
  275. num_heads (int): number of attention heads.
  276. ffn_dim: (int): hidden layer dimension of feedforward network.
  277. segment_length (int): length of each input segment.
  278. dropout (float, optional): dropout probability. (Default: 0.0)
  279. activation (str, optional): activation function to use in feedforward network.
  280. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
  281. left_context_length (int, optional): length of left context. (Default: 0)
  282. max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
  283. weight_init_gain (float or None, optional): scale factor to apply when initializing
  284. attention module parameters. (Default: ``None``)
  285. tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
  286. negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
  287. """
  288. def __init__(
  289. self,
  290. input_dim: int,
  291. num_heads: int,
  292. ffn_dim: int,
  293. segment_length: int,
  294. dropout: float = 0.0,
  295. activation: str = "relu",
  296. left_context_length: int = 0,
  297. max_memory_size: int = 0,
  298. weight_init_gain: Optional[float] = None,
  299. tanh_on_mem: bool = False,
  300. negative_inf: float = -1e8,
  301. ):
  302. super().__init__()
  303. self.attention = _EmformerAttention(
  304. input_dim=input_dim,
  305. num_heads=num_heads,
  306. dropout=dropout,
  307. weight_init_gain=weight_init_gain,
  308. tanh_on_mem=tanh_on_mem,
  309. negative_inf=negative_inf,
  310. )
  311. self.dropout = torch.nn.Dropout(dropout)
  312. self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
  313. activation_module = _get_activation_module(activation)
  314. self.pos_ff = torch.nn.Sequential(
  315. torch.nn.LayerNorm(input_dim),
  316. torch.nn.Linear(input_dim, ffn_dim),
  317. activation_module,
  318. torch.nn.Dropout(dropout),
  319. torch.nn.Linear(ffn_dim, input_dim),
  320. torch.nn.Dropout(dropout),
  321. )
  322. self.layer_norm_input = torch.nn.LayerNorm(input_dim)
  323. self.layer_norm_output = torch.nn.LayerNorm(input_dim)
  324. self.left_context_length = left_context_length
  325. self.segment_length = segment_length
  326. self.max_memory_size = max_memory_size
  327. self.input_dim = input_dim
  328. self.use_mem = max_memory_size > 0
  329. def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
  330. empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
  331. left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
  332. left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
  333. past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
  334. return [empty_memory, left_context_key, left_context_val, past_length]
  335. def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  336. past_length = state[3][0][0].item()
  337. past_left_context_length = min(self.left_context_length, past_length)
  338. past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
  339. pre_mems = state[0][self.max_memory_size - past_mem_length :]
  340. lc_key = state[1][self.left_context_length - past_left_context_length :]
  341. lc_val = state[2][self.left_context_length - past_left_context_length :]
  342. return pre_mems, lc_key, lc_val
  343. def _pack_state(
  344. self,
  345. next_k: torch.Tensor,
  346. next_v: torch.Tensor,
  347. update_length: int,
  348. mems: torch.Tensor,
  349. state: List[torch.Tensor],
  350. ) -> List[torch.Tensor]:
  351. new_k = torch.cat([state[1], next_k])
  352. new_v = torch.cat([state[2], next_v])
  353. state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
  354. state[1] = new_k[new_k.shape[0] - self.left_context_length :]
  355. state[2] = new_v[new_v.shape[0] - self.left_context_length :]
  356. state[3] = state[3] + update_length
  357. return state
  358. def _process_attention_output(
  359. self,
  360. rc_output: torch.Tensor,
  361. utterance: torch.Tensor,
  362. right_context: torch.Tensor,
  363. ) -> torch.Tensor:
  364. result = self.dropout(rc_output) + torch.cat([right_context, utterance])
  365. result = self.pos_ff(result) + result
  366. result = self.layer_norm_output(result)
  367. return result
  368. def _apply_pre_attention_layer_norm(
  369. self, utterance: torch.Tensor, right_context: torch.Tensor
  370. ) -> Tuple[torch.Tensor, torch.Tensor]:
  371. layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
  372. return (
  373. layer_norm_input[right_context.size(0) :],
  374. layer_norm_input[: right_context.size(0)],
  375. )
  376. def _apply_post_attention_ffn(
  377. self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
  378. ) -> Tuple[torch.Tensor, torch.Tensor]:
  379. rc_output = self._process_attention_output(rc_output, utterance, right_context)
  380. return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
  381. def _apply_attention_forward(
  382. self,
  383. utterance: torch.Tensor,
  384. lengths: torch.Tensor,
  385. right_context: torch.Tensor,
  386. mems: torch.Tensor,
  387. attention_mask: Optional[torch.Tensor],
  388. ) -> Tuple[torch.Tensor, torch.Tensor]:
  389. if attention_mask is None:
  390. raise ValueError("attention_mask must be not None when for_inference is False")
  391. if self.use_mem:
  392. summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
  393. else:
  394. summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
  395. rc_output, next_m = self.attention(
  396. utterance=utterance,
  397. lengths=lengths,
  398. right_context=right_context,
  399. summary=summary,
  400. mems=mems,
  401. attention_mask=attention_mask,
  402. )
  403. return rc_output, next_m
  404. def _apply_attention_infer(
  405. self,
  406. utterance: torch.Tensor,
  407. lengths: torch.Tensor,
  408. right_context: torch.Tensor,
  409. mems: torch.Tensor,
  410. state: Optional[List[torch.Tensor]],
  411. ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
  412. if state is None:
  413. state = self._init_state(utterance.size(1), device=utterance.device)
  414. pre_mems, lc_key, lc_val = self._unpack_state(state)
  415. if self.use_mem:
  416. summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
  417. summary = summary[:1]
  418. else:
  419. summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
  420. rc_output, next_m, next_k, next_v = self.attention.infer(
  421. utterance=utterance,
  422. lengths=lengths,
  423. right_context=right_context,
  424. summary=summary,
  425. mems=pre_mems,
  426. left_context_key=lc_key,
  427. left_context_val=lc_val,
  428. )
  429. state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
  430. return rc_output, next_m, state
  431. def forward(
  432. self,
  433. utterance: torch.Tensor,
  434. lengths: torch.Tensor,
  435. right_context: torch.Tensor,
  436. mems: torch.Tensor,
  437. attention_mask: torch.Tensor,
  438. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  439. r"""Forward pass for training.
  440. B: batch size;
  441. D: feature dimension of each frame;
  442. T: number of utterance frames;
  443. R: number of right context frames;
  444. M: number of memory elements.
  445. Args:
  446. utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
  447. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  448. number of valid frames for i-th batch element in ``utterance``.
  449. right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
  450. mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
  451. attention_mask (torch.Tensor): attention mask for underlying attention module.
  452. Returns:
  453. (Tensor, Tensor, Tensor):
  454. Tensor
  455. encoded utterance frames, with shape `(T, B, D)`.
  456. Tensor
  457. updated right context frames, with shape `(R, B, D)`.
  458. Tensor
  459. updated memory elements, with shape `(M, B, D)`.
  460. """
  461. (
  462. layer_norm_utterance,
  463. layer_norm_right_context,
  464. ) = self._apply_pre_attention_layer_norm(utterance, right_context)
  465. rc_output, output_mems = self._apply_attention_forward(
  466. layer_norm_utterance,
  467. lengths,
  468. layer_norm_right_context,
  469. mems,
  470. attention_mask,
  471. )
  472. output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
  473. return output_utterance, output_right_context, output_mems
  474. @torch.jit.export
  475. def infer(
  476. self,
  477. utterance: torch.Tensor,
  478. lengths: torch.Tensor,
  479. right_context: torch.Tensor,
  480. state: Optional[List[torch.Tensor]],
  481. mems: torch.Tensor,
  482. ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
  483. r"""Forward pass for inference.
  484. B: batch size;
  485. D: feature dimension of each frame;
  486. T: number of utterance frames;
  487. R: number of right context frames;
  488. M: number of memory elements.
  489. Args:
  490. utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
  491. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  492. number of valid frames for i-th batch element in ``utterance``.
  493. right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
  494. state (List[torch.Tensor] or None): list of tensors representing layer internal state
  495. generated in preceding invocation of ``infer``.
  496. mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
  497. Returns:
  498. (Tensor, Tensor, List[torch.Tensor], Tensor):
  499. Tensor
  500. encoded utterance frames, with shape `(T, B, D)`.
  501. Tensor
  502. updated right context frames, with shape `(R, B, D)`.
  503. List[Tensor]
  504. list of tensors representing layer internal state
  505. generated in current invocation of ``infer``.
  506. Tensor
  507. updated memory elements, with shape `(M, B, D)`.
  508. """
  509. (
  510. layer_norm_utterance,
  511. layer_norm_right_context,
  512. ) = self._apply_pre_attention_layer_norm(utterance, right_context)
  513. rc_output, output_mems, output_state = self._apply_attention_infer(
  514. layer_norm_utterance, lengths, layer_norm_right_context, mems, state
  515. )
  516. output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
  517. return output_utterance, output_right_context, output_state, output_mems
  518. class _EmformerImpl(torch.nn.Module):
  519. def __init__(
  520. self,
  521. emformer_layers: torch.nn.ModuleList,
  522. segment_length: int,
  523. left_context_length: int = 0,
  524. right_context_length: int = 0,
  525. max_memory_size: int = 0,
  526. ):
  527. super().__init__()
  528. self.use_mem = max_memory_size > 0
  529. self.memory_op = torch.nn.AvgPool1d(
  530. kernel_size=segment_length,
  531. stride=segment_length,
  532. ceil_mode=True,
  533. )
  534. self.emformer_layers = emformer_layers
  535. self.left_context_length = left_context_length
  536. self.right_context_length = right_context_length
  537. self.segment_length = segment_length
  538. self.max_memory_size = max_memory_size
  539. def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
  540. T = input.shape[0]
  541. num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
  542. right_context_blocks = []
  543. for seg_idx in range(num_segs - 1):
  544. start = (seg_idx + 1) * self.segment_length
  545. end = start + self.right_context_length
  546. right_context_blocks.append(input[start:end])
  547. right_context_blocks.append(input[T - self.right_context_length :])
  548. return torch.cat(right_context_blocks)
  549. def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
  550. num_segs = math.ceil(utterance_length / self.segment_length)
  551. rc = self.right_context_length
  552. lc = self.left_context_length
  553. rc_start = seg_idx * rc
  554. rc_end = rc_start + rc
  555. seg_start = max(seg_idx * self.segment_length - lc, 0)
  556. seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
  557. rc_length = self.right_context_length * num_segs
  558. if self.use_mem:
  559. m_start = max(seg_idx - self.max_memory_size, 0)
  560. mem_length = num_segs - 1
  561. col_widths = [
  562. m_start, # before memory
  563. seg_idx - m_start, # memory
  564. mem_length - seg_idx, # after memory
  565. rc_start, # before right context
  566. rc, # right context
  567. rc_length - rc_end, # after right context
  568. seg_start, # before query segment
  569. seg_end - seg_start, # query segment
  570. utterance_length - seg_end, # after query segment
  571. ]
  572. else:
  573. col_widths = [
  574. rc_start, # before right context
  575. rc, # right context
  576. rc_length - rc_end, # after right context
  577. seg_start, # before query segment
  578. seg_end - seg_start, # query segment
  579. utterance_length - seg_end, # after query segment
  580. ]
  581. return col_widths
  582. def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
  583. utterance_length = input.size(0)
  584. num_segs = math.ceil(utterance_length / self.segment_length)
  585. rc_mask = []
  586. query_mask = []
  587. summary_mask = []
  588. if self.use_mem:
  589. num_cols = 9
  590. # memory, right context, query segment
  591. rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
  592. # right context, query segment
  593. s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
  594. masks_to_concat = [rc_mask, query_mask, summary_mask]
  595. else:
  596. num_cols = 6
  597. # right context, query segment
  598. rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
  599. s_cols_mask = None
  600. masks_to_concat = [rc_mask, query_mask]
  601. for seg_idx in range(num_segs):
  602. col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)
  603. rc_mask_block = _gen_attention_mask_block(
  604. col_widths, rc_q_cols_mask, self.right_context_length, input.device
  605. )
  606. rc_mask.append(rc_mask_block)
  607. query_mask_block = _gen_attention_mask_block(
  608. col_widths,
  609. rc_q_cols_mask,
  610. min(
  611. self.segment_length,
  612. utterance_length - seg_idx * self.segment_length,
  613. ),
  614. input.device,
  615. )
  616. query_mask.append(query_mask_block)
  617. if s_cols_mask is not None:
  618. summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
  619. summary_mask.append(summary_mask_block)
  620. attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
  621. return attention_mask
  622. def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  623. r"""Forward pass for training and non-streaming inference.
  624. B: batch size;
  625. T: max number of input frames in batch;
  626. D: feature dimension of each frame.
  627. Args:
  628. input (torch.Tensor): utterance frames right-padded with right context frames, with
  629. shape `(B, T + right_context_length, D)`.
  630. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  631. number of valid utterance frames for i-th batch element in ``input``.
  632. Returns:
  633. (Tensor, Tensor):
  634. Tensor
  635. output frames, with shape `(B, T, D)`.
  636. Tensor
  637. output lengths, with shape `(B,)` and i-th element representing
  638. number of valid frames for i-th batch element in output frames.
  639. """
  640. input = input.permute(1, 0, 2)
  641. right_context = self._gen_right_context(input)
  642. utterance = input[: input.size(0) - self.right_context_length]
  643. attention_mask = self._gen_attention_mask(utterance)
  644. mems = (
  645. self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
  646. if self.use_mem
  647. else torch.empty(0).to(dtype=input.dtype, device=input.device)
  648. )
  649. output = utterance
  650. for layer in self.emformer_layers:
  651. output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
  652. return output.permute(1, 0, 2), lengths
  653. @torch.jit.export
  654. def infer(
  655. self,
  656. input: torch.Tensor,
  657. lengths: torch.Tensor,
  658. states: Optional[List[List[torch.Tensor]]] = None,
  659. ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
  660. r"""Forward pass for streaming inference.
  661. B: batch size;
  662. D: feature dimension of each frame.
  663. Args:
  664. input (torch.Tensor): utterance frames right-padded with right context frames, with
  665. shape `(B, segment_length + right_context_length, D)`.
  666. lengths (torch.Tensor): with shape `(B,)` and i-th element representing
  667. number of valid frames for i-th batch element in ``input``.
  668. states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
  669. representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)
  670. Returns:
  671. (Tensor, Tensor, List[List[Tensor]]):
  672. Tensor
  673. output frames, with shape `(B, segment_length, D)`.
  674. Tensor
  675. output lengths, with shape `(B,)` and i-th element representing
  676. number of valid frames for i-th batch element in output frames.
  677. List[List[Tensor]]
  678. output states; list of lists of tensors representing internal state
  679. generated in current invocation of ``infer``.
  680. """
  681. assert input.size(1) == self.segment_length + self.right_context_length, (
  682. "Per configured segment_length and right_context_length"
  683. f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
  684. f", but got {input.size(1)}."
  685. )
  686. input = input.permute(1, 0, 2)
  687. right_context_start_idx = input.size(0) - self.right_context_length
  688. right_context = input[right_context_start_idx:]
  689. utterance = input[:right_context_start_idx]
  690. output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
  691. mems = (
  692. self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
  693. if self.use_mem
  694. else torch.empty(0).to(dtype=input.dtype, device=input.device)
  695. )
  696. output = utterance
  697. output_states: List[List[torch.Tensor]] = []
  698. for layer_idx, layer in enumerate(self.emformer_layers):
  699. output, right_context, output_state, mems = layer.infer(
  700. output,
  701. output_lengths,
  702. right_context,
  703. None if states is None else states[layer_idx],
  704. mems,
  705. )
  706. output_states.append(output_state)
  707. return output.permute(1, 0, 2), output_lengths, output_states
  708. class Emformer(_EmformerImpl):
  709. r"""Implements the Emformer architecture introduced in
  710. *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
  711. [:footcite:`shi2021emformer`].
  712. Args:
  713. input_dim (int): input dimension.
  714. num_heads (int): number of attention heads in each Emformer layer.
  715. ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
  716. num_layers (int): number of Emformer layers to instantiate.
  717. segment_length (int): length of each input segment.
  718. dropout (float, optional): dropout probability. (Default: 0.0)
  719. activation (str, optional): activation function to use in each Emformer layer's
  720. feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
  721. left_context_length (int, optional): length of left context. (Default: 0)
  722. right_context_length (int, optional): length of right context. (Default: 0)
  723. max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
  724. weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
  725. strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
  726. tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
  727. negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
  728. Examples:
  729. >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
  730. >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim
  731. >>> lengths = torch.randint(1, 200, (128,)) # batch
  732. >>> output, lengths = emformer(input, lengths)
  733. >>> input = torch.rand(128, 5, 512)
  734. >>> lengths = torch.ones(128) * 5
  735. >>> output, lengths, states = emformer.infer(input, lengths, None)
  736. """
  737. def __init__(
  738. self,
  739. input_dim: int,
  740. num_heads: int,
  741. ffn_dim: int,
  742. num_layers: int,
  743. segment_length: int,
  744. dropout: float = 0.0,
  745. activation: str = "relu",
  746. left_context_length: int = 0,
  747. right_context_length: int = 0,
  748. max_memory_size: int = 0,
  749. weight_init_scale_strategy: Optional[str] = "depthwise",
  750. tanh_on_mem: bool = False,
  751. negative_inf: float = -1e8,
  752. ):
  753. weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
  754. emformer_layers = torch.nn.ModuleList(
  755. [
  756. _EmformerLayer(
  757. input_dim,
  758. num_heads,
  759. ffn_dim,
  760. segment_length,
  761. dropout=dropout,
  762. activation=activation,
  763. left_context_length=left_context_length,
  764. max_memory_size=max_memory_size,
  765. weight_init_gain=weight_init_gains[layer_idx],
  766. tanh_on_mem=tanh_on_mem,
  767. negative_inf=negative_inf,
  768. )
  769. for layer_idx in range(num_layers)
  770. ]
  771. )
  772. super().__init__(
  773. emformer_layers,
  774. segment_length,
  775. left_context_length=left_context_length,
  776. right_context_length=right_context_length,
  777. max_memory_size=max_memory_size,
  778. )