checkpoint.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import torch
  2. import warnings
  3. from typing import Any, Dict, Iterable, List, Optional, Tuple
  4. def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
  5. if isinstance(inputs, tuple):
  6. out = []
  7. for inp in inputs:
  8. if not isinstance(inp, torch.Tensor):
  9. out.append(inp)
  10. continue
  11. x = inp.detach()
  12. x.requires_grad = inp.requires_grad
  13. out.append(x)
  14. return tuple(out)
  15. else:
  16. raise RuntimeError(
  17. "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
  18. def check_backward_validity(inputs: Iterable[Any]) -> None:
  19. if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
  20. warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
  21. # We can't know if the run_fn will internally move some args to different devices,
  22. # which would require logic to preserve rng states for those devices as well.
  23. # We could paranoically stash and restore ALL the rng states for all visible devices,
  24. # but that seems very wasteful for most cases. Compromise: Stash the RNG state for
  25. # the device of all Tensor args.
  26. #
  27. # To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
  28. def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
  29. # This will not error out if "arg" is a CPU tensor or a non-tensor type because
  30. # the conditionals short-circuit.
  31. fwd_gpu_devices = list(set(arg.get_device() for arg in args
  32. if isinstance(arg, torch.Tensor) and arg.is_cuda))
  33. fwd_gpu_states = []
  34. for device in fwd_gpu_devices:
  35. with torch.cuda.device(device):
  36. fwd_gpu_states.append(torch.cuda.get_rng_state())
  37. return fwd_gpu_devices, fwd_gpu_states
  38. def set_device_states(devices, states) -> None:
  39. for device, state in zip(devices, states):
  40. with torch.cuda.device(device):
  41. torch.cuda.set_rng_state(state)
  42. class CheckpointFunction(torch.autograd.Function):
  43. @staticmethod
  44. def forward(ctx, run_function, preserve_rng_state, *args):
  45. check_backward_validity(args)
  46. ctx.run_function = run_function
  47. ctx.preserve_rng_state = preserve_rng_state
  48. # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
  49. ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
  50. "dtype": torch.get_autocast_gpu_dtype(),
  51. "cache_enabled": torch.is_autocast_cache_enabled()}
  52. ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
  53. "dtype": torch.get_autocast_cpu_dtype(),
  54. "cache_enabled": torch.is_autocast_cache_enabled()}
  55. if preserve_rng_state:
  56. ctx.fwd_cpu_state = torch.get_rng_state()
  57. # Don't eagerly initialize the cuda context by accident.
  58. # (If the user intends that the context is initialized later, within their
  59. # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
  60. # we have no way to anticipate this will happen before we run the function.)
  61. ctx.had_cuda_in_fwd = False
  62. if torch.cuda._initialized:
  63. ctx.had_cuda_in_fwd = True
  64. ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
  65. # Save non-tensor inputs in ctx, keep a placeholder None for tensors
  66. # to be filled out during the backward.
  67. ctx.inputs = []
  68. ctx.tensor_indices = []
  69. tensor_inputs = []
  70. for i, arg in enumerate(args):
  71. if torch.is_tensor(arg):
  72. tensor_inputs.append(arg)
  73. ctx.tensor_indices.append(i)
  74. ctx.inputs.append(None)
  75. else:
  76. ctx.inputs.append(arg)
  77. ctx.save_for_backward(*tensor_inputs)
  78. with torch.no_grad():
  79. outputs = run_function(*args)
  80. return outputs
  81. @staticmethod
  82. def backward(ctx, *args):
  83. if not torch.autograd._is_checkpoint_valid():
  84. raise RuntimeError(
  85. "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
  86. " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
  87. " argument.")
  88. # Copy the list to avoid modifying original list.
  89. inputs = list(ctx.inputs)
  90. tensor_indices = ctx.tensor_indices
  91. tensors = ctx.saved_tensors
  92. # Fill in inputs with appropriate saved tensors.
  93. for i, idx in enumerate(tensor_indices):
  94. inputs[idx] = tensors[i]
  95. # Stash the surrounding rng state, and mimic the state that was
  96. # present at this time during forward. Restore the surrounding state
  97. # when we're done.
  98. rng_devices = []
  99. if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
  100. rng_devices = ctx.fwd_gpu_devices
  101. with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
  102. if ctx.preserve_rng_state:
  103. torch.set_rng_state(ctx.fwd_cpu_state)
  104. if ctx.had_cuda_in_fwd:
  105. set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
  106. detached_inputs = detach_variable(tuple(inputs))
  107. with torch.enable_grad(), \
  108. torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
  109. torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
  110. outputs = ctx.run_function(*detached_inputs)
  111. if isinstance(outputs, torch.Tensor):
  112. outputs = (outputs,)
  113. # run backward() with only tensor that requires grad
  114. outputs_with_grad = []
  115. args_with_grad = []
  116. for i in range(len(outputs)):
  117. if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
  118. outputs_with_grad.append(outputs[i])
  119. args_with_grad.append(args[i])
  120. if len(outputs_with_grad) == 0:
  121. raise RuntimeError(
  122. "none of output has requires_grad=True,"
  123. " this checkpoint() is not necessary")
  124. torch.autograd.backward(outputs_with_grad, args_with_grad)
  125. grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
  126. for inp in detached_inputs)
  127. return (None, None) + grads
  128. def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
  129. r"""Checkpoint a model or part of the model
  130. Checkpointing works by trading compute for memory. Rather than storing all
  131. intermediate activations of the entire computation graph for computing
  132. backward, the checkpointed part does **not** save intermediate activations,
  133. and instead recomputes them in backward pass. It can be applied on any part
  134. of a model.
  135. Specifically, in the forward pass, :attr:`function` will run in
  136. :func:`torch.no_grad` manner, i.e., not storing the intermediate
  137. activations. Instead, the forward pass saves the inputs tuple and the
  138. :attr:`function` parameter. In the backwards pass, the saved inputs and
  139. :attr:`function` is retrieved, and the forward pass is computed on
  140. :attr:`function` again, now tracking the intermediate activations, and then
  141. the gradients are calculated using these activation values.
  142. The output of :attr:`function` can contain non-Tensor values and gradient
  143. recording is only performed for the Tensor values. Note that if the output
  144. consists of nested structures (ex: custom objects, lists, dicts etc.)
  145. consisting of Tensors, these Tensors nested in custom structures will not
  146. be considered as part of autograd.
  147. .. warning::
  148. If :attr:`function` invocation during backward does anything different
  149. than the one during forward, e.g., due to some global variable, the
  150. checkpointed version won't be equivalent, and unfortunately it can't be
  151. detected.
  152. .. warning::
  153. If ``use_reentrant=True`` is specified, then if the checkpointed segment
  154. contains tensors detached from the computational graph by `detach()` or
  155. `torch.no_grad()`, the backward pass will raise an error. This is
  156. because `checkpoint` makes all the outputs require gradients which
  157. causes issues when a tensor is defined to have no gradient in the model.
  158. To circumvent this, detach the tensors outside of the `checkpoint`
  159. function. Note that the checkpointed segment can contain tensors
  160. detached from the computational graph if ``use_reentrant=False`` is
  161. specified.
  162. .. warning::
  163. If ``use_reentrant=True`` is specified, at least one of the inputs needs
  164. to have :code:`requires_grad=True` if grads are needed for model inputs,
  165. otherwise the checkpointed part of the model won't have gradients. At
  166. least one of the outputs needs to have :code:`requires_grad=True` as
  167. well. Note that this does not apply if ``use_reentrant=False`` is
  168. specified.
  169. .. warning::
  170. If ``use_reentrant=True`` is specified, checkpointing currently only
  171. supports :func:`torch.autograd.backward` and only if its `inputs`
  172. argument is not passed. :func:`torch.autograd.grad`
  173. is not supported. If ``use_reentrant=False`` is specified, checkpointing
  174. will work with :func:`torch.autograd.grad`.
  175. Args:
  176. function: describes what to run in the forward pass of the model or
  177. part of the model. It should also know how to handle the inputs
  178. passed as the tuple. For example, in LSTM, if user passes
  179. ``(activation, hidden)``, :attr:`function` should correctly use the
  180. first input as ``activation`` and the second input as ``hidden``
  181. preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
  182. the RNG state during each checkpoint.
  183. use_reentrant(bool, optional, default=True): Use checkpointing
  184. implementation that requires re-entrant autograd.
  185. If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
  186. implementation that does not require re-entrant autograd. This
  187. allows ``checkpoint`` to support additional functionality, such as
  188. working as expected with ``torch.autograd.grad``. Note that future
  189. versions of PyTorch will default to ``use_reentrant=False``.
  190. args: tuple containing inputs to the :attr:`function`
  191. Returns:
  192. Output of running :attr:`function` on :attr:`*args`
  193. """
  194. # Hack to mix *args with **kwargs in a python 2.7-compliant way
  195. preserve = kwargs.pop('preserve_rng_state', True)
  196. if kwargs:
  197. raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
  198. if use_reentrant:
  199. return CheckpointFunction.apply(function, preserve, *args)
  200. else:
  201. return _checkpoint_without_reentrant(
  202. function,
  203. preserve,
  204. *args
  205. )
  206. def checkpoint_sequential(functions, segments, input, **kwargs):
  207. r"""A helper function for checkpointing sequential models.
  208. Sequential models execute a list of modules/functions in order
  209. (sequentially). Therefore, we can divide such a model in various segments
  210. and checkpoint each segment. All segments except the last will run in
  211. :func:`torch.no_grad` manner, i.e., not storing the intermediate
  212. activations. The inputs of each checkpointed segment will be saved for
  213. re-running the segment in the backward pass.
  214. See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
  215. .. warning::
  216. Checkpointing currently only supports :func:`torch.autograd.backward`
  217. and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
  218. is not supported.
  219. .. warning:
  220. At least one of the inputs needs to have :code:`requires_grad=True` if
  221. grads are needed for model inputs, otherwise the checkpointed part of the
  222. model won't have gradients.
  223. .. warning:
  224. Since PyTorch 1.4, it allows only one Tensor as the input and
  225. intermediate outputs, just like :class:`torch.nn.Sequential`.
  226. Args:
  227. functions: A :class:`torch.nn.Sequential` or the list of modules or
  228. functions (comprising the model) to run sequentially.
  229. segments: Number of chunks to create in the model
  230. input: A Tensor that is input to :attr:`functions`
  231. preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
  232. the RNG state during each checkpoint.
  233. Returns:
  234. Output of running :attr:`functions` sequentially on :attr:`*inputs`
  235. Example:
  236. >>> model = nn.Sequential(...)
  237. >>> input_var = checkpoint_sequential(model, chunks, input_var)
  238. """
  239. # Hack for keyword-only parameter in a python 2.7-compliant way
  240. preserve = kwargs.pop('preserve_rng_state', True)
  241. if kwargs:
  242. raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
  243. def run_function(start, end, functions):
  244. def forward(input):
  245. for j in range(start, end + 1):
  246. input = functions[j](input)
  247. return input
  248. return forward
  249. if isinstance(functions, torch.nn.Sequential):
  250. functions = list(functions.children())
  251. segment_size = len(functions) // segments
  252. # the last chunk has to be non-volatile
  253. end = -1
  254. for start in range(0, segment_size * (segments - 1), segment_size):
  255. end = start + segment_size - 1
  256. input = checkpoint(run_function(start, end, functions), input,
  257. preserve_rng_state=preserve)
  258. return run_function(end + 1, len(functions) - 1, functions)(input)
  259. def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
  260. """Checkpointining without re-entrant autograd
  261. Args:
  262. function: describes what to run in the forward pass of the model or
  263. part of the model. It should also know how to handle the inputs
  264. passed as the tuple. For example, in LSTM, if user passes
  265. ``(activation, hidden)``, :attr:`function` should correctly use the
  266. first input as ``activation`` and the second input as ``hidden``
  267. preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
  268. the RNG state during each checkpoint.
  269. *args: Arguments to pass in to the given ``function``.
  270. """
  271. had_autocast_in_fwd = torch.is_autocast_enabled()
  272. if preserve_rng_state:
  273. fwd_cpu_state = torch.get_rng_state()
  274. # Don't eagerly initialize the cuda context by accident.
  275. # (If the user intends that the context is initialized later, within their
  276. # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
  277. # we have no way to anticipate this will happen before we run the function.
  278. # If they do so, we raise an error.)
  279. had_cuda_in_fwd = False
  280. if torch.cuda._initialized:
  281. had_cuda_in_fwd = True
  282. fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)
  283. storage: Dict[int, Optional[torch.Tensor]] = {}
  284. counter = 0
  285. def pack(x):
  286. nonlocal counter
  287. counter += 1
  288. # TODO(varal7): Instead of returning indices, we can return things metadata (such as
  289. # size, device, ...) to catch certain cases of undeterministic behavior of the forward
  290. return counter - 1
  291. def unpack(x):
  292. unpack_counter = 0
  293. if len(storage) == 0:
  294. def inner_pack(inner):
  295. nonlocal unpack_counter
  296. storage[unpack_counter] = inner
  297. unpack_counter += 1
  298. return None
  299. def inner_unpack(packed):
  300. raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")
  301. # Stash the surrounding rng state, and mimic the state that was
  302. # present at this time during forward. Restore the surrounding state
  303. # when we're done.
  304. rng_devices = []
  305. if preserve_rng_state and had_cuda_in_fwd:
  306. rng_devices = fwd_gpu_devices
  307. with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
  308. if preserve_rng_state:
  309. torch.set_rng_state(fwd_cpu_state)
  310. if had_cuda_in_fwd:
  311. set_device_states(fwd_gpu_devices, fwd_gpu_states)
  312. with torch.enable_grad(), torch.cuda.amp.autocast(had_autocast_in_fwd):
  313. with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
  314. _unused = function(*args)
  315. if x not in storage:
  316. raise RuntimeError(
  317. "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
  318. " recomputation being triggered in between, this is not currently supported. Please"
  319. " open an issue with details on your use case so that we can prioritize adding this."
  320. )
  321. return storage.pop(x)
  322. with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
  323. output = function(*args)
  324. if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
  325. # Cuda was not initialized before running the forward, so we didn't
  326. # stash the CUDA state.
  327. raise RuntimeError(
  328. "PyTorch's CUDA state was initialized in the forward pass "
  329. "of a Checkpoint, which is not allowed. Please open an issue "
  330. "if you need this feature.")
  331. return output