resnet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. from functools import partial
  2. from typing import Any, Type, Union, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from torch import Tensor
  6. from torchvision.models.resnet import (
  7. Bottleneck,
  8. BasicBlock,
  9. ResNet,
  10. ResNet18_Weights,
  11. ResNet50_Weights,
  12. ResNeXt101_32X8D_Weights,
  13. ResNeXt101_64X4D_Weights,
  14. )
  15. from ...transforms._presets import ImageClassification
  16. from .._api import WeightsEnum, Weights
  17. from .._meta import _IMAGENET_CATEGORIES
  18. from .._utils import handle_legacy_interface, _ovewrite_named_param
  19. from .utils import _fuse_modules, _replace_relu, quantize_model
  20. __all__ = [
  21. "QuantizableResNet",
  22. "ResNet18_QuantizedWeights",
  23. "ResNet50_QuantizedWeights",
  24. "ResNeXt101_32X8D_QuantizedWeights",
  25. "ResNeXt101_64X4D_QuantizedWeights",
  26. "resnet18",
  27. "resnet50",
  28. "resnext101_32x8d",
  29. "resnext101_64x4d",
  30. ]
  31. class QuantizableBasicBlock(BasicBlock):
  32. def __init__(self, *args: Any, **kwargs: Any) -> None:
  33. super().__init__(*args, **kwargs)
  34. self.add_relu = torch.nn.quantized.FloatFunctional()
  35. def forward(self, x: Tensor) -> Tensor:
  36. identity = x
  37. out = self.conv1(x)
  38. out = self.bn1(out)
  39. out = self.relu(out)
  40. out = self.conv2(out)
  41. out = self.bn2(out)
  42. if self.downsample is not None:
  43. identity = self.downsample(x)
  44. out = self.add_relu.add_relu(out, identity)
  45. return out
  46. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  47. _fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], is_qat, inplace=True)
  48. if self.downsample:
  49. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  50. class QuantizableBottleneck(Bottleneck):
  51. def __init__(self, *args: Any, **kwargs: Any) -> None:
  52. super().__init__(*args, **kwargs)
  53. self.skip_add_relu = nn.quantized.FloatFunctional()
  54. self.relu1 = nn.ReLU(inplace=False)
  55. self.relu2 = nn.ReLU(inplace=False)
  56. def forward(self, x: Tensor) -> Tensor:
  57. identity = x
  58. out = self.conv1(x)
  59. out = self.bn1(out)
  60. out = self.relu1(out)
  61. out = self.conv2(out)
  62. out = self.bn2(out)
  63. out = self.relu2(out)
  64. out = self.conv3(out)
  65. out = self.bn3(out)
  66. if self.downsample is not None:
  67. identity = self.downsample(x)
  68. out = self.skip_add_relu.add_relu(out, identity)
  69. return out
  70. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  71. _fuse_modules(
  72. self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], is_qat, inplace=True
  73. )
  74. if self.downsample:
  75. _fuse_modules(self.downsample, ["0", "1"], is_qat, inplace=True)
  76. class QuantizableResNet(ResNet):
  77. def __init__(self, *args: Any, **kwargs: Any) -> None:
  78. super().__init__(*args, **kwargs)
  79. self.quant = torch.ao.quantization.QuantStub()
  80. self.dequant = torch.ao.quantization.DeQuantStub()
  81. def forward(self, x: Tensor) -> Tensor:
  82. x = self.quant(x)
  83. # Ensure scriptability
  84. # super(QuantizableResNet,self).forward(x)
  85. # is not scriptable
  86. x = self._forward_impl(x)
  87. x = self.dequant(x)
  88. return x
  89. def fuse_model(self, is_qat: Optional[bool] = None) -> None:
  90. r"""Fuse conv/bn/relu modules in resnet models
  91. Fuse conv+bn+relu/ Conv+relu/conv+Bn modules to prepare for quantization.
  92. Model is modified in place. Note that this operation does not change numerics
  93. and the model after modification is in floating point
  94. """
  95. _fuse_modules(self, ["conv1", "bn1", "relu"], is_qat, inplace=True)
  96. for m in self.modules():
  97. if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
  98. m.fuse_model(is_qat)
  99. def _resnet(
  100. block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
  101. layers: List[int],
  102. weights: Optional[WeightsEnum],
  103. progress: bool,
  104. quantize: bool,
  105. **kwargs: Any,
  106. ) -> QuantizableResNet:
  107. if weights is not None:
  108. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  109. if "backend" in weights.meta:
  110. _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
  111. backend = kwargs.pop("backend", "fbgemm")
  112. model = QuantizableResNet(block, layers, **kwargs)
  113. _replace_relu(model)
  114. if quantize:
  115. quantize_model(model, backend)
  116. if weights is not None:
  117. model.load_state_dict(weights.get_state_dict(progress=progress))
  118. return model
  119. _COMMON_META = {
  120. "min_size": (1, 1),
  121. "categories": _IMAGENET_CATEGORIES,
  122. "backend": "fbgemm",
  123. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
  124. "_docs": """
  125. These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
  126. weights listed below.
  127. """,
  128. }
  129. class ResNet18_QuantizedWeights(WeightsEnum):
  130. IMAGENET1K_FBGEMM_V1 = Weights(
  131. url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
  132. transforms=partial(ImageClassification, crop_size=224),
  133. meta={
  134. **_COMMON_META,
  135. "num_params": 11689512,
  136. "unquantized": ResNet18_Weights.IMAGENET1K_V1,
  137. "_metrics": {
  138. "ImageNet-1K": {
  139. "acc@1": 69.494,
  140. "acc@5": 88.882,
  141. }
  142. },
  143. },
  144. )
  145. DEFAULT = IMAGENET1K_FBGEMM_V1
  146. class ResNet50_QuantizedWeights(WeightsEnum):
  147. IMAGENET1K_FBGEMM_V1 = Weights(
  148. url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
  149. transforms=partial(ImageClassification, crop_size=224),
  150. meta={
  151. **_COMMON_META,
  152. "num_params": 25557032,
  153. "unquantized": ResNet50_Weights.IMAGENET1K_V1,
  154. "_metrics": {
  155. "ImageNet-1K": {
  156. "acc@1": 75.920,
  157. "acc@5": 92.814,
  158. }
  159. },
  160. },
  161. )
  162. IMAGENET1K_FBGEMM_V2 = Weights(
  163. url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
  164. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  165. meta={
  166. **_COMMON_META,
  167. "num_params": 25557032,
  168. "unquantized": ResNet50_Weights.IMAGENET1K_V2,
  169. "_metrics": {
  170. "ImageNet-1K": {
  171. "acc@1": 80.282,
  172. "acc@5": 94.976,
  173. }
  174. },
  175. },
  176. )
  177. DEFAULT = IMAGENET1K_FBGEMM_V2
  178. class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
  179. IMAGENET1K_FBGEMM_V1 = Weights(
  180. url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
  181. transforms=partial(ImageClassification, crop_size=224),
  182. meta={
  183. **_COMMON_META,
  184. "num_params": 88791336,
  185. "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
  186. "_metrics": {
  187. "ImageNet-1K": {
  188. "acc@1": 78.986,
  189. "acc@5": 94.480,
  190. }
  191. },
  192. },
  193. )
  194. IMAGENET1K_FBGEMM_V2 = Weights(
  195. url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
  196. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  197. meta={
  198. **_COMMON_META,
  199. "num_params": 88791336,
  200. "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
  201. "_metrics": {
  202. "ImageNet-1K": {
  203. "acc@1": 82.574,
  204. "acc@5": 96.132,
  205. }
  206. },
  207. },
  208. )
  209. DEFAULT = IMAGENET1K_FBGEMM_V2
  210. class ResNeXt101_64X4D_QuantizedWeights(WeightsEnum):
  211. IMAGENET1K_FBGEMM_V1 = Weights(
  212. url="https://download.pytorch.org/models/quantized/resnext101_64x4d_fbgemm-605a1cb3.pth",
  213. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  214. meta={
  215. **_COMMON_META,
  216. "num_params": 83455272,
  217. "recipe": "https://github.com/pytorch/vision/pull/5935",
  218. "unquantized": ResNeXt101_64X4D_Weights.IMAGENET1K_V1,
  219. "_metrics": {
  220. "ImageNet-1K": {
  221. "acc@1": 82.898,
  222. "acc@5": 96.326,
  223. }
  224. },
  225. },
  226. )
  227. DEFAULT = IMAGENET1K_FBGEMM_V1
  228. @handle_legacy_interface(
  229. weights=(
  230. "pretrained",
  231. lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  232. if kwargs.get("quantize", False)
  233. else ResNet18_Weights.IMAGENET1K_V1,
  234. )
  235. )
  236. def resnet18(
  237. *,
  238. weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None,
  239. progress: bool = True,
  240. quantize: bool = False,
  241. **kwargs: Any,
  242. ) -> QuantizableResNet:
  243. """ResNet-18 model from
  244. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  245. .. note::
  246. Note that ``quantize = True`` returns a quantized model with 8 bit
  247. weights. Quantized models only support inference and run on CPUs.
  248. GPU inference is not yet supported.
  249. Args:
  250. weights (:class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` or :class:`~torchvision.models.ResNet18_Weights`, optional): The
  251. pretrained weights for the model. See
  252. :class:`~torchvision.models.quantization.ResNet18_QuantizedWeights` below for
  253. more details, and possible values. By default, no pre-trained
  254. weights are used.
  255. progress (bool, optional): If True, displays a progress bar of the
  256. download to stderr. Default is True.
  257. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  258. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  259. base class. Please refer to the `source code
  260. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  261. for more details about this class.
  262. .. autoclass:: torchvision.models.quantization.ResNet18_QuantizedWeights
  263. :members:
  264. .. autoclass:: torchvision.models.ResNet18_Weights
  265. :members:
  266. :noindex:
  267. """
  268. weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights)
  269. return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
  270. @handle_legacy_interface(
  271. weights=(
  272. "pretrained",
  273. lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  274. if kwargs.get("quantize", False)
  275. else ResNet50_Weights.IMAGENET1K_V1,
  276. )
  277. )
  278. def resnet50(
  279. *,
  280. weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None,
  281. progress: bool = True,
  282. quantize: bool = False,
  283. **kwargs: Any,
  284. ) -> QuantizableResNet:
  285. """ResNet-50 model from
  286. `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_
  287. .. note::
  288. Note that ``quantize = True`` returns a quantized model with 8 bit
  289. weights. Quantized models only support inference and run on CPUs.
  290. GPU inference is not yet supported.
  291. Args:
  292. weights (:class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` or :class:`~torchvision.models.ResNet50_Weights`, optional): The
  293. pretrained weights for the model. See
  294. :class:`~torchvision.models.quantization.ResNet50_QuantizedWeights` below for
  295. more details, and possible values. By default, no pre-trained
  296. weights are used.
  297. progress (bool, optional): If True, displays a progress bar of the
  298. download to stderr. Default is True.
  299. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  300. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  301. base class. Please refer to the `source code
  302. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  303. for more details about this class.
  304. .. autoclass:: torchvision.models.quantization.ResNet50_QuantizedWeights
  305. :members:
  306. .. autoclass:: torchvision.models.ResNet50_Weights
  307. :members:
  308. :noindex:
  309. """
  310. weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights)
  311. return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
  312. @handle_legacy_interface(
  313. weights=(
  314. "pretrained",
  315. lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1
  316. if kwargs.get("quantize", False)
  317. else ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
  318. )
  319. )
  320. def resnext101_32x8d(
  321. *,
  322. weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None,
  323. progress: bool = True,
  324. quantize: bool = False,
  325. **kwargs: Any,
  326. ) -> QuantizableResNet:
  327. """ResNeXt-101 32x8d model from
  328. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  329. .. note::
  330. Note that ``quantize = True`` returns a quantized model with 8 bit
  331. weights. Quantized models only support inference and run on CPUs.
  332. GPU inference is not yet supported.
  333. Args:
  334. weights (:class:`~torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
  335. pretrained weights for the model. See
  336. :class:`~torchvision.models.quantization.ResNet101_32X8D_QuantizedWeights` below for
  337. more details, and possible values. By default, no pre-trained
  338. weights are used.
  339. progress (bool, optional): If True, displays a progress bar of the
  340. download to stderr. Default is True.
  341. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  342. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  343. base class. Please refer to the `source code
  344. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  345. for more details about this class.
  346. .. autoclass:: torchvision.models.quantization.ResNeXt101_32X8D_QuantizedWeights
  347. :members:
  348. .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
  349. :members:
  350. :noindex:
  351. """
  352. weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights)
  353. _ovewrite_named_param(kwargs, "groups", 32)
  354. _ovewrite_named_param(kwargs, "width_per_group", 8)
  355. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
  356. def resnext101_64x4d(
  357. *,
  358. weights: Optional[Union[ResNeXt101_64X4D_QuantizedWeights, ResNeXt101_64X4D_Weights]] = None,
  359. progress: bool = True,
  360. quantize: bool = False,
  361. **kwargs: Any,
  362. ) -> QuantizableResNet:
  363. """ResNeXt-101 64x4d model from
  364. `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_
  365. .. note::
  366. Note that ``quantize = True`` returns a quantized model with 8 bit
  367. weights. Quantized models only support inference and run on CPUs.
  368. GPU inference is not yet supported.
  369. Args:
  370. weights (:class:`~torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights` or :class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
  371. pretrained weights for the model. See
  372. :class:`~torchvision.models.quantization.ResNet101_64X4D_QuantizedWeights` below for
  373. more details, and possible values. By default, no pre-trained
  374. weights are used.
  375. progress (bool, optional): If True, displays a progress bar of the
  376. download to stderr. Default is True.
  377. quantize (bool, optional): If True, return a quantized version of the model. Default is False.
  378. **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableResNet``
  379. base class. Please refer to the `source code
  380. <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/resnet.py>`_
  381. for more details about this class.
  382. .. autoclass:: torchvision.models.quantization.ResNeXt101_64X4D_QuantizedWeights
  383. :members:
  384. .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
  385. :members:
  386. :noindex:
  387. """
  388. weights = (ResNeXt101_64X4D_QuantizedWeights if quantize else ResNeXt101_64X4D_Weights).verify(weights)
  389. _ovewrite_named_param(kwargs, "groups", 64)
  390. _ovewrite_named_param(kwargs, "width_per_group", 4)
  391. return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
# The dictionary below is an internal implementation detail and will be removed in v0.15
  393. from .._utils import _ModelURLs
  394. from ..resnet import model_urls # noqa: F401
  395. quant_model_urls = _ModelURLs(
  396. {
  397. "resnet18_fbgemm": ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
  398. "resnet50_fbgemm": ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
  399. "resnext101_32x8d_fbgemm": ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
  400. }
  401. )