poolers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. import warnings
  2. from typing import Optional, List, Dict, Tuple, Union
  3. import torch
  4. import torch.fx
  5. import torchvision
  6. from torch import nn, Tensor
  7. from torchvision.ops.boxes import box_area
  8. from ..utils import _log_api_usage_once
  9. from .roi_align import roi_align
  10. # copying result_idx_in_level to a specific index in result[]
  11. # is not supported by ONNX tracing yet.
  12. # _onnx_merge_levels() is an implementation supported by ONNX
  13. # that merges the levels to the right indices
  14. @torch.jit.unused
  15. def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
  16. first_result = unmerged_results[0]
  17. dtype, device = first_result.dtype, first_result.device
  18. res = torch.zeros(
  19. (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
  20. )
  21. for level in range(len(unmerged_results)):
  22. index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
  23. index = index.expand(
  24. index.size(0),
  25. unmerged_results[level].size(1),
  26. unmerged_results[level].size(2),
  27. unmerged_results[level].size(3),
  28. )
  29. res = res.scatter(0, index, unmerged_results[level])
  30. return res
  31. # TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744/pytorch/issues/26744
  32. def initLevelMapper(
  33. k_min: int,
  34. k_max: int,
  35. canonical_scale: int = 224,
  36. canonical_level: int = 4,
  37. eps: float = 1e-6,
  38. ):
  39. return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
  40. class LevelMapper:
  41. """Determine which FPN level each RoI in a set of RoIs should map to based
  42. on the heuristic in the FPN paper.
  43. Args:
  44. k_min (int)
  45. k_max (int)
  46. canonical_scale (int)
  47. canonical_level (int)
  48. eps (float)
  49. """
  50. def __init__(
  51. self,
  52. k_min: int,
  53. k_max: int,
  54. canonical_scale: int = 224,
  55. canonical_level: int = 4,
  56. eps: float = 1e-6,
  57. ):
  58. self.k_min = k_min
  59. self.k_max = k_max
  60. self.s0 = canonical_scale
  61. self.lvl0 = canonical_level
  62. self.eps = eps
  63. def __call__(self, boxlists: List[Tensor]) -> Tensor:
  64. """
  65. Args:
  66. boxlists (list[BoxList])
  67. """
  68. # Compute level ids
  69. s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))
  70. # Eqn.(1) in FPN paper
  71. target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
  72. target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
  73. return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
  74. def _convert_to_roi_format(boxes: List[Tensor]) -> Tensor:
  75. concat_boxes = torch.cat(boxes, dim=0)
  76. device, dtype = concat_boxes.device, concat_boxes.dtype
  77. ids = torch.cat(
  78. [torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) for i, b in enumerate(boxes)],
  79. dim=0,
  80. )
  81. rois = torch.cat([ids, concat_boxes], dim=1)
  82. return rois
  83. def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
  84. # assumption: the scale is of the form 2 ** (-k), with k integer
  85. size = feature.shape[-2:]
  86. possible_scales: List[float] = []
  87. for s1, s2 in zip(size, original_size):
  88. approx_scale = float(s1) / float(s2)
  89. scale = 2 ** float(torch.tensor(approx_scale).log2().round())
  90. possible_scales.append(scale)
  91. return possible_scales[0]
  92. @torch.fx.wrap
  93. def _setup_scales(
  94. features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
  95. ) -> Tuple[List[float], LevelMapper]:
  96. if not image_shapes:
  97. raise ValueError("images list should not be empty")
  98. max_x = 0
  99. max_y = 0
  100. for shape in image_shapes:
  101. max_x = max(shape[0], max_x)
  102. max_y = max(shape[1], max_y)
  103. original_input_shape = (max_x, max_y)
  104. scales = [_infer_scale(feat, original_input_shape) for feat in features]
  105. # get the levels in the feature map by leveraging the fact that the network always
  106. # downsamples by a factor of 2 at each level.
  107. lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
  108. lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
  109. map_levels = initLevelMapper(
  110. int(lvl_min),
  111. int(lvl_max),
  112. canonical_scale=canonical_scale,
  113. canonical_level=canonical_level,
  114. )
  115. return scales, map_levels
  116. @torch.fx.wrap
  117. def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]:
  118. x_filtered = []
  119. for k, v in x.items():
  120. if k in featmap_names:
  121. x_filtered.append(v)
  122. return x_filtered
@torch.fx.wrap
def _multiscale_roi_align(
    x_filtered: List[Tensor],
    boxes: List[Tensor],
    output_size: List[int],
    sampling_ratio: int,
    scales: Optional[List[float]],
    mapper: Optional[LevelMapper],
) -> Tensor:
    """
    Pool every RoI from the feature map of the level assigned to it.

    With a single feature map, all RoIs are pooled from it directly. Otherwise
    ``mapper`` assigns each RoI a level and each level is pooled separately.
    The pooled features are then gathered into one tensor whose row ``k``
    belongs to the ``k``-th box of the concatenated ``boxes``.

    Args:
        x_filtered (List[Tensor]): List of input tensors, one per pyramid level.
        boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
            (x1, y1, x2, y2) format and in the image reference size, not the feature map
            reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        output_size (Union[List[Tuple[int, int]], List[int]]): size of the output.
            NOTE(review): on the multi-level path it is concatenated to a tuple,
            so in eager mode it must be a tuple (``MultiScaleRoIAlign`` passes one).
        sampling_ratio (int): sampling ratio for ROIAlign
        scales (Optional[List[float]]): spatial scale of each tensor in ``x_filtered``.
            Must not be None; callers are expected to infer it beforehand.
        mapper (Optional[LevelMapper]): assigns a zero-based level index (into
            ``x_filtered``) to each box. Must not be None.
    Returns:
        result (Tensor): pooled features. On the multi-level path its shape is
        ``(num_rois, num_channels, *output_size)``, with the dtype and device of
        ``x_filtered[0]``.
    Raises:
        ValueError: if ``scales`` or ``mapper`` is None.
    """
    if scales is None or mapper is None:
        raise ValueError("scales and mapper should not be None")

    num_levels = len(x_filtered)
    # One row per RoI: batch index followed by the box, in concatenation order.
    rois = _convert_to_roi_format(boxes)

    if num_levels == 1:
        # Single level: every RoI is pooled from it; no level assignment needed.
        return roi_align(
            x_filtered[0],
            rois,
            output_size=output_size,
            spatial_scale=scales[0],
            sampling_ratio=sampling_ratio,
        )

    # Level index for each RoI; compared below against positions in x_filtered.
    levels = mapper(boxes)

    num_rois = len(rois)
    num_channels = x_filtered[0].shape[1]

    dtype, device = x_filtered[0].dtype, x_filtered[0].device
    # NOTE(review): ``(num_rois, num_channels) + output_size`` is tuple
    # concatenation; a list output_size (as annotated) would raise TypeError
    # in eager mode.
    result = torch.zeros(
        (
            num_rois,
            num_channels,
        )
        + output_size,
        dtype=dtype,
        device=device,
    )

    tracing_results = []
    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        # Indices (into rois / result rows) of the RoIs assigned to this level.
        idx_in_level = torch.where(levels == level)[0]
        rois_per_level = rois[idx_in_level]

        result_idx_in_level = roi_align(
            per_level_feature,
            rois_per_level,
            output_size=output_size,
            spatial_scale=scale,
            sampling_ratio=sampling_ratio,
        )

        if torchvision._is_tracing():
            # Copying into specific indices of result[] is not supported by
            # ONNX tracing; collect the per-level outputs and merge them with
            # an out-of-place scatter after the loop.
            tracing_results.append(result_idx_in_level.to(dtype))
        else:
            # result and result_idx_in_level's dtypes are based on dtypes of different
            # elements in x_filtered.  x_filtered contains tensors output by different
            # layers.  When autocast is active, it may choose different dtypes for
            # different layers' outputs.  Therefore, we defensively match result's dtype
            # before copying elements from result_idx_in_level in the following op.
            # We need to cast manually (can't rely on autocast to cast for us) because
            # the op acts on result in-place, and autocast only affects out-of-place ops.
            result[idx_in_level] = result_idx_in_level.to(result.dtype)

    if torchvision._is_tracing():
        result = _onnx_merge_levels(levels, tracing_results)

    return result
  195. class MultiScaleRoIAlign(nn.Module):
  196. """
  197. Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
  198. It infers the scale of the pooling via the heuristics specified in eq. 1
  199. of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
  200. They keyword-only parameters ``canonical_scale`` and ``canonical_level``
  201. correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
  202. have the following meaning: ``canonical_level`` is the target level of the pyramid from
  203. which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.
  204. Args:
  205. featmap_names (List[str]): the names of the feature maps that will be used
  206. for the pooling.
  207. output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
  208. sampling_ratio (int): sampling ratio for ROIAlign
  209. canonical_scale (int, optional): canonical_scale for LevelMapper
  210. canonical_level (int, optional): canonical_level for LevelMapper
  211. Examples::
  212. >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
  213. >>> i = OrderedDict()
  214. >>> i['feat1'] = torch.rand(1, 5, 64, 64)
  215. >>> i['feat2'] = torch.rand(1, 5, 32, 32) # this feature won't be used in the pooling
  216. >>> i['feat3'] = torch.rand(1, 5, 16, 16)
  217. >>> # create some random bounding boxes
  218. >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
  219. >>> # original image size, before computing the feature maps
  220. >>> image_sizes = [(512, 512)]
  221. >>> output = m(i, [boxes], image_sizes)
  222. >>> print(output.shape)
  223. >>> torch.Size([6, 5, 3, 3])
  224. """
  225. __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}
  226. def __init__(
  227. self,
  228. featmap_names: List[str],
  229. output_size: Union[int, Tuple[int], List[int]],
  230. sampling_ratio: int,
  231. *,
  232. canonical_scale: int = 224,
  233. canonical_level: int = 4,
  234. ):
  235. super().__init__()
  236. _log_api_usage_once(self)
  237. if isinstance(output_size, int):
  238. output_size = (output_size, output_size)
  239. self.featmap_names = featmap_names
  240. self.sampling_ratio = sampling_ratio
  241. self.output_size = tuple(output_size)
  242. self.scales = None
  243. self.map_levels = None
  244. self.canonical_scale = canonical_scale
  245. self.canonical_level = canonical_level
  246. def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
  247. warnings.warn("The 'convert_to_roi_format' method is deprecated since 0.12 and will be removed in 0.14.")
  248. return _convert_to_roi_format(boxes)
  249. def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
  250. warnings.warn("The 'infer_scale' method is deprecated since 0.12 and will be removed in 0.14.")
  251. return _infer_scale(feature, original_size)
  252. def setup_setup_scales(
  253. self,
  254. features: List[Tensor],
  255. image_shapes: List[Tuple[int, int]],
  256. ) -> None:
  257. warnings.warn("The 'setup_setup_scales' method is deprecated since 0.12 and will be removed in 0.14.")
  258. self.scales, self.map_levels = _setup_scales(features, image_shapes, self.canonical_scale, self.canonical_level)
  259. def forward(
  260. self,
  261. x: Dict[str, Tensor],
  262. boxes: List[Tensor],
  263. image_shapes: List[Tuple[int, int]],
  264. ) -> Tensor:
  265. """
  266. Args:
  267. x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
  268. all the same number of channels, but they can have different sizes.
  269. boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
  270. (x1, y1, x2, y2) format and in the image reference size, not the feature map
  271. reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  272. image_shapes (List[Tuple[height, width]]): the sizes of each image before they
  273. have been fed to a CNN to obtain feature maps. This allows us to infer the
  274. scale factor for each one of the levels to be pooled.
  275. Returns:
  276. result (Tensor)
  277. """
  278. x_filtered = _filter_input(x, self.featmap_names)
  279. if self.scales is None or self.map_levels is None:
  280. self.scales, self.map_levels = _setup_scales(
  281. x_filtered, image_shapes, self.canonical_scale, self.canonical_level
  282. )
  283. return _multiscale_roi_align(
  284. x_filtered,
  285. boxes,
  286. self.output_size,
  287. self.sampling_ratio,
  288. self.scales,
  289. self.map_levels,
  290. )
  291. def __repr__(self) -> str:
  292. return (
  293. f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
  294. f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
  295. )