fcos.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780
  1. import math
  2. import warnings
  3. from collections import OrderedDict
  4. from functools import partial
  5. from typing import Any, Callable, Dict, List, Tuple, Optional
  6. import torch
  7. from torch import nn, Tensor
  8. from ...ops import sigmoid_focal_loss, generalized_box_iou_loss
  9. from ...ops import boxes as box_ops
  10. from ...ops import misc as misc_nn_ops
  11. from ...ops.feature_pyramid_network import LastLevelP6P7
  12. from ...transforms._presets import ObjectDetection
  13. from ...utils import _log_api_usage_once
  14. from .._api import WeightsEnum, Weights
  15. from .._meta import _COCO_CATEGORIES
  16. from .._utils import handle_legacy_interface, _ovewrite_value_param
  17. from ..resnet import ResNet50_Weights, resnet50
  18. from . import _utils as det_utils
  19. from .anchor_utils import AnchorGenerator
  20. from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  21. from .transform import GeneralizedRCNNTransform
  22. __all__ = [
  23. "FCOS",
  24. "FCOS_ResNet50_FPN_Weights",
  25. "fcos_resnet50_fpn",
  26. ]
  27. class FCOSHead(nn.Module):
  28. """
  29. A regression and classification head for use in FCOS.
  30. Args:
  31. in_channels (int): number of channels of the input feature
  32. num_anchors (int): number of anchors to be predicted
  33. num_classes (int): number of classes to be predicted
  34. num_convs (Optional[int]): number of conv layer of head. Default: 4.
  35. """
  36. __annotations__ = {
  37. "box_coder": det_utils.BoxLinearCoder,
  38. }
  39. def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
  40. super().__init__()
  41. self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
  42. self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
  43. self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
  44. def compute_loss(
  45. self,
  46. targets: List[Dict[str, Tensor]],
  47. head_outputs: Dict[str, Tensor],
  48. anchors: List[Tensor],
  49. matched_idxs: List[Tensor],
  50. ) -> Dict[str, Tensor]:
  51. cls_logits = head_outputs["cls_logits"] # [N, HWA, C]
  52. bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4]
  53. bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1]
  54. all_gt_classes_targets = []
  55. all_gt_boxes_targets = []
  56. for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
  57. if len(targets_per_image["labels"]) == 0:
  58. gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
  59. gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
  60. else:
  61. gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
  62. gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
  63. gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud
  64. all_gt_classes_targets.append(gt_classes_targets)
  65. all_gt_boxes_targets.append(gt_boxes_targets)
  66. all_gt_classes_targets = torch.stack(all_gt_classes_targets)
  67. # compute foregroud
  68. foregroud_mask = all_gt_classes_targets >= 0
  69. num_foreground = foregroud_mask.sum().item()
  70. # classification loss
  71. gt_classes_targets = torch.zeros_like(cls_logits)
  72. gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
  73. loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
  74. # regression loss: GIoU loss
  75. # TODO: vectorize this instead of using a for loop
  76. pred_boxes = [
  77. self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
  78. for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression)
  79. ]
  80. # amp issue: pred_boxes need to convert float
  81. loss_bbox_reg = generalized_box_iou_loss(
  82. torch.stack(pred_boxes)[foregroud_mask].float(),
  83. torch.stack(all_gt_boxes_targets)[foregroud_mask],
  84. reduction="sum",
  85. )
  86. # ctrness loss
  87. bbox_reg_targets = [
  88. self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
  89. for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
  90. ]
  91. bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
  92. if len(bbox_reg_targets) == 0:
  93. gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
  94. else:
  95. left_right = bbox_reg_targets[:, :, [0, 2]]
  96. top_bottom = bbox_reg_targets[:, :, [1, 3]]
  97. gt_ctrness_targets = torch.sqrt(
  98. (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
  99. * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
  100. )
  101. pred_centerness = bbox_ctrness.squeeze(dim=2)
  102. loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
  103. pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
  104. )
  105. return {
  106. "classification": loss_cls / max(1, num_foreground),
  107. "bbox_regression": loss_bbox_reg / max(1, num_foreground),
  108. "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
  109. }
  110. def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
  111. cls_logits = self.classification_head(x)
  112. bbox_regression, bbox_ctrness = self.regression_head(x)
  113. return {
  114. "cls_logits": cls_logits,
  115. "bbox_regression": bbox_regression,
  116. "bbox_ctrness": bbox_ctrness,
  117. }
  118. class FCOSClassificationHead(nn.Module):
  119. """
  120. A classification head for use in FCOS.
  121. Args:
  122. in_channels (int): number of channels of the input feature.
  123. num_anchors (int): number of anchors to be predicted.
  124. num_classes (int): number of classes to be predicted.
  125. num_convs (Optional[int]): number of conv layer. Default: 4.
  126. prior_probability (Optional[float]): probability of prior. Default: 0.01.
  127. norm_layer: Module specifying the normalization layer to use.
  128. """
  129. def __init__(
  130. self,
  131. in_channels: int,
  132. num_anchors: int,
  133. num_classes: int,
  134. num_convs: int = 4,
  135. prior_probability: float = 0.01,
  136. norm_layer: Optional[Callable[..., nn.Module]] = None,
  137. ) -> None:
  138. super().__init__()
  139. self.num_classes = num_classes
  140. self.num_anchors = num_anchors
  141. if norm_layer is None:
  142. norm_layer = partial(nn.GroupNorm, 32)
  143. conv = []
  144. for _ in range(num_convs):
  145. conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
  146. conv.append(norm_layer(in_channels))
  147. conv.append(nn.ReLU())
  148. self.conv = nn.Sequential(*conv)
  149. for layer in self.conv.children():
  150. if isinstance(layer, nn.Conv2d):
  151. torch.nn.init.normal_(layer.weight, std=0.01)
  152. torch.nn.init.constant_(layer.bias, 0)
  153. self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
  154. torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
  155. torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
  156. def forward(self, x: List[Tensor]) -> Tensor:
  157. all_cls_logits = []
  158. for features in x:
  159. cls_logits = self.conv(features)
  160. cls_logits = self.cls_logits(cls_logits)
  161. # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
  162. N, _, H, W = cls_logits.shape
  163. cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
  164. cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  165. cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4)
  166. all_cls_logits.append(cls_logits)
  167. return torch.cat(all_cls_logits, dim=1)
  168. class FCOSRegressionHead(nn.Module):
  169. """
  170. A regression head for use in FCOS, which combines regression branch and center-ness branch.
  171. This can obtain better performance.
  172. Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
  173. Args:
  174. in_channels (int): number of channels of the input feature
  175. num_anchors (int): number of anchors to be predicted
  176. num_convs (Optional[int]): number of conv layer. Default: 4.
  177. norm_layer: Module specifying the normalization layer to use.
  178. """
  179. def __init__(
  180. self,
  181. in_channels: int,
  182. num_anchors: int,
  183. num_convs: int = 4,
  184. norm_layer: Optional[Callable[..., nn.Module]] = None,
  185. ):
  186. super().__init__()
  187. if norm_layer is None:
  188. norm_layer = partial(nn.GroupNorm, 32)
  189. conv = []
  190. for _ in range(num_convs):
  191. conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
  192. conv.append(norm_layer(in_channels))
  193. conv.append(nn.ReLU())
  194. self.conv = nn.Sequential(*conv)
  195. self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
  196. self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
  197. for layer in [self.bbox_reg, self.bbox_ctrness]:
  198. torch.nn.init.normal_(layer.weight, std=0.01)
  199. torch.nn.init.zeros_(layer.bias)
  200. for layer in self.conv.children():
  201. if isinstance(layer, nn.Conv2d):
  202. torch.nn.init.normal_(layer.weight, std=0.01)
  203. torch.nn.init.zeros_(layer.bias)
  204. def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
  205. all_bbox_regression = []
  206. all_bbox_ctrness = []
  207. for features in x:
  208. bbox_feature = self.conv(features)
  209. bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
  210. bbox_ctrness = self.bbox_ctrness(bbox_feature)
  211. # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
  212. N, _, H, W = bbox_regression.shape
  213. bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  214. bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  215. bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
  216. all_bbox_regression.append(bbox_regression)
  217. # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
  218. bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
  219. bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
  220. bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
  221. all_bbox_ctrness.append(bbox_ctrness)
  222. return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
class FCOS(nn.Module):
    """
    Implements FCOS.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification, regression
    and centerness losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (List[float], optional): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
            Defaults to ``[0.485, 0.456, 0.406]``.
        image_std (List[float], optional): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
            Defaults to ``[0.229, 0.224, 0.225]``.
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. FCOS needs exactly one anchor per position of each level: its width and height equal
            the stride of the feature map and its aspect ratio is 1.0, so the anchor center is the
            "point" of the FCOS paper.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        center_sampling_radius (float): radius of the "center" of a groundtruth box, in units of the
            anchor size, within which all anchor points are labeled positive. Default: 1.5.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        topk_candidates (int): Number of best detections to keep before NMS, per feature level.
        **kwargs: forwarded to ``GeneralizedRCNNTransform``.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FCOS
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FCOS needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # FCOS requires a single anchor per spatial location, so use one size
        >>> # per level and a single aspect ratio of 1.0 (FCOS.__init__ checks this).
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((8,), (16,), (32,), (64,), (128,)),
        >>>     aspect_ratios=((1.0,),)
        >>> )
        >>>
        >>> # put the pieces together inside a FCOS model
        >>> model = FCOS(
        >>>     backbone,
        >>>     num_classes=80,
        >>>     anchor_generator=anchor_generator,
        >>> )
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        # transform parameters
        min_size: int = 800,
        max_size: int = 1333,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Anchor parameters
        anchor_generator: Optional[AnchorGenerator] = None,
        head: Optional[nn.Module] = None,
        center_sampling_radius: float = 1.5,
        score_thresh: float = 0.2,
        nms_thresh: float = 0.6,
        detections_per_img: int = 100,
        topk_candidates: int = 1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator
        # FCOS needs exactly one anchor (point) per location.
        # NOTE(review): only the first level is checked here.
        if self.anchor_generator.num_anchors_per_location()[0] != 1:
            raise ValueError(
                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
            )

        if head is None:
            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.center_sampling_radius = center_sampling_radius
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Return ``losses`` in training mode and ``detections`` otherwise (eager mode only).

        NOTE(review): the return annotation says Tuple, but only one of the two values is returned.
        """
        if self.training:
            return losses

        return detections

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        head_outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        num_anchors_per_level: List[int],
    ) -> Dict[str, Tensor]:
        """
        Assign each anchor point to a ground-truth box (or background) and compute the losses.

        An anchor point matches a gt box when all of the following hold:
        it lies within ``center_sampling_radius * anchor_size`` of the box center, it lies inside the
        box, and its largest distance to the box sides falls in the scale range of its level.
        Among several matches, the box with the smallest area wins.

        Args:
            targets: per-image ground truth with "boxes" and "labels".
            head_outputs: output of ``self.head``.
            anchors: per-image anchors, all levels concatenated.
            num_anchors_per_level: number of anchors in each feature level.

        Returns:
            The loss dict produced by ``self.head.compute_loss``.
        """
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                # no ground truth: every anchor is background (-1)
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            gt_boxes = targets_per_image["boxes"]
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # (M, 2)
            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # (N, 2)
            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
            # center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            # compute pairwise distance between N points and M boxes
            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4)
            # anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # each anchor is only responsible for certain scale range:
            # max side distance in (4 * size, 8 * size), with no lower bound on the
            # first level and no upper bound on the last level.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])

            # match the GT box with minimum area, if there are multiple GT matches:
            # matched pairs score 1e8 - area (smaller box -> higher score), unmatched pairs score 0.
            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # (M,)
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            # note: despite its name, min_values holds the best (maximum) score per anchor
            min_values, matched_idx = pairwise_match.max(dim=1)  # (N,), per-anchor match
            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1

            matched_idxs.append(matched_idx)

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(
        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
    ) -> List[Dict[str, Tensor]]:
        """
        Turn per-level head outputs into final detections for each image.

        Args:
            head_outputs: "cls_logits", "bbox_regression" and "bbox_ctrness", each split into one
                tensor per feature level (as done in ``forward``) and indexed by image first.
            anchors: per-image list of per-level anchors.
            image_shapes: size of each image after ``self.transform``, used to clip boxes.

        Returns:
            One dict per image with "boxes", "scores" and "labels", keeping at most
            ``self.detections_per_img`` entries after ``box_ops.batched_nms``.
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]
        box_ctrness = head_outputs["bbox_ctrness"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                # score = sqrt(sigmoid(cls) * sigmoid(ctrness)); the (HWA, 1) centerness broadcasts over
                # classes, and flatten() lays scores out as anchor_idx * num_classes + class_idx.
                scores_per_level = torch.sqrt(
                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
                ).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions (per level)
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # recover (anchor, class) from the flattened index
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression, grouped by label; the first detections_per_img kept indices are
            # retained (batched_nms is presumably score-sorted — see torchvision.ops docs)
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional,
                required in training mode)

        Returns:
            In eager mode: during training, a dict of losses ("classification", "bbox_regression",
            "bbox_ctrness" with the default head); during evaluation, a list with one dict per image
            holding "boxes", "scores" and "labels", rescaled to the original image sizes.
            Under TorchScript, always a ``(losses, detections)`` tuple.
        """
        if self.training:

            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )

        # remember the input sizes so detections can be mapped back after postprocessing
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features = list(features.values())

        # compute the fcos heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
        # recover level sizes (H * W per level, i.e. one anchor per location)
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
        else:
            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
class FCOS_ResNet50_FPN_Weights(WeightsEnum):
    # Pretrained weights available for ``fcos_resnet50_fpn``.
    # COCO_V1: full-model checkpoint with COCO categories; ``_metrics`` reports 39.2 box mAP
    # on COCO-val2017.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 32269600,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 39.2,
                }
            },
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # ``DEFAULT`` refers to the same entry as ``COCO_V1``.
    DEFAULT = COCO_V1
  565. @handle_legacy_interface(
  566. weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
  567. weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
  568. )
  569. def fcos_resnet50_fpn(
  570. *,
  571. weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
  572. progress: bool = True,
  573. num_classes: Optional[int] = None,
  574. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  575. trainable_backbone_layers: Optional[int] = None,
  576. **kwargs: Any,
  577. ) -> FCOS:
  578. """
  579. Constructs a FCOS model with a ResNet-50-FPN backbone.
  580. .. betastatus:: detection module
  581. Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
  582. `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
  583. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  584. image, and should be in ``0-1`` range. Different images can have different sizes.
  585. The behavior of the model changes depending if it is in training or evaluation mode.
  586. During training, the model expects both the input tensors, as well as a targets (list of dictionary),
  587. containing:
  588. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  589. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  590. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  591. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  592. losses.
  593. During inference, the model requires only the input tensors, and returns the post-processed
  594. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  595. follows, where ``N`` is the number of detections:
  596. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  597. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  598. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  599. - scores (``Tensor[N]``): the scores of each detection
  600. For more details on the output, you may refer to :ref:`instance_seg_output`.
  601. Example:
  602. >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
  603. >>> model.eval()
  604. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  605. >>> predictions = model(x)
  606. Args:
  607. weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
  608. pretrained weights to use. See
  609. :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
  610. below for more details, and possible values. By default, no
  611. pre-trained weights are used.
  612. progress (bool): If True, displays a progress bar of the download to stderr
  613. num_classes (int, optional): number of output classes of the model (including the background)
  614. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
  615. the backbone.
  616. trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
  617. from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  618. trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
  619. **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
  620. base class. Please refer to the `source code
  621. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
  622. for more details about this class.
  623. .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
  624. :members:
  625. """
  626. weights = FCOS_ResNet50_FPN_Weights.verify(weights)
  627. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  628. if weights is not None:
  629. weights_backbone = None
  630. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  631. elif num_classes is None:
  632. num_classes = 91
  633. is_trained = weights is not None or weights_backbone is not None
  634. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  635. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  636. backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  637. backbone = _resnet_fpn_extractor(
  638. backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
  639. )
  640. model = FCOS(backbone, num_classes, **kwargs)
  641. if weights is not None:
  642. model.load_state_dict(weights.get_state_dict(progress=progress))
  643. return model
  644. # The dictionary below is internal implementation detail and will be removed in v0.15
  645. from .._utils import _ModelURLs
  646. model_urls = _ModelURLs(
  647. {
  648. "fcos_resnet50_fpn_coco": FCOS_ResNet50_FPN_Weights.COCO_V1.url,
  649. }
  650. )