functions.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import functools
  2. def async_execution(fn):
  3. r"""
  4. A decorator for a function indicating that the return value of the function
  5. is guaranteed to be a :class:`~torch.futures.Future` object and this
  6. function can run asynchronously on the RPC callee. More specifically, the
  7. callee extracts the :class:`~torch.futures.Future` returned by the wrapped
  8. function and installs subsequent processing steps as a callback to that
  9. :class:`~torch.futures.Future`. The installed callback will read the value
  10. from the :class:`~torch.futures.Future` when completed and send the
  11. value back as the RPC response. That also means the returned
  12. :class:`~torch.futures.Future` only exists on the callee side and is never
  13. sent through RPC. This decorator is useful when the wrapped function's
  14. (``fn``) execution needs to pause and resume due to, e.g., containing
  15. :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
  16. .. note:: To enable asynchronous execution, applications must pass the
  17. function object returned by this decorator to RPC APIs. If RPC detected
  18. attributes installed by this decorator, it knows that this function
  19. returns a ``Future`` object and will handle that accordingly.
  20. However, this does not mean this decorator has to be outmost one when
  21. defining a function. For example, when combined with ``@staticmethod``
  22. or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
  23. inner decorator to allow the target function be recognized as a static
  24. or class function. This target function can still execute asynchronously
  25. because, when accessed, the static or class method preserves attributes
  26. installed by ``@rpc.functions.async_execution``.
  27. Example::
  28. The returned :class:`~torch.futures.Future` object can come from
  29. :meth:`~torch.distributed.rpc.rpc_async`,
  30. :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
  31. constructor. The example below shows directly using the
  32. :class:`~torch.futures.Future` returned by
  33. :meth:`~torch.futures.Future.then`.
  34. >>> from torch.distributed import rpc
  35. >>>
  36. >>> # omitting setup and shutdown RPC
  37. >>>
  38. >>> # On all workers
  39. >>> @rpc.functions.async_execution
  40. >>> def async_add_chained(to, x, y, z):
  41. >>> # This function runs on "worker1" and returns immediately when
  42. >>> # the callback is installed through the `then(cb)` API. In the
  43. >>> # mean time, the `rpc_async` to "worker2" can run concurrently.
  44. >>> # When the return value of that `rpc_async` arrives at
  45. >>> # "worker1", "worker1" will run the lambda function accordingly
  46. >>> # and set the value for the previously returned `Future`, which
  47. >>> # will then trigger RPC to send the result back to "worker0".
  48. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  49. >>> lambda fut: fut.wait() + z
  50. >>> )
  51. >>>
  52. >>> # On worker0
  53. >>> ret = rpc.rpc_sync(
  54. >>> "worker1",
  55. >>> async_add_chained,
  56. >>> args=("worker2", torch.ones(2), 1, 1)
  57. >>> )
  58. >>> print(ret) # prints tensor([3., 3.])
  59. When combined with TorchScript decorators, this decorator must be the
  60. outmost one.
  61. >>> from torch import Tensor
  62. >>> from torch.futures import Future
  63. >>> from torch.distributed import rpc
  64. >>>
  65. >>> # omitting setup and shutdown RPC
  66. >>>
  67. >>> # On all workers
  68. >>> @torch.jit.script
  69. >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
  70. >>> return x + y
  71. >>>
  72. >>> @rpc.functions.async_execution
  73. >>> @torch.jit.script
  74. >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
  75. >>> return rpc.rpc_async(to, script_add, (x, y))
  76. >>>
  77. >>> # On worker0
  78. >>> ret = rpc.rpc_sync(
  79. >>> "worker1",
  80. >>> async_add,
  81. >>> args=("worker2", torch.ones(2), 1)
  82. >>> )
  83. >>> print(ret) # prints tensor([2., 2.])
  84. When combined with static or class method, this decorator must be the
  85. inner one.
  86. >>> from torch.distributed import rpc
  87. >>>
  88. >>> # omitting setup and shutdown RPC
  89. >>>
  90. >>> # On all workers
  91. >>> class AsyncExecutionClass:
  92. >>>
  93. >>> @staticmethod
  94. >>> @rpc.functions.async_execution
  95. >>> def static_async_add(to, x, y, z):
  96. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  97. >>> lambda fut: fut.wait() + z
  98. >>> )
  99. >>>
  100. >>> @classmethod
  101. >>> @rpc.functions.async_execution
  102. >>> def class_async_add(cls, to, x, y, z):
  103. >>> ret_fut = torch.futures.Future()
  104. >>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
  105. >>> lambda fut: ret_fut.set_result(fut.wait() + z)
  106. >>> )
  107. >>> return ret_fut
  108. >>>
  109. >>> @rpc.functions.async_execution
  110. >>> def bound_async_add(self, to, x, y, z):
  111. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  112. >>> lambda fut: fut.wait() + z
  113. >>> )
  114. >>>
  115. >>> # On worker0
  116. >>> ret = rpc.rpc_sync(
  117. >>> "worker1",
  118. >>> AsyncExecutionClass.static_async_add,
  119. >>> args=("worker2", torch.ones(2), 1, 2)
  120. >>> )
  121. >>> print(ret) # prints tensor([4., 4.])
  122. >>>
  123. >>> ret = rpc.rpc_sync(
  124. >>> "worker1",
  125. >>> AsyncExecutionClass.class_async_add,
  126. >>> args=("worker2", torch.ones(2), 1, 2)
  127. >>> )
  128. >>> print(ret) # prints tensor([4., 4.])
  129. This decorator also works with RRef helpers, i.e., .
  130. :meth:`torch.distributed.rpc.RRef.rpc_sync`,
  131. :meth:`torch.distributed.rpc.RRef.rpc_async`, and
  132. :meth:`torch.distributed.rpc.RRef.remote`.
  133. >>> from torch.distributed import rpc
  134. >>>
  135. >>> # reuse the AsyncExecutionClass class above
  136. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  137. >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
  138. >>> print(ret) # prints tensor([4., 4.])
  139. >>>
  140. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  141. >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
  142. >>> print(ret) # prints tensor([4., 4.])
  143. >>>
  144. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  145. >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
  146. >>> print(ret) # prints tensor([4., 4.])
  147. """
  148. @functools.wraps(fn)
  149. def wrapper(*args, **kwargs):
  150. return fn(*args, **kwargs)
  151. # Can't declare and use attributes of function objects (mypy#2087)
  152. wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
  153. return wrapper