api.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. __all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
  2. "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
  3. import collections
  4. import contextlib
  5. import functools
  6. import inspect
  7. import logging
  8. import threading
  9. from typing import Dict, Generic, TypeVar, Set, Any
  10. import torch
  11. from torch.futures import Future
  12. from torch._C._distributed_rpc import (
  13. PyRRef,
  14. RemoteProfilerManager,
  15. WorkerInfo,
  16. TensorPipeAgent,
  17. get_rpc_timeout,
  18. _cleanup_python_rpc_handler,
  19. _delete_all_user_and_unforked_owner_rrefs,
  20. _destroy_rref_context,
  21. _get_current_rpc_agent,
  22. _invoke_remote_builtin,
  23. _invoke_remote_python_udf,
  24. _invoke_remote_torchscript,
  25. _invoke_rpc_builtin,
  26. _invoke_rpc_python_udf,
  27. _invoke_rpc_torchscript,
  28. _is_current_rpc_agent_set,
  29. _reset_current_rpc_agent,
  30. _set_and_start_rpc_agent,
  31. )
  32. from .internal import (
  33. PythonUDF,
  34. RPCExecMode,
  35. _internal_rpc_pickler,
  36. _build_rpc_profiling_key,
  37. )
  38. from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
  39. from ._utils import _group_membership_management, _update_group_membership
logger = logging.getLogger(__name__)
# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
# make sure there are no references to any RRef in the application code and
# that Python GC has done its job to delete those RRefs. This could result in
# bad debugging experiences, especially for large applications. Therefore, by
# default, we are going to ignore RRef leaks during shutdown. This is usually
# fine as shutdown means applications have done training and no longer care
# about states.
#
# To enable RRef leak checking, set this _ignore_rref_leak to False.
# (It is passed to `_destroy_rref_context()` by `_finalize_shutdown()`.)
_ignore_rref_leak = True
# Pickler used by `remote()` and `_invoke_rpc()` to serialize Python UDFs;
# temporarily replaced inside a `_use_rpc_pickler(...)` block.
_default_pickler = _internal_rpc_pickler
  52. @contextlib.contextmanager
  53. def _use_rpc_pickler(rpc_pickler):
  54. r"""
  55. rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
  56. """
  57. global _default_pickler
  58. _default_pickler = rpc_pickler
  59. try:
  60. yield
  61. finally:
  62. _default_pickler = _internal_rpc_pickler
  63. def _require_initialized(func):
  64. @functools.wraps(func)
  65. def wrapper(*args, **kwargs):
  66. if not _is_current_rpc_agent_set():
  67. raise RuntimeError(
  68. "RPC has not been initialized. Call "
  69. "torch.distributed.rpc.init_rpc first."
  70. )
  71. return func(*args, **kwargs)
  72. return wrapper
  73. class AllGatherStates(object):
  74. def __init__(self):
  75. # Each `gathered_objects` is an empty dict at beginning.
  76. # The leader worker is elected as the first worker in a sorted worker
  77. # name list. Whenever there is a worker entering `_all_gather()`, it
  78. # runs `_gather_to_leader()` on the leader to add its own name and
  79. # data obj to this dict. The leader also adds itself's name to the dict
  80. # on calling `_all_gather()`.
  81. # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
  82. # will broadcast the gathered dict to all follower workers and set their
  83. # `gathered_objects` field and the `proceed_signal` field.
  84. self.gathered_objects = {}
  85. # All workers wait on this signal until it receives all gathered
  86. # objects.
  87. self.proceed_signal = threading.Event()
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is filled in by `_init_rpc_states()` when the RPC layer
# is initialized.
_ALL_WORKER_NAMES: Set[Any] = set()
# Reentrant lock guarding the two dicts below.
_all_gather_dict_lock = threading.RLock()
# Concatenated sorted worker names of a gather group -> next sequence number
# to use for that group.
_all_gather_sequence_id: Dict[str, int] = {}
# Sequence id -> AllGatherStates; entries are created on first access and
# popped by `_all_gather()` when its round finishes.
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
  94. def _init_rpc_states(agent):
  95. worker_infos = agent.get_worker_infos()
  96. global _ALL_WORKER_NAMES
  97. _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
  98. # NB: backend implementation might have already set the rpc_agent.
  99. if not _is_current_rpc_agent_set():
  100. _set_and_start_rpc_agent(agent)
  101. def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
  102. with _all_gather_dict_lock:
  103. if not worker_names:
  104. worker_names = _ALL_WORKER_NAMES
  105. assert (
  106. worker_name in worker_names
  107. ), f"{worker_name} is not expected by leader."
  108. states = _all_gather_sequence_id_to_states[sequence_id]
  109. assert (
  110. worker_name not in states.gathered_objects
  111. ), f"{worker_name} reported intent sequence id {sequence_id} twice. "
  112. states.gathered_objects[worker_name] = obj
  113. if worker_names == set(states.gathered_objects.keys()):
  114. states.proceed_signal.set()
  115. def _broadcast_to_followers(sequence_id, objects_map):
  116. with _all_gather_dict_lock:
  117. states = _all_gather_sequence_id_to_states[sequence_id]
  118. assert (
  119. not states.proceed_signal.is_set()
  120. ), "Termination signal sequence id {} got set twice.".format(sequence_id)
  121. states.gathered_objects = objects_map
  122. states.proceed_signal.set()
# Per-thread storage. While a `_wait_all()` block is active on a thread, it
# keeps the futures to wait on in `_thread_local_var.future_list`.
# NOTE(review): presumably `rpc_async` appends to that list when it exists;
# confirm in `rpc_async`, which is defined outside this chunk.
_thread_local_var = threading.local()
  124. @contextlib.contextmanager
  125. def _wait_all():
  126. r"""
  127. A context manager that collects all futures returned by ``rpc_async`` and
  128. waits them on the context manager's exit; relieving the user of needing
  129. to explicitly call wait.
  130. Example::
  131. >>> # On worker 0:
  132. >>> import torch
  133. >>> import torch.distributed.rpc as rpc
  134. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  135. >>> with rpc._wait_all():
  136. >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
  137. >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
  138. >>> #fut_1 and fut_2 are waited on
  139. """
  140. _thread_local_var.future_list = []
  141. try:
  142. yield
  143. finally:
  144. try:
  145. torch.futures.wait_all(_thread_local_var.future_list)
  146. finally:
  147. del _thread_local_var.future_list
  148. @_require_initialized
  149. def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
  150. r"""
  151. This is similar to torch.distributed.all_gather(), but is using RPC. It
  152. picks the worker with the smallest name (alphabetic order) as the leader.
  153. Then all followers send their data ``obj`` to the leader. After the leader
  154. has received all, it will broadcast the results back to all followers. This
  155. function blocks until all workers have received the gathered results.
  156. """
  157. if not worker_names:
  158. assert (
  159. _ALL_WORKER_NAMES is not None
  160. ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
  161. worker_names = _ALL_WORKER_NAMES
  162. leader_name = sorted(worker_names)[0]
  163. self_name = _get_current_rpc_agent().get_worker_info().name
  164. with _all_gather_dict_lock:
  165. concat_names = "".join(sorted(worker_names))
  166. sequence_num = _all_gather_sequence_id.get(concat_names, 0)
  167. _all_gather_sequence_id[concat_names] = sequence_num + 1
  168. sequence_id = concat_names + str(sequence_num)
  169. is_leader = leader_name == self_name
  170. if timeout == UNSET_RPC_TIMEOUT:
  171. # Timeout is specified by agent for RPC calls
  172. rpc_timeout = get_rpc_timeout()
  173. # No timeout for signal
  174. signal_timeout = None
  175. elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
  176. # No timeout for RPC
  177. rpc_timeout = timeout
  178. # No timeout for signal
  179. signal_timeout = None
  180. else:
  181. # Signal and RPC timeout use the same timeout
  182. signal_timeout = rpc_timeout = timeout
  183. # Phase 1: Followers send it's object to the leader
  184. if is_leader:
  185. _gather_to_leader(sequence_id, self_name, obj, worker_names)
  186. else:
  187. rpc_sync(
  188. leader_name,
  189. _gather_to_leader,
  190. args=(sequence_id, self_name, obj, worker_names),
  191. timeout=rpc_timeout,
  192. )
  193. with _all_gather_dict_lock:
  194. states = _all_gather_sequence_id_to_states[sequence_id]
  195. # Timeout is either set by function parameter or None (which is indefinite)
  196. states.proceed_signal.wait(timeout=signal_timeout)
  197. # Phase 2: Leader broadcast gathered results to all followers
  198. # Leader's signal is the first to be unblocked, after receiving all
  199. # followers' data objects.
  200. if is_leader:
  201. worker_name_to_response_future_dict = dict()
  202. for follower_name in worker_names - {leader_name}:
  203. fut = rpc_async(
  204. follower_name,
  205. _broadcast_to_followers,
  206. args=(sequence_id, states.gathered_objects),
  207. timeout=rpc_timeout
  208. )
  209. worker_name_to_response_future_dict[follower_name] = fut
  210. errors = []
  211. for follower_name, fut in worker_name_to_response_future_dict.items():
  212. try:
  213. fut.wait()
  214. except RuntimeError as ex:
  215. errors.append((follower_name, ex))
  216. if errors:
  217. raise RuntimeError(
  218. f"Followers {[e[0] for e in errors]} timed out in _all_gather "
  219. f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
  220. )
  221. # Clean up for the states using the sequence_id
  222. with _all_gather_dict_lock:
  223. states = _all_gather_sequence_id_to_states.pop(sequence_id)
  224. return states.gathered_objects
  225. @_require_initialized
  226. def _barrier(worker_names):
  227. r"""
  228. Synchronizes local and remote RPC processes.
  229. This will block until all local and remote RPC processes specified under worker_names
  230. reach this method to wait for all outstanding work to complete.
  231. Args:
  232. worker_names (List[str]): The set of workers to synchronize.
  233. """
  234. try:
  235. _all_gather(None, set(worker_names))
  236. except RuntimeError as ex:
  237. logger.error(
  238. f"Failed to complete barrier, got error {ex}"
  239. )
  240. @_require_initialized
  241. def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
  242. r"""
  243. Block until all local and remote RPC processes reach this method and wait
  244. for all outstanding work to complete. Every RPC process must call this
  245. method before exit to perform a graceful shutdown. This should be used to
  246. terminate the RPC framework, and there is no guarantee that the RPC
  247. framework will work after this method returns.
  248. """
  249. try:
  250. _all_gather(None, timeout=timeout)
  251. except RuntimeError as ex:
  252. logger.error(
  253. f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
  254. )
  255. raise ex
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout (float): used only when ``graceful=True``. It is passed to
                         ``_wait_all_workers`` (static groups) and to
                         ``agent.join(shutdown=True, timeout=...)``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        >>> export MASTER_ADDR=localhost
        >>> export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            # Non-TensorPipe agents and static TensorPipe groups: barrier with
            # every worker, drop user/unforked-owner RRefs, then join.
            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group so we need to grab the token for the operation
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    # Tell every other member to update its group membership
                    # with this worker's info (the trailing False argument is
                    # presumably "not joining"; confirm in `_utils`).
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()
def _finalize_shutdown():
    """Tear down local RPC state: the RRef context, the agent, and the Python RPC handler.

    ``_destroy_rref_context`` may raise on a detected RRef leak (unless
    ``_ignore_rref_leak`` is true). The agent shutdown, handler cleanup and
    agent reset in the ``finally`` clause still run in that case.
    NOTE(review): those three calls run sequentially with no inner guard, so
    if ``agent.shutdown()`` raises, the handler cleanup and agent reset are
    skipped.
    """
    try:
        # This raises a `TORCH_CHECK()` exception on RRef leak detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # clean up python rpc handler in shutdown(), see comments in
        # PythonRpcHandler::cleanup(), call it in python API because the
        # cleanup() function has python dependency, it assumes python
        # interpreter exists.
        # No matter if RRef leak exception is raised, this clean-up code
        # must run to avoid destruction segfault in Python 3.5.
        #
        # future.wait() should not be called after shutdown().
        # pythonRpcHandler is cleaned up in shutdown(), after
        # shutdown(), python objects returned from rpc python call can not be
        # resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()
  339. @_require_initialized
  340. def get_worker_info(worker_name=None):
  341. r"""
  342. Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
  343. Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
  344. expensive string on every invocation.
  345. Args:
  346. worker_name (str): the string name of a worker. If ``None``, return the
  347. the id of the current worker. (default ``None``)
  348. Returns:
  349. :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
  350. ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
  351. current worker if ``worker_name`` is ``None``.
  352. """
  353. if worker_name is not None:
  354. return _get_current_rpc_agent().get_worker_info(worker_name)
  355. else:
  356. return _get_current_rpc_agent().get_worker_info()
  357. def _to_worker_info(to):
  358. if isinstance(to, WorkerInfo):
  359. return to
  360. elif isinstance(to, str) or isinstance(to, int):
  361. return get_worker_info(to)
  362. else:
  363. raise ValueError("Cannot get WorkerInfo from name {}".format(to))
  364. def _rref_typeof_on_owner(rref, blocking=True):
  365. rref_type = type(rref.local_value())
  366. if blocking:
  367. return rref_type
  368. else:
  369. # Wrap result into a completed Future. This is so that if blocking=`False`
  370. # is specified, we return a future regardless of if this call is on user
  371. # or owner.
  372. future = Future[type]()
  373. future.set_result(rref_type)
  374. return future
  375. def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
  376. fut = rpc_async(
  377. rref.owner(),
  378. _rref_typeof_on_owner,
  379. args=(rref,),
  380. timeout=timeout
  381. )
  382. if blocking:
  383. return fut.wait()
  384. else:
  385. return fut
# Type variable for the value held by an RRef, so annotations like `RRef[int]` work.
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
try:
    # Combine the implementation class and the type class.
    class RRef(PyRRef, Generic[T]):
        pass
except TypeError:
    # TypeError: metaclass conflict: the metaclass of a derived class
    # must be a (non-strict) subclass of the metaclasses of all its bases
    # Fall back to a metaclass deriving from both PyRRef's and Generic's
    # metaclasses so the two bases can be combined.
    # Mypy doesn't understand __class__ (mypy bug #4177)
    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
        pass
    # Combine the implementation class and the type class.
    # Types for classes expecting a certain generic parameter (mypy bug #7791)
    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
        pass
  402. # Install docstrings from `PyRRef` to `RRef`.
  403. #
  404. # This is for the fact that pybind11 generates the parameter
  405. # `self` as type `rpc.PyRRef`, so a `:inherited-members:`
  406. # under `.. autoclass:: RRef` does not work.
  407. # we have to do the following process to replacee `rpc.PyRRef` with `rpc.RRef`.
  408. #
  409. def method_factory(method_name, docstring):
  410. def method(self, *args, **kwargs):
  411. return getattr(super(RRef, self), method_name)(*args, **kwargs)
  412. if method.__doc__:
  413. method.__doc__ = docstring
  414. return method
# For each public PyRRef method (plus `__str__`), attach a forwarding method
# to RRef whose docstring names `torch.distributed.rpc.RRef`. The loop
# variables (`method_name`, `method`, `docstring`, `new_method`) remain as
# module-level names afterwards; `new_method` is listed in `__all__`.
for method_name, method in inspect.getmembers(PyRRef):
    # Ignore magic methods, except "__str__".
    if method_name.startswith("_") and method_name != "__str__":
        continue
    # Get pybind11 generated docstring.
    # It's like,
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
        Blocking call that copies the value of the RRef from the owner
        to the local node and returns it. If the current node is the
        owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    # NOTE: `assert` is stripped under `python -O`; a missing docstring would
    # then fail on `.replace` below instead.
    assert docstring is not None, "RRef user-facing methods should all have docstrings."
    # Do surgery on pybind11 generated docstrings.
    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")
    # Attach user-facing RRef method with modified docstring.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)
  434. @_require_initialized
  435. def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
  436. r"""
  437. Make a remote call to run ``func`` on worker ``to`` and return an
  438. :class:`~torch.distributed.rpc.RRef` to the result value immediately.
  439. Worker ``to`` will be the owner of the returned
  440. :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
  441. a user. The owner manages the global reference count of its
  442. :class:`~torch.distributed.rpc.RRef`, and the owner
  443. :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
  444. are no living references to it.
  445. Args:
  446. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  447. func (callable): a callable function, such as Python callables, builtin
  448. operators (e.g. :meth:`~torch.add`) and annotated
  449. TorchScript functions.
  450. args (tuple): the argument tuple for the ``func`` invocation.
  451. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  452. invocation.
  453. timeout (float, optional): timeout in seconds for this remote call. If the
  454. creation of this
  455. :class:`~torch.distributed.rpc.RRef` on worker
  456. ``to`` is not successfully processed on this
  457. worker within this timeout, then the next time
  458. there is an attempt to use the RRef (such as
  459. ``to_here()``), a timeout will be raised
  460. indicating this failure. A value of 0 indicates
  461. an infinite timeout, i.e. a timeout error will
  462. never be raised. If not provided, the default
  463. value set during initialization or with
  464. ``_set_rpc_timeout`` is used.
  465. Returns:
  466. A user :class:`~torch.distributed.rpc.RRef` instance to the result
  467. value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
  468. to retrieve the result value locally.
  469. .. warning ::
  470. The ``remote`` API does not copy storages of argument tensors until
  471. sending them over the wire, which could be done by a different thread
  472. depending on the RPC backend type. The caller should make sure that the
  473. contents of those tensors stay intact until the returned RRef is
  474. confirmed by the owner, which can be checked using the
  475. :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.
  476. .. warning ::
  477. Errors such as timeouts for the ``remote`` API are handled on a
  478. best-effort basis. This means that when remote calls initiated by
  479. ``remote`` fail, such as with a timeout error, we take a best-effort
  480. approach to error handling. This means that errors are handled and set
  481. on the resulting RRef on an asynchronous basis. If the RRef has not been
  482. used by the application before this handling (such as ``to_here`` or
  483. fork call), then future uses of the ``RRef`` will appropriately raise
  484. errors. However, it is possible that the user application will use the
  485. ``RRef`` before the errors are handled. In this case, errors may not be
  486. raised as they have not yet been handled.
  487. Example::
  488. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  489. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  490. API for more details. For example,
  491. >>> export MASTER_ADDR=localhost
  492. >>> export MASTER_PORT=5678
  493. Then run the following code in two different processes:
  494. >>> # On worker 0:
  495. >>> import torch
  496. >>> import torch.distributed.rpc as rpc
  497. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  498. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  499. >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  500. >>> x = rref1.to_here() + rref2.to_here()
  501. >>> rpc.shutdown()
  502. >>> # On worker 1:
  503. >>> import torch.distributed.rpc as rpc
  504. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  505. >>> rpc.shutdown()
  506. Below is an example of running a TorchScript function using RPC.
  507. >>> # On both workers:
  508. >>> @torch.jit.script
  509. >>> def my_script_add(t1, t2):
  510. >>> return torch.add(t1, t2)
  511. >>> # On worker 0:
  512. >>> import torch.distributed.rpc as rpc
  513. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  514. >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
  515. >>> rref.to_here()
  516. >>> rpc.shutdown()
  517. >>> # On worker 1:
  518. >>> import torch.distributed.rpc as rpc
  519. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  520. >>> rpc.shutdown()
  521. """
  522. torch._C._log_api_usage_once("torch.distributed.rpc_remote")
  523. qualified_name = torch.jit._builtins._find_builtin(func)
  524. dst_worker_info = _to_worker_info(to)
  525. should_profile = _get_should_profile()
  526. ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info)
  527. with ctx_manager as rf:
  528. args = args if args else ()
  529. kwargs = kwargs if kwargs else {}
  530. is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
  531. if is_async_exec:
  532. wrapped = func._wrapped_async_rpc_function
  533. if isinstance(wrapped, torch.jit.ScriptFunction):
  534. func = wrapped
  535. if qualified_name is not None:
  536. rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs)
  537. elif isinstance(func, torch.jit.ScriptFunction):
  538. rref = _invoke_remote_torchscript(
  539. dst_worker_info.name,
  540. torch._jit_internal._qualified_name(func),
  541. timeout,
  542. is_async_exec,
  543. *args,
  544. **kwargs,
  545. )
  546. else:
  547. (pickled_python_udf, tensors) = _default_pickler.serialize(
  548. PythonUDF(func, args, kwargs)
  549. )
  550. rref = _invoke_remote_python_udf(
  551. dst_worker_info,
  552. pickled_python_udf,
  553. tensors,
  554. timeout,
  555. is_async_exec
  556. )
  557. # attach profiling information
  558. if should_profile:
  559. assert torch.autograd._profiler_enabled()
  560. assert rf is not None
  561. fut = rf._call_end_callbacks_on_future(rref._get_future())
  562. rref._set_profiling_future(fut)
  563. return rref
  564. def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RPC_TIMEOUT):
  565. if not callable(func):
  566. raise TypeError("function should be callable.")
  567. qualified_name = torch.jit._builtins._find_builtin(func)
  568. dst_worker_info = _to_worker_info(to)
  569. should_profile = _get_should_profile()
  570. ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info)
  571. with ctx_manager as rf:
  572. args = args if args else ()
  573. kwargs = kwargs if kwargs else {}
  574. is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
  575. if is_async_exec:
  576. wrapped = func._wrapped_async_rpc_function
  577. if isinstance(wrapped, torch.jit.ScriptFunction):
  578. func = wrapped
  579. if qualified_name is not None:
  580. fut = _invoke_rpc_builtin(
  581. dst_worker_info,
  582. qualified_name,
  583. rpc_timeout,
  584. *args,
  585. **kwargs
  586. )
  587. elif isinstance(func, torch.jit.ScriptFunction):
  588. fut = _invoke_rpc_torchscript(
  589. dst_worker_info.name,
  590. torch._jit_internal._qualified_name(func),
  591. args,
  592. kwargs,
  593. rpc_timeout,
  594. is_async_exec
  595. )
  596. else:
  597. (pickled_python_udf, tensors) = _default_pickler.serialize(
  598. PythonUDF(func, args, kwargs)
  599. )
  600. fut = _invoke_rpc_python_udf(
  601. dst_worker_info,
  602. pickled_python_udf,
  603. tensors,
  604. rpc_timeout,
  605. is_async_exec
  606. )
  607. if should_profile:
  608. assert torch.autograd._profiler_enabled()
  609. assert rf is not None
  610. # Schedule profiling callbacks to run when the future completes.
  611. # This returns a future that is completed when the original future
  612. # completes and the profiling callbacks have been completed as well,
  613. # to guarantee that fut.wait() completes the profiling. This new
  614. # future will contain the same value as the original future.
  615. fut = rf._call_end_callbacks_on_future(fut)
  616. return fut
  617. @_require_initialized
  618. def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
  619. r"""
  620. Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
  621. messages are sent and received in parallel to execution of Python code. This
  622. method is thread-safe.
  623. Args:
  624. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  625. func (callable): a callable function, such as Python callables, builtin
  626. operators (e.g. :meth:`~torch.add`) and annotated
  627. TorchScript functions.
  628. args (tuple): the argument tuple for the ``func`` invocation.
  629. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  630. invocation.
  631. timeout (float, optional): timeout in seconds to use for this RPC. If
  632. the RPC does not complete in this amount of
  633. time, an exception indicating it has
  634. timed out will be raised. A value of 0
  635. indicates an infinite timeout, i.e. a timeout
  636. error will never be raised. If not provided,
  637. the default value set during initialization
  638. or with ``_set_rpc_timeout`` is used.
  639. Returns:
  640. Returns the result of running ``func`` with ``args`` and ``kwargs``.
  641. Example::
  642. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  643. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  644. API for more details. For example,
  645. >>> export MASTER_ADDR=localhost
  646. >>> export MASTER_PORT=5678
  647. Then run the following code in two different processes:
  648. >>> # On worker 0:
  649. >>> import torch
  650. >>> import torch.distributed.rpc as rpc
  651. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  652. >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
  653. >>> rpc.shutdown()
  654. >>> # On worker 1:
  655. >>> import torch.distributed.rpc as rpc
  656. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  657. >>> rpc.shutdown()
  658. Below is an example of running a TorchScript function using RPC.
  659. >>> # On both workers:
  660. >>> @torch.jit.script
  661. >>> def my_script_add(t1, t2):
  662. >>> return torch.add(t1, t2)
  663. >>> # On worker 0:
  664. >>> import torch.distributed.rpc as rpc
  665. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  666. >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
  667. >>> rpc.shutdown()
  668. >>> # On worker 1:
  669. >>> import torch.distributed.rpc as rpc
  670. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  671. >>> rpc.shutdown()
  672. """
  673. torch._C._log_api_usage_once("torch.distributed.rpc_sync")
  674. fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
  675. return fut.wait()
  676. @_require_initialized
  677. def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
  678. r"""
  679. Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
  680. messages are sent and received in parallel to execution of Python code. This
  681. method is thread-safe. This method will immediately return a
  682. :class:`~torch.futures.Future` that can be awaited on.
  683. Args:
  684. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  685. func (callable): a callable function, such as Python callables, builtin
  686. operators (e.g. :meth:`~torch.add`) and annotated
  687. TorchScript functions.
  688. args (tuple): the argument tuple for the ``func`` invocation.
  689. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  690. invocation.
  691. timeout (float, optional): timeout in seconds to use for this RPC. If
  692. the RPC does not complete in this amount of
  693. time, an exception indicating it has
  694. timed out will be raised. A value of 0
  695. indicates an infinite timeout, i.e. a timeout
  696. error will never be raised. If not provided,
  697. the default value set during initialization
  698. or with ``_set_rpc_timeout`` is used.
  699. Returns:
  700. Returns a :class:`~torch.futures.Future` object that can be waited
  701. on. When completed, the return value of ``func`` on ``args`` and
  702. ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
  703. object.
  704. .. warning ::
  705. Using GPU tensors as arguments or return values of ``func`` is not
  706. supported since we don't support sending GPU tensors over the wire. You
  707. need to explicitly copy GPU tensors to CPU before using them as
  708. arguments or return values of ``func``.
  709. .. warning ::
  710. The ``rpc_async`` API does not copy storages of argument tensors until
  711. sending them over the wire, which could be done by a different thread
  712. depending on the RPC backend type. The caller should make sure that the
  713. contents of those tensors stay intact until the returned
  714. :class:`~torch.futures.Future` completes.
  715. Example::
  716. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  717. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  718. API for more details. For example,
  719. >>> export MASTER_ADDR=localhost
  720. >>> export MASTER_PORT=5678
  721. Then run the following code in two different processes:
  722. >>> # On worker 0:
  723. >>> import torch
  724. >>> import torch.distributed.rpc as rpc
  725. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  726. >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
  727. >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
  728. >>> result = fut1.wait() + fut2.wait()
  729. >>> rpc.shutdown()
  730. >>> # On worker 1:
  731. >>> import torch.distributed.rpc as rpc
  732. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  733. >>> rpc.shutdown()
  734. Below is an example of running a TorchScript function using RPC.
  735. >>> # On both workers:
  736. >>> @torch.jit.script
  737. >>> def my_script_add(t1, t2):
  738. >>> return torch.add(t1, t2)
  739. >>> # On worker 0:
  740. >>> import torch.distributed.rpc as rpc
  741. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  742. >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
  743. >>> ret = fut.wait()
  744. >>> rpc.shutdown()
  745. >>> # On worker 1:
  746. >>> import torch.distributed.rpc as rpc
  747. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  748. >>> rpc.shutdown()
  749. """
  750. torch._C._log_api_usage_once("torch.distributed.rpc_async")
  751. fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
  752. if hasattr(_thread_local_var, "future_list"):
  753. _thread_local_var.future_list.append(fut)
  754. return fut
  755. def _get_should_profile():
  756. # Legacy profiler should be enabled. RPC profiling is not supported with
  757. # Kineto profiler.
  758. ActiveProfilerType = torch._C._autograd.ActiveProfilerType
  759. return (
  760. torch.autograd._profiler_enabled() and
  761. torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
  762. )
  763. def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
  764. ctx_manager = contextlib.suppress()
  765. if should_profile:
  766. # Create appropriate string representation based on type of func
  767. # (builtin, script, python)
  768. if qualified_name is None:
  769. func_name = (
  770. torch._jit_internal._qualified_name(func)
  771. if isinstance(func, torch.jit.ScriptFunction)
  772. else func.__qualname__
  773. )
  774. else:
  775. func_name = qualified_name
  776. # Build RPC profiling key.
  777. rpc_profiling_key = _build_rpc_profiling_key(
  778. rpc_type,
  779. func_name,
  780. get_worker_info().name,
  781. dst_worker_info.name,
  782. )
  783. RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
  784. # Mypy doesn't support re-def of a variable not in the same block (#1174)
  785. ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
  786. return ctx_manager