wrappers.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import torch
  2. import torch._prims as prims
  3. from torch._prims.utils import (
  4. Number,
  5. NumberType,
  6. TensorLike,
  7. TensorLikeType,
  8. ELEMENTWISE_TYPE_PROMOTION_KIND,
  9. )
  10. import torch._prims.utils as utils
  11. from torch.utils._pytree import tree_flatten
  12. from typing import Callable, Sequence, Union
  13. import inspect
  14. from functools import wraps, reduce
  15. import operator
  16. import warnings
  17. from itertools import chain
  18. # TODO: implement ref.cast with an option to enforce safe casting
  19. def _maybe_convert_to_dtype(
  20. a: Union[TensorLikeType, NumberType, Sequence], dtype: torch.dtype
  21. ) -> Union[TensorLikeType, NumberType, Sequence]:
  22. if isinstance(a, TensorLike):
  23. if a.dtype != dtype:
  24. # NOTE: this is incorrect on the CPU
  25. # See https://github.com/pytorch/pytorch/issues/77553
  26. return prims.convert_element_type(a, dtype)
  27. return a
  28. if isinstance(a, Number):
  29. return utils.dtype_to_type(dtype)(a)
  30. if isinstance(a, Sequence):
  31. return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
  32. raise ValueError(
  33. "Received type {0} that is neither a tensor or a number!".format(type(a))
  34. )
  35. def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
  36. if not isinstance(a, Number):
  37. msg = "Found unknown type {0} when trying to convert scalars!".format(type(a))
  38. raise ValueError(msg)
  39. if not utils.is_weakly_lesser_type(type(a), typ):
  40. msg = "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
  41. a, type(a), typ
  42. )
  43. raise ValueError(msg)
  44. return typ(a)
  45. def _annotation_has_type(*, typ, annotation):
  46. if hasattr(annotation, "__args__"):
  47. for a in annotation.__args__:
  48. if _annotation_has_type(typ=typ, annotation=a):
  49. return True
  50. return False
  51. return typ is annotation
  52. class elementwise_type_promotion_wrapper(object):
  53. """
  54. Adds elementwise type promotion to a Python reference implementation.
  55. Takes two kwargs, type_promoting_args and type_promotion_kind.
  56. type_promoting_args must be a string Sequence specifiying the argument names of all
  57. arguments that participate in type promotion (and should be type promoted). If the
  58. arg specifies a Sequence-type then every element of the Sequence will participate in
  59. type promotion.
  60. type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
  61. See its documentation for details.
  62. Other type promotion behavior, like validating the Python type of scalar arguments, must
  63. be handled separately.
  64. """
  65. def __init__(
  66. self,
  67. *,
  68. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  69. type_promoting_args: Sequence[str] = None,
  70. ):
  71. self.type_promoting_arg_names = type_promoting_args
  72. self.type_promotion_kind = type_promotion_kind
  73. def __call__(self, fn: Callable) -> Callable:
  74. sig = inspect.signature(fn)
  75. @wraps(fn)
  76. def _fn(*args, **kwargs):
  77. bound = sig.bind(*args, **kwargs)
  78. type_promoting_args = tuple(
  79. bound.arguments[x]
  80. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  81. if x in bound.arguments.keys()
  82. )
  83. flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
  84. compute_dtype, result_dtype = utils.elementwise_dtypes(
  85. *flattened_type_promoting_args,
  86. type_promotion_kind=self.type_promotion_kind,
  87. )
  88. promoted_args = {
  89. x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
  90. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  91. if x in bound.arguments.keys()
  92. }
  93. bound.arguments.update(promoted_args)
  94. result = fn(**bound.arguments)
  95. # FIXME?: assumes result is a single tensor
  96. assert isinstance(result, TensorLike)
  97. return _maybe_convert_to_dtype(result, result_dtype)
  98. _fn.__signature__ = sig # type: ignore[attr-defined]
  99. return _fn
  100. # TODO: handle tuples of tensors
  101. def _maybe_resize_out(out: TensorLikeType, shape):
  102. if out.numel() == 0:
  103. return prims.resize(out, shape)
  104. if out.numel() != reduce(operator.mul, shape, 1):
  105. msg = (
  106. "An output with one or more elements was resized since it had shape {0} "
  107. "which does not match the required output shape {1}. "
  108. "This behavior is deprecated, and in a future PyTorch release outputs will not "
  109. "be resized unless they have zero elements. "
  110. "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format(
  111. str(out.shape), str(shape)
  112. )
  113. )
  114. warnings.warn(msg)
  115. return prims.resize(out, shape)
  116. return out
  117. def _safe_copy_out(*, copy_from: TensorLikeType, copy_to: TensorLikeType):
  118. # Checks same device
  119. if copy_from.device != copy_to.device:
  120. msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
  121. copy_from.device, copy_to.device
  122. )
  123. raise RuntimeError(msg)
  124. # Checks safe cast
  125. if not utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype):
  126. msg = "Attempting to cast from {0} to out tensor with dtype {1}, but this can't be cast because it is not safe!".format(
  127. copy_from.dtype, copy_to.dtype
  128. )
  129. raise RuntimeError(msg)
  130. return prims.copy_to(copy_to, copy_from)
  131. # FIXME: only supports out parameter that is literally called "out"
  132. def out_wrapper(fn: Callable) -> Callable:
  133. """
  134. Adds the out parameter to a Python reference.
  135. Note that this currently only supports operations that return a single tensor.
  136. """
  137. @wraps(fn)
  138. def _fn(*args, out=None, **kwargs):
  139. result = fn(*args, **kwargs)
  140. if out is not None:
  141. assert isinstance(out, TensorLike)
  142. out = _maybe_resize_out(out, result.shape)
  143. return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
  144. return out
  145. return result
  146. sig = inspect.signature(fn)
  147. out_param = inspect.Parameter(
  148. "out",
  149. kind=inspect.Parameter.KEYWORD_ONLY,
  150. default=None,
  151. annotation=TensorLikeType,
  152. )
  153. params = chain(sig.parameters.values(), (out_param,))
  154. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  155. parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
  156. )
  157. _fn.__annotations__ = fn.__annotations__
  158. _fn.__annotations__["out"] = TensorLikeType
  159. return _fn
  160. def out_wrapper_multi(*out_names):
  161. def go(fn: Callable) -> Callable:
  162. @wraps(fn)
  163. def _fn(*args, **kwargs):
  164. out_kwargs = {}
  165. has_out_kwargs = None
  166. for o in out_names:
  167. out_kwargs[o] = kwargs.pop(o, None)
  168. # Either all of the out kwargs are set or none of them
  169. if has_out_kwargs is None:
  170. has_out_kwargs = out_kwargs[o] is not None
  171. else:
  172. assert has_out_kwargs == (out_kwargs[o] is not None)
  173. result = fn(*args, **kwargs)
  174. assert isinstance(result, tuple)
  175. if has_out_kwargs:
  176. final_result = []
  177. for i, o in enumerate(out_names):
  178. out = out_kwargs[o]
  179. assert isinstance(out, TensorLike)
  180. out = _maybe_resize_out(out, result[i].shape)
  181. final_result.append(_safe_copy_out(copy_from=result[i], copy_to=out)) # type: ignore[arg-type]
  182. return tuple(final_result)
  183. return result
  184. sig = inspect.signature(fn)
  185. out_params = []
  186. for o in out_names:
  187. out_params.append(
  188. inspect.Parameter(
  189. o,
  190. kind=inspect.Parameter.KEYWORD_ONLY,
  191. default=None,
  192. annotation=TensorLikeType,
  193. )
  194. )
  195. params = chain(sig.parameters.values(), out_params)
  196. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  197. parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
  198. )
  199. _fn.__annotations__ = fn.__annotations__
  200. for o in out_names:
  201. _fn.__annotations__[o] = TensorLikeType
  202. return _fn
  203. return go