# graphs.py (19 KB)
  1. import gc
  2. import torch
  3. from ._utils import _dummy_type
  4. if not hasattr(torch._C, '_CudaStreamBase'):
  5. # Define dummy base classes
  6. torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph')
  7. torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle')
  8. torch._C.__dict__['_cuda_isCurrentStreamCapturing'] = _dummy_type('_cuda_isCurrentStreamCapturing')
  9. from torch._C import _CUDAGraph # noqa: F401
  10. from torch._C import _graph_pool_handle
  11. from torch._C import _cuda_isCurrentStreamCapturing
  12. def is_current_stream_capturing():
  13. r"""
  14. Returns True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
  15. If a CUDA context does not exist on the current device, returns False without initializing the context.
  16. """
  17. return _cuda_isCurrentStreamCapturing()
  18. # Python shim helps Sphinx process docstrings more reliably.
  19. def graph_pool_handle():
  20. r"""
  21. Returns an opaque token representing the id of a graph memory pool.
  22. See :ref:`Graph memory management<graph-memory-management>`.
  23. .. warning::
  24. This API is in beta and may change in future releases.
  25. """
  26. return _graph_pool_handle()
  27. # Python shim helps Sphinx process docstrings more reliably.
  28. class CUDAGraph(torch._C._CUDAGraph):
  29. r"""
  30. Wrapper around a CUDA graph.
  31. .. warning::
  32. This API is in beta and may change in future releases.
  33. """
  34. def __new__(cls):
  35. return super(CUDAGraph, cls).__new__(cls)
  36. def __init__(self):
  37. super(CUDAGraph, self).__init__()
  38. def capture_begin(self, pool=None):
  39. r"""
  40. Begins capturing CUDA work on the current stream.
  41. Typically, you shouldn't call ``capture_begin`` yourself.
  42. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  43. which call ``capture_begin`` internally.
  44. Arguments:
  45. pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
  46. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
  47. with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
  48. """
  49. # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side,
  50. # so I'm not taking any chances.
  51. if pool is None:
  52. super(CUDAGraph, self).capture_begin()
  53. else:
  54. super(CUDAGraph, self).capture_begin(pool)
  55. def capture_end(self):
  56. r"""
  57. Ends CUDA graph capture on the current stream.
  58. After ``capture_end``, ``replay`` may be called on this instance.
  59. Typically, you shouldn't call ``capture_end`` yourself.
  60. Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
  61. which call ``capture_end`` internally.
  62. """
  63. super(CUDAGraph, self).capture_end()
  64. def replay(self):
  65. r"""
  66. Replays the CUDA work captured by this graph.
  67. """
  68. super(CUDAGraph, self).replay()
  69. def reset(self):
  70. r"""
  71. Deletes the graph currently held by this instance.
  72. """
  73. super(CUDAGraph, self).reset()
  74. def pool(self):
  75. r"""
  76. Returns an opaque token representing the id of this graph's memory pool.
  77. This id can optionally be passed to another graph's ``capture_begin``,
  78. which hints the other graph may share the same memory pool.
  79. """
  80. return super(CUDAGraph, self).pool()
  81. class graph(object):
  82. r"""
  83. Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph`
  84. object for later replay.
  85. See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
  86. detailed use, and constraints.
  87. Arguments:
  88. cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
  89. pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
  90. :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
  91. may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
  92. stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
  93. If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
  94. .. note::
  95. For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
  96. used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
  97. .. warning::
  98. This API is in beta and may change in future releases.
  99. """
  100. default_capture_stream = None
  101. def __init__(self,
  102. cuda_graph,
  103. pool=None,
  104. stream=None):
  105. # Lazy-init of default_capture_stream helps avoid circular-import errors.
  106. # Not thread safe, but graphs already have the general (explicitly documented)
  107. # restriction that only one capture may be underway at a time in the process.
  108. if self.__class__.default_capture_stream is None:
  109. self.__class__.default_capture_stream = torch.cuda.Stream()
  110. self.pool = () if pool is None else (pool,)
  111. self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream
  112. assert self.capture_stream is not None
  113. self.stream_ctx = torch.cuda.stream(self.capture_stream)
  114. self.cuda_graph = cuda_graph
  115. def __enter__(self):
  116. # Free as much memory as we can for the graph
  117. torch.cuda.synchronize()
  118. gc.collect()
  119. torch.cuda.empty_cache()
  120. # Stackoverflow seems comfortable with this pattern
  121. # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
  122. self.stream_ctx.__enter__()
  123. self.cuda_graph.capture_begin(*self.pool)
  124. def __exit__(self, exc_type, exc_value, traceback):
  125. self.cuda_graph.capture_end()
  126. self.stream_ctx.__exit__(exc_type, exc_value, traceback)
  127. # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
  128. def make_graphed_callables(callables, sample_args):
  129. r"""
  130. Accepts callables (functions or :class:`nn.Module<torch.nn.Module>`\ s)
  131. and returns graphed versions.
  132. Each graphed callable's forward pass runs its source callable's
  133. forward CUDA work as a CUDA graph inside a single autograd node.
  134. The graphed callable's forward pass also appends
  135. a backward node to the autograd graph. During backward, this node runs the
  136. callable's backward work as a CUDA graph.
  137. Therefore, each graphed callable should be a drop-in replacement for its source callable
  138. in an autograd-enabled training loop.
  139. See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
  140. If you pass a tuple of several callables, their captures will use the same memory pool.
  141. See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
  142. Arguments:
  143. callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
  144. See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
  145. is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
  146. they'll run in the live workload.
  147. sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
  148. If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
  149. If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
  150. .. note::
  151. The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
  152. that's expected for the corresponding real input in the training loop.
  153. .. warning::
  154. This API is in beta and may change in future releases.
  155. .. warning::
  156. ``sample_args`` for each callable must be a tuple of Tensors. Other types and keyword args
  157. are not allowed.
  158. .. warning::
  159. Returned callables do not support higher order differentiation (e.g., double backward).
  160. .. warning::
  161. In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
  162. may be trainable. Buffers must have ``requires_grad=False``.
  163. .. warning::
  164. After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
  165. you may not add or remove any of that Module's parameters or buffers.
  166. .. warning::
  167. :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
  168. registered on them at the time they are passed. However, registering hooks on modules *after* passing them
  169. through :func:`~torch.cuda.make_graphed_callables` is allowed.
  170. .. warning::
  171. When running a graphed callable, you must pass its arguments in the same order and format
  172. they appeared in that callable's ``sample_args``.
  173. .. warning::
  174. All Tensor outputs of graphed callables must require grad.
  175. """
  176. just_one_callable = False
  177. if not isinstance(callables, tuple):
  178. just_one_callable = True
  179. callables = (callables,)
  180. sample_args = (sample_args,)
  181. for c, args in zip(callables, sample_args):
  182. if isinstance(c, torch.nn.Module):
  183. assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \
  184. "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \
  185. "on modules after passing them through make_graphed_callables is allowed."
  186. assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \
  187. ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \
  188. "``requires_grad=False``."
  189. assert all(isinstance(arg, torch.Tensor) for arg in args), "In the beta API, sample_args " + \
  190. "for each callable must be a tuple of Tensors. Other types and keyword args are not allowed."
  191. # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
  192. # passes to forward (ie, its sample_args) AND the module's parameter attributes.
  193. per_callable_len_user_args = [len(args) for args in sample_args]
  194. per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
  195. for c in callables]
  196. per_callable_static_input_surfaces = [sample_args[i] + per_callable_module_params[i]
  197. for i in range(len(callables))]
  198. fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  199. bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
  200. mempool = graph_pool_handle()
  201. # Warmup
  202. # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
  203. # from ending up in any captures.
  204. torch.cuda.synchronize()
  205. with torch.cuda.stream(torch.cuda.Stream()):
  206. for func, args, static_input_surface in zip(callables,
  207. sample_args,
  208. per_callable_static_input_surfaces):
  209. for _ in range(3):
  210. outputs = func(*args)
  211. outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
  212. grad_inputs = torch.autograd.grad(outputs=outputs,
  213. inputs=tuple(i for i in static_input_surface if i.requires_grad),
  214. grad_outputs=tuple(torch.empty_like(o) for o in outputs),
  215. only_inputs=True,
  216. allow_unused=False)
  217. del outputs, grad_inputs
  218. torch.cuda.synchronize()
  219. # All captures here share a mempool. To avoid replays corrupting each other's memory,
  220. # the safest approach is to capture all passes in the same order they'll run:
  221. # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
  222. # Capture forward graphs
  223. per_callable_static_outputs = []
  224. per_callable_output_was_tensor = []
  225. for func, args, fwd_graph in zip(callables,
  226. sample_args,
  227. fwd_graphs):
  228. with torch.cuda.graph(fwd_graph, pool=mempool):
  229. outputs = func(*args)
  230. # Assumes model output is a tensor or tuple of tensors
  231. if isinstance(outputs, torch.Tensor):
  232. per_callable_output_was_tensor.append(True)
  233. outputs = (outputs,)
  234. else:
  235. per_callable_output_was_tensor.append(False)
  236. per_callable_static_outputs.append(outputs)
  237. # Capture backward graphs in reverse order
  238. per_callable_static_grad_outputs = []
  239. per_callable_static_grad_inputs = []
  240. for static_input_surface, static_outputs, bwd_graph, module_params in \
  241. zip(reversed(per_callable_static_input_surfaces),
  242. reversed(per_callable_static_outputs),
  243. reversed(bwd_graphs),
  244. reversed(per_callable_module_params)):
  245. # For now, assumes all static_outputs require grad
  246. assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
  247. static_grad_outputs = tuple(torch.empty_like(o) for o in static_outputs)
  248. with torch.cuda.graph(bwd_graph, pool=mempool):
  249. grad_inputs = torch.autograd.grad(outputs=static_outputs,
  250. inputs=tuple(i for i in static_input_surface if i.requires_grad),
  251. grad_outputs=static_grad_outputs,
  252. only_inputs=True,
  253. allow_unused=False)
  254. # Constructs a tuple suitable for returning from Graphed.backward:
  255. # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
  256. # I couldn't think of a slick one-liner for this pattern.
  257. static_grad_inputs = []
  258. grad_idx = 0
  259. for arg in static_input_surface:
  260. if arg.requires_grad:
  261. static_grad_inputs.append(grad_inputs[grad_idx])
  262. grad_idx += 1
  263. else:
  264. static_grad_inputs.append(None) # type: ignore[arg-type]
  265. static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
  266. per_callable_static_grad_outputs.append(static_grad_outputs)
  267. per_callable_static_grad_inputs.append(static_grad_inputs)
  268. # Reverses the most recent two lists
  269. per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
  270. per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
  271. # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
  272. def make_graphed_autograd_function(fwd_graph,
  273. bwd_graph,
  274. module_params,
  275. len_user_args,
  276. output_was_tensor,
  277. static_input_surface,
  278. static_outputs,
  279. static_grad_outputs,
  280. static_grad_inputs):
  281. class Graphed(torch.autograd.Function):
  282. @staticmethod
  283. def forward(ctx, *inputs):
  284. # At this stage, only the user args may (potentially) be new tensors.
  285. for i in range(len_user_args):
  286. if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
  287. static_input_surface[i].copy_(inputs[i])
  288. fwd_graph.replay()
  289. assert isinstance(static_outputs, tuple)
  290. return tuple(o.detach() for o in static_outputs)
  291. @staticmethod
  292. @torch.autograd.function.once_differentiable
  293. def backward(ctx, *grads):
  294. for g, grad in zip(static_grad_outputs, grads):
  295. if g is None:
  296. assert grad is None
  297. else:
  298. # don't copy if autograd gods have been kind and the
  299. # incoming grad is already in the right place
  300. if g.data_ptr() != grad.data_ptr():
  301. g.copy_(grad)
  302. bwd_graph.replay()
  303. # Input args that didn't require grad expect a None gradient.
  304. assert isinstance(static_grad_inputs, tuple)
  305. return tuple(b.detach() if b is not None else b for b in static_grad_inputs)
  306. def functionalized(*user_args):
  307. # Runs the autograd function with inputs == all inputs to the graph that might require grad
  308. # (explicit user args + module parameters)
  309. # Assumes module params didn't change since capture.
  310. out = Graphed.apply(*(user_args + module_params))
  311. return out[0] if output_was_tensor else out
  312. return functionalized
  313. # Put together the final graphed callables
  314. ret = []
  315. for i, func in enumerate(callables):
  316. graphed = make_graphed_autograd_function(fwd_graphs[i],
  317. bwd_graphs[i],
  318. per_callable_module_params[i],
  319. per_callable_len_user_args[i],
  320. per_callable_output_was_tensor[i],
  321. per_callable_static_input_surfaces[i],
  322. per_callable_static_outputs[i],
  323. per_callable_static_grad_outputs[i],
  324. per_callable_static_grad_inputs[i])
  325. if isinstance(func, torch.nn.Module):
  326. def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
  327. def new_fwd(*user_args):
  328. # If the module's training-or-eval state matches what we graphed,
  329. # run the graph, otherwise run the original forward method
  330. if func.training == graph_training_state:
  331. return graphed(*user_args)
  332. else:
  333. return orig_fwd(*user_args)
  334. return new_fwd
  335. func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment]
  336. ret.append(func)
  337. else:
  338. ret.append(graphed)
  339. if just_one_callable:
  340. return ret[0]
  341. return tuple(ret)