hooks.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import torch
  2. from collections import OrderedDict
  3. import weakref
  4. import warnings
  5. import functools
  6. from typing import Any
  7. class RemovableHandle(object):
  8. """A handle which provides the capability to remove a hook."""
  9. id: int
  10. next_id: int = 0
  11. def __init__(self, hooks_dict: Any) -> None:
  12. self.hooks_dict_ref = weakref.ref(hooks_dict)
  13. self.id = RemovableHandle.next_id
  14. RemovableHandle.next_id += 1
  15. def remove(self) -> None:
  16. hooks_dict = self.hooks_dict_ref()
  17. if hooks_dict is not None and self.id in hooks_dict:
  18. del hooks_dict[self.id]
  19. def __getstate__(self):
  20. return (self.hooks_dict_ref(), self.id)
  21. def __setstate__(self, state) -> None:
  22. if state[0] is None:
  23. # create a dead reference
  24. self.hooks_dict_ref = weakref.ref(OrderedDict())
  25. else:
  26. self.hooks_dict_ref = weakref.ref(state[0])
  27. self.id = state[1]
  28. RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)
  29. def __enter__(self) -> 'RemovableHandle':
  30. return self
  31. def __exit__(self, type: Any, value: Any, tb: Any) -> None:
  32. self.remove()
  33. def unserializable_hook(f):
  34. """
  35. Decorator which marks a function as an unserializable hook.
  36. This suppresses warnings that would otherwise arise if you attempt
  37. to serialize a tensor that has a hook.
  38. """
  39. f.__torch_unserializable__ = True
  40. return f
  41. def warn_if_has_hooks(tensor):
  42. if tensor._backward_hooks:
  43. for k in tensor._backward_hooks:
  44. hook = tensor._backward_hooks[k]
  45. if not hasattr(k, "__torch_unserializable__"):
  46. warnings.warn("backward hook {} on tensor will not be "
  47. "serialized. If this is expected, you can "
  48. "decorate the function with @torch.utils.hooks.unserializable_hook "
  49. "to suppress this warning".format(repr(hook)))
  50. class BackwardHook(object):
  51. """
  52. A wrapper class to implement nn.Module backward hooks.
  53. It handles:
  54. - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
  55. - Generating the proper Node to capture a set of Tensor's gradients
  56. - Linking the gradients captures for the outputs with the gradients captured for the input
  57. - Calling the user hook once both output and input gradients are available
  58. """
  59. def __init__(self, module, user_hooks):
  60. self.user_hooks = user_hooks
  61. self.module = module
  62. self.grad_outputs = None
  63. self.n_outputs = -1
  64. self.output_tensors_index = None
  65. self.n_inputs = -1
  66. self.input_tensors_index = None
  67. def _pack_with_none(self, indices, values, size):
  68. res = [None] * size
  69. for idx, val in zip(indices, values):
  70. res[idx] = val
  71. return tuple(res)
  72. def _unpack_none(self, indices, values):
  73. res = []
  74. for idx in indices:
  75. res.append(values[idx])
  76. return tuple(res)
  77. def _set_user_hook(self, grad_fn, user_hook):
  78. @functools.wraps(user_hook)
  79. def hook(grad_input, _):
  80. if self.grad_outputs is None:
  81. raise RuntimeError("Module backward hook for grad_input is called before "
  82. "the grad_output one. This happens because the gradient "
  83. "in your nn.Module flows to the Module's input without "
  84. "passing through the Module's output. Make sure that the "
  85. "output depends on the input and that the loss is computed "
  86. "based on the output.")
  87. grad_input = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)
  88. res = user_hook(self.module, grad_input, self.grad_outputs)
  89. if res is None:
  90. return res
  91. if len(res) != len(grad_input):
  92. raise RuntimeError("Backward hook returned an invalid number of grad_input, "
  93. "got {}, but expected {}".format(len(res), len(grad_input)))
  94. return self._unpack_none(self.input_tensors_index, res)
  95. grad_fn.register_hook(hook)
  96. def _apply_on_tensors(self, fn, args):
  97. # Can be used to apply the given function to the tensors contained in the
  98. # args. Will return updated args and the tensors indices
  99. tensors_idx = []
  100. tensors = []
  101. requires_grad = False
  102. for i, arg in enumerate(args):
  103. if isinstance(arg, torch.Tensor):
  104. tensors_idx.append(i)
  105. tensors.append(arg)
  106. requires_grad |= arg.requires_grad
  107. if not (requires_grad and torch.is_grad_enabled()):
  108. return args, None
  109. new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
  110. if len(new_tensors) == 0:
  111. raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")
  112. grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
  113. if len(grad_fns) == 0:
  114. raise RuntimeError("Error while setting up backward hooks. Please open "
  115. "an issue with a code sample to reproduce this.")
  116. fn(grad_fns[0])
  117. arg_list = list(args)
  118. for idx, val in zip(tensors_idx, new_tensors):
  119. arg_list[idx] = val
  120. return tuple(arg_list), tensors_idx
  121. def setup_input_hook(self, args):
  122. def fn(grad_fn):
  123. for hook in self.user_hooks:
  124. self._set_user_hook(grad_fn, hook)
  125. res, input_idx = self._apply_on_tensors(fn, args)
  126. self.n_inputs = len(args)
  127. self.input_tensors_index = input_idx
  128. return res
  129. def setup_output_hook(self, args):
  130. def fn(grad_fn):
  131. def hook(_, grad_output):
  132. self.grad_outputs = self._pack_with_none(self.output_tensors_index,
  133. grad_output,
  134. self.n_outputs)
  135. # Special case if no input required gradients, this hook should call the user
  136. # hook directly
  137. if self.input_tensors_index is None:
  138. grad_inputs = self._pack_with_none([], [], self.n_inputs)
  139. for user_hook in self.user_hooks:
  140. res = user_hook(self.module, grad_inputs, self.grad_outputs)
  141. if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
  142. raise RuntimeError("Backward hook for Modules where no input requires "
  143. "gradient should always return None or None for all gradients.")
  144. grad_fn.register_hook(hook)
  145. is_tuple = True
  146. if not isinstance(args, tuple):
  147. args = (args,)
  148. is_tuple = False
  149. res, output_idx = self._apply_on_tensors(fn, args)
  150. self.n_outputs = len(args)
  151. self.output_tensors_index = output_idx
  152. if not is_tuple:
  153. res = res[0]
  154. return res