ssd.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. import warnings
  2. from collections import OrderedDict
  3. from typing import Any, Dict, List, Optional, Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import nn, Tensor
  7. from ...ops import boxes as box_ops
  8. from ...transforms._presets import ObjectDetection
  9. from ...utils import _log_api_usage_once
  10. from .._api import WeightsEnum, Weights
  11. from .._meta import _COCO_CATEGORIES
  12. from .._utils import handle_legacy_interface, _ovewrite_value_param
  13. from ..vgg import VGG, VGG16_Weights, vgg16
  14. from . import _utils as det_utils
  15. from .anchor_utils import DefaultBoxGenerator
  16. from .backbone_utils import _validate_trainable_layers
  17. from .transform import GeneralizedRCNNTransform
  18. __all__ = [
  19. "SSD300_VGG16_Weights",
  20. "ssd300_vgg16",
  21. ]
class SSD300_VGG16_Weights(WeightsEnum):
    """Pretrained weights that can be passed as ``weights=`` to :func:`ssd300_vgg16`."""

    # COCO checkpoint. ``meta`` carries the category names (``ssd300_vgg16`` derives
    # ``num_classes`` from their count), the training-recipe link and the reported
    # COCO-val2017 box mAP; "min_size" presumably is the minimum accepted input size.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 35641826,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 25.1,
                }
            },
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # Enum alias: DEFAULT refers to the same member as COCO_V1.
    DEFAULT = COCO_V1
  40. def _xavier_init(conv: nn.Module):
  41. for layer in conv.modules():
  42. if isinstance(layer, nn.Conv2d):
  43. torch.nn.init.xavier_uniform_(layer.weight)
  44. if layer.bias is not None:
  45. torch.nn.init.constant_(layer.bias, 0.0)
  46. class SSDHead(nn.Module):
  47. def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
  48. super().__init__()
  49. self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
  50. self.regression_head = SSDRegressionHead(in_channels, num_anchors)
  51. def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
  52. return {
  53. "bbox_regression": self.regression_head(x),
  54. "cls_logits": self.classification_head(x),
  55. }
  56. class SSDScoringHead(nn.Module):
  57. def __init__(self, module_list: nn.ModuleList, num_columns: int):
  58. super().__init__()
  59. self.module_list = module_list
  60. self.num_columns = num_columns
  61. def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
  62. """
  63. This is equivalent to self.module_list[idx](x),
  64. but torchscript doesn't support this yet
  65. """
  66. num_blocks = len(self.module_list)
  67. if idx < 0:
  68. idx += num_blocks
  69. out = x
  70. for i, module in enumerate(self.module_list):
  71. if i == idx:
  72. out = module(x)
  73. return out
  74. def forward(self, x: List[Tensor]) -> Tensor:
  75. all_results = []
  76. for i, features in enumerate(x):
  77. results = self._get_result_from_module_list(features, i)
  78. # Permute output from (N, A * K, H, W) to (N, HWA, K).
  79. N, _, H, W = results.shape
  80. results = results.view(N, -1, self.num_columns, H, W)
  81. results = results.permute(0, 3, 4, 1, 2)
  82. results = results.reshape(N, -1, self.num_columns) # Size=(N, HWA, K)
  83. all_results.append(results)
  84. return torch.cat(all_results, dim=1)
  85. class SSDClassificationHead(SSDScoringHead):
  86. def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
  87. cls_logits = nn.ModuleList()
  88. for channels, anchors in zip(in_channels, num_anchors):
  89. cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
  90. _xavier_init(cls_logits)
  91. super().__init__(cls_logits, num_classes)
  92. class SSDRegressionHead(SSDScoringHead):
  93. def __init__(self, in_channels: List[int], num_anchors: List[int]):
  94. bbox_reg = nn.ModuleList()
  95. for channels, anchors in zip(in_channels, num_anchors):
  96. bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
  97. _xavier_init(bbox_reg)
  98. super().__init__(bbox_reg, 4)
  99. class SSD(nn.Module):
  100. """
  101. Implements SSD architecture from `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
  102. The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
  103. image, and should be in 0-1 range. Different images can have different sizes but they will be resized
  104. to a fixed size before passing it to the backbone.
  105. The behavior of the model changes depending if it is in training or evaluation mode.
  106. During training, the model expects both the input tensors, as well as a targets (list of dictionary),
  107. containing:
  108. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  109. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  110. - labels (Int64Tensor[N]): the class label for each ground-truth box
  111. The model returns a Dict[Tensor] during training, containing the classification and regression
  112. losses.
  113. During inference, the model requires only the input tensors, and returns the post-processed
  114. predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
  115. follows, where ``N`` is the number of detections:
  116. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  117. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  118. - labels (Int64Tensor[N]): the predicted labels for each detection
  119. - scores (Tensor[N]): the scores for each detection
  120. Args:
  121. backbone (nn.Module): the network used to compute the features for the model.
  122. It should contain an out_channels attribute with the list of the output channels of
  123. each feature map. The backbone should return a single Tensor or an OrderedDict[Tensor].
  124. anchor_generator (DefaultBoxGenerator): module that generates the default boxes for a
  125. set of feature maps.
  126. size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them
  127. to the backbone.
  128. num_classes (int): number of output classes of the model (including the background).
  129. image_mean (Tuple[float, float, float]): mean values used for input normalization.
  130. They are generally the mean values of the dataset on which the backbone has been trained
  131. on
  132. image_std (Tuple[float, float, float]): std values used for input normalization.
  133. They are generally the std values of the dataset on which the backbone has been trained on
  134. head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing
  135. a classification and regression module.
  136. score_thresh (float): Score threshold used for postprocessing the detections.
  137. nms_thresh (float): NMS threshold used for postprocessing the detections.
  138. detections_per_img (int): Number of best detections to keep after NMS.
  139. iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
  140. considered as positive during training.
  141. topk_candidates (int): Number of best detections to keep before NMS.
  142. positive_fraction (float): a number between 0 and 1 which indicates the proportion of positive
  143. proposals used during the training of the classification head. It is used to estimate the negative to
  144. positive ratio.
  145. """
  146. __annotations__ = {
  147. "box_coder": det_utils.BoxCoder,
  148. "proposal_matcher": det_utils.Matcher,
  149. }
  150. def __init__(
  151. self,
  152. backbone: nn.Module,
  153. anchor_generator: DefaultBoxGenerator,
  154. size: Tuple[int, int],
  155. num_classes: int,
  156. image_mean: Optional[List[float]] = None,
  157. image_std: Optional[List[float]] = None,
  158. head: Optional[nn.Module] = None,
  159. score_thresh: float = 0.01,
  160. nms_thresh: float = 0.45,
  161. detections_per_img: int = 200,
  162. iou_thresh: float = 0.5,
  163. topk_candidates: int = 400,
  164. positive_fraction: float = 0.25,
  165. **kwargs: Any,
  166. ):
  167. super().__init__()
  168. _log_api_usage_once(self)
  169. self.backbone = backbone
  170. self.anchor_generator = anchor_generator
  171. self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
  172. if head is None:
  173. if hasattr(backbone, "out_channels"):
  174. out_channels = backbone.out_channels
  175. else:
  176. out_channels = det_utils.retrieve_out_channels(backbone, size)
  177. if len(out_channels) != len(anchor_generator.aspect_ratios):
  178. raise ValueError(
  179. f"The length of the output channels from the backbone ({len(out_channels)}) do not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
  180. )
  181. num_anchors = self.anchor_generator.num_anchors_per_location()
  182. head = SSDHead(out_channels, num_anchors, num_classes)
  183. self.head = head
  184. self.proposal_matcher = det_utils.SSDMatcher(iou_thresh)
  185. if image_mean is None:
  186. image_mean = [0.485, 0.456, 0.406]
  187. if image_std is None:
  188. image_std = [0.229, 0.224, 0.225]
  189. self.transform = GeneralizedRCNNTransform(
  190. min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
  191. )
  192. self.score_thresh = score_thresh
  193. self.nms_thresh = nms_thresh
  194. self.detections_per_img = detections_per_img
  195. self.topk_candidates = topk_candidates
  196. self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
  197. # used only on torchscript mode
  198. self._has_warned = False
  199. @torch.jit.unused
  200. def eager_outputs(
  201. self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
  202. ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
  203. if self.training:
  204. return losses
  205. return detections
  206. def compute_loss(
  207. self,
  208. targets: List[Dict[str, Tensor]],
  209. head_outputs: Dict[str, Tensor],
  210. anchors: List[Tensor],
  211. matched_idxs: List[Tensor],
  212. ) -> Dict[str, Tensor]:
  213. bbox_regression = head_outputs["bbox_regression"]
  214. cls_logits = head_outputs["cls_logits"]
  215. # Match original targets with default boxes
  216. num_foreground = 0
  217. bbox_loss = []
  218. cls_targets = []
  219. for (
  220. targets_per_image,
  221. bbox_regression_per_image,
  222. cls_logits_per_image,
  223. anchors_per_image,
  224. matched_idxs_per_image,
  225. ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
  226. # produce the matching between boxes and targets
  227. foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
  228. foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
  229. num_foreground += foreground_matched_idxs_per_image.numel()
  230. # Calculate regression loss
  231. matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
  232. bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
  233. anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
  234. target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
  235. bbox_loss.append(
  236. torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
  237. )
  238. # Estimate ground truth for class targets
  239. gt_classes_target = torch.zeros(
  240. (cls_logits_per_image.size(0),),
  241. dtype=targets_per_image["labels"].dtype,
  242. device=targets_per_image["labels"].device,
  243. )
  244. gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
  245. foreground_matched_idxs_per_image
  246. ]
  247. cls_targets.append(gt_classes_target)
  248. bbox_loss = torch.stack(bbox_loss)
  249. cls_targets = torch.stack(cls_targets)
  250. # Calculate classification loss
  251. num_classes = cls_logits.size(-1)
  252. cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
  253. cls_targets.size()
  254. )
  255. # Hard Negative Sampling
  256. foreground_idxs = cls_targets > 0
  257. num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
  258. # num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio
  259. negative_loss = cls_loss.clone()
  260. negative_loss[foreground_idxs] = -float("inf") # use -inf to detect positive values that creeped in the sample
  261. values, idx = negative_loss.sort(1, descending=True)
  262. # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values))
  263. background_idxs = idx.sort(1)[1] < num_negative
  264. N = max(1, num_foreground)
  265. return {
  266. "bbox_regression": bbox_loss.sum() / N,
  267. "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
  268. }
  269. def forward(
  270. self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
  271. ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
  272. if self.training:
  273. if targets is None:
  274. torch._assert(False, "targets should not be none when in training mode")
  275. else:
  276. for target in targets:
  277. boxes = target["boxes"]
  278. if isinstance(boxes, torch.Tensor):
  279. torch._assert(
  280. len(boxes.shape) == 2 and boxes.shape[-1] == 4,
  281. f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
  282. )
  283. else:
  284. torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
  285. # get the original image sizes
  286. original_image_sizes: List[Tuple[int, int]] = []
  287. for img in images:
  288. val = img.shape[-2:]
  289. torch._assert(
  290. len(val) == 2,
  291. f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
  292. )
  293. original_image_sizes.append((val[0], val[1]))
  294. # transform the input
  295. images, targets = self.transform(images, targets)
  296. # Check for degenerate boxes
  297. if targets is not None:
  298. for target_idx, target in enumerate(targets):
  299. boxes = target["boxes"]
  300. degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
  301. if degenerate_boxes.any():
  302. bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
  303. degen_bb: List[float] = boxes[bb_idx].tolist()
  304. torch._assert(
  305. False,
  306. "All bounding boxes should have positive height and width."
  307. f" Found invalid box {degen_bb} for target at index {target_idx}.",
  308. )
  309. # get the features from the backbone
  310. features = self.backbone(images.tensors)
  311. if isinstance(features, torch.Tensor):
  312. features = OrderedDict([("0", features)])
  313. features = list(features.values())
  314. # compute the ssd heads outputs using the features
  315. head_outputs = self.head(features)
  316. # create the set of anchors
  317. anchors = self.anchor_generator(images, features)
  318. losses = {}
  319. detections: List[Dict[str, Tensor]] = []
  320. if self.training:
  321. matched_idxs = []
  322. if targets is None:
  323. torch._assert(False, "targets should not be none when in training mode")
  324. else:
  325. for anchors_per_image, targets_per_image in zip(anchors, targets):
  326. if targets_per_image["boxes"].numel() == 0:
  327. matched_idxs.append(
  328. torch.full(
  329. (anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device
  330. )
  331. )
  332. continue
  333. match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
  334. matched_idxs.append(self.proposal_matcher(match_quality_matrix))
  335. losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
  336. else:
  337. detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
  338. detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
  339. if torch.jit.is_scripting():
  340. if not self._has_warned:
  341. warnings.warn("SSD always returns a (Losses, Detections) tuple in scripting")
  342. self._has_warned = True
  343. return losses, detections
  344. return self.eager_outputs(losses, detections)
  345. def postprocess_detections(
  346. self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
  347. ) -> List[Dict[str, Tensor]]:
  348. bbox_regression = head_outputs["bbox_regression"]
  349. pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
  350. num_classes = pred_scores.size(-1)
  351. device = pred_scores.device
  352. detections: List[Dict[str, Tensor]] = []
  353. for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes):
  354. boxes = self.box_coder.decode_single(boxes, anchors)
  355. boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  356. image_boxes = []
  357. image_scores = []
  358. image_labels = []
  359. for label in range(1, num_classes):
  360. score = scores[:, label]
  361. keep_idxs = score > self.score_thresh
  362. score = score[keep_idxs]
  363. box = boxes[keep_idxs]
  364. # keep only topk scoring predictions
  365. num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
  366. score, idxs = score.topk(num_topk)
  367. box = box[idxs]
  368. image_boxes.append(box)
  369. image_scores.append(score)
  370. image_labels.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))
  371. image_boxes = torch.cat(image_boxes, dim=0)
  372. image_scores = torch.cat(image_scores, dim=0)
  373. image_labels = torch.cat(image_labels, dim=0)
  374. # non-maximum suppression
  375. keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
  376. keep = keep[: self.detections_per_img]
  377. detections.append(
  378. {
  379. "boxes": image_boxes[keep],
  380. "scores": image_scores[keep],
  381. "labels": image_labels[keep],
  382. }
  383. )
  384. return detections
  385. class SSDFeatureExtractorVGG(nn.Module):
  386. def __init__(self, backbone: nn.Module, highres: bool):
  387. super().__init__()
  388. _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d))
  389. # Patch ceil_mode for maxpool3 to get the same WxH output sizes as the paper
  390. backbone[maxpool3_pos].ceil_mode = True
  391. # parameters used for L2 regularization + rescaling
  392. self.scale_weight = nn.Parameter(torch.ones(512) * 20)
  393. # Multiple Feature maps - page 4, Fig 2 of SSD paper
  394. self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3
  395. # SSD300 case - page 4, Fig 2 of SSD paper
  396. extra = nn.ModuleList(
  397. [
  398. nn.Sequential(
  399. nn.Conv2d(1024, 256, kernel_size=1),
  400. nn.ReLU(inplace=True),
  401. nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
  402. nn.ReLU(inplace=True),
  403. ),
  404. nn.Sequential(
  405. nn.Conv2d(512, 128, kernel_size=1),
  406. nn.ReLU(inplace=True),
  407. nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
  408. nn.ReLU(inplace=True),
  409. ),
  410. nn.Sequential(
  411. nn.Conv2d(256, 128, kernel_size=1),
  412. nn.ReLU(inplace=True),
  413. nn.Conv2d(128, 256, kernel_size=3), # conv10_2
  414. nn.ReLU(inplace=True),
  415. ),
  416. nn.Sequential(
  417. nn.Conv2d(256, 128, kernel_size=1),
  418. nn.ReLU(inplace=True),
  419. nn.Conv2d(128, 256, kernel_size=3), # conv11_2
  420. nn.ReLU(inplace=True),
  421. ),
  422. ]
  423. )
  424. if highres:
  425. # Additional layers for the SSD512 case. See page 11, footernote 5.
  426. extra.append(
  427. nn.Sequential(
  428. nn.Conv2d(256, 128, kernel_size=1),
  429. nn.ReLU(inplace=True),
  430. nn.Conv2d(128, 256, kernel_size=4), # conv12_2
  431. nn.ReLU(inplace=True),
  432. )
  433. )
  434. _xavier_init(extra)
  435. fc = nn.Sequential(
  436. nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False), # add modified maxpool5
  437. nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
  438. nn.ReLU(inplace=True),
  439. nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7
  440. nn.ReLU(inplace=True),
  441. )
  442. _xavier_init(fc)
  443. extra.insert(
  444. 0,
  445. nn.Sequential(
  446. *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5
  447. fc,
  448. ),
  449. )
  450. self.extra = extra
  451. def forward(self, x: Tensor) -> Dict[str, Tensor]:
  452. # L2 regularization + Rescaling of 1st block's feature map
  453. x = self.features(x)
  454. rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
  455. output = [rescaled]
  456. # Calculating Feature maps for the rest blocks
  457. for block in self.extra:
  458. x = block(x)
  459. output.append(x)
  460. return OrderedDict([(str(i), v) for i, v in enumerate(output)])
  461. def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
  462. backbone = backbone.features
  463. # Gather the indices of maxpools. These are the locations of output blocks.
  464. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
  465. num_stages = len(stage_indices)
  466. # find the index of the layer from which we wont freeze
  467. torch._assert(
  468. 0 <= trainable_layers <= num_stages,
  469. f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}",
  470. )
  471. freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
  472. for b in backbone[:freeze_before]:
  473. for parameter in b.parameters():
  474. parameter.requires_grad_(False)
  475. return SSDFeatureExtractorVGG(backbone, highres)
  476. @handle_legacy_interface(
  477. weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
  478. weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
  479. )
  480. def ssd300_vgg16(
  481. *,
  482. weights: Optional[SSD300_VGG16_Weights] = None,
  483. progress: bool = True,
  484. num_classes: Optional[int] = None,
  485. weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES,
  486. trainable_backbone_layers: Optional[int] = None,
  487. **kwargs: Any,
  488. ) -> SSD:
  489. """The SSD300 model is based on the `SSD: Single Shot MultiBox Detector
  490. <https://arxiv.org/abs/1512.02325>`_ paper.
  491. .. betastatus:: detection module
  492. The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
  493. image, and should be in 0-1 range. Different images can have different sizes but they will be resized
  494. to a fixed size before passing it to the backbone.
  495. The behavior of the model changes depending if it is in training or evaluation mode.
  496. During training, the model expects both the input tensors, as well as a targets (list of dictionary),
  497. containing:
  498. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  499. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  500. - labels (Int64Tensor[N]): the class label for each ground-truth box
  501. The model returns a Dict[Tensor] during training, containing the classification and regression
  502. losses.
  503. During inference, the model requires only the input tensors, and returns the post-processed
  504. predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
  505. follows, where ``N`` is the number of detections:
  506. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  507. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  508. - labels (Int64Tensor[N]): the predicted labels for each detection
  509. - scores (Tensor[N]): the scores for each detection
  510. Example:
  511. >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT)
  512. >>> model.eval()
  513. >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
  514. >>> predictions = model(x)
  515. Args:
  516. weights (:class:`~torchvision.models.detection.SSD300_VGG16_Weights`, optional): The pretrained
  517. weights to use. See
  518. :class:`~torchvision.models.detection.SSD300_VGG16_Weights`
  519. below for more details, and possible values. By default, no
  520. pre-trained weights are used.
  521. progress (bool, optional): If True, displays a progress bar of the download to stderr
  522. Default is True.
  523. num_classes (int, optional): number of output classes of the model (including the background)
  524. weights_backbone (:class:`~torchvision.models.VGG16_Weights`, optional): The pretrained weights for the
  525. backbone
  526. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  527. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  528. passed (the default) this value is set to 4.
  529. **kwargs: parameters passed to the ``torchvision.models.detection.SSD``
  530. base class. Please refer to the `source code
  531. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py>`_
  532. for more details about this class.
  533. .. autoclass:: torchvision.models.detection.SSD300_VGG16_Weights
  534. :members:
  535. """
  536. weights = SSD300_VGG16_Weights.verify(weights)
  537. weights_backbone = VGG16_Weights.verify(weights_backbone)
  538. if "size" in kwargs:
  539. warnings.warn("The size of the model is already fixed; ignoring the parameter.")
  540. if weights is not None:
  541. weights_backbone = None
  542. num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
  543. elif num_classes is None:
  544. num_classes = 91
  545. trainable_backbone_layers = _validate_trainable_layers(
  546. weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
  547. )
  548. # Use custom backbones more appropriate for SSD
  549. backbone = vgg16(weights=weights_backbone, progress=progress)
  550. backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
  551. anchor_generator = DefaultBoxGenerator(
  552. [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
  553. scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
  554. steps=[8, 16, 32, 64, 100, 300],
  555. )
  556. defaults = {
  557. # Rescale the input in a way compatible to the backbone
  558. "image_mean": [0.48235, 0.45882, 0.40784],
  559. "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor
  560. }
  561. kwargs: Any = {**defaults, **kwargs}
  562. model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
  563. if weights is not None:
  564. model.load_state_dict(weights.get_state_dict(progress=progress))
  565. return model
  566. # The dictionary below is internal implementation detail and will be removed in v0.15
  567. from .._utils import _ModelURLs
  568. model_urls = _ModelURLs(
  569. {
  570. "ssd300_vgg16_coco": SSD300_VGG16_Weights.COCO_V1.url,
  571. }
  572. )
  573. backbone_urls = _ModelURLs(
  574. {
  575. # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses
  576. # the same input standardization method as the paper.
  577. # Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
  578. # Only the `features` weights have proper values, those on the `classifier` module are filled with nans.
  579. "vgg16_features": VGG16_Weights.IMAGENET1K_FEATURES.url,
  580. }
  581. )