resnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. from functools import partial
  2. from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union
  3. import torch.nn as nn
  4. from torch import Tensor
  5. from ...transforms._presets import VideoClassification
  6. from ...utils import _log_api_usage_once
  7. from .._api import WeightsEnum, Weights
  8. from .._meta import _KINETICS400_CATEGORIES
  9. from .._utils import handle_legacy_interface, _ovewrite_named_param
  10. __all__ = [
  11. "VideoResNet",
  12. "R3D_18_Weights",
  13. "MC3_18_Weights",
  14. "R2Plus1D_18_Weights",
  15. "r3d_18",
  16. "mc3_18",
  17. "r2plus1d_18",
  18. ]
  19. class Conv3DSimple(nn.Conv3d):
  20. def __init__(
  21. self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
  22. ) -> None:
  23. super().__init__(
  24. in_channels=in_planes,
  25. out_channels=out_planes,
  26. kernel_size=(3, 3, 3),
  27. stride=stride,
  28. padding=padding,
  29. bias=False,
  30. )
  31. @staticmethod
  32. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  33. return stride, stride, stride
  34. class Conv2Plus1D(nn.Sequential):
  35. def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
  36. super().__init__(
  37. nn.Conv3d(
  38. in_planes,
  39. midplanes,
  40. kernel_size=(1, 3, 3),
  41. stride=(1, stride, stride),
  42. padding=(0, padding, padding),
  43. bias=False,
  44. ),
  45. nn.BatchNorm3d(midplanes),
  46. nn.ReLU(inplace=True),
  47. nn.Conv3d(
  48. midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
  49. ),
  50. )
  51. @staticmethod
  52. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  53. return stride, stride, stride
  54. class Conv3DNoTemporal(nn.Conv3d):
  55. def __init__(
  56. self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
  57. ) -> None:
  58. super().__init__(
  59. in_channels=in_planes,
  60. out_channels=out_planes,
  61. kernel_size=(1, 3, 3),
  62. stride=(1, stride, stride),
  63. padding=(0, padding, padding),
  64. bias=False,
  65. )
  66. @staticmethod
  67. def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
  68. return 1, stride, stride
  69. class BasicBlock(nn.Module):
  70. expansion = 1
  71. def __init__(
  72. self,
  73. inplanes: int,
  74. planes: int,
  75. conv_builder: Callable[..., nn.Module],
  76. stride: int = 1,
  77. downsample: Optional[nn.Module] = None,
  78. ) -> None:
  79. midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
  80. super().__init__()
  81. self.conv1 = nn.Sequential(
  82. conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  83. )
  84. self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
  85. self.relu = nn.ReLU(inplace=True)
  86. self.downsample = downsample
  87. self.stride = stride
  88. def forward(self, x: Tensor) -> Tensor:
  89. residual = x
  90. out = self.conv1(x)
  91. out = self.conv2(out)
  92. if self.downsample is not None:
  93. residual = self.downsample(x)
  94. out += residual
  95. out = self.relu(out)
  96. return out
  97. class Bottleneck(nn.Module):
  98. expansion = 4
  99. def __init__(
  100. self,
  101. inplanes: int,
  102. planes: int,
  103. conv_builder: Callable[..., nn.Module],
  104. stride: int = 1,
  105. downsample: Optional[nn.Module] = None,
  106. ) -> None:
  107. super().__init__()
  108. midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
  109. # 1x1x1
  110. self.conv1 = nn.Sequential(
  111. nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  112. )
  113. # Second kernel
  114. self.conv2 = nn.Sequential(
  115. conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
  116. )
  117. # 1x1x1
  118. self.conv3 = nn.Sequential(
  119. nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
  120. nn.BatchNorm3d(planes * self.expansion),
  121. )
  122. self.relu = nn.ReLU(inplace=True)
  123. self.downsample = downsample
  124. self.stride = stride
  125. def forward(self, x: Tensor) -> Tensor:
  126. residual = x
  127. out = self.conv1(x)
  128. out = self.conv2(out)
  129. out = self.conv3(out)
  130. if self.downsample is not None:
  131. residual = self.downsample(x)
  132. out += residual
  133. out = self.relu(out)
  134. return out
  135. class BasicStem(nn.Sequential):
  136. """The default conv-batchnorm-relu stem"""
  137. def __init__(self) -> None:
  138. super().__init__(
  139. nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
  140. nn.BatchNorm3d(64),
  141. nn.ReLU(inplace=True),
  142. )
  143. class R2Plus1dStem(nn.Sequential):
  144. """R(2+1)D stem is different than the default one as it uses separated 3D convolution"""
  145. def __init__(self) -> None:
  146. super().__init__(
  147. nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
  148. nn.BatchNorm3d(45),
  149. nn.ReLU(inplace=True),
  150. nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
  151. nn.BatchNorm3d(64),
  152. nn.ReLU(inplace=True),
  153. )
  154. class VideoResNet(nn.Module):
  155. def __init__(
  156. self,
  157. block: Type[Union[BasicBlock, Bottleneck]],
  158. conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
  159. layers: List[int],
  160. stem: Callable[..., nn.Module],
  161. num_classes: int = 400,
  162. zero_init_residual: bool = False,
  163. ) -> None:
  164. """Generic resnet video generator.
  165. Args:
  166. block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
  167. conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
  168. function for each layer
  169. layers (List[int]): number of blocks per layer
  170. stem (Callable[..., nn.Module]): module specifying the ResNet stem.
  171. num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
  172. zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
  173. """
  174. super().__init__()
  175. _log_api_usage_once(self)
  176. self.inplanes = 64
  177. self.stem = stem()
  178. self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
  179. self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
  180. self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
  181. self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)
  182. self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
  183. self.fc = nn.Linear(512 * block.expansion, num_classes)
  184. # init weights
  185. for m in self.modules():
  186. if isinstance(m, nn.Conv3d):
  187. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  188. if m.bias is not None:
  189. nn.init.constant_(m.bias, 0)
  190. elif isinstance(m, nn.BatchNorm3d):
  191. nn.init.constant_(m.weight, 1)
  192. nn.init.constant_(m.bias, 0)
  193. elif isinstance(m, nn.Linear):
  194. nn.init.normal_(m.weight, 0, 0.01)
  195. nn.init.constant_(m.bias, 0)
  196. if zero_init_residual:
  197. for m in self.modules():
  198. if isinstance(m, Bottleneck):
  199. nn.init.constant_(m.bn3.weight, 0) # type: ignore[union-attr, arg-type]
  200. def forward(self, x: Tensor) -> Tensor:
  201. x = self.stem(x)
  202. x = self.layer1(x)
  203. x = self.layer2(x)
  204. x = self.layer3(x)
  205. x = self.layer4(x)
  206. x = self.avgpool(x)
  207. # Flatten the layer to fc
  208. x = x.flatten(1)
  209. x = self.fc(x)
  210. return x
  211. def _make_layer(
  212. self,
  213. block: Type[Union[BasicBlock, Bottleneck]],
  214. conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
  215. planes: int,
  216. blocks: int,
  217. stride: int = 1,
  218. ) -> nn.Sequential:
  219. downsample = None
  220. if stride != 1 or self.inplanes != planes * block.expansion:
  221. ds_stride = conv_builder.get_downsample_stride(stride)
  222. downsample = nn.Sequential(
  223. nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
  224. nn.BatchNorm3d(planes * block.expansion),
  225. )
  226. layers = []
  227. layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
  228. self.inplanes = planes * block.expansion
  229. for i in range(1, blocks):
  230. layers.append(block(self.inplanes, planes, conv_builder))
  231. return nn.Sequential(*layers)
  232. def _video_resnet(
  233. block: Type[Union[BasicBlock, Bottleneck]],
  234. conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
  235. layers: List[int],
  236. stem: Callable[..., nn.Module],
  237. weights: Optional[WeightsEnum],
  238. progress: bool,
  239. **kwargs: Any,
  240. ) -> VideoResNet:
  241. if weights is not None:
  242. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  243. model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
  244. if weights is not None:
  245. model.load_state_dict(weights.get_state_dict(progress=progress))
  246. return model
  247. _COMMON_META = {
  248. "min_size": (1, 1),
  249. "categories": _KINETICS400_CATEGORIES,
  250. "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
  251. "_docs": """These weights reproduce closely the accuracy of the paper for 16-frame clip inputs.""",
  252. }
  253. class R3D_18_Weights(WeightsEnum):
  254. KINETICS400_V1 = Weights(
  255. url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
  256. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  257. meta={
  258. **_COMMON_META,
  259. "num_params": 33371472,
  260. "_metrics": {
  261. "Kinetics-400": {
  262. "acc@1": 52.75,
  263. "acc@5": 75.45,
  264. }
  265. },
  266. },
  267. )
  268. DEFAULT = KINETICS400_V1
  269. class MC3_18_Weights(WeightsEnum):
  270. KINETICS400_V1 = Weights(
  271. url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
  272. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  273. meta={
  274. **_COMMON_META,
  275. "num_params": 11695440,
  276. "_metrics": {
  277. "Kinetics-400": {
  278. "acc@1": 53.90,
  279. "acc@5": 76.29,
  280. }
  281. },
  282. },
  283. )
  284. DEFAULT = KINETICS400_V1
  285. class R2Plus1D_18_Weights(WeightsEnum):
  286. KINETICS400_V1 = Weights(
  287. url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
  288. transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  289. meta={
  290. **_COMMON_META,
  291. "num_params": 31505325,
  292. "_metrics": {
  293. "Kinetics-400": {
  294. "acc@1": 57.50,
  295. "acc@5": 78.81,
  296. }
  297. },
  298. },
  299. )
  300. DEFAULT = KINETICS400_V1
  301. @handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
  302. def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  303. """Construct 18 layer Resnet3D model.
  304. .. betastatus:: video module
  305. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  306. Args:
  307. weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
  308. pretrained weights to use. See
  309. :class:`~torchvision.models.video.R3D_18_Weights`
  310. below for more details, and possible values. By default, no
  311. pre-trained weights are used.
  312. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  313. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  314. Please refer to the `source code
  315. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  316. for more details about this class.
  317. .. autoclass:: torchvision.models.video.R3D_18_Weights
  318. :members:
  319. """
  320. weights = R3D_18_Weights.verify(weights)
  321. return _video_resnet(
  322. BasicBlock,
  323. [Conv3DSimple] * 4,
  324. [2, 2, 2, 2],
  325. BasicStem,
  326. weights,
  327. progress,
  328. **kwargs,
  329. )
  330. @handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
  331. def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  332. """Construct 18 layer Mixed Convolution network as in
  333. .. betastatus:: video module
  334. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  335. Args:
  336. weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
  337. pretrained weights to use. See
  338. :class:`~torchvision.models.video.MC3_18_Weights`
  339. below for more details, and possible values. By default, no
  340. pre-trained weights are used.
  341. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  342. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  343. Please refer to the `source code
  344. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  345. for more details about this class.
  346. .. autoclass:: torchvision.models.video.MC3_18_Weights
  347. :members:
  348. """
  349. weights = MC3_18_Weights.verify(weights)
  350. return _video_resnet(
  351. BasicBlock,
  352. [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
  353. [2, 2, 2, 2],
  354. BasicStem,
  355. weights,
  356. progress,
  357. **kwargs,
  358. )
  359. @handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
  360. def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
  361. """Construct 18 layer deep R(2+1)D network as in
  362. .. betastatus:: video module
  363. Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.
  364. Args:
  365. weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
  366. pretrained weights to use. See
  367. :class:`~torchvision.models.video.R2Plus1D_18_Weights`
  368. below for more details, and possible values. By default, no
  369. pre-trained weights are used.
  370. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  371. **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
  372. Please refer to the `source code
  373. <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
  374. for more details about this class.
  375. .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
  376. :members:
  377. """
  378. weights = R2Plus1D_18_Weights.verify(weights)
  379. return _video_resnet(
  380. BasicBlock,
  381. [Conv2Plus1D] * 4,
  382. [2, 2, 2, 2],
  383. R2Plus1dStem,
  384. weights,
  385. progress,
  386. **kwargs,
  387. )
  388. # The dictionary below is internal implementation detail and will be removed in v0.15
  389. from .._utils import _ModelURLs
  390. model_urls = _ModelURLs(
  391. {
  392. "r3d_18": R3D_18_Weights.KINETICS400_V1.url,
  393. "mc3_18": MC3_18_Weights.KINETICS400_V1.url,
  394. "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url,
  395. }
  396. )