resnet.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959
  1. from functools import partial
  2. from typing import Type, Any, Callable, Union, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from ..transforms._presets import ImageClassification
  7. from ..utils import _log_api_usage_once
  8. from ._api import WeightsEnum, Weights
  9. from ._meta import _IMAGENET_CATEGORIES
  10. from ._utils import handle_legacy_interface, _ovewrite_named_param
  11. __all__ = [
  12. "ResNet",
  13. "ResNet18_Weights",
  14. "ResNet34_Weights",
  15. "ResNet50_Weights",
  16. "ResNet101_Weights",
  17. "ResNet152_Weights",
  18. "ResNeXt50_32X4D_Weights",
  19. "ResNeXt101_32X8D_Weights",
  20. "ResNeXt101_64X4D_Weights",
  21. "Wide_ResNet50_2_Weights",
  22. "Wide_ResNet101_2_Weights",
  23. "resnet18",
  24. "resnet34",
  25. "resnet50",
  26. "resnet101",
  27. "resnet152",
  28. "resnext50_32x4d",
  29. "resnext101_32x8d",
  30. "resnext101_64x4d",
  31. "wide_resnet50_2",
  32. "wide_resnet101_2",
  33. ]
  34. def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
  35. """3x3 convolution with padding"""
  36. return nn.Conv2d(
  37. in_planes,
  38. out_planes,
  39. kernel_size=3,
  40. stride=stride,
  41. padding=dilation,
  42. groups=groups,
  43. bias=False,
  44. dilation=dilation,
  45. )
  46. def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  47. """1x1 convolution"""
  48. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
  49. class BasicBlock(nn.Module):
  50. expansion: int = 1
  51. def __init__(
  52. self,
  53. inplanes: int,
  54. planes: int,
  55. stride: int = 1,
  56. downsample: Optional[nn.Module] = None,
  57. groups: int = 1,
  58. base_width: int = 64,
  59. dilation: int = 1,
  60. norm_layer: Optional[Callable[..., nn.Module]] = None,
  61. ) -> None:
  62. super().__init__()
  63. if norm_layer is None:
  64. norm_layer = nn.BatchNorm2d
  65. if groups != 1 or base_width != 64:
  66. raise ValueError("BasicBlock only supports groups=1 and base_width=64")
  67. if dilation > 1:
  68. raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
  69. # Both self.conv1 and self.downsample layers downsample the input when stride != 1
  70. self.conv1 = conv3x3(inplanes, planes, stride)
  71. self.bn1 = norm_layer(planes)
  72. self.relu = nn.ReLU(inplace=True)
  73. self.conv2 = conv3x3(planes, planes)
  74. self.bn2 = norm_layer(planes)
  75. self.downsample = downsample
  76. self.stride = stride
  77. def forward(self, x: Tensor) -> Tensor:
  78. identity = x
  79. out = self.conv1(x)
  80. out = self.bn1(out)
  81. out = self.relu(out)
  82. out = self.conv2(out)
  83. out = self.bn2(out)
  84. if self.downsample is not None:
  85. identity = self.downsample(x)
  86. out += identity
  87. out = self.relu(out)
  88. return out
  89. class Bottleneck(nn.Module):
  90. # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
  91. # while original implementation places the stride at the first 1x1 convolution(self.conv1)
  92. # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
  93. # This variant is also known as ResNet V1.5 and improves accuracy according to
  94. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
  95. expansion: int = 4
  96. def __init__(
  97. self,
  98. inplanes: int,
  99. planes: int,
  100. stride: int = 1,
  101. downsample: Optional[nn.Module] = None,
  102. groups: int = 1,
  103. base_width: int = 64,
  104. dilation: int = 1,
  105. norm_layer: Optional[Callable[..., nn.Module]] = None,
  106. ) -> None:
  107. super().__init__()
  108. if norm_layer is None:
  109. norm_layer = nn.BatchNorm2d
  110. width = int(planes * (base_width / 64.0)) * groups
  111. # Both self.conv2 and self.downsample layers downsample the input when stride != 1
  112. self.conv1 = conv1x1(inplanes, width)
  113. self.bn1 = norm_layer(width)
  114. self.conv2 = conv3x3(width, width, stride, groups, dilation)
  115. self.bn2 = norm_layer(width)
  116. self.conv3 = conv1x1(width, planes * self.expansion)
  117. self.bn3 = norm_layer(planes * self.expansion)
  118. self.relu = nn.ReLU(inplace=True)
  119. self.downsample = downsample
  120. self.stride = stride
  121. def forward(self, x: Tensor) -> Tensor:
  122. identity = x
  123. out = self.conv1(x)
  124. out = self.bn1(out)
  125. out = self.relu(out)
  126. out = self.conv2(out)
  127. out = self.bn2(out)
  128. out = self.relu(out)
  129. out = self.conv3(out)
  130. out = self.bn3(out)
  131. if self.downsample is not None:
  132. identity = self.downsample(x)
  133. out += identity
  134. out = self.relu(out)
  135. return out
class ResNet(nn.Module):
    """ResNet-style image classifier assembled from a residual ``block`` type.

    Layout: a 7x7 stride-2 stem convolution on 3-channel input followed by norm and
    ReLU, a 3x3 stride-2 max-pool, four residual stages (``layer1``..``layer4`` with
    64, 128, 256 and 512 base planes), global average pooling, and a final
    ``nn.Linear`` classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``). Its
            ``expansion`` multiplies each stage's base planes to give that stage's
            output channels.
        layers: number of blocks per stage; ``layers[0]``..``layers[3]`` are read.
        num_classes: output features of the final ``nn.Linear``.
        zero_init_residual: if True, set the weight of the last norm layer in every
            residual branch (``bn3`` of ``Bottleneck``, ``bn2`` of ``BasicBlock``) to 0.
        groups: convolution groups, forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: three flags for ``layer2``, ``layer3`` and
            ``layer4``; a True flag makes that stage use dilation instead of stride 2.
        norm_layer: callable building a normalization layer from a channel count;
            ``nn.BatchNorm2d`` when None.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly 3 elements.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count and dilation. _make_layer reads and advances both,
        # so the stages below must be built in order.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the stride-2 downsampling of layer2/3/4 with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 7x7 stride-2 conv + norm + ReLU, then 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; layer1 keeps resolution, layer2-4 use stride 2 unless dilated.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # Head: global average pool to 1x1, then a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialize weights: Kaiming-normal (fan_out, relu) for every conv, and
        # weight=1 / bias=0 for BatchNorm2d and GroupNorm. Other norm types are left as built.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one residual stage of ``blocks`` blocks with ``planes`` base channels.

        Only the first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a ``downsample`` projection (strided 1x1 conv + norm).
        With ``dilate=True`` the stride is folded into ``self.dilation`` and the stage
        runs at stride 1; the first block still uses the dilation in effect before
        this stage, later blocks use the updated one.

        Side effects: advances ``self.inplanes`` to ``planes * block.expansion`` and,
        when ``dilate`` is True, multiplies ``self.dilation`` by ``stride``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the shortcut whenever its shape would not match the block output.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers: List[nn.Module] = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        # Remaining blocks keep resolution and channel count, so no downsample.
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, the four stages and the classifier head; returns ``fc`` outputs."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse everything after the batch dimension before the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Compute class logits for ``x``; delegates to ``_forward_impl``."""
        return self._forward_impl(x)
  245. def _resnet(
  246. block: Type[Union[BasicBlock, Bottleneck]],
  247. layers: List[int],
  248. weights: Optional[WeightsEnum],
  249. progress: bool,
  250. **kwargs: Any,
  251. ) -> ResNet:
  252. if weights is not None:
  253. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  254. model = ResNet(block, layers, **kwargs)
  255. if weights is not None:
  256. model.load_state_dict(weights.get_state_dict(progress=progress))
  257. return model
  258. _COMMON_META = {
  259. "min_size": (1, 1),
  260. "categories": _IMAGENET_CATEGORIES,
  261. }
  262. class ResNet18_Weights(WeightsEnum):
  263. IMAGENET1K_V1 = Weights(
  264. url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
  265. transforms=partial(ImageClassification, crop_size=224),
  266. meta={
  267. **_COMMON_META,
  268. "num_params": 11689512,
  269. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
  270. "_metrics": {
  271. "ImageNet-1K": {
  272. "acc@1": 69.758,
  273. "acc@5": 89.078,
  274. }
  275. },
  276. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  277. },
  278. )
  279. DEFAULT = IMAGENET1K_V1
  280. class ResNet34_Weights(WeightsEnum):
  281. IMAGENET1K_V1 = Weights(
  282. url="https://download.pytorch.org/models/resnet34-b627a593.pth",
  283. transforms=partial(ImageClassification, crop_size=224),
  284. meta={
  285. **_COMMON_META,
  286. "num_params": 21797672,
  287. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
  288. "_metrics": {
  289. "ImageNet-1K": {
  290. "acc@1": 73.314,
  291. "acc@5": 91.420,
  292. }
  293. },
  294. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  295. },
  296. )
  297. DEFAULT = IMAGENET1K_V1
  298. class ResNet50_Weights(WeightsEnum):
  299. IMAGENET1K_V1 = Weights(
  300. url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
  301. transforms=partial(ImageClassification, crop_size=224),
  302. meta={
  303. **_COMMON_META,
  304. "num_params": 25557032,
  305. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
  306. "_metrics": {
  307. "ImageNet-1K": {
  308. "acc@1": 76.130,
  309. "acc@5": 92.862,
  310. }
  311. },
  312. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  313. },
  314. )
  315. IMAGENET1K_V2 = Weights(
  316. url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
  317. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  318. meta={
  319. **_COMMON_META,
  320. "num_params": 25557032,
  321. "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
  322. "_metrics": {
  323. "ImageNet-1K": {
  324. "acc@1": 80.858,
  325. "acc@5": 95.434,
  326. }
  327. },
  328. "_docs": """
  329. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  330. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  331. """,
  332. },
  333. )
  334. DEFAULT = IMAGENET1K_V2
  335. class ResNet101_Weights(WeightsEnum):
  336. IMAGENET1K_V1 = Weights(
  337. url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
  338. transforms=partial(ImageClassification, crop_size=224),
  339. meta={
  340. **_COMMON_META,
  341. "num_params": 44549160,
  342. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
  343. "_metrics": {
  344. "ImageNet-1K": {
  345. "acc@1": 77.374,
  346. "acc@5": 93.546,
  347. }
  348. },
  349. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  350. },
  351. )
  352. IMAGENET1K_V2 = Weights(
  353. url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
  354. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  355. meta={
  356. **_COMMON_META,
  357. "num_params": 44549160,
  358. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
  359. "_metrics": {
  360. "ImageNet-1K": {
  361. "acc@1": 81.886,
  362. "acc@5": 95.780,
  363. }
  364. },
  365. "_docs": """
  366. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  367. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  368. """,
  369. },
  370. )
  371. DEFAULT = IMAGENET1K_V2
  372. class ResNet152_Weights(WeightsEnum):
  373. IMAGENET1K_V1 = Weights(
  374. url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
  375. transforms=partial(ImageClassification, crop_size=224),
  376. meta={
  377. **_COMMON_META,
  378. "num_params": 60192808,
  379. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
  380. "_metrics": {
  381. "ImageNet-1K": {
  382. "acc@1": 78.312,
  383. "acc@5": 94.046,
  384. }
  385. },
  386. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  387. },
  388. )
  389. IMAGENET1K_V2 = Weights(
  390. url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
  391. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  392. meta={
  393. **_COMMON_META,
  394. "num_params": 60192808,
  395. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
  396. "_metrics": {
  397. "ImageNet-1K": {
  398. "acc@1": 82.284,
  399. "acc@5": 96.002,
  400. }
  401. },
  402. "_docs": """
  403. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  404. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  405. """,
  406. },
  407. )
  408. DEFAULT = IMAGENET1K_V2
  409. class ResNeXt50_32X4D_Weights(WeightsEnum):
  410. IMAGENET1K_V1 = Weights(
  411. url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
  412. transforms=partial(ImageClassification, crop_size=224),
  413. meta={
  414. **_COMMON_META,
  415. "num_params": 25028904,
  416. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
  417. "_metrics": {
  418. "ImageNet-1K": {
  419. "acc@1": 77.618,
  420. "acc@5": 93.698,
  421. }
  422. },
  423. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  424. },
  425. )
  426. IMAGENET1K_V2 = Weights(
  427. url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
  428. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  429. meta={
  430. **_COMMON_META,
  431. "num_params": 25028904,
  432. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
  433. "_metrics": {
  434. "ImageNet-1K": {
  435. "acc@1": 81.198,
  436. "acc@5": 95.340,
  437. }
  438. },
  439. "_docs": """
  440. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  441. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  442. """,
  443. },
  444. )
  445. DEFAULT = IMAGENET1K_V2
  446. class ResNeXt101_32X8D_Weights(WeightsEnum):
  447. IMAGENET1K_V1 = Weights(
  448. url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
  449. transforms=partial(ImageClassification, crop_size=224),
  450. meta={
  451. **_COMMON_META,
  452. "num_params": 88791336,
  453. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
  454. "_metrics": {
  455. "ImageNet-1K": {
  456. "acc@1": 79.312,
  457. "acc@5": 94.526,
  458. }
  459. },
  460. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  461. },
  462. )
  463. IMAGENET1K_V2 = Weights(
  464. url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
  465. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  466. meta={
  467. **_COMMON_META,
  468. "num_params": 88791336,
  469. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
  470. "_metrics": {
  471. "ImageNet-1K": {
  472. "acc@1": 82.834,
  473. "acc@5": 96.228,
  474. }
  475. },
  476. "_docs": """
  477. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  478. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  479. """,
  480. },
  481. )
  482. DEFAULT = IMAGENET1K_V2
  483. class ResNeXt101_64X4D_Weights(WeightsEnum):
  484. IMAGENET1K_V1 = Weights(
  485. url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
  486. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  487. meta={
  488. **_COMMON_META,
  489. "num_params": 83455272,
  490. "recipe": "https://github.com/pytorch/vision/pull/5935",
  491. "_metrics": {
  492. "ImageNet-1K": {
  493. "acc@1": 83.246,
  494. "acc@5": 96.454,
  495. }
  496. },
  497. "_docs": """
  498. These weights were trained from scratch by using TorchVision's `new training recipe
  499. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  500. """,
  501. },
  502. )
  503. DEFAULT = IMAGENET1K_V1
  504. class Wide_ResNet50_2_Weights(WeightsEnum):
  505. IMAGENET1K_V1 = Weights(
  506. url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
  507. transforms=partial(ImageClassification, crop_size=224),
  508. meta={
  509. **_COMMON_META,
  510. "num_params": 68883240,
  511. "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
  512. "_metrics": {
  513. "ImageNet-1K": {
  514. "acc@1": 78.468,
  515. "acc@5": 94.086,
  516. }
  517. },
  518. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  519. },
  520. )
  521. IMAGENET1K_V2 = Weights(
  522. url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
  523. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  524. meta={
  525. **_COMMON_META,
  526. "num_params": 68883240,
  527. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
  528. "_metrics": {
  529. "ImageNet-1K": {
  530. "acc@1": 81.602,
  531. "acc@5": 95.758,
  532. }
  533. },
  534. "_docs": """
  535. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  536. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  537. """,
  538. },
  539. )
  540. DEFAULT = IMAGENET1K_V2
  541. class Wide_ResNet101_2_Weights(WeightsEnum):
  542. IMAGENET1K_V1 = Weights(
  543. url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
  544. transforms=partial(ImageClassification, crop_size=224),
  545. meta={
  546. **_COMMON_META,
  547. "num_params": 126886696,
  548. "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
  549. "_metrics": {
  550. "ImageNet-1K": {
  551. "acc@1": 78.848,
  552. "acc@5": 94.284,
  553. }
  554. },
  555. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  556. },
  557. )
  558. IMAGENET1K_V2 = Weights(
  559. url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
  560. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  561. meta={
  562. **_COMMON_META,
  563. "num_params": 126886696,
  564. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
  565. "_metrics": {
  566. "ImageNet-1K": {
  567. "acc@1": 82.510,
  568. "acc@5": 96.020,
  569. }
  570. },
  571. "_docs": """
  572. These weights improve upon the results of the original paper by using TorchVision's `new training recipe
  573. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  574. """,
  575. },
  576. )
  577. DEFAULT = IMAGENET1K_V2
  578. @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
  579. def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  580. """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
  581. Args:
  582. weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
  583. pretrained weights to use. See
  584. :class:`~torchvision.models.ResNet18_Weights` below for
  585. more details, and possible values. By default, no pre-trained
  586. weights are used.
  587. progress (bool, optional): If True, displays a progress bar of the
  588. download to stderr. Default is True.
  589. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  590. base class. Please refer to the `source code
  591. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  592. for more details about this class.
  593. .. autoclass:: torchvision.models.ResNet18_Weights
  594. :members:
  595. """
  596. weights = ResNet18_Weights.verify(weights)
  597. return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
  598. @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
  599. def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  600. """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
  601. Args:
  602. weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
  603. pretrained weights to use. See
  604. :class:`~torchvision.models.ResNet34_Weights` below for
  605. more details, and possible values. By default, no pre-trained
  606. weights are used.
  607. progress (bool, optional): If True, displays a progress bar of the
  608. download to stderr. Default is True.
  609. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  610. base class. Please refer to the `source code
  611. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  612. for more details about this class.
  613. .. autoclass:: torchvision.models.ResNet34_Weights
  614. :members:
  615. """
  616. weights = ResNet34_Weights.verify(weights)
  617. return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
  618. @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
  619. def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  620. """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
  621. .. note::
  622. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  623. convolution while the original paper places it to the first 1x1 convolution.
  624. This variant improves the accuracy and is known as `ResNet V1.5
  625. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  626. Args:
  627. weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  628. pretrained weights to use. See
  629. :class:`~torchvision.models.ResNet50_Weights` below for
  630. more details, and possible values. By default, no pre-trained
  631. weights are used.
  632. progress (bool, optional): If True, displays a progress bar of the
  633. download to stderr. Default is True.
  634. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  635. base class. Please refer to the `source code
  636. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  637. for more details about this class.
  638. .. autoclass:: torchvision.models.ResNet50_Weights
  639. :members:
  640. """
  641. weights = ResNet50_Weights.verify(weights)
  642. return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
  643. @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
  644. def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  645. """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
  646. .. note::
  647. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  648. convolution while the original paper places it to the first 1x1 convolution.
  649. This variant improves the accuracy and is known as `ResNet V1.5
  650. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  651. Args:
  652. weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
  653. pretrained weights to use. See
  654. :class:`~torchvision.models.ResNet101_Weights` below for
  655. more details, and possible values. By default, no pre-trained
  656. weights are used.
  657. progress (bool, optional): If True, displays a progress bar of the
  658. download to stderr. Default is True.
  659. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  660. base class. Please refer to the `source code
  661. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  662. for more details about this class.
  663. .. autoclass:: torchvision.models.ResNet101_Weights
  664. :members:
  665. """
  666. weights = ResNet101_Weights.verify(weights)
  667. return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  668. @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
  669. def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
  670. """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
  671. .. note::
  672. The bottleneck of TorchVision places the stride for downsampling to the second 3x3
  673. convolution while the original paper places it to the first 1x1 convolution.
  674. This variant improves the accuracy and is known as `ResNet V1.5
  675. <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.
  676. Args:
  677. weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
  678. pretrained weights to use. See
  679. :class:`~torchvision.models.ResNet152_Weights` below for
  680. more details, and possible values. By default, no pre-trained
  681. weights are used.
  682. progress (bool, optional): If True, displays a progress bar of the
  683. download to stderr. Default is True.
  684. **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
  685. base class. Please refer to the `source code
  686. <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
  687. for more details about this class.
  688. .. autoclass:: torchvision.models.ResNet152_Weights
  689. :members:
  690. """
  691. weights = ResNet152_Weights.verify(weights)
  692. return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
  def resnext50_32x4d(
      *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ResNet:
      """ResNeXt-50 32x4d model from
      `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

      This builder fixes ``groups=32`` and ``width_per_group=4`` (the "32x4d" in the
      model name); both are applied to ``kwargs`` through ``_ovewrite_named_param``
      before the network is constructed.

      Args:
          weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.ResNeXt50_32X4D_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
          :members:
      """
      weights = ResNeXt50_32X4D_Weights.verify(weights)

      # Cardinality 32, per-group width 4 -> the "32x4d" variant.
      _ovewrite_named_param(kwargs, "groups", 32)
      _ovewrite_named_param(kwargs, "width_per_group", 4)
      return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
  def resnext101_32x8d(
      *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ResNet:
      """ResNeXt-101 32x8d model from
      `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

      This builder fixes ``groups=32`` and ``width_per_group=8`` (the "32x8d" in the
      model name); both are applied to ``kwargs`` through ``_ovewrite_named_param``
      before the network is constructed.

      Args:
          weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
          :members:
      """
      weights = ResNeXt101_32X8D_Weights.verify(weights)

      # Cardinality 32, per-group width 8 -> the "32x8d" variant.
      _ovewrite_named_param(kwargs, "groups", 32)
      _ovewrite_named_param(kwargs, "width_per_group", 8)
      return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  def resnext101_64x4d(
      *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ResNet:
      """ResNeXt-101 64x4d model from
      `Aggregated Residual Transformations for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

      This builder fixes ``groups=64`` and ``width_per_group=4`` (the "64x4d" in the
      model name); both are applied to ``kwargs`` through ``_ovewrite_named_param``
      before the network is constructed.

      Args:
          weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
          :members:
      """
      weights = ResNeXt101_64X4D_Weights.verify(weights)

      # Cardinality 64, per-group width 4 -> the "64x4d" variant.
      _ovewrite_named_param(kwargs, "groups", 64)
      _ovewrite_named_param(kwargs, "width_per_group", 4)
      return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
  def wide_resnet50_2(
      *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ResNet:
      """Wide ResNet-50-2 model from
      `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

      The architecture matches ResNet except that the inner (bottleneck) channel
      count of every block is doubled, while the outer 1x1 convolutions keep their
      channel count. For example, the last block of ResNet-50 is 2048-512-2048
      channels, whereas in Wide ResNet-50-2 it is 2048-1024-2048.

      Args:
          weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
          :members:
      """
      checked_weights = Wide_ResNet50_2_Weights.verify(weights)

      # Doubling the width of 64 is the "-2" in the model name.
      base_width = 64
      _ovewrite_named_param(kwargs, "width_per_group", base_width * 2)

      stage_depths = [3, 4, 6, 3]
      return _resnet(Bottleneck, stage_depths, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
  def wide_resnet101_2(
      *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ResNet:
      """Wide ResNet-101-2 model from
      `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

      The architecture matches ResNet except that the inner (bottleneck) channel
      count of every block is doubled, while the outer 1x1 convolutions keep their
      channel count. For example, the last block of ResNet-101 is 2048-512-2048
      channels, whereas in Wide ResNet-101-2 it is 2048-1024-2048.

      Args:
          weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
          :members:
      """
      checked_weights = Wide_ResNet101_2_Weights.verify(weights)

      # Doubling the width of 64 is the "-2" in the model name.
      base_width = 64
      _ovewrite_named_param(kwargs, "width_per_group", base_width * 2)

      stage_depths = [3, 4, 23, 3]
      return _resnet(Bottleneck, stage_depths, checked_weights, progress, **kwargs)
  # The dictionary below is an internal implementation detail and will be removed in v0.15.
  from ._utils import _ModelURLs

  # Legacy mapping: builder name -> URL of that builder's IMAGENET1K_V1 checkpoint.
  # Keys are listed in the same order as before so iteration order is unchanged.
  model_urls = _ModelURLs(
      {
          arch_name: weights_enum.IMAGENET1K_V1.url
          for arch_name, weights_enum in (
              ("resnet18", ResNet18_Weights),
              ("resnet34", ResNet34_Weights),
              ("resnet50", ResNet50_Weights),
              ("resnet101", ResNet101_Weights),
              ("resnet152", ResNet152_Weights),
              ("resnext50_32x4d", ResNeXt50_32X4D_Weights),
              ("resnext101_32x8d", ResNeXt101_32X8D_Weights),
              ("wide_resnet50_2", Wide_ResNet50_2_Weights),
              ("wide_resnet101_2", Wide_ResNet101_2_Weights),
          )
      }
  )