streams.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import ctypes
  2. import torch
  3. from ._utils import _dummy_type
  4. if not hasattr(torch._C, '_CudaStreamBase'):
  5. # Define dummy base classes
  6. torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase')
  7. torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase')
  8. class Stream(torch._C._CudaStreamBase):
  9. r"""Wrapper around a CUDA stream.
  10. A CUDA stream is a linear sequence of execution that belongs to a specific
  11. device, independent from other streams. See :ref:`cuda-semantics` for
  12. details.
  13. Args:
  14. device(torch.device or int, optional): a device on which to allocate
  15. the stream. If :attr:`device` is ``None`` (default) or a negative
  16. integer, this will use the current device.
  17. priority(int, optional): priority of the stream. Can be either
  18. -1 (high priority) or 0 (low priority). By default, streams have
  19. priority 0.
  20. .. note:: Although CUDA versions >= 11 support more than two levels of
  21. priorities, in PyTorch, we only support two levels of priorities.
  22. """
  23. def __new__(cls, device=None, priority=0, **kwargs):
  24. with torch.cuda.device(device):
  25. return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
  26. def wait_event(self, event):
  27. r"""Makes all future work submitted to the stream wait for an event.
  28. Args:
  29. event (torch.cuda.Event): an event to wait for.
  30. .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
  31. `CUDA Stream documentation`_ for more info.
  32. This function returns without waiting for :attr:`event`: only future
  33. operations are affected.
  34. .. _CUDA Stream documentation:
  35. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
  36. """
  37. event.wait(self)
  38. def wait_stream(self, stream):
  39. r"""Synchronizes with another stream.
  40. All future work submitted to this stream will wait until all kernels
  41. submitted to a given stream at the time of call complete.
  42. Args:
  43. stream (Stream): a stream to synchronize.
  44. .. note:: This function returns without waiting for currently enqueued
  45. kernels in :attr:`stream`: only future operations are affected.
  46. """
  47. self.wait_event(stream.record_event())
  48. def record_event(self, event=None):
  49. r"""Records an event.
  50. Args:
  51. event (torch.cuda.Event, optional): event to record. If not given, a new one
  52. will be allocated.
  53. Returns:
  54. Recorded event.
  55. """
  56. if event is None:
  57. event = Event()
  58. event.record(self)
  59. return event
  60. def query(self):
  61. r"""Checks if all the work submitted has been completed.
  62. Returns:
  63. A boolean indicating if all kernels in this stream are completed."""
  64. return super(Stream, self).query()
  65. def synchronize(self):
  66. r"""Wait for all the kernels in this stream to complete.
  67. .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
  68. `CUDA Stream documentation`_ for more info.
  69. """
  70. super(Stream, self).synchronize()
  71. @property
  72. def _as_parameter_(self):
  73. return ctypes.c_void_p(self.cuda_stream)
  74. def __eq__(self, o):
  75. if isinstance(o, Stream):
  76. return super(Stream, self).__eq__(o)
  77. return False
  78. def __hash__(self):
  79. return hash((self.cuda_stream, self.device))
  80. def __repr__(self):
  81. return ('<torch.cuda.Stream device={0} cuda_stream={1:#x}>'
  82. .format(self.device, self.cuda_stream))
  83. class ExternalStream(Stream):
  84. r"""Wrapper around an externally allocated CUDA stream.
  85. This class is used to wrap streams allocated in other libraries in order
  86. to facilitate data exchange and multi-library interactions.
  87. .. note:: This class doesn't manage the stream life-cycle, it is the user
  88. responsibility to keep the referenced stream alive while this class is
  89. being used.
  90. Args:
  91. stream_ptr(int): Integer representation of the `cudaStream_t` value.
  92. allocated externally.
  93. device(torch.device or int, optional): the device where the stream
  94. was originally allocated. if device is specified incorrectly,
  95. subsequent launches using this stream may fail.
  96. """
  97. def __new__(cls, stream_ptr, device=None, **kwargs):
  98. with torch.cuda.device(device):
  99. return super(Stream, cls).__new__(cls, stream_ptr=stream_ptr, **kwargs)
  100. class Event(torch._C._CudaEventBase):
  101. r"""Wrapper around a CUDA event.
  102. CUDA events are synchronization markers that can be used to monitor the
  103. device's progress, to accurately measure timing, and to synchronize CUDA
  104. streams.
  105. The underlying CUDA events are lazily initialized when the event is first
  106. recorded or exported to another process. After creation, only streams on the
  107. same device may record the event. However, streams on any device can wait on
  108. the event.
  109. Args:
  110. enable_timing (bool, optional): indicates if the event should measure time
  111. (default: ``False``)
  112. blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
  113. interprocess (bool): if ``True``, the event can be shared between processes
  114. (default: ``False``)
  115. .. _CUDA Event Documentation:
  116. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
  117. """
  118. def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
  119. return super(Event, cls).__new__(
  120. cls,
  121. enable_timing=enable_timing, blocking=blocking, interprocess=interprocess)
  122. @classmethod
  123. def from_ipc_handle(cls, device, handle):
  124. r"""Reconstruct an event from an IPC handle on the given device."""
  125. return super(Event, cls).from_ipc_handle(device, handle)
  126. def record(self, stream=None):
  127. r"""Records the event in a given stream.
  128. Uses ``torch.cuda.current_stream()`` if no stream is specified. The
  129. stream's device must match the event's device."""
  130. if stream is None:
  131. stream = torch.cuda.current_stream()
  132. super(Event, self).record(stream)
  133. def wait(self, stream=None):
  134. r"""Makes all future work submitted to the given stream wait for this
  135. event.
  136. Use ``torch.cuda.current_stream()`` if no stream is specified.
  137. .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
  138. `CUDA Event documentation`_ for more info.
  139. """
  140. if stream is None:
  141. stream = torch.cuda.current_stream()
  142. super(Event, self).wait(stream)
  143. def query(self):
  144. r"""Checks if all work currently captured by event has completed.
  145. Returns:
  146. A boolean indicating if all work currently captured by event has
  147. completed.
  148. """
  149. return super(Event, self).query()
  150. def elapsed_time(self, end_event):
  151. r"""Returns the time elapsed in milliseconds after the event was
  152. recorded and before the end_event was recorded.
  153. """
  154. return super(Event, self).elapsed_time(end_event)
  155. def synchronize(self):
  156. r"""Waits for the event to complete.
  157. Waits until the completion of all work currently captured in this event.
  158. This prevents the CPU thread from proceeding until the event completes.
  159. .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
  160. `CUDA Event documentation`_ for more info.
  161. """
  162. super(Event, self).synchronize()
  163. def ipc_handle(self):
  164. r"""Returns an IPC handle of this event. If not recorded yet, the event
  165. will use the current device. """
  166. return super(Event, self).ipc_handle()
  167. @property
  168. def _as_parameter_(self):
  169. return ctypes.c_void_p(self.cuda_event)
  170. def __repr__(self):
  171. if self.cuda_event:
  172. return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
  173. else:
  174. return '<torch.cuda.Event uninitialized>'