deeplabv3.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. from functools import partial
  2. from typing import Any, List, Optional
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from ...transforms._presets import SemanticSegmentation
  7. from .._api import WeightsEnum, Weights
  8. from .._meta import _VOC_CATEGORIES
  9. from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param
  10. from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large
  11. from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights
  12. from ._utils import _SimpleSegmentationModel
  13. from .fcn import FCNHead
# Public API of this module: the model class, one weights enum per backbone,
# and one builder function per backbone.
__all__ = [
    "DeepLabV3",
    "DeepLabV3_ResNet50_Weights",
    "DeepLabV3_ResNet101_Weights",
    "DeepLabV3_MobileNet_V3_Large_Weights",
    "deeplabv3_mobilenet_v3_large",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
]
  23. class DeepLabV3(_SimpleSegmentationModel):
  24. """
  25. Implements DeepLabV3 model from
  26. `"Rethinking Atrous Convolution for Semantic Image Segmentation"
  27. <https://arxiv.org/abs/1706.05587>`_.
  28. Args:
  29. backbone (nn.Module): the network used to compute the features for the model.
  30. The backbone should return an OrderedDict[Tensor], with the key being
  31. "out" for the last feature map used, and "aux" if an auxiliary classifier
  32. is used.
  33. classifier (nn.Module): module that takes the "out" element returned from
  34. the backbone and returns a dense prediction.
  35. aux_classifier (nn.Module, optional): auxiliary classifier used during training
  36. """
  37. pass
  38. class DeepLabHead(nn.Sequential):
  39. def __init__(self, in_channels: int, num_classes: int) -> None:
  40. super().__init__(
  41. ASPP(in_channels, [12, 24, 36]),
  42. nn.Conv2d(256, 256, 3, padding=1, bias=False),
  43. nn.BatchNorm2d(256),
  44. nn.ReLU(),
  45. nn.Conv2d(256, num_classes, 1),
  46. )
  47. class ASPPConv(nn.Sequential):
  48. def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
  49. modules = [
  50. nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
  51. nn.BatchNorm2d(out_channels),
  52. nn.ReLU(),
  53. ]
  54. super().__init__(*modules)
  55. class ASPPPooling(nn.Sequential):
  56. def __init__(self, in_channels: int, out_channels: int) -> None:
  57. super().__init__(
  58. nn.AdaptiveAvgPool2d(1),
  59. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  60. nn.BatchNorm2d(out_channels),
  61. nn.ReLU(),
  62. )
  63. def forward(self, x: torch.Tensor) -> torch.Tensor:
  64. size = x.shape[-2:]
  65. for mod in self:
  66. x = mod(x)
  67. return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
  68. class ASPP(nn.Module):
  69. def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
  70. super().__init__()
  71. modules = []
  72. modules.append(
  73. nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
  74. )
  75. rates = tuple(atrous_rates)
  76. for rate in rates:
  77. modules.append(ASPPConv(in_channels, out_channels, rate))
  78. modules.append(ASPPPooling(in_channels, out_channels))
  79. self.convs = nn.ModuleList(modules)
  80. self.project = nn.Sequential(
  81. nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
  82. nn.BatchNorm2d(out_channels),
  83. nn.ReLU(),
  84. nn.Dropout(0.5),
  85. )
  86. def forward(self, x: torch.Tensor) -> torch.Tensor:
  87. _res = []
  88. for conv in self.convs:
  89. _res.append(conv(x))
  90. res = torch.cat(_res, dim=1)
  91. return self.project(res)
  92. def _deeplabv3_resnet(
  93. backbone: ResNet,
  94. num_classes: int,
  95. aux: Optional[bool],
  96. ) -> DeepLabV3:
  97. return_layers = {"layer4": "out"}
  98. if aux:
  99. return_layers["layer3"] = "aux"
  100. backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
  101. aux_classifier = FCNHead(1024, num_classes) if aux else None
  102. classifier = DeepLabHead(2048, num_classes)
  103. return DeepLabV3(backbone, classifier, aux_classifier)
# Metadata shared by the weights entries defined in this file; each entry
# unpacks it into its own ``meta`` dict and adds model-specific fields.
_COMMON_META = {
    "categories": _VOC_CATEGORIES,
    "min_size": (1, 1),
    "_docs": """
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    """,
}
# Pretrained weights available for the DeepLabV3 model with a ResNet-50 backbone.
class DeepLabV3_ResNet50_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        # Location of the checkpoint file.
        url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
        # Preprocessing preset paired with these weights, with resize_size fixed to 520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            # Shared fields first, then the fields specific to this checkpoint.
            **_COMMON_META,
            "num_params": 42004074,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 66.4,
                    "pixel_acc": 92.4,
                }
            },
        },
    )
    # DEFAULT is bound to the same entry as COCO_WITH_VOC_LABELS_V1.
    DEFAULT = COCO_WITH_VOC_LABELS_V1
  129. class DeepLabV3_ResNet101_Weights(WeightsEnum):
  130. COCO_WITH_VOC_LABELS_V1 = Weights(
  131. url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
  132. transforms=partial(SemanticSegmentation, resize_size=520),
  133. meta={
  134. **_COMMON_META,
  135. "num_params": 60996202,
  136. "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
  137. "_metrics": {
  138. "COCO-val2017-VOC-labels": {
  139. "miou": 67.4,
  140. "pixel_acc": 92.4,
  141. }
  142. },
  143. },
  144. )
  145. DEFAULT = COCO_WITH_VOC_LABELS_V1
# Pretrained weights available for the DeepLabV3 model with a MobileNetV3-Large backbone.
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
    COCO_WITH_VOC_LABELS_V1 = Weights(
        # Location of the checkpoint file.
        url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
        # Preprocessing preset paired with these weights, with resize_size fixed to 520.
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            # Shared fields first, then the fields specific to this checkpoint.
            **_COMMON_META,
            "num_params": 11029328,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 60.3,
                    "pixel_acc": 91.2,
                }
            },
        },
    )
    # DEFAULT is bound to the same entry as COCO_WITH_VOC_LABELS_V1.
    DEFAULT = COCO_WITH_VOC_LABELS_V1
  163. def _deeplabv3_mobilenetv3(
  164. backbone: MobileNetV3,
  165. num_classes: int,
  166. aux: Optional[bool],
  167. ) -> DeepLabV3:
  168. backbone = backbone.features
  169. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
  170. # The first and last blocks are always included because they are the C0 (conv1) and Cn.
  171. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
  172. out_pos = stage_indices[-1] # use C5 which has output_stride = 16
  173. out_inplanes = backbone[out_pos].out_channels
  174. aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
  175. aux_inplanes = backbone[aux_pos].out_channels
  176. return_layers = {str(out_pos): "out"}
  177. if aux:
  178. return_layers[str(aux_pos)] = "aux"
  179. backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
  180. aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
  181. classifier = DeepLabHead(out_inplanes, num_classes)
  182. return DeepLabV3(backbone, classifier, aux_classifier)
  183. @handle_legacy_interface(
  184. weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1),
  185. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  186. )
  187. def deeplabv3_resnet50(
  188. *,
  189. weights: Optional[DeepLabV3_ResNet50_Weights] = None,
  190. progress: bool = True,
  191. num_classes: Optional[int] = None,
  192. aux_loss: Optional[bool] = None,
  193. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  194. **kwargs: Any,
  195. ) -> DeepLabV3:
  196. """Constructs a DeepLabV3 model with a ResNet-50 backbone.
  197. .. betastatus:: segmentation module
  198. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  199. Args:
  200. weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
  201. pretrained weights to use. See
  202. :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
  203. more details, and possible values. By default, no pre-trained
  204. weights are used.
  205. progress (bool, optional): If True, displays a progress bar of the
  206. download to stderr. Default is True.
  207. num_classes (int, optional): number of output classes of the model (including the background)
  208. aux_loss (bool, optional): If True, it uses an auxiliary loss
  209. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
  210. backbone
  211. **kwargs: unused
  212. .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
  213. :members:
  214. """
  215. weights = DeepLabV3_ResNet50_Weights.verify(weights)
  216. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  217. if weights is not None:
  218. weights_backbone = None
  219. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  220. aux_loss = _ovewrite_value_param(aux_loss, True)
  221. elif num_classes is None:
  222. num_classes = 21
  223. backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
  224. model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
  225. if weights is not None:
  226. model.load_state_dict(weights.get_state_dict(progress=progress))
  227. return model
  228. @handle_legacy_interface(
  229. weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1),
  230. weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1),
  231. )
  232. def deeplabv3_resnet101(
  233. *,
  234. weights: Optional[DeepLabV3_ResNet101_Weights] = None,
  235. progress: bool = True,
  236. num_classes: Optional[int] = None,
  237. aux_loss: Optional[bool] = None,
  238. weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1,
  239. **kwargs: Any,
  240. ) -> DeepLabV3:
  241. """Constructs a DeepLabV3 model with a ResNet-101 backbone.
  242. .. betastatus:: segmentation module
  243. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  244. Args:
  245. weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
  246. pretrained weights to use. See
  247. :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
  248. more details, and possible values. By default, no pre-trained
  249. weights are used.
  250. progress (bool, optional): If True, displays a progress bar of the
  251. download to stderr. Default is True.
  252. num_classes (int, optional): number of output classes of the model (including the background)
  253. aux_loss (bool, optional): If True, it uses an auxiliary loss
  254. weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
  255. backbone
  256. **kwargs: unused
  257. .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
  258. :members:
  259. """
  260. weights = DeepLabV3_ResNet101_Weights.verify(weights)
  261. weights_backbone = ResNet101_Weights.verify(weights_backbone)
  262. if weights is not None:
  263. weights_backbone = None
  264. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  265. aux_loss = _ovewrite_value_param(aux_loss, True)
  266. elif num_classes is None:
  267. num_classes = 21
  268. backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
  269. model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
  270. if weights is not None:
  271. model.load_state_dict(weights.get_state_dict(progress=progress))
  272. return model
  273. @handle_legacy_interface(
  274. weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
  275. weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
  276. )
  277. def deeplabv3_mobilenet_v3_large(
  278. *,
  279. weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None,
  280. progress: bool = True,
  281. num_classes: Optional[int] = None,
  282. aux_loss: Optional[bool] = None,
  283. weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
  284. **kwargs: Any,
  285. ) -> DeepLabV3:
  286. """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
  287. Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.
  288. Args:
  289. weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
  290. pretrained weights to use. See
  291. :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
  292. more details, and possible values. By default, no pre-trained
  293. weights are used.
  294. progress (bool, optional): If True, displays a progress bar of the
  295. download to stderr. Default is True.
  296. num_classes (int, optional): number of output classes of the model (including the background)
  297. aux_loss (bool, optional): If True, it uses an auxiliary loss
  298. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
  299. for the backbone
  300. **kwargs: unused
  301. .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
  302. :members:
  303. """
  304. weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights)
  305. weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
  306. if weights is not None:
  307. weights_backbone = None
  308. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  309. aux_loss = _ovewrite_value_param(aux_loss, True)
  310. elif num_classes is None:
  311. num_classes = 21
  312. backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
  313. model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
  314. if weights is not None:
  315. model.load_state_dict(weights.get_state_dict(progress=progress))
  316. return model
# The dictionary below is an internal implementation detail and will be removed in v0.15.
from .._utils import _ModelURLs


# Maps each legacy checkpoint name to the URL of the matching weights entry
# defined in this file.
model_urls = _ModelURLs(
    {
        "deeplabv3_resnet50_coco": DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
        "deeplabv3_resnet101_coco": DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
        "deeplabv3_mobilenet_v3_large_coco": DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
    }
)