mnasnet.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. import warnings
  2. from functools import partial
  3. from typing import Any, Dict, List, Optional
  4. import torch
  5. import torch.nn as nn
  6. from torch import Tensor
  7. from ..transforms._presets import ImageClassification
  8. from ..utils import _log_api_usage_once
  9. from ._api import WeightsEnum, Weights
  10. from ._meta import _IMAGENET_CATEGORIES
  11. from ._utils import handle_legacy_interface, _ovewrite_named_param
  __all__ = [
      # Model class.
      "MNASNet",
      # Pretrained-weight enums, one per depth multiplier.
      "MNASNet0_5_Weights", "MNASNet0_75_Weights", "MNASNet1_0_Weights", "MNASNet1_3_Weights",
      # Builder functions.
      "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3",
  ]
  # The paper specifies a batch-norm momentum of 0.9997 in TensorFlow's convention,
  # where momentum is the weight kept on the running statistics
  # (running = momentum * running + (1 - momentum) * batch). PyTorch's BatchNorm
  # momentum is instead the weight given to the new batch statistic
  # (running = (1 - momentum) * running + momentum * batch), so the equivalent
  # PyTorch value is 1 - 0.9997.
  _BN_MOMENTUM = 1 - 0.9997
  class _InvertedResidual(nn.Module):
      """Inverted residual block used by MNASNet.

      The block expands ``in_ch`` to ``in_ch * expansion_factor`` channels with a 1x1
      convolution, filters them with a ``kernel_size`` depthwise convolution that carries
      the ``stride``, and projects back to ``out_ch`` with a 1x1 convolution that has no
      activation. Every convolution is followed by batch norm with ``bn_momentum``.
      The input is added to the output only when stride is 1 and ``in_ch == out_ch``.

      Raises:
          ValueError: If ``stride`` is not 1 or 2, or ``kernel_size`` is not 3 or 5.
      """

      def __init__(
          self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
      ) -> None:
          super().__init__()
          if stride not in (1, 2):
              raise ValueError(f"stride should be 1 or 2 instead of {stride}")
          if kernel_size not in (3, 5):
              raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")

          hidden = in_ch * expansion_factor
          # An identity shortcut is only shape-compatible when neither the channel
          # count nor the spatial resolution changes.
          self.apply_residual = stride == 1 and in_ch == out_ch

          expand = [
              nn.Conv2d(in_ch, hidden, 1, bias=False),
              nn.BatchNorm2d(hidden, momentum=bn_momentum),
              nn.ReLU(inplace=True),
          ]
          # groups=hidden makes this depthwise; padding k // 2 (odd k) keeps the spatial
          # size unchanged at stride 1.
          depthwise = [
              nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size // 2, stride=stride, groups=hidden, bias=False),
              nn.BatchNorm2d(hidden, momentum=bn_momentum),
              nn.ReLU(inplace=True),
          ]
          # Linear bottleneck: deliberately no activation after the projection.
          project = [
              nn.Conv2d(hidden, out_ch, 1, bias=False),
              nn.BatchNorm2d(out_ch, momentum=bn_momentum),
          ]
          # Flattened in this order so submodule indices (and state-dict keys
          # layers.0 .. layers.7) stay expand -> depthwise -> project.
          self.layers = nn.Sequential(*expand, *depthwise, *project)

      def forward(self, input: Tensor) -> Tensor:
          out = self.layers(input)
          return out + input if self.apply_residual else out
  def _stack(
      in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
  ) -> nn.Sequential:
      """Build ``repeats`` inverted residual blocks in sequence.

      Only the first block applies ``stride`` and maps ``in_ch -> out_ch``. The remaining
      ``repeats - 1`` blocks map ``out_ch -> out_ch`` at stride 1, so each of them gets
      a residual connection.

      Raises:
          ValueError: If ``repeats`` is smaller than 1.
      """
      if repeats < 1:
          raise ValueError(f"repeats should be >= 1, instead got {repeats}")
      blocks = [_InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)]
      blocks.extend(
          _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)
          for _ in range(repeats - 1)
      )
      return nn.Sequential(*blocks)
  def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
      """Round ``val`` to a multiple of ``divisor`` without shrinking it by too much.

      ``val`` is first rounded to the *nearest* multiple of ``divisor`` (exact halves
      round up) and clamped to at least ``divisor``. If that result is still smaller
      than ``round_up_bias * val``, one more ``divisor`` is added. With the default
      bias the result is therefore never more than 10% below ``val``::

          (83, 8) -> 80   # nearest multiple, only ~3.6% below 83
          (84, 8) -> 88   # 84 is halfway between 80 and 88; halves round up
          (11, 8) -> 16   # nearest is 8, which is >10% below 11, so bump to 16
          (3, 8)  -> 8    # clamped to the divisor

      Raises:
          ValueError: If ``round_up_bias`` is not strictly between 0 and 1.
      """
      if not 0.0 < round_up_bias < 1.0:
          raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
      # int(val + divisor / 2) // divisor is round-half-up for positive val.
      nearest = max(divisor, int(val + divisor / 2) // divisor * divisor)
      return nearest if nearest >= round_up_bias * val else nearest + divisor
  def _get_depths(alpha: float) -> List[int]:
      """Scale the base MNASNet channel widths by ``alpha``.

      As in the reference MobileNet code, each base width is multiplied by ``alpha`` and
      rounded to a multiple of 8 with ``_round_to_multiple_of``, which prefers rounding
      up rather than down.

      Returns:
          Eight scaled widths derived from ``[32, 16, 24, 40, 80, 96, 192, 320]``.
          ``MNASNet`` uses the first two for its stem and the remaining six as the
          output widths of its inverted-residual stages.
      """
      depths = [32, 16, 24, 40, 80, 96, 192, 320]
      return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
  class MNASNet(torch.nn.Module):
      """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
      implements the B1 variant of the model.

      Args:
          alpha (float): Depth multiplier. Every channel width of the stem and of the
              inverted-residual stages is scaled by it and rounded to a multiple of 8
              (see ``_get_depths``). Must be greater than 0.
          num_classes (int): Number of outputs of the final linear layer. Default: 1000.
          dropout (float): Dropout probability applied before the final linear layer.
              Default: 0.2.

      Raises:
          ValueError: If ``alpha`` is not greater than 0.

      >>> model = MNASNet(1.0, num_classes=1000)
      >>> x = torch.rand(1, 3, 224, 224)
      >>> y = model(x)
      >>> y.dim()
      2
      >>> y.nelement()
      1000
      """

      # Version 2 adds depth scaling in the initial stages of the network.
      # nn.Module records _version in the state-dict metadata; _load_from_state_dict
      # reads it back to decide whether a checkpoint needs the fixed-width v1 stem.
      _version = 2

      def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
          super().__init__()
          _log_api_usage_once(self)
          if alpha <= 0.0:
              raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
          self.alpha = alpha
          self.num_classes = num_classes
          depths = _get_depths(alpha)
          layers = [
              # First layer: regular conv.
              nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
              nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
              nn.ReLU(inplace=True),
              # Depthwise separable, no skip.
              nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
              nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
              nn.ReLU(inplace=True),
              nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
              nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
              # MNASNet blocks: stacks of inverted residuals.
              # _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, bn_momentum)
              _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
              _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
              _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
              _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
              _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
              _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
              # Final mapping to classifier input.
              nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
              nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
              nn.ReLU(inplace=True),
          ]
          # NOTE: indices 0-8 (the stem plus the first stack) are the entries that
          # _load_from_state_dict replaces when loading a v1 checkpoint.
          self.layers = nn.Sequential(*layers)
          self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                  if m.bias is not None:
                      nn.init.zeros_(m.bias)
              elif isinstance(m, nn.BatchNorm2d):
                  nn.init.ones_(m.weight)
                  nn.init.zeros_(m.bias)
              elif isinstance(m, nn.Linear):
                  nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                  nn.init.zeros_(m.bias)

      def forward(self, x: Tensor) -> Tensor:
          """Return logits of shape ``(N, num_classes)`` for an input of shape ``(N, 3, H, W)``."""
          x = self.layers(x)
          # Equivalent to global avgpool and removing H and W dimensions.
          x = x.mean([2, 3])
          return self.classifier(x)

      def _load_from_state_dict(
          self,
          state_dict: Dict,
          prefix: str,
          local_metadata: Dict,
          strict: bool,
          missing_keys: List[str],
          unexpected_keys: List[str],
          error_msgs: List[str],
      ) -> None:
          """Load hook invoked by ``nn.Module.load_state_dict``; upgrades v1 checkpoints.

          Version-1 checkpoints were trained with a fixed-width stem (32 -> 16 channels)
          regardless of ``alpha``. For such checkpoints with ``alpha != 1.0``, the first
          nine entries of ``self.layers`` are rebuilt with the v1 shapes before the
          parameters are copied in. The rebuilt modules are moved to the device and dtype
          of the existing stem. The instance is then marked as version 1 so it is saved
          as such.

          Raises:
              ValueError: If the ``version`` entry of ``local_metadata`` is missing or is
                  not 1 or 2.
          """
          version = local_metadata.get("version", None)
          if version not in [1, 2]:
              raise ValueError(f"version should be set to 1 or 2 instead of {version}")

          if version == 1 and self.alpha != 1.0:
              # In the initial version of the model (v1), stem was fixed-size.
              # All other layer configurations were the same. This will patch
              # the model so that it's identical to v1. Model with alpha 1.0 is
              # unaffected.
              depths = _get_depths(self.alpha)
              v1_stem = [
                  nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                  nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                  nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                  nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                  _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
              ]
              # The freshly built modules live on the default device/dtype. Move them to
              # wherever the model currently lives; otherwise a model that was already
              # moved (e.g. .cuda() or .double()) ends up with a mismatched stem after
              # loading.
              reference = self.layers[0].weight
              for idx, layer in enumerate(v1_stem):
                  self.layers[idx] = layer.to(device=reference.device, dtype=reference.dtype)

              # The model is now identical to v1, and must be saved as such.
              self._version = 1
              warnings.warn(
                  "A new version of MNASNet model has been implemented. "
                  "Your checkpoint was saved using the previous version. "
                  "This checkpoint will load and work as before, but "
                  "you may want to upgrade by training a newer model or "
                  "transfer learning from an updated ImageNet checkpoint.",
                  UserWarning,
              )

          super()._load_from_state_dict(
              state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
          )
  # Metadata shared by every MNASNet weight entry below; individual entries copy it and
  # may override keys (e.g. "recipe").
  _COMMON_META = dict(
      min_size=(1, 1),
      categories=_IMAGENET_CATEGORIES,
      recipe="https://github.com/1e100/mnasnet_trainer",
  )
  class MNASNet0_5_Weights(WeightsEnum):
      # Keeps the recipe URL inherited from _COMMON_META.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=2218512,
              _metrics={"ImageNet-1K": {"acc@1": 67.734, "acc@5": 87.490}},
              _docs="These weights reproduce closely the results of the paper.",
          ),
      )
      # DEFAULT is an alias of the only checkpoint above.
      DEFAULT = IMAGENET1K_V1
  class MNASNet0_75_Weights(WeightsEnum):
      # Retrained checkpoint: overrides the shared "recipe" URL and uses resize_size=232.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              recipe="https://github.com/pytorch/vision/pull/6019",
              num_params=3170208,
              _metrics={"ImageNet-1K": {"acc@1": 71.180, "acc@5": 90.496}},
              _docs=(
                  "\nThese weights were trained from scratch by using TorchVision's `new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )
      # DEFAULT is an alias of the only checkpoint above.
      DEFAULT = IMAGENET1K_V1
  class MNASNet1_0_Weights(WeightsEnum):
      # Keeps the recipe URL inherited from _COMMON_META.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=4383312,
              _metrics={"ImageNet-1K": {"acc@1": 73.456, "acc@5": 91.510}},
              _docs="These weights reproduce closely the results of the paper.",
          ),
      )
      # DEFAULT is an alias of the only checkpoint above.
      DEFAULT = IMAGENET1K_V1
  class MNASNet1_3_Weights(WeightsEnum):
      # Retrained checkpoint: overrides the shared "recipe" URL and uses resize_size=232.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              recipe="https://github.com/pytorch/vision/pull/6019",
              num_params=6282256,
              _metrics={"ImageNet-1K": {"acc@1": 76.506, "acc@5": 93.522}},
              _docs=(
                  "\nThese weights were trained from scratch by using TorchVision's `new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )
      # DEFAULT is an alias of the only checkpoint above.
      DEFAULT = IMAGENET1K_V1
  def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
      """Shared builder behind the public ``mnasnet*`` factory functions.

      Args:
          alpha: Depth multiplier forwarded to :class:`MNASNet`.
          weights: Weights enum member (already passed through ``verify`` by the callers
              in this module), or ``None`` for a randomly initialized model.
          progress: Whether to display a progress bar while downloading the weights.
          **kwargs: Extra keyword arguments for :class:`MNASNet`. When ``weights`` is
              given, ``num_classes`` is overwritten via ``_ovewrite_named_param`` with the
              number of categories listed in ``weights.meta``.

      Returns:
          The constructed model, with the pretrained state dict loaded when ``weights``
          is not ``None``.
      """
      if weights is not None:
          _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

      model = MNASNet(alpha, **kwargs)

      # Same explicit ``is not None`` test as above, so both branches agree on whether
      # weights were requested.
      if weights is not None:
          model.load_state_dict(weights.get_state_dict(progress=progress))

      return model
  @handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
  def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
      """Build an MNASNet with a depth multiplier of 0.5, as introduced in
      `MnasNet: Platform-Aware Neural Architecture Search for Mobile
      <https://arxiv.org/pdf/1807.11626.pdf>`_.

      Args:
          weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): Pretrained
              weights to load. See :class:`~torchvision.models.MNASNet0_5_Weights` below
              for the available values. By default no pretrained weights are used.
          progress (bool, optional): If True, shows a progress bar on stderr while the
              weights download. Default is True.
          **kwargs: Forwarded to the ``torchvision.models.mnasnet.MNASNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.MNASNet0_5_Weights
          :members:
      """
      return _mnasnet(0.5, MNASNet0_5_Weights.verify(weights), progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
  def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
      """Build an MNASNet with a depth multiplier of 0.75, as introduced in
      `MnasNet: Platform-Aware Neural Architecture Search for Mobile
      <https://arxiv.org/pdf/1807.11626.pdf>`_.

      Args:
          weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Pretrained
              weights to load. See :class:`~torchvision.models.MNASNet0_75_Weights` below
              for the available values. By default no pretrained weights are used.
          progress (bool, optional): If True, shows a progress bar on stderr while the
              weights download. Default is True.
          **kwargs: Forwarded to the ``torchvision.models.mnasnet.MNASNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.MNASNet0_75_Weights
          :members:
      """
      return _mnasnet(0.75, MNASNet0_75_Weights.verify(weights), progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
  def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
      """Build an MNASNet with a depth multiplier of 1.0, as introduced in
      `MnasNet: Platform-Aware Neural Architecture Search for Mobile
      <https://arxiv.org/pdf/1807.11626.pdf>`_.

      Args:
          weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): Pretrained
              weights to load. See :class:`~torchvision.models.MNASNet1_0_Weights` below
              for the available values. By default no pretrained weights are used.
          progress (bool, optional): If True, shows a progress bar on stderr while the
              weights download. Default is True.
          **kwargs: Forwarded to the ``torchvision.models.mnasnet.MNASNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.MNASNet1_0_Weights
          :members:
      """
      return _mnasnet(1.0, MNASNet1_0_Weights.verify(weights), progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
  def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
      """Build an MNASNet with a depth multiplier of 1.3, as introduced in
      `MnasNet: Platform-Aware Neural Architecture Search for Mobile
      <https://arxiv.org/pdf/1807.11626.pdf>`_.

      Args:
          weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Pretrained
              weights to load. See :class:`~torchvision.models.MNASNet1_3_Weights` below
              for the available values. By default no pretrained weights are used.
          progress (bool, optional): If True, shows a progress bar on stderr while the
              weights download. Default is True.
          **kwargs: Forwarded to the ``torchvision.models.mnasnet.MNASNet`` base class.
              Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.MNASNet1_3_Weights
          :members:
      """
      return _mnasnet(1.3, MNASNet1_3_Weights.verify(weights), progress, **kwargs)