# internal.py
import collections
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import _get_current_rpc_agent
  12. # Thread local tensor tables to store tensors while pickling torch.Tensor
  13. # objects
  14. _thread_local_tensor_tables = threading.local()
  15. _pickler = pickle.Pickler
  16. _unpickler = pickle.Unpickler
  17. class RPCExecMode(Enum):
  18. SYNC = "sync"
  19. ASYNC = "async"
  20. ASYNC_JIT = "async_jit"
  21. REMOTE = "remote"
  22. class _InternalRPCPickler:
  23. r"""
  24. This class provides serialize() and deserialize() interfaces to serialize
  25. data to be "binary string + tensor table" format
  26. So for RPC python UDF function and args, non tensor data will be serialized
  27. into regular binary string, tensor data will be put into thread local tensor
  28. tables, this serialization format is consistent with builtin operator and args
  29. using JIT pickler. This format will make tensor handling in C++ much easier,
  30. e.g. attach tensor to distributed autograd graph in C++
  31. """
  32. def __init__(self):
  33. # Ignore type error because dispatch_table is defined in third-party package
  34. self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined]
  35. self._dispatch_table[torch.Tensor] = self._tensor_reducer
  36. # Used for registering customized picklers.
  37. self._class_reducer_dict = {}
  38. def _register_reducer(self, obj_class, reducer):
  39. # For the same class, only register the reducer once.
  40. if obj_class not in self._class_reducer_dict:
  41. self._class_reducer_dict[obj_class] = reducer
  42. @classmethod
  43. def _tensor_receiver(cls, tensor_index):
  44. global _thread_local_tensor_tables
  45. return _thread_local_tensor_tables.recv_tables[tensor_index]
  46. def _tensor_reducer(self, tensor):
  47. global _thread_local_tensor_tables
  48. _thread_local_tensor_tables.send_tables.append(tensor)
  49. tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
  50. return (_InternalRPCPickler._tensor_receiver, (tensor_index,))
  51. @classmethod
  52. def _py_rref_receiver(cls, rref_fork_data):
  53. return dist.rpc.PyRRef._deserialize(rref_fork_data)
  54. def _py_rref_reducer(self, py_rref):
  55. rref_fork_data = py_rref._serialize()
  56. return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))
  57. def _rref_reducer(self, rref):
  58. return self._py_rref_reducer(rref)
  59. @classmethod
  60. def _script_module_receiver(cls, script_module_serialized):
  61. """
  62. Given a serialized representation of a ScriptModule created with torch.jit.save,
  63. loads and returns the ScriptModule.
  64. """
  65. f = io.BytesIO(script_module_serialized)
  66. m = torch.jit.load(f)
  67. return m
  68. def _script_module_reducer(self, script_module):
  69. """
  70. Serializes a ScriptModule.
  71. """
  72. f = io.BytesIO()
  73. torch.jit.save(script_module, f)
  74. return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))
  75. def serialize(self, obj):
  76. r"""
  77. Serialize non tensor data into binary string, tensor data into
  78. tensor table
  79. """
  80. f = io.BytesIO()
  81. p = _pickler(f)
  82. p.dispatch_table = self._dispatch_table
  83. # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref,
  84. # user picklers could have different initialization function from _InternalRPCPickler,
  85. # but all the user picklers should call serialize() and use _rref_reducer to pickle rref
  86. # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not
  87. # compiled yet, it is not good place to acces rpc.RRef inside _InternalRPCPickler constructor,
  88. # so puting rref's dispatch table here
  89. #
  90. # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
  91. # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
  92. # Ignore type error because dispatch_table is defined in third-party package
  93. p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index]
  94. # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
  95. # Ignore type error because dispatch_table is defined in third-party package
  96. p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index]
  97. # Add dispatch pickling for ScriptModule or its subclass.
  98. if isinstance(obj, torch.jit.ScriptModule):
  99. # Ignore type error because dispatch_table is defined in third-party package
  100. p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index]
  101. # Install customized picklers.
  102. for class_name in self._class_reducer_dict.keys():
  103. p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index]
  104. # save _thread_local_tensor_tables.send_tables if it is in nested call
  105. global _thread_local_tensor_tables
  106. if hasattr(_thread_local_tensor_tables, "send_tables"):
  107. old_send_tables = _thread_local_tensor_tables.send_tables
  108. else:
  109. old_send_tables = None
  110. _thread_local_tensor_tables.send_tables = []
  111. p.dump(obj)
  112. # restore _thread_local_tensor_tables.send_tables if return
  113. # from nested call, otherwise clean up the table
  114. tensors = _thread_local_tensor_tables.send_tables
  115. if old_send_tables is not None:
  116. _thread_local_tensor_tables.send_tables = old_send_tables
  117. else:
  118. del _thread_local_tensor_tables.send_tables
  119. return (f.getvalue(), tensors)
  120. def deserialize(self, binary_data, tensor_table):
  121. r"""
  122. Deserialize binary string + tensor table to original obj
  123. """
  124. # save _thread_local_tensor_tables.recv_tables if it is in nested call
  125. global _thread_local_tensor_tables
  126. if hasattr(_thread_local_tensor_tables, "recv_tables"):
  127. old_recv_tables = _thread_local_tensor_tables.recv_tables
  128. else:
  129. old_recv_tables = None
  130. _thread_local_tensor_tables.recv_tables = tensor_table
  131. try:
  132. unpickler = _unpickler(io.BytesIO(binary_data))
  133. ret = unpickler.load()
  134. except AttributeError as e:
  135. # Occurs when function is not found on module/class during
  136. # unpickling.
  137. except_str = (
  138. str(e)
  139. + """ Default RPC pickler does not serialize
  140. function code. Ensure that UDFs are defined on both caller and
  141. callee modules."""
  142. )
  143. ret = AttributeError(except_str)
  144. # Ensure the stack trace gets preserved
  145. ret.__cause__ = e
  146. # restore _thread_local_tensor_tables.recv_tables if return
  147. # from nested call, otherwise clean up the table
  148. if old_recv_tables is not None:
  149. _thread_local_tensor_tables.recv_tables = old_recv_tables
  150. else:
  151. del _thread_local_tensor_tables.recv_tables
  152. return ret
  153. # Create _internal_rpc_pickler only once to initialize _dispatch_table only once
  154. _internal_rpc_pickler = _InternalRPCPickler()
  155. def serialize(obj):
  156. return _internal_rpc_pickler.serialize(obj)
  157. def deserialize(binary_data, tensor_table):
  158. return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
  159. def _run_function(python_udf):
  160. r"""
  161. This function is exclusively called from C++.
  162. See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
  163. Runs a Python UDF and returns its return value.
  164. Wraps any exception in ``RemoteException`` if the function raises.
  165. """
  166. try:
  167. if isinstance(python_udf, AttributeError):
  168. raise python_udf
  169. result = python_udf.func(*python_udf.args, **python_udf.kwargs)
  170. except Exception as e:
  171. # except str = exception info + traceback string
  172. except_str = (
  173. f"On {_get_current_rpc_agent().get_worker_info()}:\n"
  174. f"{repr(e)}\n{traceback.format_exc()}"
  175. )
  176. print(except_str, file=sys.stderr)
  177. result = RemoteException(except_str, type(e))
  178. return result
  179. def _handle_exception(result):
  180. if isinstance(result, RemoteException):
  181. raise result.exception_type(result.msg.encode("utf-8").decode("unicode_escape"))
  182. def _build_rpc_profiling_key(
  183. exec_type, func_name, current_worker_name, dst_worker_name
  184. ):
  185. """
  186. Builds the key that RPC calls are profiled with using the autograd profiler.
  187. This will be the name of the corresponding Event recorded in the profiler.
  188. Args:
  189. exec_type (RPCExecMode): Type of RPC/RRef call
  190. func_name (str): Name of function being profiled.
  191. current_worker_name (str): Name of current worker.
  192. dst_worker_name (str): Name of the destination worker.
  193. Returns:
  194. String representing profiling key
  195. """
  196. profile_key = "rpc_{rpc_type}#{func_name}({current_worker} -> {dst_worker})".format(
  197. rpc_type=exec_type.value,
  198. func_name=func_name,
  199. current_worker=current_worker_name,
  200. dst_worker=dst_worker_name,
  201. )
  202. return profile_key
  203. def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
  204. """
  205. This function should be called from RPC/RRef functions to create a
  206. RecordFunction object for profiling. This function also runs the before
  207. callbacks that start the profiling, though the user is responsible for
  208. running the appropriate callbacks when the function to be profiled finishes.
  209. Args:
  210. exec_type (RPCExecMode): Type of RPC/RRef call
  211. func_name (str): Name of function being profiled.
  212. current_worker_name (str): Name of current worker.
  213. dest_worker_name (str): Name of the destination worker.
  214. Returns:
  215. An instance of `torch.autograd._RecordFunction`.
  216. """
  217. assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
  218. profile_key = "rpc_{}#{}({} -> {})".format(
  219. exec_type.value, str(func_name), current_worker_name, dest_worker_name
  220. )
  221. rf = torch.autograd._RecordFunction() # type: ignore[attr-defined]
  222. torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined]
  223. return rf
  224. PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
  225. RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])