diou_loss.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from typing import Tuple
  2. import torch
  3. from ..utils import _log_api_usage_once
  4. from ._utils import _loss_inter_union, _upcast_non_float
  5. def distance_box_iou_loss(
  6. boxes1: torch.Tensor,
  7. boxes2: torch.Tensor,
  8. reduction: str = "none",
  9. eps: float = 1e-7,
  10. ) -> torch.Tensor:
  11. """
  12. Gradient-friendly IoU loss with an additional penalty that is non-zero when the
  13. distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
  14. boxes, the distance IoU is the same as the IoU loss.
  15. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
  16. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
  17. ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the
  18. same dimensions.
  19. Args:
  20. boxes1 (Tensor[N, 4]): first set of boxes
  21. boxes2 (Tensor[N, 4]): second set of boxes
  22. reduction (string, optional): Specifies the reduction to apply to the output:
  23. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
  24. applied to the output. ``'mean'``: The output will be averaged.
  25. ``'sum'``: The output will be summed. Default: ``'none'``
  26. eps (float, optional): small number to prevent division by zero. Default: 1e-7
  27. Returns:
  28. Tensor: Loss tensor with the reduction option applied.
  29. Reference:
  30. Zhaohui Zheng et. al: Distance Intersection over Union Loss:
  31. https://arxiv.org/abs/1911.08287
  32. """
  33. # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
  34. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  35. _log_api_usage_once(distance_box_iou_loss)
  36. boxes1 = _upcast_non_float(boxes1)
  37. boxes2 = _upcast_non_float(boxes2)
  38. loss, _ = _diou_iou_loss(boxes1, boxes2, eps)
  39. if reduction == "mean":
  40. loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
  41. elif reduction == "sum":
  42. loss = loss.sum()
  43. return loss
  44. def _diou_iou_loss(
  45. boxes1: torch.Tensor,
  46. boxes2: torch.Tensor,
  47. eps: float = 1e-7,
  48. ) -> Tuple[torch.Tensor, torch.Tensor]:
  49. intsct, union = _loss_inter_union(boxes1, boxes2)
  50. iou = intsct / (union + eps)
  51. # smallest enclosing box
  52. x1, y1, x2, y2 = boxes1.unbind(dim=-1)
  53. x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
  54. xc1 = torch.min(x1, x1g)
  55. yc1 = torch.min(y1, y1g)
  56. xc2 = torch.max(x2, x2g)
  57. yc2 = torch.max(y2, y2g)
  58. # The diagonal distance of the smallest enclosing box squared
  59. diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
  60. # centers of boxes
  61. x_p = (x2 + x1) / 2
  62. y_p = (y2 + y1) / 2
  63. x_g = (x1g + x2g) / 2
  64. y_g = (y1g + y2g) / 2
  65. # The distance between boxes' centers squared.
  66. centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
  67. # The distance IoU is the IoU penalized by a normalized
  68. # distance between boxes' centers squared.
  69. loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
  70. return loss, iou