_deprecated.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. """This module exists since the `torch.testing` exposed a lot of stuff that shouldn't have been public. Although this
  2. was never documented anywhere, some other internal FB projects as well as downstream OSS projects might use this. Thus,
  3. we don't internalize without warning, but still go through a deprecation cycle.
  4. """
  5. import functools
  6. import random
  7. import warnings
  8. from typing import Any, Callable, Dict, Optional, Tuple, Union
  9. import torch
  10. from . import _legacy
  11. __all__ = [
  12. "rand",
  13. "randn",
  14. "assert_allclose",
  15. "get_all_device_types",
  16. "make_non_contiguous",
  17. ]
  18. def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable:
  19. def outer_wrapper(fn: Callable) -> Callable:
  20. name = fn.__name__
  21. head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "
  22. @functools.wraps(fn)
  23. def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
  24. return_value = fn(*args, **kwargs)
  25. tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
  26. msg = (head + tail).strip()
  27. warnings.warn(msg, FutureWarning)
  28. return return_value
  29. return inner_wrapper
  30. return outer_wrapper
  31. rand = warn_deprecated("Use torch.rand() instead.")(torch.rand)
  32. randn = warn_deprecated("Use torch.randn() instead.")(torch.randn)
  33. _DTYPE_PRECISIONS = {
  34. torch.float16: (1e-3, 1e-3),
  35. torch.float32: (1e-4, 1e-5),
  36. torch.float64: (1e-5, 1e-8),
  37. }
  38. def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
  39. actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
  40. expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
  41. return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)
  42. @warn_deprecated(
  43. "Use torch.testing.assert_close() instead. "
  44. "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
  45. )
  46. def assert_allclose(
  47. actual: Any,
  48. expected: Any,
  49. rtol: Optional[float] = None,
  50. atol: Optional[float] = None,
  51. equal_nan: bool = True,
  52. msg: str = "",
  53. ) -> None:
  54. if not isinstance(actual, torch.Tensor):
  55. actual = torch.tensor(actual)
  56. if not isinstance(expected, torch.Tensor):
  57. expected = torch.tensor(expected, dtype=actual.dtype)
  58. if rtol is None and atol is None:
  59. rtol, atol = _get_default_rtol_and_atol(actual, expected)
  60. torch.testing.assert_close(
  61. actual,
  62. expected,
  63. rtol=rtol,
  64. atol=atol,
  65. equal_nan=equal_nan,
  66. check_device=True,
  67. check_dtype=False,
  68. check_stride=False,
  69. msg=msg or None,
  70. )
  71. getter_instructions = (
  72. lambda name, args, kwargs, return_value: f"This call can be replaced with {return_value}." # noqa: E731
  73. )
  74. # Deprecate and expose all dtype getters
  75. for name in _legacy.__all_dtype_getters__:
  76. fn = getattr(_legacy, name)
  77. globals()[name] = warn_deprecated(getter_instructions)(fn)
  78. __all__.append(name)
  79. get_all_device_types = warn_deprecated(getter_instructions)(_legacy.get_all_device_types)
  80. @warn_deprecated(
  81. "Depending on the use case there a different replacement options:\n\n"
  82. "- If you are using `make_non_contiguous` in combination with a creation function to create a noncontiguous tensor "
  83. "with random values, use `torch.testing.make_tensor(..., noncontiguous=True)` instead.\n"
  84. "- If you are using `make_non_contiguous` with a specific tensor, you can replace this call with "
  85. "`torch.repeat_interleave(input, 2, dim=-1)[..., ::2]`.\n"
  86. "- If you are using `make_non_contiguous` in the PyTorch test suite, use "
  87. "`torch.testing._internal.common_utils.noncontiguous_like` instead."
  88. )
  89. def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
  90. if tensor.numel() <= 1: # can't make non-contiguous
  91. return tensor.clone()
  92. osize = list(tensor.size())
  93. # randomly inflate a few dimensions in osize
  94. for _ in range(2):
  95. dim = random.randint(0, len(osize) - 1)
  96. add = random.randint(4, 15)
  97. osize[dim] = osize[dim] + add
  98. # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
  99. # (which will always happen with a 1-dimensional tensor), so let's make a new
  100. # right-most dimension and cut it off
  101. input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
  102. input = input.select(len(input.size()) - 1, random.randint(0, 1))
  103. # now extract the input of correct size from 'input'
  104. for i in range(len(osize)):
  105. if input.size(i) != tensor.size(i):
  106. bounds = random.randint(1, input.size(i) - tensor.size(i))
  107. input = input.narrow(i, bounds, tensor.size(i))
  108. input.copy_(tensor)
  109. # Use .data here to hide the view relation between input and other temporary Tensors
  110. return input.data