convnext.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. from functools import partial
  2. from typing import Any, Callable, List, Optional, Sequence
  3. import torch
  4. from torch import nn, Tensor
  5. from torch.nn import functional as F
  6. from ..ops.misc import Conv2dNormActivation, Permute
  7. from ..ops.stochastic_depth import StochasticDepth
  8. from ..transforms._presets import ImageClassification
  9. from ..utils import _log_api_usage_once
  10. from ._api import WeightsEnum, Weights
  11. from ._meta import _IMAGENET_CATEGORIES
  12. from ._utils import handle_legacy_interface, _ovewrite_named_param
# Names exported by ``from <this module> import *``: the model class, the
# four weight enums and the four builder functions.
__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]
  24. class LayerNorm2d(nn.LayerNorm):
  25. def forward(self, x: Tensor) -> Tensor:
  26. x = x.permute(0, 2, 3, 1)
  27. x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  28. x = x.permute(0, 3, 1, 2)
  29. return x
  30. class CNBlock(nn.Module):
  31. def __init__(
  32. self,
  33. dim,
  34. layer_scale: float,
  35. stochastic_depth_prob: float,
  36. norm_layer: Optional[Callable[..., nn.Module]] = None,
  37. ) -> None:
  38. super().__init__()
  39. if norm_layer is None:
  40. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  41. self.block = nn.Sequential(
  42. nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
  43. Permute([0, 2, 3, 1]),
  44. norm_layer(dim),
  45. nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
  46. nn.GELU(),
  47. nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
  48. Permute([0, 3, 1, 2]),
  49. )
  50. self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
  51. self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
  52. def forward(self, input: Tensor) -> Tensor:
  53. result = self.layer_scale * self.block(input)
  54. result = self.stochastic_depth(result)
  55. result += input
  56. return result
  57. class CNBlockConfig:
  58. # Stores information listed at Section 3 of the ConvNeXt paper
  59. def __init__(
  60. self,
  61. input_channels: int,
  62. out_channels: Optional[int],
  63. num_layers: int,
  64. ) -> None:
  65. self.input_channels = input_channels
  66. self.out_channels = out_channels
  67. self.num_layers = num_layers
  68. def __repr__(self) -> str:
  69. s = self.__class__.__name__ + "("
  70. s += "input_channels={input_channels}"
  71. s += ", out_channels={out_channels}"
  72. s += ", num_layers={num_layers}"
  73. s += ")"
  74. return s.format(**self.__dict__)
  75. class ConvNeXt(nn.Module):
  76. def __init__(
  77. self,
  78. block_setting: List[CNBlockConfig],
  79. stochastic_depth_prob: float = 0.0,
  80. layer_scale: float = 1e-6,
  81. num_classes: int = 1000,
  82. block: Optional[Callable[..., nn.Module]] = None,
  83. norm_layer: Optional[Callable[..., nn.Module]] = None,
  84. **kwargs: Any,
  85. ) -> None:
  86. super().__init__()
  87. _log_api_usage_once(self)
  88. if not block_setting:
  89. raise ValueError("The block_setting should not be empty")
  90. elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
  91. raise TypeError("The block_setting should be List[CNBlockConfig]")
  92. if block is None:
  93. block = CNBlock
  94. if norm_layer is None:
  95. norm_layer = partial(LayerNorm2d, eps=1e-6)
  96. layers: List[nn.Module] = []
  97. # Stem
  98. firstconv_output_channels = block_setting[0].input_channels
  99. layers.append(
  100. Conv2dNormActivation(
  101. 3,
  102. firstconv_output_channels,
  103. kernel_size=4,
  104. stride=4,
  105. padding=0,
  106. norm_layer=norm_layer,
  107. activation_layer=None,
  108. bias=True,
  109. )
  110. )
  111. total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
  112. stage_block_id = 0
  113. for cnf in block_setting:
  114. # Bottlenecks
  115. stage: List[nn.Module] = []
  116. for _ in range(cnf.num_layers):
  117. # adjust stochastic depth probability based on the depth of the stage block
  118. sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
  119. stage.append(block(cnf.input_channels, layer_scale, sd_prob))
  120. stage_block_id += 1
  121. layers.append(nn.Sequential(*stage))
  122. if cnf.out_channels is not None:
  123. # Downsampling
  124. layers.append(
  125. nn.Sequential(
  126. norm_layer(cnf.input_channels),
  127. nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
  128. )
  129. )
  130. self.features = nn.Sequential(*layers)
  131. self.avgpool = nn.AdaptiveAvgPool2d(1)
  132. lastblock = block_setting[-1]
  133. lastconv_output_channels = (
  134. lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
  135. )
  136. self.classifier = nn.Sequential(
  137. norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
  138. )
  139. for m in self.modules():
  140. if isinstance(m, (nn.Conv2d, nn.Linear)):
  141. nn.init.trunc_normal_(m.weight, std=0.02)
  142. if m.bias is not None:
  143. nn.init.zeros_(m.bias)
  144. def _forward_impl(self, x: Tensor) -> Tensor:
  145. x = self.features(x)
  146. x = self.avgpool(x)
  147. x = self.classifier(x)
  148. return x
  149. def forward(self, x: Tensor) -> Tensor:
  150. return self._forward_impl(x)
  151. def _convnext(
  152. block_setting: List[CNBlockConfig],
  153. stochastic_depth_prob: float,
  154. weights: Optional[WeightsEnum],
  155. progress: bool,
  156. **kwargs: Any,
  157. ) -> ConvNeXt:
  158. if weights is not None:
  159. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  160. model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
  161. if weights is not None:
  162. model.load_state_dict(weights.get_state_dict(progress=progress))
  163. return model
# Metadata entries shared by the ConvNeXt weight enums in this file; each enum
# merges this dict into the ``meta`` of its own ``Weights`` entry.
_COMMON_META = {
    # NOTE(review): presumably the smallest supported input size - confirm against the consumers of this key.
    "min_size": (32, 32),
    # Also read by ``_convnext`` to derive the number of classes of a pretrained model.
    "categories": _IMAGENET_CATEGORIES,
    # Link to the training recipe in the reference scripts.
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    # reStructuredText blurb describing the weights.
    "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
}
  174. class ConvNeXt_Tiny_Weights(WeightsEnum):
  175. IMAGENET1K_V1 = Weights(
  176. url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
  177. transforms=partial(ImageClassification, crop_size=224, resize_size=236),
  178. meta={
  179. **_COMMON_META,
  180. "num_params": 28589128,
  181. "_metrics": {
  182. "ImageNet-1K": {
  183. "acc@1": 82.520,
  184. "acc@5": 96.146,
  185. }
  186. },
  187. },
  188. )
  189. DEFAULT = IMAGENET1K_V1
  190. class ConvNeXt_Small_Weights(WeightsEnum):
  191. IMAGENET1K_V1 = Weights(
  192. url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
  193. transforms=partial(ImageClassification, crop_size=224, resize_size=230),
  194. meta={
  195. **_COMMON_META,
  196. "num_params": 50223688,
  197. "_metrics": {
  198. "ImageNet-1K": {
  199. "acc@1": 83.616,
  200. "acc@5": 96.650,
  201. }
  202. },
  203. },
  204. )
  205. DEFAULT = IMAGENET1K_V1
  206. class ConvNeXt_Base_Weights(WeightsEnum):
  207. IMAGENET1K_V1 = Weights(
  208. url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
  209. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  210. meta={
  211. **_COMMON_META,
  212. "num_params": 88591464,
  213. "_metrics": {
  214. "ImageNet-1K": {
  215. "acc@1": 84.062,
  216. "acc@5": 96.870,
  217. }
  218. },
  219. },
  220. )
  221. DEFAULT = IMAGENET1K_V1
  222. class ConvNeXt_Large_Weights(WeightsEnum):
  223. IMAGENET1K_V1 = Weights(
  224. url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
  225. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  226. meta={
  227. **_COMMON_META,
  228. "num_params": 197767336,
  229. "_metrics": {
  230. "ImageNet-1K": {
  231. "acc@1": 84.414,
  232. "acc@5": 96.976,
  233. }
  234. },
  235. },
  236. )
  237. DEFAULT = IMAGENET1K_V1
  238. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
  239. def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
  240. """ConvNeXt Tiny model architecture from the
  241. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  242. Args:
  243. weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
  244. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
  245. below for more details and possible values. By default, no pre-trained weights are used.
  246. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  247. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  248. base class. Please refer to the `source code
  249. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  250. for more details about this class.
  251. .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
  252. :members:
  253. """
  254. weights = ConvNeXt_Tiny_Weights.verify(weights)
  255. block_setting = [
  256. CNBlockConfig(96, 192, 3),
  257. CNBlockConfig(192, 384, 3),
  258. CNBlockConfig(384, 768, 9),
  259. CNBlockConfig(768, None, 3),
  260. ]
  261. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
  262. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  263. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
  264. def convnext_small(
  265. *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
  266. ) -> ConvNeXt:
  267. """ConvNeXt Small model architecture from the
  268. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  269. Args:
  270. weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
  271. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
  272. below for more details and possible values. By default, no pre-trained weights are used.
  273. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  274. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  275. base class. Please refer to the `source code
  276. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  277. for more details about this class.
  278. .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
  279. :members:
  280. """
  281. weights = ConvNeXt_Small_Weights.verify(weights)
  282. block_setting = [
  283. CNBlockConfig(96, 192, 3),
  284. CNBlockConfig(192, 384, 3),
  285. CNBlockConfig(384, 768, 27),
  286. CNBlockConfig(768, None, 3),
  287. ]
  288. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
  289. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  290. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
  291. def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
  292. """ConvNeXt Base model architecture from the
  293. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  294. Args:
  295. weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
  296. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
  297. below for more details and possible values. By default, no pre-trained weights are used.
  298. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  299. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  300. base class. Please refer to the `source code
  301. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  302. for more details about this class.
  303. .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
  304. :members:
  305. """
  306. weights = ConvNeXt_Base_Weights.verify(weights)
  307. block_setting = [
  308. CNBlockConfig(128, 256, 3),
  309. CNBlockConfig(256, 512, 3),
  310. CNBlockConfig(512, 1024, 27),
  311. CNBlockConfig(1024, None, 3),
  312. ]
  313. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
  314. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  315. @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
  316. def convnext_large(
  317. *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
  318. ) -> ConvNeXt:
  319. """ConvNeXt Large model architecture from the
  320. `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
  321. Args:
  322. weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
  323. weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
  324. below for more details and possible values. By default, no pre-trained weights are used.
  325. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
  326. **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
  327. base class. Please refer to the `source code
  328. <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
  329. for more details about this class.
  330. .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
  331. :members:
  332. """
  333. weights = ConvNeXt_Large_Weights.verify(weights)
  334. block_setting = [
  335. CNBlockConfig(192, 384, 3),
  336. CNBlockConfig(384, 768, 3),
  337. CNBlockConfig(768, 1536, 27),
  338. CNBlockConfig(1536, None, 3),
  339. ]
  340. stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
  341. return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)