retinanet.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. import math
  2. import warnings
  3. from collections import OrderedDict
  4. from functools import partial
  5. from typing import Any, Callable, Dict, List, Tuple, Optional
  6. import torch
  7. from torch import nn, Tensor
  8. from ...ops import sigmoid_focal_loss
  9. from ...ops import boxes as box_ops
  10. from ...ops import misc as misc_nn_ops
  11. from ...ops.feature_pyramid_network import LastLevelP6P7
  12. from ...transforms._presets import ObjectDetection
  13. from ...utils import _log_api_usage_once
  14. from .._api import WeightsEnum, Weights
  15. from .._meta import _COCO_CATEGORIES
  16. from .._utils import handle_legacy_interface, _ovewrite_value_param
  17. from ..resnet import ResNet50_Weights, resnet50
  18. from . import _utils as det_utils
  19. from ._utils import overwrite_eps, _box_loss
  20. from .anchor_utils import AnchorGenerator
  21. from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  22. from .transform import GeneralizedRCNNTransform
  23. __all__ = [
  24. "RetinaNet",
  25. "RetinaNet_ResNet50_FPN_Weights",
  26. "RetinaNet_ResNet50_FPN_V2_Weights",
  27. "retinanet_resnet50_fpn",
  28. "retinanet_resnet50_fpn_v2",
  29. ]
  30. def _sum(x: List[Tensor]) -> Tensor:
  31. res = x[0]
  32. for i in x[1:]:
  33. res = res + i
  34. return res
  35. def _v1_to_v2_weights(state_dict, prefix):
  36. for i in range(4):
  37. for type in ["weight", "bias"]:
  38. old_key = f"{prefix}conv.{2*i}.{type}"
  39. new_key = f"{prefix}conv.{i}.0.{type}"
  40. if old_key in state_dict:
  41. state_dict[new_key] = state_dict.pop(old_key)
  42. def _default_anchorgen():
  43. anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
  44. aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
  45. anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
  46. return anchor_generator
  47. class RetinaNetHead(nn.Module):
  48. """
  49. A regression and classification head for use in RetinaNet.
  50. Args:
  51. in_channels (int): number of channels of the input feature
  52. num_anchors (int): number of anchors to be predicted
  53. num_classes (int): number of classes to be predicted
  54. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  55. """
  56. def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
  57. super().__init__()
  58. self.classification_head = RetinaNetClassificationHead(
  59. in_channels, num_anchors, num_classes, norm_layer=norm_layer
  60. )
  61. self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)
  62. def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
  63. # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
  64. return {
  65. "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
  66. "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
  67. }
  68. def forward(self, x):
  69. # type: (List[Tensor]) -> Dict[str, Tensor]
  70. return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
  71. class RetinaNetClassificationHead(nn.Module):
  72. """
  73. A classification head for use in RetinaNet.
  74. Args:
  75. in_channels (int): number of channels of the input feature
  76. num_anchors (int): number of anchors to be predicted
  77. num_classes (int): number of classes to be predicted
  78. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  79. """
  80. _version = 2
  81. def __init__(
  82. self,
  83. in_channels,
  84. num_anchors,
  85. num_classes,
  86. prior_probability=0.01,
  87. norm_layer: Optional[Callable[..., nn.Module]] = None,
  88. ):
  89. super().__init__()
  90. conv = []
  91. for _ in range(4):
  92. conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
  93. self.conv = nn.Sequential(*conv)
  94. for layer in self.conv.modules():
  95. if isinstance(layer, nn.Conv2d):
  96. torch.nn.init.normal_(layer.weight, std=0.01)
  97. if layer.bias is not None:
  98. torch.nn.init.constant_(layer.bias, 0)
  99. self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
  100. torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
  101. torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
  102. self.num_classes = num_classes
  103. self.num_anchors = num_anchors
  104. # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
  105. # TorchScript doesn't support class attributes.
  106. # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
  107. self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
  108. def _load_from_state_dict(
  109. self,
  110. state_dict,
  111. prefix,
  112. local_metadata,
  113. strict,
  114. missing_keys,
  115. unexpected_keys,
  116. error_msgs,
  117. ):
  118. version = local_metadata.get("version", None)
  119. if version is None or version < 2:
  120. _v1_to_v2_weights(state_dict, prefix)
  121. super()._load_from_state_dict(
  122. state_dict,
  123. prefix,
  124. local_metadata,
  125. strict,
  126. missing_keys,
  127. unexpected_keys,
  128. error_msgs,
  129. )
  130. def compute_loss(self, targets, head_outputs, matched_idxs):
  131. # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
  132. losses = []
  133. cls_logits = head_outputs["cls_logits"]
  134. for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
  135. # determine only the foreground
  136. foreground_idxs_per_image = matched_idxs_per_image >= 0
  137. num_foreground = foreground_idxs_per_image.sum()
  138. # create the target classification
  139. gt_classes_target = torch.zeros_like(cls_logits_per_image)
  140. gt_classes_target[
  141. foreground_idxs_per_image,
  142. targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
  143. ] = 1.0
  144. # find indices for which anchors should be ignored
  145. valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
  146. # compute the classification loss
  147. losses.append(
  148. sigmoid_focal_loss(
  149. cls_logits_per_image[valid_idxs_per_image],
  150. gt_classes_target[valid_idxs_per_image],
  151. reduction="sum",
  152. )
  153. / max(1, num_foreground)
  154. )
  155. return _sum(losses) / len(targets)
  156. def forward(self, x):
  157. # type: (List[Tensor]) -> Tensor
  158. all_cls_logits = []
  159. for features in x:
  160. cls_logits = self.conv(features)
  161. cls_logits = self.cls_logits(cls_logits)
  162. # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
  163. N, _, H, W = cls_logits.shape
  164. cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
  165. cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  166. cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4)
  167. all_cls_logits.append(cls_logits)
  168. return torch.cat(all_cls_logits, dim=1)
  169. class RetinaNetRegressionHead(nn.Module):
  170. """
  171. A regression head for use in RetinaNet.
  172. Args:
  173. in_channels (int): number of channels of the input feature
  174. num_anchors (int): number of anchors to be predicted
  175. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  176. """
  177. _version = 2
  178. __annotations__ = {
  179. "box_coder": det_utils.BoxCoder,
  180. }
  181. def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
  182. super().__init__()
  183. conv = []
  184. for _ in range(4):
  185. conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
  186. self.conv = nn.Sequential(*conv)
  187. self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
  188. torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
  189. torch.nn.init.zeros_(self.bbox_reg.bias)
  190. for layer in self.conv.modules():
  191. if isinstance(layer, nn.Conv2d):
  192. torch.nn.init.normal_(layer.weight, std=0.01)
  193. if layer.bias is not None:
  194. torch.nn.init.zeros_(layer.bias)
  195. self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
  196. self._loss_type = "l1"
  197. def _load_from_state_dict(
  198. self,
  199. state_dict,
  200. prefix,
  201. local_metadata,
  202. strict,
  203. missing_keys,
  204. unexpected_keys,
  205. error_msgs,
  206. ):
  207. version = local_metadata.get("version", None)
  208. if version is None or version < 2:
  209. _v1_to_v2_weights(state_dict, prefix)
  210. super()._load_from_state_dict(
  211. state_dict,
  212. prefix,
  213. local_metadata,
  214. strict,
  215. missing_keys,
  216. unexpected_keys,
  217. error_msgs,
  218. )
  219. def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
  220. # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
  221. losses = []
  222. bbox_regression = head_outputs["bbox_regression"]
  223. for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
  224. targets, bbox_regression, anchors, matched_idxs
  225. ):
  226. # determine only the foreground indices, ignore the rest
  227. foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
  228. num_foreground = foreground_idxs_per_image.numel()
  229. # select only the foreground boxes
  230. matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
  231. bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
  232. anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
  233. # compute the loss
  234. losses.append(
  235. _box_loss(
  236. self._loss_type,
  237. self.box_coder,
  238. anchors_per_image,
  239. matched_gt_boxes_per_image,
  240. bbox_regression_per_image,
  241. )
  242. / max(1, num_foreground)
  243. )
  244. return _sum(losses) / max(1, len(targets))
  245. def forward(self, x):
  246. # type: (List[Tensor]) -> Tensor
  247. all_bbox_regression = []
  248. for features in x:
  249. bbox_regression = self.conv(features)
  250. bbox_regression = self.bbox_reg(bbox_regression)
  251. # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
  252. N, _, H, W = bbox_regression.shape
  253. bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  254. bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  255. bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
  256. all_bbox_regression.append(bbox_regression)
  257. return torch.cat(all_bbox_regression, dim=1)
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and the targets (a list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses (keys ``"classification"`` and ``"bbox_regression"`` when the default head is used).

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each detection
        - scores (Tensor[N]): the scores for each detection

    When the model is compiled with TorchScript, ``forward`` always returns a
    ``(losses, detections)`` tuple, with the unused element left empty.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on. Defaults to ``[0.485, 0.456, 0.406]``.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on.
            Defaults to ``[0.229, 0.224, 0.225]``.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. If None, a default generator with three sizes and three aspect ratios per
            level is used.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        proposal_matcher (det_utils.Matcher, optional): module that assigns anchors to ground-truth
            boxes during training. If None, ``Matcher(fg_iou_thresh, bg_iou_thresh,
            allow_low_quality_matches=True)`` is used.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training. Only used when ``proposal_matcher`` is None.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training. Only used when ``proposal_matcher`` is None.
        topk_candidates (int): Number of best detections to keep before NMS.
        **kwargs: extra keyword arguments forwarded to ``GeneralizedRCNNTransform``.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    # Attribute type declarations (presumably consumed by TorchScript when scripting the model).
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # anchor, head, matching and post-processing parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        # The default head is sized from backbone.out_channels, so the attribute is mandatory.
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            # The default head uses the anchor count of the first level for every level.
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        # Decodes regression outputs into boxes at inference time.
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # Eager-mode return value: only the losses while training, only the detections in eval.
        # The type comment above declares a tuple; in practice a single element is returned.
        # When scripting, forward returns the tuple directly and never calls this method.
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        # Match every anchor of each image to a ground-truth box, then let the head compute its losses.
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                # No ground truth: mark every anchor with -1, which the heads treat as
                # non-foreground (and not ignored).
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        # head_outputs and anchors are already split per feature level (see forward).
        # Steps per image and level: threshold scores, keep the top-k, decode and clip boxes.
        # Then run class-aware NMS over all levels and keep at most detections_per_img results.
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                # (scores are flattened over (anchor, class), so each flat index encodes both)
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # recover the anchor index and class label from the flat (anchor, class) index
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression, performed independently per label
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            # truncate (assumes batched_nms returns indices sorted by decreasing score — verify)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]], optional): ground-truth boxes and labels for each image.
                Required in training mode.

        Returns:
            result (dict[str, Tensor] or list[dict[str, Tensor]]): the output from the model.
                During training, a dict of losses (``"classification"`` and ``"bbox_regression"``
                with the default head).
                During evaluation, one dict per image with ``"boxes"``, ``"scores"`` and ``"labels"``.
                When scripted with TorchScript, the ``(losses, detections)`` tuple is always returned.
        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes (needed to map detections back after resizing)
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input (normalisation/resizing/batching; boxes in targets are transformed too)
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes: head outputs are concatenated over levels, so derive the
            # anchors-per-location count A from the total and split back per level
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
  554. _COMMON_META = {
  555. "categories": _COCO_CATEGORIES,
  556. "min_size": (1, 1),
  557. }
# Pretrained weights for the RetinaNet ResNet-50-FPN model built by ``retinanet_resnet50_fpn``.
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    # COCO-trained checkpoint. ``meta`` combines the shared _COMMON_META (categories, min_size)
    # with this entry's parameter count, training-recipe link and reported COCO-val2017 box mAP.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # DEFAULT is bound to the same Weights value as COCO_V1.
    DEFAULT = COCO_V1
# Pretrained weights for the V2 RetinaNet ResNet-50-FPN variant (presumably consumed by
# ``retinanet_resnet50_fpn_v2``, which is not visible here).
class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    # COCO-trained checkpoint. ``meta`` combines the shared _COMMON_META (categories, min_size)
    # with this entry's parameter count, recipe link and reported COCO-val2017 box mAP.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    # DEFAULT is bound to the same Weights value as COCO_V1.
    DEFAULT = COCO_V1
  592. @handle_legacy_interface(
  593. weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
  594. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  595. )
  596. def retinanet_resnet50_fpn(
  597. *,
  598. weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
  599. progress: bool = True,
  600. num_classes: Optional[int] = None,
  601. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  602. trainable_backbone_layers: Optional[int] = None,
  603. **kwargs: Any,
  604. ) -> RetinaNet:
  605. """
  606. Constructs a RetinaNet model with a ResNet-50-FPN backbone.
  607. .. betastatus:: detection module
  608. Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
  609. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  610. image, and should be in ``0-1`` range. Different images can have different sizes.
  611. The behavior of the model changes depending if it is in training or evaluation mode.
  612. During training, the model expects both the input tensors, as well as a targets (list of dictionary),
  613. containing:
  614. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  615. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  616. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  617. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  618. losses.
  619. During inference, the model requires only the input tensors, and returns the post-processed
  620. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  621. follows, where ``N`` is the number of detections:
  622. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  623. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  624. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  625. - scores (``Tensor[N]``): the scores of each detection
  626. For more details on the output, you may refer to :ref:`instance_seg_output`.
  627. Example::
  628. >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
  629. >>> model.eval()
  630. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  631. >>> predictions = model(x)
  632. Args:
  633. weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
  634. pretrained weights to use. See
  635. :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
  636. below for more details, and possible values. By default, no
  637. pre-trained weights are used.
  638. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  639. num_classes (int, optional): number of output classes of the model (including the background)
  640. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  641. the backbone.
  642. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  643. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  644. passed (the default) this value is set to 3.
  645. **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
  646. base class. Please refer to the `source code
  647. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
  648. for more details about this class.
  649. .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
  650. :members:
  651. """
  652. weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
  653. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  654. if weights is not None:
  655. weights_backbone = None
  656. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  657. elif num_classes is None:
  658. num_classes = 91
  659. is_trained = weights is not None or weights_backbone is not None
  660. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  661. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  662. backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  663. # skip P2 because it generates too many anchors (according to their paper)
  664. backbone = _resnet_fpn_extractor(
  665. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
  666. )
  667. model = RetinaNet(backbone, num_classes, **kwargs)
  668. if weights is not None:
  669. model.load_state_dict(weights.get_state_dict(progress=progress))
  670. if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
  671. overwrite_eps(model, 0.0)
  672. return model
  673. def retinanet_resnet50_fpn_v2(
  674. *,
  675. weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
  676. progress: bool = True,
  677. num_classes: Optional[int] = None,
  678. weights_backbone: Optional[ResNet50_Weights] = None,
  679. trainable_backbone_layers: Optional[int] = None,
  680. **kwargs: Any,
  681. ) -> RetinaNet:
  682. """
  683. Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
  684. .. betastatus:: detection module
  685. Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
  686. <https://arxiv.org/abs/1912.02424>`_.
  687. :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
  688. Args:
  689. weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
  690. pretrained weights to use. See
  691. :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
  692. below for more details, and possible values. By default, no
  693. pre-trained weights are used.
  694. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  695. num_classes (int, optional): number of output classes of the model (including the background)
  696. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  697. the backbone.
  698. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  699. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  700. passed (the default) this value is set to 3.
  701. **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
  702. base class. Please refer to the `source code
  703. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
  704. for more details about this class.
  705. .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
  706. :members:
  707. """
  708. weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
  709. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  710. if weights is not None:
  711. weights_backbone = None
  712. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  713. elif num_classes is None:
  714. num_classes = 91
  715. is_trained = weights is not None or weights_backbone is not None
  716. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  717. backbone = resnet50(weights=weights_backbone, progress=progress)
  718. backbone = _resnet_fpn_extractor(
  719. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
  720. )
  721. anchor_generator = _default_anchorgen()
  722. head = RetinaNetHead(
  723. backbone.out_channels,
  724. anchor_generator.num_anchors_per_location()[0],
  725. num_classes,
  726. norm_layer=partial(nn.GroupNorm, 32),
  727. )
  728. head.regression_head._loss_type = "giou"
  729. model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
  730. if weights is not None:
  731. model.load_state_dict(weights.get_state_dict(progress=progress))
  732. return model
  733. # The dictionary below is internal implementation detail and will be removed in v0.15
  734. from .._utils import _ModelURLs
  735. model_urls = _ModelURLs(
  736. {
  737. "retinanet_resnet50_fpn_coco": RetinaNet_ResNet50_FPN_Weights.COCO_V1.url,
  738. }
  739. )