misc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import warnings
  2. from typing import Callable, List, Optional
  3. import torch
  4. from torch import Tensor
  5. from ..utils import _log_api_usage_once
  6. interpolate = torch.nn.functional.interpolate
  7. class FrozenBatchNorm2d(torch.nn.Module):
  8. """
  9. BatchNorm2d where the batch statistics and the affine parameters are fixed
  10. Args:
  11. num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
  12. eps (float): a value added to the denominator for numerical stability. Default: 1e-5
  13. """
  14. def __init__(
  15. self,
  16. num_features: int,
  17. eps: float = 1e-5,
  18. ):
  19. super().__init__()
  20. _log_api_usage_once(self)
  21. self.eps = eps
  22. self.register_buffer("weight", torch.ones(num_features))
  23. self.register_buffer("bias", torch.zeros(num_features))
  24. self.register_buffer("running_mean", torch.zeros(num_features))
  25. self.register_buffer("running_var", torch.ones(num_features))
  26. def _load_from_state_dict(
  27. self,
  28. state_dict: dict,
  29. prefix: str,
  30. local_metadata: dict,
  31. strict: bool,
  32. missing_keys: List[str],
  33. unexpected_keys: List[str],
  34. error_msgs: List[str],
  35. ):
  36. num_batches_tracked_key = prefix + "num_batches_tracked"
  37. if num_batches_tracked_key in state_dict:
  38. del state_dict[num_batches_tracked_key]
  39. super()._load_from_state_dict(
  40. state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
  41. )
  42. def forward(self, x: Tensor) -> Tensor:
  43. # move reshapes to the beginning
  44. # to make it fuser-friendly
  45. w = self.weight.reshape(1, -1, 1, 1)
  46. b = self.bias.reshape(1, -1, 1, 1)
  47. rv = self.running_var.reshape(1, -1, 1, 1)
  48. rm = self.running_mean.reshape(1, -1, 1, 1)
  49. scale = w * (rv + self.eps).rsqrt()
  50. bias = b - rm * scale
  51. return x * scale + bias
  52. def __repr__(self) -> str:
  53. return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
  54. class ConvNormActivation(torch.nn.Sequential):
  55. def __init__(
  56. self,
  57. in_channels: int,
  58. out_channels: int,
  59. kernel_size: int = 3,
  60. stride: int = 1,
  61. padding: Optional[int] = None,
  62. groups: int = 1,
  63. norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
  64. activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
  65. dilation: int = 1,
  66. inplace: Optional[bool] = True,
  67. bias: Optional[bool] = None,
  68. conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
  69. ) -> None:
  70. if padding is None:
  71. padding = (kernel_size - 1) // 2 * dilation
  72. if bias is None:
  73. bias = norm_layer is None
  74. layers = [
  75. conv_layer(
  76. in_channels,
  77. out_channels,
  78. kernel_size,
  79. stride,
  80. padding,
  81. dilation=dilation,
  82. groups=groups,
  83. bias=bias,
  84. )
  85. ]
  86. if norm_layer is not None:
  87. layers.append(norm_layer(out_channels))
  88. if activation_layer is not None:
  89. params = {} if inplace is None else {"inplace": inplace}
  90. layers.append(activation_layer(**params))
  91. super().__init__(*layers)
  92. _log_api_usage_once(self)
  93. self.out_channels = out_channels
  94. if self.__class__ == ConvNormActivation:
  95. warnings.warn(
  96. "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
  97. )
  98. class Conv2dNormActivation(ConvNormActivation):
  99. """
  100. Configurable block used for Convolution2d-Normalization-Activation blocks.
  101. Args:
  102. in_channels (int): Number of channels in the input image
  103. out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
  104. kernel_size: (int, optional): Size of the convolving kernel. Default: 3
  105. stride (int, optional): Stride of the convolution. Default: 1
  106. padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
  107. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  108. norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
  109. activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
  110. dilation (int): Spacing between kernel elements. Default: 1
  111. inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
  112. bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
  113. """
  114. def __init__(
  115. self,
  116. in_channels: int,
  117. out_channels: int,
  118. kernel_size: int = 3,
  119. stride: int = 1,
  120. padding: Optional[int] = None,
  121. groups: int = 1,
  122. norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
  123. activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
  124. dilation: int = 1,
  125. inplace: Optional[bool] = True,
  126. bias: Optional[bool] = None,
  127. ) -> None:
  128. super().__init__(
  129. in_channels,
  130. out_channels,
  131. kernel_size,
  132. stride,
  133. padding,
  134. groups,
  135. norm_layer,
  136. activation_layer,
  137. dilation,
  138. inplace,
  139. bias,
  140. torch.nn.Conv2d,
  141. )
  142. class Conv3dNormActivation(ConvNormActivation):
  143. """
  144. Configurable block used for Convolution3d-Normalization-Activation blocks.
  145. Args:
  146. in_channels (int): Number of channels in the input video.
  147. out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
  148. kernel_size: (int, optional): Size of the convolving kernel. Default: 3
  149. stride (int, optional): Stride of the convolution. Default: 1
  150. padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
  151. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  152. norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
  153. activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
  154. dilation (int): Spacing between kernel elements. Default: 1
  155. inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
  156. bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
  157. """
  158. def __init__(
  159. self,
  160. in_channels: int,
  161. out_channels: int,
  162. kernel_size: int = 3,
  163. stride: int = 1,
  164. padding: Optional[int] = None,
  165. groups: int = 1,
  166. norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
  167. activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
  168. dilation: int = 1,
  169. inplace: Optional[bool] = True,
  170. bias: Optional[bool] = None,
  171. ) -> None:
  172. super().__init__(
  173. in_channels,
  174. out_channels,
  175. kernel_size,
  176. stride,
  177. padding,
  178. groups,
  179. norm_layer,
  180. activation_layer,
  181. dilation,
  182. inplace,
  183. bias,
  184. torch.nn.Conv3d,
  185. )
  186. class SqueezeExcitation(torch.nn.Module):
  187. """
  188. This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
  189. Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
  190. Args:
  191. input_channels (int): Number of channels in the input image
  192. squeeze_channels (int): Number of squeeze channels
  193. activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
  194. scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
  195. """
  196. def __init__(
  197. self,
  198. input_channels: int,
  199. squeeze_channels: int,
  200. activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
  201. scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
  202. ) -> None:
  203. super().__init__()
  204. _log_api_usage_once(self)
  205. self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
  206. self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
  207. self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
  208. self.activation = activation()
  209. self.scale_activation = scale_activation()
  210. def _scale(self, input: Tensor) -> Tensor:
  211. scale = self.avgpool(input)
  212. scale = self.fc1(scale)
  213. scale = self.activation(scale)
  214. scale = self.fc2(scale)
  215. return self.scale_activation(scale)
  216. def forward(self, input: Tensor) -> Tensor:
  217. scale = self._scale(input)
  218. return scale * input
  219. class MLP(torch.nn.Sequential):
  220. """This block implements the multi-layer perceptron (MLP) module.
  221. Args:
  222. in_channels (int): Number of channels of the input
  223. hidden_channels (List[int]): List of the hidden channel dimensions
  224. norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
  225. activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
  226. inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
  227. bias (bool): Whether to use bias in the linear layer. Default ``True``
  228. dropout (float): The probability for the dropout layer. Default: 0.0
  229. """
  230. def __init__(
  231. self,
  232. in_channels: int,
  233. hidden_channels: List[int],
  234. norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
  235. activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
  236. inplace: Optional[bool] = True,
  237. bias: bool = True,
  238. dropout: float = 0.0,
  239. ):
  240. # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
  241. # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
  242. params = {} if inplace is None else {"inplace": inplace}
  243. layers = []
  244. in_dim = in_channels
  245. for hidden_dim in hidden_channels[:-1]:
  246. layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
  247. if norm_layer is not None:
  248. layers.append(norm_layer(hidden_dim))
  249. layers.append(activation_layer(**params))
  250. layers.append(torch.nn.Dropout(dropout, **params))
  251. in_dim = hidden_dim
  252. layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
  253. layers.append(torch.nn.Dropout(dropout, **params))
  254. super().__init__(*layers)
  255. _log_api_usage_once(self)
  256. class Permute(torch.nn.Module):
  257. """This module returns a view of the tensor input with its dimensions permuted.
  258. Args:
  259. dims (List[int]): The desired ordering of dimensions
  260. """
  261. def __init__(self, dims: List[int]):
  262. super().__init__()
  263. self.dims = dims
  264. def forward(self, x: Tensor) -> Tensor:
  265. return torch.permute(x, self.dims)