squeezenet.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from functools import partial
  2. from typing import Any, Optional
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.init as init
  6. from ..transforms._presets import ImageClassification
  7. from ..utils import _log_api_usage_once
  8. from ._api import WeightsEnum, Weights
  9. from ._meta import _IMAGENET_CATEGORIES
  10. from ._utils import handle_legacy_interface, _ovewrite_named_param
  11. __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
  12. class Fire(nn.Module):
  13. def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
  14. super().__init__()
  15. self.inplanes = inplanes
  16. self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
  17. self.squeeze_activation = nn.ReLU(inplace=True)
  18. self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
  19. self.expand1x1_activation = nn.ReLU(inplace=True)
  20. self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
  21. self.expand3x3_activation = nn.ReLU(inplace=True)
  22. def forward(self, x: torch.Tensor) -> torch.Tensor:
  23. x = self.squeeze_activation(self.squeeze(x))
  24. return torch.cat(
  25. [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
  26. )
  27. class SqueezeNet(nn.Module):
  28. def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
  29. super().__init__()
  30. _log_api_usage_once(self)
  31. self.num_classes = num_classes
  32. if version == "1_0":
  33. self.features = nn.Sequential(
  34. nn.Conv2d(3, 96, kernel_size=7, stride=2),
  35. nn.ReLU(inplace=True),
  36. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  37. Fire(96, 16, 64, 64),
  38. Fire(128, 16, 64, 64),
  39. Fire(128, 32, 128, 128),
  40. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  41. Fire(256, 32, 128, 128),
  42. Fire(256, 48, 192, 192),
  43. Fire(384, 48, 192, 192),
  44. Fire(384, 64, 256, 256),
  45. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  46. Fire(512, 64, 256, 256),
  47. )
  48. elif version == "1_1":
  49. self.features = nn.Sequential(
  50. nn.Conv2d(3, 64, kernel_size=3, stride=2),
  51. nn.ReLU(inplace=True),
  52. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  53. Fire(64, 16, 64, 64),
  54. Fire(128, 16, 64, 64),
  55. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  56. Fire(128, 32, 128, 128),
  57. Fire(256, 32, 128, 128),
  58. nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
  59. Fire(256, 48, 192, 192),
  60. Fire(384, 48, 192, 192),
  61. Fire(384, 64, 256, 256),
  62. Fire(512, 64, 256, 256),
  63. )
  64. else:
  65. # FIXME: Is this needed? SqueezeNet should only be called from the
  66. # FIXME: squeezenet1_x() functions
  67. # FIXME: This checking is not done for the other models
  68. raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
  69. # Final convolution is initialized differently from the rest
  70. final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
  71. self.classifier = nn.Sequential(
  72. nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
  73. )
  74. for m in self.modules():
  75. if isinstance(m, nn.Conv2d):
  76. if m is final_conv:
  77. init.normal_(m.weight, mean=0.0, std=0.01)
  78. else:
  79. init.kaiming_uniform_(m.weight)
  80. if m.bias is not None:
  81. init.constant_(m.bias, 0)
  82. def forward(self, x: torch.Tensor) -> torch.Tensor:
  83. x = self.features(x)
  84. x = self.classifier(x)
  85. return torch.flatten(x, 1)
  86. def _squeezenet(
  87. version: str,
  88. weights: Optional[WeightsEnum],
  89. progress: bool,
  90. **kwargs: Any,
  91. ) -> SqueezeNet:
  92. if weights is not None:
  93. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  94. model = SqueezeNet(version, **kwargs)
  95. if weights is not None:
  96. model.load_state_dict(weights.get_state_dict(progress=progress))
  97. return model
  98. _COMMON_META = {
  99. "categories": _IMAGENET_CATEGORIES,
  100. "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
  101. "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
  102. }
  103. class SqueezeNet1_0_Weights(WeightsEnum):
  104. IMAGENET1K_V1 = Weights(
  105. url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
  106. transforms=partial(ImageClassification, crop_size=224),
  107. meta={
  108. **_COMMON_META,
  109. "min_size": (21, 21),
  110. "num_params": 1248424,
  111. "_metrics": {
  112. "ImageNet-1K": {
  113. "acc@1": 58.092,
  114. "acc@5": 80.420,
  115. }
  116. },
  117. },
  118. )
  119. DEFAULT = IMAGENET1K_V1
  120. class SqueezeNet1_1_Weights(WeightsEnum):
  121. IMAGENET1K_V1 = Weights(
  122. url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
  123. transforms=partial(ImageClassification, crop_size=224),
  124. meta={
  125. **_COMMON_META,
  126. "min_size": (17, 17),
  127. "num_params": 1235496,
  128. "_metrics": {
  129. "ImageNet-1K": {
  130. "acc@1": 58.178,
  131. "acc@5": 80.624,
  132. }
  133. },
  134. },
  135. )
  136. DEFAULT = IMAGENET1K_V1
  137. @handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
  138. def squeezenet1_0(
  139. *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
  140. ) -> SqueezeNet:
  141. """SqueezeNet model architecture from the `SqueezeNet: AlexNet-level
  142. accuracy with 50x fewer parameters and <0.5MB model size
  143. <https://arxiv.org/abs/1602.07360>`_ paper.
  144. Args:
  145. weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional): The
  146. pretrained weights to use. See
  147. :class:`~torchvision.models.SqueezeNet1_0_Weights` below for
  148. more details, and possible values. By default, no pre-trained
  149. weights are used.
  150. progress (bool, optional): If True, displays a progress bar of the
  151. download to stderr. Default is True.
  152. **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
  153. base class. Please refer to the `source code
  154. <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
  155. for more details about this class.
  156. .. autoclass:: torchvision.models.SqueezeNet1_0_Weights
  157. :members:
  158. """
  159. weights = SqueezeNet1_0_Weights.verify(weights)
  160. return _squeezenet("1_0", weights, progress, **kwargs)
  161. @handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
  162. def squeezenet1_1(
  163. *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
  164. ) -> SqueezeNet:
  165. """SqueezeNet 1.1 model from the `official SqueezeNet repo
  166. <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
  167. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
  168. than SqueezeNet 1.0, without sacrificing accuracy.
  169. Args:
  170. weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional): The
  171. pretrained weights to use. See
  172. :class:`~torchvision.models.SqueezeNet1_1_Weights` below for
  173. more details, and possible values. By default, no pre-trained
  174. weights are used.
  175. progress (bool, optional): If True, displays a progress bar of the
  176. download to stderr. Default is True.
  177. **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
  178. base class. Please refer to the `source code
  179. <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
  180. for more details about this class.
  181. .. autoclass:: torchvision.models.SqueezeNet1_1_Weights
  182. :members:
  183. """
  184. weights = SqueezeNet1_1_Weights.verify(weights)
  185. return _squeezenet("1_1", weights, progress, **kwargs)
  186. # The dictionary below is internal implementation detail and will be removed in v0.15
  187. from ._utils import _ModelURLs
  188. model_urls = _ModelURLs(
  189. {
  190. "squeezenet1_0": SqueezeNet1_0_Weights.IMAGENET1K_V1.url,
  191. "squeezenet1_1": SqueezeNet1_1_Weights.IMAGENET1K_V1.url,
  192. }
  193. )