api.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import sys
  8. import uuid
  9. from dataclasses import dataclass, field
  10. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  11. import torch.distributed.elastic.rendezvous.registry as rdzv_registry
  12. from torch.distributed.elastic import events, metrics
  13. from torch.distributed.elastic.agent.server.api import WorkerSpec
  14. from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
  15. from torch.distributed.elastic.multiprocessing import SignalException, Std
  16. from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
  17. from torch.distributed.elastic.rendezvous import RendezvousParameters
  18. from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
  19. from torch.distributed.elastic.utils.logging import get_logger
  20. logger = get_logger()
  21. @dataclass
  22. class LaunchConfig:
  23. """
  24. Creates a rendezvous config.
  25. Args:
  26. min_nodes: Minimum amount of nodes that the user function will
  27. be launched on. Elastic agent ensures that the user
  28. function start only when the min_nodes amount enters
  29. the rendezvous.
  30. max_nodes: Maximum amount of nodes that the user function
  31. will be launched on.
  32. nproc_per_node: On each node the elastic agent will launch
  33. this amount of workers that will execute user
  34. defined function.
  35. rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
  36. rdzv_endpoint: The endpoint of the rdzv sync. storage.
  37. rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
  38. rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
  39. to be removed in future versions, see the note below. The default timeout is 900 seconds.
  40. run_id: The unique run id of the job (if not passed a unique one will be
  41. deduced from run environment - flow workflow id in flow - or auto generated).
  42. role: User defined role of the worker (defaults to "trainer").
  43. max_restarts: The maximum amount of restarts that elastic agent will conduct
  44. on workers before failure.
  45. monitor_interval: The interval in seconds that is used by the elastic_agent
  46. as a period of monitoring workers.
  47. start_method: The method is used by the elastic agent to start the
  48. workers (spawn, fork, forkserver).
  49. log_dir: base log directory where log files are written. If not set,
  50. one is created in a tmp dir but NOT removed on exit.
  51. redirects: configuration to redirect stdout/stderr to log files.
  52. Pass a single ``Std`` enum to redirect all workers,
  53. or a mapping keyed by local_rank to selectively redirect.
  54. tee: configuration to "tee" stdout/stderr to console + log file.
  55. metrics_cfg: configuration to initialize metrics.
  56. ..note:
  57. `rdzv_timeout` is a legacy argument that will be removed in future.
  58. Set the timeout via `rdzv_configs['timeout']`
  59. """
  60. min_nodes: int
  61. max_nodes: int
  62. nproc_per_node: int
  63. run_id: str = ""
  64. role: str = "default_role"
  65. rdzv_endpoint: str = ""
  66. rdzv_backend: str = "etcd"
  67. rdzv_configs: Dict[str, Any] = field(default_factory=dict)
  68. rdzv_timeout: int = -1
  69. max_restarts: int = 3
  70. monitor_interval: float = 30
  71. start_method: str = "spawn"
  72. log_dir: Optional[str] = None
  73. redirects: Union[Std, Dict[int, Std]] = Std.NONE
  74. tee: Union[Std, Dict[int, Std]] = Std.NONE
  75. metrics_cfg: Dict[str, str] = field(default_factory=dict)
  76. def __post_init__(self):
  77. default_timeout = 900
  78. if self.rdzv_timeout != -1:
  79. self.rdzv_configs["timeout"] = self.rdzv_timeout
  80. elif "timeout" not in self.rdzv_configs:
  81. self.rdzv_configs["timeout"] = default_timeout
  82. class elastic_launch:
  83. """
  84. Launches an torchelastic agent on the container that invoked the entrypoint.
  85. 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
  86. ``entrypoint`` can be a function or a command.
  87. 2. The return value is a map of each worker's output mapped
  88. by their respective global rank.
  89. Usage
  90. ::
  91. def worker_fn(foo):
  92. # ...
  93. def main():
  94. # entrypoint is a function.
  95. outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
  96. # return rank 0's output
  97. return outputs[0]
  98. # entrypoint is a command and ``script.py`` is the python module.
  99. ouptuts = elestic_launch(LaunchConfig, "script.py")(args)
  100. ouptuts = elestic_launch(LaunchConfig, "python")("script.py")
  101. """
  102. def __init__(
  103. self,
  104. config: LaunchConfig,
  105. entrypoint: Union[Callable, str, None],
  106. ):
  107. self._config = config
  108. self._entrypoint = entrypoint
  109. def __call__(self, *args):
  110. return launch_agent(self._config, self._entrypoint, list(args))
  111. def _get_entrypoint_name(
  112. entrypoint: Union[Callable, str, None], args: List[Any]
  113. ) -> str:
  114. """Retrive entrypoint name with the rule:
  115. 1. If entrypoint is a function, use ``entrypont.__qualname__``.
  116. 2. If entrypoint is a string, check its value:
  117. 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
  118. which does not start with hifen letter (for example, "-u" will be skipped).
  119. 2.2 otherwise, use ``entrypoint`` value.
  120. 3. Otherwise, return empty string.
  121. """
  122. if isinstance(entrypoint, Callable): # type: ignore[arg-type]
  123. return entrypoint.__name__ # type: ignore[union-attr]
  124. elif isinstance(entrypoint, str):
  125. if entrypoint == sys.executable:
  126. return next((arg for arg in args if arg[0] != "-"), "")
  127. else:
  128. return entrypoint
  129. else:
  130. return ""
  131. def _get_addr_and_port(
  132. rdzv_parameters: RendezvousParameters,
  133. ) -> Tuple[Optional[str], Optional[int]]:
  134. if rdzv_parameters.backend != "static":
  135. return (None, None)
  136. endpoint = rdzv_parameters.endpoint
  137. endpoint = endpoint.strip()
  138. if not endpoint:
  139. raise ValueError(
  140. "Endpoint is missing in endpoint. Try to add --master_addr and --master_port"
  141. )
  142. master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
  143. if master_port == -1:
  144. raise ValueError(
  145. f"port is missing in endpoint: {endpoint}. Try to specify --master_port"
  146. )
  147. return (master_addr, master_port)
  148. def launch_agent(
  149. config: LaunchConfig,
  150. entrypoint: Union[Callable, str, None],
  151. args: List[Any],
  152. ) -> Dict[int, Any]:
  153. if not config.run_id:
  154. run_id = str(uuid.uuid4().int)
  155. logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
  156. config.run_id = run_id
  157. entrypoint_name = _get_entrypoint_name(entrypoint, args)
  158. logger.info(
  159. f"Starting elastic_operator with launch configs:\n"
  160. f" entrypoint : {entrypoint_name}\n"
  161. f" min_nodes : {config.min_nodes}\n"
  162. f" max_nodes : {config.max_nodes}\n"
  163. f" nproc_per_node : {config.nproc_per_node}\n"
  164. f" run_id : {config.run_id}\n"
  165. f" rdzv_backend : {config.rdzv_backend}\n"
  166. f" rdzv_endpoint : {config.rdzv_endpoint}\n"
  167. f" rdzv_configs : {config.rdzv_configs}\n"
  168. f" max_restarts : {config.max_restarts}\n"
  169. f" monitor_interval : {config.monitor_interval}\n"
  170. f" log_dir : {config.log_dir}\n"
  171. f" metrics_cfg : {config.metrics_cfg}\n"
  172. )
  173. rdzv_parameters = RendezvousParameters(
  174. backend=config.rdzv_backend,
  175. endpoint=config.rdzv_endpoint,
  176. run_id=config.run_id,
  177. min_nodes=config.min_nodes,
  178. max_nodes=config.max_nodes,
  179. **config.rdzv_configs,
  180. )
  181. master_addr, master_port = _get_addr_and_port(rdzv_parameters)
  182. spec = WorkerSpec(
  183. role=config.role,
  184. local_world_size=config.nproc_per_node,
  185. entrypoint=entrypoint,
  186. args=tuple(args),
  187. rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
  188. max_restarts=config.max_restarts,
  189. monitor_interval=config.monitor_interval,
  190. redirects=config.redirects,
  191. tee=config.tee,
  192. master_addr=master_addr,
  193. master_port=master_port,
  194. )
  195. agent = LocalElasticAgent(
  196. spec=spec, start_method=config.start_method, log_dir=config.log_dir
  197. )
  198. shutdown_rdzv = True
  199. try:
  200. metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
  201. result = agent.run()
  202. # records that agent.run() has succeeded NOT that workers have succeeded
  203. events.record(agent.get_event_succeeded())
  204. if result.is_failed():
  205. # ChildFailedError is treated specially by @record
  206. # if the error files for the failed children exist
  207. # @record will copy the first error (root cause)
  208. # to the error file of the launcher process.
  209. raise ChildFailedError(
  210. name=entrypoint_name,
  211. failures=result.failures,
  212. )
  213. return result.return_values
  214. except ChildFailedError:
  215. raise
  216. except SignalException:
  217. # when the agent dies with a signal do NOT shutdown the rdzv_handler
  218. # since this closes the rendezvous on this rdzv_id permanently and
  219. # prevents any additional scaling events
  220. shutdown_rdzv = False
  221. events.record(agent.get_event_failed())
  222. raise
  223. except Exception:
  224. events.record(agent.get_event_failed())
  225. raise
  226. finally:
  227. if shutdown_rdzv:
  228. spec.rdzv_handler.shutdown()