vgg.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. from functools import partial
  2. from typing import Union, List, Dict, Any, Optional, cast
  3. import torch
  4. import torch.nn as nn
  5. from ..transforms._presets import ImageClassification
  6. from ..utils import _log_api_usage_once
  7. from ._api import WeightsEnum, Weights
  8. from ._meta import _IMAGENET_CATEGORIES
  9. from ._utils import handle_legacy_interface, _ovewrite_named_param
  10. __all__ = [
  11. "VGG",
  12. "VGG11_Weights",
  13. "VGG11_BN_Weights",
  14. "VGG13_Weights",
  15. "VGG13_BN_Weights",
  16. "VGG16_Weights",
  17. "VGG16_BN_Weights",
  18. "VGG19_Weights",
  19. "VGG19_BN_Weights",
  20. "vgg11",
  21. "vgg11_bn",
  22. "vgg13",
  23. "vgg13_bn",
  24. "vgg16",
  25. "vgg16_bn",
  26. "vgg19",
  27. "vgg19_bn",
  28. ]
  29. class VGG(nn.Module):
  30. def __init__(
  31. self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
  32. ) -> None:
  33. super().__init__()
  34. _log_api_usage_once(self)
  35. self.features = features
  36. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  37. self.classifier = nn.Sequential(
  38. nn.Linear(512 * 7 * 7, 4096),
  39. nn.ReLU(True),
  40. nn.Dropout(p=dropout),
  41. nn.Linear(4096, 4096),
  42. nn.ReLU(True),
  43. nn.Dropout(p=dropout),
  44. nn.Linear(4096, num_classes),
  45. )
  46. if init_weights:
  47. for m in self.modules():
  48. if isinstance(m, nn.Conv2d):
  49. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  50. if m.bias is not None:
  51. nn.init.constant_(m.bias, 0)
  52. elif isinstance(m, nn.BatchNorm2d):
  53. nn.init.constant_(m.weight, 1)
  54. nn.init.constant_(m.bias, 0)
  55. elif isinstance(m, nn.Linear):
  56. nn.init.normal_(m.weight, 0, 0.01)
  57. nn.init.constant_(m.bias, 0)
  58. def forward(self, x: torch.Tensor) -> torch.Tensor:
  59. x = self.features(x)
  60. x = self.avgpool(x)
  61. x = torch.flatten(x, 1)
  62. x = self.classifier(x)
  63. return x
  64. def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
  65. layers: List[nn.Module] = []
  66. in_channels = 3
  67. for v in cfg:
  68. if v == "M":
  69. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  70. else:
  71. v = cast(int, v)
  72. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  73. if batch_norm:
  74. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  75. else:
  76. layers += [conv2d, nn.ReLU(inplace=True)]
  77. in_channels = v
  78. return nn.Sequential(*layers)
  79. cfgs: Dict[str, List[Union[str, int]]] = {
  80. "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  81. "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  82. "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
  83. "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
  84. }
  85. def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
  86. if weights is not None:
  87. kwargs["init_weights"] = False
  88. if weights.meta["categories"] is not None:
  89. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  90. model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
  91. if weights is not None:
  92. model.load_state_dict(weights.get_state_dict(progress=progress))
  93. return model
  94. _COMMON_META = {
  95. "min_size": (32, 32),
  96. "categories": _IMAGENET_CATEGORIES,
  97. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
  98. "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
  99. }
  100. class VGG11_Weights(WeightsEnum):
  101. IMAGENET1K_V1 = Weights(
  102. url="https://download.pytorch.org/models/vgg11-8a719046.pth",
  103. transforms=partial(ImageClassification, crop_size=224),
  104. meta={
  105. **_COMMON_META,
  106. "num_params": 132863336,
  107. "_metrics": {
  108. "ImageNet-1K": {
  109. "acc@1": 69.020,
  110. "acc@5": 88.628,
  111. }
  112. },
  113. },
  114. )
  115. DEFAULT = IMAGENET1K_V1
  116. class VGG11_BN_Weights(WeightsEnum):
  117. IMAGENET1K_V1 = Weights(
  118. url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
  119. transforms=partial(ImageClassification, crop_size=224),
  120. meta={
  121. **_COMMON_META,
  122. "num_params": 132868840,
  123. "_metrics": {
  124. "ImageNet-1K": {
  125. "acc@1": 70.370,
  126. "acc@5": 89.810,
  127. }
  128. },
  129. },
  130. )
  131. DEFAULT = IMAGENET1K_V1
  132. class VGG13_Weights(WeightsEnum):
  133. IMAGENET1K_V1 = Weights(
  134. url="https://download.pytorch.org/models/vgg13-19584684.pth",
  135. transforms=partial(ImageClassification, crop_size=224),
  136. meta={
  137. **_COMMON_META,
  138. "num_params": 133047848,
  139. "_metrics": {
  140. "ImageNet-1K": {
  141. "acc@1": 69.928,
  142. "acc@5": 89.246,
  143. }
  144. },
  145. },
  146. )
  147. DEFAULT = IMAGENET1K_V1
  148. class VGG13_BN_Weights(WeightsEnum):
  149. IMAGENET1K_V1 = Weights(
  150. url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
  151. transforms=partial(ImageClassification, crop_size=224),
  152. meta={
  153. **_COMMON_META,
  154. "num_params": 133053736,
  155. "_metrics": {
  156. "ImageNet-1K": {
  157. "acc@1": 71.586,
  158. "acc@5": 90.374,
  159. }
  160. },
  161. },
  162. )
  163. DEFAULT = IMAGENET1K_V1
  164. class VGG16_Weights(WeightsEnum):
  165. IMAGENET1K_V1 = Weights(
  166. url="https://download.pytorch.org/models/vgg16-397923af.pth",
  167. transforms=partial(ImageClassification, crop_size=224),
  168. meta={
  169. **_COMMON_META,
  170. "num_params": 138357544,
  171. "_metrics": {
  172. "ImageNet-1K": {
  173. "acc@1": 71.592,
  174. "acc@5": 90.382,
  175. }
  176. },
  177. },
  178. )
  179. IMAGENET1K_FEATURES = Weights(
  180. # Weights ported from https://github.com/amdegroot/ssd.pytorch/
  181. url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
  182. transforms=partial(
  183. ImageClassification,
  184. crop_size=224,
  185. mean=(0.48235, 0.45882, 0.40784),
  186. std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
  187. ),
  188. meta={
  189. **_COMMON_META,
  190. "num_params": 138357544,
  191. "categories": None,
  192. "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
  193. "_metrics": {
  194. "ImageNet-1K": {
  195. "acc@1": float("nan"),
  196. "acc@5": float("nan"),
  197. }
  198. },
  199. "_docs": """
  200. These weights can't be used for classification because they are missing values in the `classifier`
  201. module. Only the `features` module has valid values and can be used for feature extraction. The weights
  202. were trained using the original input standardization method as described in the paper.
  203. """,
  204. },
  205. )
  206. DEFAULT = IMAGENET1K_V1
  207. class VGG16_BN_Weights(WeightsEnum):
  208. IMAGENET1K_V1 = Weights(
  209. url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
  210. transforms=partial(ImageClassification, crop_size=224),
  211. meta={
  212. **_COMMON_META,
  213. "num_params": 138365992,
  214. "_metrics": {
  215. "ImageNet-1K": {
  216. "acc@1": 73.360,
  217. "acc@5": 91.516,
  218. }
  219. },
  220. },
  221. )
  222. DEFAULT = IMAGENET1K_V1
  223. class VGG19_Weights(WeightsEnum):
  224. IMAGENET1K_V1 = Weights(
  225. url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
  226. transforms=partial(ImageClassification, crop_size=224),
  227. meta={
  228. **_COMMON_META,
  229. "num_params": 143667240,
  230. "_metrics": {
  231. "ImageNet-1K": {
  232. "acc@1": 72.376,
  233. "acc@5": 90.876,
  234. }
  235. },
  236. },
  237. )
  238. DEFAULT = IMAGENET1K_V1
  239. class VGG19_BN_Weights(WeightsEnum):
  240. IMAGENET1K_V1 = Weights(
  241. url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
  242. transforms=partial(ImageClassification, crop_size=224),
  243. meta={
  244. **_COMMON_META,
  245. "num_params": 143678248,
  246. "_metrics": {
  247. "ImageNet-1K": {
  248. "acc@1": 74.218,
  249. "acc@5": 91.842,
  250. }
  251. },
  252. },
  253. )
  254. DEFAULT = IMAGENET1K_V1
  255. @handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
  256. def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  257. """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  258. Args:
  259. weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
  260. pretrained weights to use. See
  261. :class:`~torchvision.models.VGG11_Weights` below for
  262. more details, and possible values. By default, no pre-trained
  263. weights are used.
  264. progress (bool, optional): If True, displays a progress bar of the
  265. download to stderr. Default is True.
  266. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  267. base class. Please refer to the `source code
  268. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  269. for more details about this class.
  270. .. autoclass:: torchvision.models.VGG11_Weights
  271. :members:
  272. """
  273. weights = VGG11_Weights.verify(weights)
  274. return _vgg("A", False, weights, progress, **kwargs)
  275. @handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
  276. def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  277. """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  278. Args:
  279. weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
  280. pretrained weights to use. See
  281. :class:`~torchvision.models.VGG11_BN_Weights` below for
  282. more details, and possible values. By default, no pre-trained
  283. weights are used.
  284. progress (bool, optional): If True, displays a progress bar of the
  285. download to stderr. Default is True.
  286. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  287. base class. Please refer to the `source code
  288. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  289. for more details about this class.
  290. .. autoclass:: torchvision.models.VGG11_BN_Weights
  291. :members:
  292. """
  293. weights = VGG11_BN_Weights.verify(weights)
  294. return _vgg("A", True, weights, progress, **kwargs)
  295. @handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
  296. def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  297. """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  298. Args:
  299. weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
  300. pretrained weights to use. See
  301. :class:`~torchvision.models.VGG13_Weights` below for
  302. more details, and possible values. By default, no pre-trained
  303. weights are used.
  304. progress (bool, optional): If True, displays a progress bar of the
  305. download to stderr. Default is True.
  306. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  307. base class. Please refer to the `source code
  308. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  309. for more details about this class.
  310. .. autoclass:: torchvision.models.VGG13_Weights
  311. :members:
  312. """
  313. weights = VGG13_Weights.verify(weights)
  314. return _vgg("B", False, weights, progress, **kwargs)
  315. @handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
  316. def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  317. """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  318. Args:
  319. weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
  320. pretrained weights to use. See
  321. :class:`~torchvision.models.VGG13_BN_Weights` below for
  322. more details, and possible values. By default, no pre-trained
  323. weights are used.
  324. progress (bool, optional): If True, displays a progress bar of the
  325. download to stderr. Default is True.
  326. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  327. base class. Please refer to the `source code
  328. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  329. for more details about this class.
  330. .. autoclass:: torchvision.models.VGG13_BN_Weights
  331. :members:
  332. """
  333. weights = VGG13_BN_Weights.verify(weights)
  334. return _vgg("B", True, weights, progress, **kwargs)
  335. @handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
  336. def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  337. """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  338. Args:
  339. weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
  340. pretrained weights to use. See
  341. :class:`~torchvision.models.VGG16_Weights` below for
  342. more details, and possible values. By default, no pre-trained
  343. weights are used.
  344. progress (bool, optional): If True, displays a progress bar of the
  345. download to stderr. Default is True.
  346. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  347. base class. Please refer to the `source code
  348. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  349. for more details about this class.
  350. .. autoclass:: torchvision.models.VGG16_Weights
  351. :members:
  352. """
  353. weights = VGG16_Weights.verify(weights)
  354. return _vgg("D", False, weights, progress, **kwargs)
  355. @handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
  356. def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  357. """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  358. Args:
  359. weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
  360. pretrained weights to use. See
  361. :class:`~torchvision.models.VGG16_BN_Weights` below for
  362. more details, and possible values. By default, no pre-trained
  363. weights are used.
  364. progress (bool, optional): If True, displays a progress bar of the
  365. download to stderr. Default is True.
  366. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  367. base class. Please refer to the `source code
  368. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  369. for more details about this class.
  370. .. autoclass:: torchvision.models.VGG16_BN_Weights
  371. :members:
  372. """
  373. weights = VGG16_BN_Weights.verify(weights)
  374. return _vgg("D", True, weights, progress, **kwargs)
  375. @handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
  376. def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  377. """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  378. Args:
  379. weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
  380. pretrained weights to use. See
  381. :class:`~torchvision.models.VGG19_Weights` below for
  382. more details, and possible values. By default, no pre-trained
  383. weights are used.
  384. progress (bool, optional): If True, displays a progress bar of the
  385. download to stderr. Default is True.
  386. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  387. base class. Please refer to the `source code
  388. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  389. for more details about this class.
  390. .. autoclass:: torchvision.models.VGG19_Weights
  391. :members:
  392. """
  393. weights = VGG19_Weights.verify(weights)
  394. return _vgg("E", False, weights, progress, **kwargs)
  395. @handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
  396. def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
  397. """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  398. Args:
  399. weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
  400. pretrained weights to use. See
  401. :class:`~torchvision.models.VGG19_BN_Weights` below for
  402. more details, and possible values. By default, no pre-trained
  403. weights are used.
  404. progress (bool, optional): If True, displays a progress bar of the
  405. download to stderr. Default is True.
  406. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  407. base class. Please refer to the `source code
  408. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  409. for more details about this class.
  410. .. autoclass:: torchvision.models.VGG19_BN_Weights
  411. :members:
  412. """
  413. weights = VGG19_BN_Weights.verify(weights)
  414. return _vgg("E", True, weights, progress, **kwargs)
  415. # The dictionary below is internal implementation detail and will be removed in v0.15
  416. from ._utils import _ModelURLs
  417. model_urls = _ModelURLs(
  418. {
  419. "vgg11": VGG11_Weights.IMAGENET1K_V1.url,
  420. "vgg13": VGG13_Weights.IMAGENET1K_V1.url,
  421. "vgg16": VGG16_Weights.IMAGENET1K_V1.url,
  422. "vgg19": VGG19_Weights.IMAGENET1K_V1.url,
  423. "vgg11_bn": VGG11_BN_Weights.IMAGENET1K_V1.url,
  424. "vgg13_bn": VGG13_BN_Weights.IMAGENET1K_V1.url,
  425. "vgg16_bn": VGG16_BN_Weights.IMAGENET1K_V1.url,
  426. "vgg19_bn": VGG19_BN_Weights.IMAGENET1K_V1.url,
  427. }
  428. )