anomaly_mode.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import torch
  2. import warnings
  3. from typing import Any
# Names exported by ``from ... import *``: the two anomaly-detection
# context managers defined in this module.
__all__ = ["detect_anomaly", "set_detect_anomaly"]
  5. class detect_anomaly(object):
  6. r"""Context-manager that enable anomaly detection for the autograd engine.
  7. This does two things:
  8. - Running the forward pass with detection enabled will allow the backward
  9. pass to print the traceback of the forward operation that created the failing
  10. backward function.
  11. - Any backward computation that generate "nan" value will raise an error.
  12. .. warning::
  13. This mode should be enabled only for debugging as the different tests
  14. will slow down your program execution.
  15. Example:
  16. >>> import torch
  17. >>> from torch import autograd
  18. >>> class MyFunc(autograd.Function):
  19. ... @staticmethod
  20. ... def forward(ctx, inp):
  21. ... return inp.clone()
  22. ... @staticmethod
  23. ... def backward(ctx, gO):
  24. ... # Error during the backward pass
  25. ... raise RuntimeError("Some error in backward")
  26. ... return gO.clone()
  27. >>> def run_fn(a):
  28. ... out = MyFunc.apply(a)
  29. ... return out.sum()
  30. >>> inp = torch.rand(10, 10, requires_grad=True)
  31. >>> out = run_fn(inp)
  32. >>> out.backward()
  33. Traceback (most recent call last):
  34. File "<stdin>", line 1, in <module>
  35. File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
  36. torch.autograd.backward(self, gradient, retain_graph, create_graph)
  37. File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
  38. allow_unreachable=True) # allow_unreachable flag
  39. File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
  40. return self._forward_cls.backward(self, *args)
  41. File "<stdin>", line 8, in backward
  42. RuntimeError: Some error in backward
  43. >>> with autograd.detect_anomaly():
  44. ... inp = torch.rand(10, 10, requires_grad=True)
  45. ... out = run_fn(inp)
  46. ... out.backward()
  47. Traceback of forward call that caused the error:
  48. File "tmp.py", line 53, in <module>
  49. out = run_fn(inp)
  50. File "tmp.py", line 44, in run_fn
  51. out = MyFunc.apply(a)
  52. Traceback (most recent call last):
  53. File "<stdin>", line 4, in <module>
  54. File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
  55. torch.autograd.backward(self, gradient, retain_graph, create_graph)
  56. File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
  57. allow_unreachable=True) # allow_unreachable flag
  58. File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
  59. return self._forward_cls.backward(self, *args)
  60. File "<stdin>", line 8, in backward
  61. RuntimeError: Some error in backward
  62. """
  63. def __init__(self) -> None:
  64. self.prev = torch.is_anomaly_enabled()
  65. warnings.warn('Anomaly Detection has been enabled. '
  66. 'This mode will increase the runtime '
  67. 'and should only be enabled for debugging.', stacklevel=2)
  68. def __enter__(self) -> None:
  69. torch.set_anomaly_enabled(True)
  70. def __exit__(self, *args: Any) -> None:
  71. torch.set_anomaly_enabled(self.prev)
  72. class set_detect_anomaly(object):
  73. r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
  74. ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
  75. based on its argument :attr:`mode`.
  76. It can be used as a context-manager or as a function.
  77. See ``detect_anomaly`` above for details of the anomaly detection behaviour.
  78. Args:
  79. mode (bool): Flag whether to enable anomaly detection (``True``),
  80. or disable (``False``).
  81. """
  82. def __init__(self, mode: bool) -> None:
  83. self.prev = torch.is_anomaly_enabled()
  84. torch.set_anomaly_enabled(mode)
  85. def __enter__(self) -> None:
  86. pass
  87. def __exit__(self, *args: Any) -> None:
  88. torch.set_anomaly_enabled(self.prev)