graph.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import torch
  2. from typing import Callable, Any
  3. class saved_tensors_hooks():
  4. """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
  5. Use this context-manager to define how intermediary results of an operation
  6. should be packed before saving, and unpacked on retrieval.
  7. In that context, the ``pack_hook`` function will be called everytime an
  8. operation saves a tensor for backward (this includes intermediary results
  9. saved using
  10. :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
  11. also those recorded by a PyTorch-defined operation). The output of
  12. ``pack_hook`` is then stored in the computation graph instead of the
  13. original tensor.
  14. The ``unpack_hook`` is called when the saved tensor needs to be accessed,
  15. namely when executing :func:`torch.Tensor.backward()` or
  16. :func:`torch.autograd.grad()`. It takes as argument the *packed* object
  17. returned by ``pack_hook`` and should return a tensor which has the same
  18. content as the original tensor (passed as input to the corresponding
  19. ``pack_hook``).
  20. The hooks should have the following signatures:
  21. pack_hook(tensor: Tensor) -> Any
  22. unpack_hook(Any) -> Tensor
  23. where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
  24. In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
  25. of value, size, dtype and device.
  26. Example::
  27. >>> def pack_hook(x):
  28. ... print("Packing", x)
  29. ... return x
  30. >>>
  31. >>> def unpack_hook(x):
  32. ... print("Unpacking", x)
  33. ... return x
  34. >>>
  35. >>> a = torch.ones(5, requires_grad=True)
  36. >>> b = torch.ones(5, requires_grad=True) * 2
  37. >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
  38. ... y = a * b
  39. Packing tensor([1., 1., 1., 1., 1.])
  40. Packing tensor([2., 2., 2., 2., 2.])
  41. >>> y.sum().backward()
  42. Unpacking tensor([1., 1., 1., 1., 1.])
  43. Unpacking tensor([2., 2., 2., 2., 2.])
  44. .. warning ::
  45. Performing an inplace operation on the input to either hooks may lead
  46. to undefined behavior.
  47. .. warning ::
  48. Only one pair of hooks is allowed at a time. When recursively nesting this
  49. context-manager, only the inner-most pair of hooks will be applied.
  50. """
  51. def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
  52. self.pack_hook = pack_hook
  53. self.unpack_hook = unpack_hook
  54. def __enter__(self):
  55. torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
  56. def __exit__(self, *args: Any):
  57. torch._C._autograd._pop_saved_tensors_default_hooks()
  58. class save_on_cpu(saved_tensors_hooks):
  59. """Context-manager under which tensors saved by the forward pass will be
  60. stored on cpu, then retrieved for backward.
  61. When performing operations within this context manager, intermediary
  62. results saved in the graph during the forward pass will be moved to CPU,
  63. then copied back to the original device when needed for the backward pass.
  64. If the graph was already on CPU, no tensor copy is performed.
  65. Use this context-manager to trade compute for GPU memory usage (e.g.
  66. when your model doesn't fit in GPU memory during training).
  67. Args:
  68. pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
  69. during packing and copied to GPU asynchronously during unpacking.
  70. Defaults to ``False``.
  71. Also see :ref:`cuda-memory-pinning`.
  72. Example::
  73. >>> a = torch.randn(5, requires_grad=True, device="cuda")
  74. >>> b = torch.randn(5, requires_grad=True, device="cuda")
  75. >>> c = torch.randn(5, requires_grad=True, device="cuda")
  76. >>>
  77. >>> def f(a, b, c):
  78. ... prod_1 = a * b # a and b are saved on GPU
  79. ... with torch.autograd.graph.save_on_cpu():
  80. ... prod_2 = prod_1 * c # prod_1 and c are saved on CPU
  81. ... y = prod_2 * a # prod_2 and a are saved on GPU
  82. ... return y
  83. >>>
  84. >>> y = f(a, b, c)
  85. >>> del a, b, c # for illustration only
  86. >>> # the content of a, b, and prod_2 are still alive on GPU
  87. >>> # the content of prod_1 and c only live on CPU
  88. >>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward
  89. >>> # all intermediary tensors are released (deleted) after the call to backward
  90. """
  91. def __init__(self, pin_memory=False):
  92. def pack_to_cpu(tensor):
  93. if not pin_memory:
  94. return (tensor.device, tensor.cpu())
  95. packed = torch.empty(
  96. tensor.size(),
  97. dtype=tensor.dtype,
  98. layout=tensor.layout,
  99. pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
  100. packed.copy_(tensor)
  101. return (tensor.device, packed)
  102. def unpack_from_cpu(packed):
  103. device, tensor = packed
  104. return tensor.to(device, non_blocking=pin_memory)
  105. super().__init__(pack_to_cpu, unpack_from_cpu)