mobilenetv2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. import warnings
  2. from functools import partial
  3. from typing import Callable, Any, Optional, List
  4. import torch
  5. from torch import Tensor
  6. from torch import nn
  7. from ..ops.misc import Conv2dNormActivation
  8. from ..transforms._presets import ImageClassification
  9. from ..utils import _log_api_usage_once
  10. from ._api import WeightsEnum, Weights
  11. from ._meta import _IMAGENET_CATEGORIES
  12. from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
  13. __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
  14. # necessary for backwards compatibility
  15. class _DeprecatedConvBNAct(Conv2dNormActivation):
  16. def __init__(self, *args, **kwargs):
  17. warnings.warn(
  18. "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. "
  19. "Use torchvision.ops.misc.Conv2dNormActivation instead.",
  20. FutureWarning,
  21. )
  22. if kwargs.get("norm_layer", None) is None:
  23. kwargs["norm_layer"] = nn.BatchNorm2d
  24. if kwargs.get("activation_layer", None) is None:
  25. kwargs["activation_layer"] = nn.ReLU6
  26. super().__init__(*args, **kwargs)
  27. ConvBNReLU = _DeprecatedConvBNAct
  28. ConvBNActivation = _DeprecatedConvBNAct
  29. class InvertedResidual(nn.Module):
  30. def __init__(
  31. self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
  32. ) -> None:
  33. super().__init__()
  34. self.stride = stride
  35. if stride not in [1, 2]:
  36. raise ValueError(f"stride should be 1 or 2 insted of {stride}")
  37. if norm_layer is None:
  38. norm_layer = nn.BatchNorm2d
  39. hidden_dim = int(round(inp * expand_ratio))
  40. self.use_res_connect = self.stride == 1 and inp == oup
  41. layers: List[nn.Module] = []
  42. if expand_ratio != 1:
  43. # pw
  44. layers.append(
  45. Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
  46. )
  47. layers.extend(
  48. [
  49. # dw
  50. Conv2dNormActivation(
  51. hidden_dim,
  52. hidden_dim,
  53. stride=stride,
  54. groups=hidden_dim,
  55. norm_layer=norm_layer,
  56. activation_layer=nn.ReLU6,
  57. ),
  58. # pw-linear
  59. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
  60. norm_layer(oup),
  61. ]
  62. )
  63. self.conv = nn.Sequential(*layers)
  64. self.out_channels = oup
  65. self._is_cn = stride > 1
  66. def forward(self, x: Tensor) -> Tensor:
  67. if self.use_res_connect:
  68. return x + self.conv(x)
  69. else:
  70. return self.conv(x)
  71. class MobileNetV2(nn.Module):
  72. def __init__(
  73. self,
  74. num_classes: int = 1000,
  75. width_mult: float = 1.0,
  76. inverted_residual_setting: Optional[List[List[int]]] = None,
  77. round_nearest: int = 8,
  78. block: Optional[Callable[..., nn.Module]] = None,
  79. norm_layer: Optional[Callable[..., nn.Module]] = None,
  80. dropout: float = 0.2,
  81. ) -> None:
  82. """
  83. MobileNet V2 main class
  84. Args:
  85. num_classes (int): Number of classes
  86. width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
  87. inverted_residual_setting: Network structure
  88. round_nearest (int): Round the number of channels in each layer to be a multiple of this number
  89. Set to 1 to turn off rounding
  90. block: Module specifying inverted residual building block for mobilenet
  91. norm_layer: Module specifying the normalization layer to use
  92. dropout (float): The droupout probability
  93. """
  94. super().__init__()
  95. _log_api_usage_once(self)
  96. if block is None:
  97. block = InvertedResidual
  98. if norm_layer is None:
  99. norm_layer = nn.BatchNorm2d
  100. input_channel = 32
  101. last_channel = 1280
  102. if inverted_residual_setting is None:
  103. inverted_residual_setting = [
  104. # t, c, n, s
  105. [1, 16, 1, 1],
  106. [6, 24, 2, 2],
  107. [6, 32, 3, 2],
  108. [6, 64, 4, 2],
  109. [6, 96, 3, 1],
  110. [6, 160, 3, 2],
  111. [6, 320, 1, 1],
  112. ]
  113. # only check the first element, assuming user knows t,c,n,s are required
  114. if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
  115. raise ValueError(
  116. f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
  117. )
  118. # building first layer
  119. input_channel = _make_divisible(input_channel * width_mult, round_nearest)
  120. self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
  121. features: List[nn.Module] = [
  122. Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
  123. ]
  124. # building inverted residual blocks
  125. for t, c, n, s in inverted_residual_setting:
  126. output_channel = _make_divisible(c * width_mult, round_nearest)
  127. for i in range(n):
  128. stride = s if i == 0 else 1
  129. features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
  130. input_channel = output_channel
  131. # building last several layers
  132. features.append(
  133. Conv2dNormActivation(
  134. input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
  135. )
  136. )
  137. # make it nn.Sequential
  138. self.features = nn.Sequential(*features)
  139. # building classifier
  140. self.classifier = nn.Sequential(
  141. nn.Dropout(p=dropout),
  142. nn.Linear(self.last_channel, num_classes),
  143. )
  144. # weight initialization
  145. for m in self.modules():
  146. if isinstance(m, nn.Conv2d):
  147. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  148. if m.bias is not None:
  149. nn.init.zeros_(m.bias)
  150. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
  151. nn.init.ones_(m.weight)
  152. nn.init.zeros_(m.bias)
  153. elif isinstance(m, nn.Linear):
  154. nn.init.normal_(m.weight, 0, 0.01)
  155. nn.init.zeros_(m.bias)
  156. def _forward_impl(self, x: Tensor) -> Tensor:
  157. # This exists since TorchScript doesn't support inheritance, so the superclass method
  158. # (this one) needs to have a name other than `forward` that can be accessed in a subclass
  159. x = self.features(x)
  160. # Cannot use "squeeze" as batch-size can be 1
  161. x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
  162. x = torch.flatten(x, 1)
  163. x = self.classifier(x)
  164. return x
  165. def forward(self, x: Tensor) -> Tensor:
  166. return self._forward_impl(x)
  167. _COMMON_META = {
  168. "num_params": 3504872,
  169. "min_size": (1, 1),
  170. "categories": _IMAGENET_CATEGORIES,
  171. }
class MobileNet_V2_Weights(WeightsEnum):
    """Pretrained weight definitions for MobileNetV2.

    Each member pairs a checkpoint URL with the preprocessing factory to use
    and a metadata dict (shared entries from ``_COMMON_META`` plus the
    training recipe link, ImageNet-1K accuracy figures and a description).
    """

    # First set of ImageNet-1K weights: 224 crop, remaining preprocessing left at preset defaults.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Second set of ImageNet-1K weights: 224 crop after an explicit resize to 232.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # DEFAULT is bound to the V2 weights.
    DEFAULT = IMAGENET1K_V2
  208. @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
  209. def mobilenet_v2(
  210. *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
  211. ) -> MobileNetV2:
  212. """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
  213. Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.
  214. Args:
  215. weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
  216. pretrained weights to use. See
  217. :class:`~torchvision.models.MobileNet_V2_Weights` below for
  218. more details, and possible values. By default, no pre-trained
  219. weights are used.
  220. progress (bool, optional): If True, displays a progress bar of the
  221. download to stderr. Default is True.
  222. **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
  223. base class. Please refer to the `source code
  224. <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
  225. for more details about this class.
  226. .. autoclass:: torchvision.models.MobileNet_V2_Weights
  227. :members:
  228. """
  229. weights = MobileNet_V2_Weights.verify(weights)
  230. if weights is not None:
  231. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  232. model = MobileNetV2(**kwargs)
  233. if weights is not None:
  234. model.load_state_dict(weights.get_state_dict(progress=progress))
  235. return model
  236. # The dictionary below is internal implementation detail and will be removed in v0.15
  237. from ._utils import _ModelURLs
  238. model_urls = _ModelURLs(
  239. {
  240. "mobilenet_v2": MobileNet_V2_Weights.IMAGENET1K_V1.url,
  241. }
  242. )