functional.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. import torch
  2. import torch.distributed as dist
  3. from torch.autograd import Function
  4. # The two imports below are not always available depending on the
  5. # USE_DISTRIBUTED compile flag. Make sure they raise import error
  6. # if we're trying to use them.
  7. from torch.distributed import group, ReduceOp
  8. def broadcast(tensor, src, group=group.WORLD):
  9. """
  10. Broadcasts the tensor to the whole group.
  11. ``tensor`` must have the same number of elements in all processes
  12. participating in the collective.
  13. Arguments:
  14. tensor (Tensor): Data to be sent if ``src`` is the rank of current
  15. process.
  16. src (int): Source rank.
  17. group (ProcessGroup, optional): The process group to work on.
  18. Returns:
  19. Tensor: Received tensor from the broadcast op.
  20. """
  21. return _Broadcast.apply(src, group, tensor)
  22. def gather(tensor, dst=0, group=group.WORLD):
  23. """
  24. Gathers a list of tensors in a single process.
  25. Arguments:
  26. tensor (Tensor): Input tensor.
  27. dst (int, optional): Destination rank (default is 0).
  28. group (ProcessGroup, optional): The process group to work on.
  29. Returns:
  30. tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
  31. """
  32. return _Gather.apply(dst, group, tensor)
  33. def scatter(tensors, src=0, group=group.WORLD):
  34. """
  35. Scatters a list of tensors to all processes in a group.
  36. Each process will receive exactly one tensor and store its data in the
  37. ``tensor`` argument.
  38. Arguments:
  39. tensors (list[Tensor]): List of tensors to scatter on the source rank.
  40. Receivers must pass ``None`.
  41. src (int, optional): Source rank (default is 0).
  42. group (ProcessGroup, optional): The process group to work on.
  43. Returns:
  44. Tensor: Output tensor from the scatter operation.
  45. """
  46. return _Scatter.apply(src, group, *tensors)
  47. def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
  48. """
  49. Reduces the tensor data across all machines.
  50. Only the process with rank ``dst`` is going to receive the final result.
  51. Arguments:
  52. tensor (Tensor): Input of the collective.
  53. dst (int): Destination rank.
  54. op (optional): One of the values from
  55. ``torch.distributed.ReduceOp``
  56. enum. Specifies an operation used for element-wise reductions.
  57. group (ProcessGroup, optional): The process group to work on.
  58. Returns:
  59. Tensor: Output of the collective.
  60. """
  61. return _Reduce.apply(dst, op, group, tensor)
  62. def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
  63. """
  64. Reduces, then scatters a list of tensors to all processes in a group.
  65. Arguments:
  66. output (Tensor): Output tensor.
  67. input_list (list[Tensor]): List of tensors to reduce and scatter.
  68. op (optional): One of the values from
  69. ``torch.distributed.ReduceOp``
  70. enum. Specifies an operation used for element-wise reductions.
  71. group (ProcessGroup, optional): The process group to work on.
  72. Returns:
  73. Tensor: Output of the collective.
  74. """
  75. return _Reduce_Scatter.apply(op, group, output, *input_list)
  76. def all_gather(tensor, group=group.WORLD):
  77. """
  78. Gathers tensors from the whole group in a list.
  79. Arguments:
  80. tensor (Tensor): Tensor to be broadcast from current process.
  81. group (ProcessGroup, optional): The process group to work on.
  82. Returns:
  83. tuple([Tensor]): Output of the collective.
  84. """
  85. return _AllGather.apply(group, tensor)
  86. def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
  87. """
  88. Each process scatters list of input tensors to all processes in a group and
  89. return gathered list of tensors in output list.
  90. Arguments:
  91. out_tensor_list (list[Tensor]): list of tensors to gather one per rank.
  92. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
  93. group (ProcessGroup, optional): The process group to work on.
  94. Returns:
  95. tuple([Tensor]): Output of the collective.
  96. """
  97. return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
  98. def all_to_all_single(
  99. output,
  100. input,
  101. output_split_sizes=None,
  102. input_split_sizes=None,
  103. group=group.WORLD,
  104. ):
  105. """
  106. Each process splits input tensor and then scatters the split list
  107. to all processes in a group. Then concatenate the received tensors from all
  108. the processes in the group and return single output tensor.
  109. Arguments:
  110. output (Tensor): Gathered cancatenated output tensor.
  111. input (Tensor): Input tensor to scatter.
  112. output_split_sizes: (list[Int], optional): Output split sizes for dim 0
  113. if specified None or empty, dim 0 of ``output`` tensor must divide
  114. equally by ``world_size``.
  115. input_split_sizes: (list[Int], optional): Input split sizes for dim 0
  116. if specified None or empty, dim 0 of ``input`` tensor must divide
  117. equally by ``world_size``.
  118. Returns:
  119. Tensor: Output of the collective.
  120. """
  121. return _AlltoAllSingle.apply(
  122. group, output, output_split_sizes, input_split_sizes, input
  123. )
  124. def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
  125. """
  126. Reduces the tensor data across all machines in such a way that all get
  127. the final result.
  128. After the call the returned tensor is going to be bitwise
  129. identical in all processes.
  130. Arguments:
  131. tensor (Tensor): Input of the collective.
  132. op (optional): One of the values from
  133. ``torch.distributed.ReduceOp``
  134. enum. Specifies an operation used for element-wise reductions.
  135. group (ProcessGroup, optional): The process group to work on.
  136. Returns:
  137. Tensor: Output of the collective
  138. """
  139. return _AllReduce.apply(op, group, tensor)
  140. class _Broadcast(Function):
  141. @staticmethod
  142. def forward(ctx, src, group, tensor):
  143. ctx.src = src
  144. ctx.group = group
  145. ctx.rank = dist.get_rank()
  146. # torch.distributed makes all the calls in place
  147. # we allocate new tensors to avoid this
  148. tensor = tensor.clone()
  149. dist.broadcast(tensor, src, group=group)
  150. return tensor
  151. @staticmethod
  152. def backward(ctx, grad_output):
  153. gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
  154. if ctx.src != ctx.rank:
  155. gx.zero_()
  156. return (None, None, gx)
  157. class _Gather(Function):
  158. @staticmethod
  159. def forward(ctx, dst, group, tensor):
  160. ctx.dst = dst
  161. ctx.group = group
  162. # Need to create a list of tensors here to do the
  163. # aggregation, get it from the group size
  164. # tensor should be correctly sized for the method
  165. # gathering
  166. tensor_list = [
  167. torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
  168. ]
  169. tensor = tensor.contiguous()
  170. if dist.get_rank(group=group) == dst:
  171. dist.gather(tensor, tensor_list, dst, group=group)
  172. else:
  173. dist.gather(tensor, None, dst, group=group)
  174. return tuple(tensor_list)
  175. @staticmethod
  176. def backward(ctx, *grad_outputs):
  177. return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)
  178. class _Scatter(Function):
  179. @staticmethod
  180. def forward(ctx, src, group, *tensors):
  181. ctx.src = src
  182. ctx.group = group
  183. assert all(t.size() == tensors[0].size() for t in tensors)
  184. output = torch.zeros_like(tensors[0])
  185. if dist.get_rank(group=group) == src:
  186. dist.scatter(output, list(tensors), src, group=group)
  187. else:
  188. dist.scatter(output, None, src, group=group)
  189. return output
  190. @staticmethod
  191. def backward(ctx, grad_output):
  192. return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)
  193. class _Reduce(Function):
  194. @staticmethod
  195. def forward(ctx, src, op, group, tensor):
  196. ctx.src = src
  197. ctx.group = group
  198. tensor = tensor.clone()
  199. dist.reduce(tensor, src, op=op, group=group)
  200. return tensor
  201. @staticmethod
  202. def backward(ctx, grad_output):
  203. return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
  204. class _Reduce_Scatter(Function):
  205. @staticmethod
  206. def forward(ctx, op, group, tensor, *input_tensor_list):
  207. ctx.group = group
  208. input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
  209. dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
  210. return tensor
  211. @staticmethod
  212. def backward(ctx, grad_output):
  213. return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
  214. class _AllGather(Function):
  215. @staticmethod
  216. def forward(ctx, group, tensor):
  217. ctx.group = group
  218. out_tensor_list = [
  219. torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
  220. ]
  221. dist.all_gather(out_tensor_list, tensor.contiguous(), group=group)
  222. return tuple(out_tensor_list)
  223. @staticmethod
  224. def backward(ctx, *grad_outputs):
  225. if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
  226. rank = dist.get_rank()
  227. gx = torch.empty_like(grad_outputs[rank])
  228. _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
  229. else:
  230. # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum()
  231. # to emulate the ReduceScatter behavior
  232. tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
  233. gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
  234. gx = torch.sum(torch.stack(gxs), dim=0)
  235. return (None, gx)
  236. class _AlltoAll(Function):
  237. @staticmethod
  238. def forward(ctx, group, out_tensor_list, *tensors):
  239. ctx.group = group
  240. ctx.input_tensor_size_list = [
  241. tensors[i].size() for i in range(dist.get_world_size(group=group))
  242. ]
  243. my_rank = dist.get_rank(group=group)
  244. tensors = tuple(t.contiguous() for t in tensors)
  245. # Implement it on means of scatter/gather, send/recv async operations have issues
  246. if dist.get_backend(group=group) is dist.Backend.GLOO:
  247. for i in range(dist.get_world_size(group=group)):
  248. to_send = None
  249. if i == my_rank:
  250. to_send = list(tensors)
  251. dist.scatter(out_tensor_list[i], to_send, i, group=group)
  252. else:
  253. dist.all_to_all(
  254. out_tensor_list,
  255. list(tensors),
  256. group=group,
  257. )
  258. return tuple(out_tensor_list)
  259. @staticmethod
  260. def backward(ctx, *grad_outputs):
  261. tensor_list = [
  262. torch.empty(size, device=grad_outputs[0].device)
  263. for size in ctx.input_tensor_size_list
  264. ]
  265. return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
  266. class _AlltoAllSingle(Function):
  267. @staticmethod
  268. def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
  269. ctx.group = group
  270. ctx.input_size = input.size()
  271. ctx.output_split_sizes = input_split_sizes
  272. ctx.input_split_sizes = output_split_sizes
  273. dist.all_to_all_single(
  274. output,
  275. input,
  276. output_split_sizes=output_split_sizes,
  277. input_split_sizes=input_split_sizes,
  278. group=group,
  279. )
  280. return output
  281. @staticmethod
  282. def backward(ctx, grad_output):
  283. tensor = torch.empty(ctx.input_size, device=grad_output.device)
  284. return (None, None, None, None) + (
  285. _AlltoAllSingle.apply(
  286. ctx.group,
  287. tensor,
  288. ctx.output_split_sizes,
  289. ctx.input_split_sizes,
  290. grad_output.contiguous(),
  291. ),
  292. )
  293. class _AllReduce(Function):
  294. @staticmethod
  295. def forward(ctx, op, group, tensor):
  296. ctx.group = group
  297. ctx.op = op
  298. tensor = tensor.clone()
  299. dist.all_reduce(tensor, op=op, group=group)
  300. return tensor
  301. @staticmethod
  302. def backward(ctx, grad_output):
  303. return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)