conv_tasnet.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """Implements Conv-TasNet with building blocks of it.
  2. Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
  3. """
  4. from typing import Optional, Tuple
  5. import torch
  6. class ConvBlock(torch.nn.Module):
  7. """1D Convolutional block.
  8. Args:
  9. io_channels (int): The number of input/output channels, <B, Sc>
  10. hidden_channels (int): The number of channels in the internal layers, <H>.
  11. kernel_size (int): The convolution kernel size of the middle layer, <P>.
  12. padding (int): Padding value of the convolution in the middle layer.
  13. dilation (int, optional): Dilation value of the convolution in the middle layer.
  14. no_redisual (bool, optional): Disable residual block/output.
  15. Note:
  16. This implementation corresponds to the "non-causal" setting in the paper.
  17. """
  18. def __init__(
  19. self,
  20. io_channels: int,
  21. hidden_channels: int,
  22. kernel_size: int,
  23. padding: int,
  24. dilation: int = 1,
  25. no_residual: bool = False,
  26. ):
  27. super().__init__()
  28. self.conv_layers = torch.nn.Sequential(
  29. torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
  30. torch.nn.PReLU(),
  31. torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
  32. torch.nn.Conv1d(
  33. in_channels=hidden_channels,
  34. out_channels=hidden_channels,
  35. kernel_size=kernel_size,
  36. padding=padding,
  37. dilation=dilation,
  38. groups=hidden_channels,
  39. ),
  40. torch.nn.PReLU(),
  41. torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
  42. )
  43. self.res_out = (
  44. None
  45. if no_residual
  46. else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
  47. )
  48. self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
  49. def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
  50. feature = self.conv_layers(input)
  51. if self.res_out is None:
  52. residual = None
  53. else:
  54. residual = self.res_out(feature)
  55. skip_out = self.skip_out(feature)
  56. return residual, skip_out
  57. class MaskGenerator(torch.nn.Module):
  58. """TCN (Temporal Convolution Network) Separation Module
  59. Generates masks for separation.
  60. Args:
  61. input_dim (int): Input feature dimension, <N>.
  62. num_sources (int): The number of sources to separate.
  63. kernel_size (int): The convolution kernel size of conv blocks, <P>.
  64. num_featrs (int): Input/output feature dimenstion of conv blocks, <B, Sc>.
  65. num_hidden (int): Intermediate feature dimention of conv blocks, <H>
  66. num_layers (int): The number of conv blocks in one stack, <X>.
  67. num_stacks (int): The number of conv block stacks, <R>.
  68. msk_activate (str): The activation function of the mask output.
  69. Note:
  70. This implementation corresponds to the "non-causal" setting in the paper.
  71. """
  72. def __init__(
  73. self,
  74. input_dim: int,
  75. num_sources: int,
  76. kernel_size: int,
  77. num_feats: int,
  78. num_hidden: int,
  79. num_layers: int,
  80. num_stacks: int,
  81. msk_activate: str,
  82. ):
  83. super().__init__()
  84. self.input_dim = input_dim
  85. self.num_sources = num_sources
  86. self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
  87. self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)
  88. self.receptive_field = 0
  89. self.conv_layers = torch.nn.ModuleList([])
  90. for s in range(num_stacks):
  91. for l in range(num_layers):
  92. multi = 2**l
  93. self.conv_layers.append(
  94. ConvBlock(
  95. io_channels=num_feats,
  96. hidden_channels=num_hidden,
  97. kernel_size=kernel_size,
  98. dilation=multi,
  99. padding=multi,
  100. # The last ConvBlock does not need residual
  101. no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
  102. )
  103. )
  104. self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
  105. self.output_prelu = torch.nn.PReLU()
  106. self.output_conv = torch.nn.Conv1d(
  107. in_channels=num_feats,
  108. out_channels=input_dim * num_sources,
  109. kernel_size=1,
  110. )
  111. if msk_activate == "sigmoid":
  112. self.mask_activate = torch.nn.Sigmoid()
  113. elif msk_activate == "relu":
  114. self.mask_activate = torch.nn.ReLU()
  115. else:
  116. raise ValueError(f"Unsupported activation {msk_activate}")
  117. def forward(self, input: torch.Tensor) -> torch.Tensor:
  118. """Generate separation mask.
  119. Args:
  120. input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
  121. Returns:
  122. Tensor: shape [batch, num_sources, features, frames]
  123. """
  124. batch_size = input.shape[0]
  125. feats = self.input_norm(input)
  126. feats = self.input_conv(feats)
  127. output = 0.0
  128. for layer in self.conv_layers:
  129. residual, skip = layer(feats)
  130. if residual is not None: # the last conv layer does not produce residual
  131. feats = feats + residual
  132. output = output + skip
  133. output = self.output_prelu(output)
  134. output = self.output_conv(output)
  135. output = self.mask_activate(output)
  136. return output.view(batch_size, self.num_sources, self.input_dim, -1)
  137. class ConvTasNet(torch.nn.Module):
  138. """Conv-TasNet: a fully-convolutional time-domain audio separation network
  139. *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
  140. [:footcite:`Luo_2019`].
  141. Args:
  142. num_sources (int, optional): The number of sources to split.
  143. enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
  144. enc_num_feats (int, optional): The feature dimensions passed to mask generator, <N>.
  145. msk_kernel_size (int, optional): The convolution kernel size of the mask generator, <P>.
  146. msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
  147. msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
  148. msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
  149. msk_num_stacks (int, optional): The numbr of conv blocks of the mask generator, <R>.
  150. msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
  151. Note:
  152. This implementation corresponds to the "non-causal" setting in the paper.
  153. """
  154. def __init__(
  155. self,
  156. num_sources: int = 2,
  157. # encoder/decoder parameters
  158. enc_kernel_size: int = 16,
  159. enc_num_feats: int = 512,
  160. # mask generator parameters
  161. msk_kernel_size: int = 3,
  162. msk_num_feats: int = 128,
  163. msk_num_hidden_feats: int = 512,
  164. msk_num_layers: int = 8,
  165. msk_num_stacks: int = 3,
  166. msk_activate: str = "sigmoid",
  167. ):
  168. super().__init__()
  169. self.num_sources = num_sources
  170. self.enc_num_feats = enc_num_feats
  171. self.enc_kernel_size = enc_kernel_size
  172. self.enc_stride = enc_kernel_size // 2
  173. self.encoder = torch.nn.Conv1d(
  174. in_channels=1,
  175. out_channels=enc_num_feats,
  176. kernel_size=enc_kernel_size,
  177. stride=self.enc_stride,
  178. padding=self.enc_stride,
  179. bias=False,
  180. )
  181. self.mask_generator = MaskGenerator(
  182. input_dim=enc_num_feats,
  183. num_sources=num_sources,
  184. kernel_size=msk_kernel_size,
  185. num_feats=msk_num_feats,
  186. num_hidden=msk_num_hidden_feats,
  187. num_layers=msk_num_layers,
  188. num_stacks=msk_num_stacks,
  189. msk_activate=msk_activate,
  190. )
  191. self.decoder = torch.nn.ConvTranspose1d(
  192. in_channels=enc_num_feats,
  193. out_channels=1,
  194. kernel_size=enc_kernel_size,
  195. stride=self.enc_stride,
  196. padding=self.enc_stride,
  197. bias=False,
  198. )
  199. def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
  200. """Pad input Tensor so that the end of the input tensor corresponds with
  201. 1. (if kernel size is odd) the center of the last convolution kernel
  202. or 2. (if kernel size is even) the end of the first half of the last convolution kernel
  203. Assumption:
  204. The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
  205. on the both ends in Conv1D
  206. |<--- k_1 --->|
  207. | | |<-- k_n-1 -->|
  208. | | | |<--- k_n --->|
  209. | | | | |
  210. | | | | |
  211. | v v v |
  212. |<---->|<--- input signal --->|<--->|<---->|
  213. stride PAD stride
  214. Args:
  215. input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
  216. Returns:
  217. Tensor: Padded Tensor
  218. int: Number of paddings performed
  219. """
  220. batch_size, num_channels, num_frames = input.shape
  221. is_odd = self.enc_kernel_size % 2
  222. num_strides = (num_frames - is_odd) // self.enc_stride
  223. num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
  224. if num_remainings == 0:
  225. return input, 0
  226. num_paddings = self.enc_stride - num_remainings
  227. pad = torch.zeros(
  228. batch_size,
  229. num_channels,
  230. num_paddings,
  231. dtype=input.dtype,
  232. device=input.device,
  233. )
  234. return torch.cat([input, pad], 2), num_paddings
  235. def forward(self, input: torch.Tensor) -> torch.Tensor:
  236. """Perform source separation. Generate audio source waveforms.
  237. Args:
  238. input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
  239. Returns:
  240. Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
  241. """
  242. if input.ndim != 3 or input.shape[1] != 1:
  243. raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")
  244. # B: batch size
  245. # L: input frame length
  246. # L': padded input frame length
  247. # F: feature dimension
  248. # M: feature frame length
  249. # S: number of sources
  250. padded, num_pads = self._align_num_frames_with_strides(input) # B, 1, L'
  251. batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
  252. feats = self.encoder(padded) # B, F, M
  253. masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M
  254. masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1) # B*S, F, M
  255. decoded = self.decoder(masked) # B*S, 1, L'
  256. output = decoded.view(batch_size, self.num_sources, num_padded_frames) # B, S, L'
  257. if num_pads > 0:
  258. output = output[..., :-num_pads] # B, S, L
  259. return output