| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- """This module exists because `torch.testing` exposed a lot of functionality that shouldn't have been public. Although
- none of it was ever documented, some internal FB projects as well as downstream OSS projects might use it. Thus, we
- don't make it private without warning, but go through a deprecation cycle first.
- """
- import functools
- import random
- import warnings
- from typing import Any, Callable, Dict, Optional, Tuple, Union
- import torch
- from . import _legacy
# Public names of this module. This has to stay a mutable list: the deprecated
# dtype getters taken from `_legacy` are appended to it further down at import time.
__all__ = [
    "rand",
    "randn",
    "assert_allclose",
    "get_all_device_types",
    "make_non_contiguous",
]
def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable:
    """Build a decorator that makes a function emit a ``FutureWarning`` on every call.

    The warning message starts with a fixed deprecation notice that names the wrapped
    function (``torch.testing.<name>() is deprecated since 1.12 ...``) and is followed by
    the upgrade instructions.

    Args:
        instructions: Either the literal upgrade instructions, or a callable that builds
            them. The callable is invoked as
            ``instructions(name, args, kwargs, return_value)`` with the name of the
            wrapped function, the positional and keyword arguments of the call, and the
            value the wrapped function returned.

    Returns:
        A decorator. The decorated function forwards all arguments to the original
        function, warns, and returns the original return value unchanged.
    """

    def outer_wrapper(fn: Callable) -> Callable:
        name = fn.__name__
        head = f"torch.testing.{name}() is deprecated since 1.12 and will be removed in 1.14. "

        @functools.wraps(fn)
        def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            # The wrapped function runs first since callable instructions may need its return value.
            # As a consequence, no warning is emitted if the wrapped function raises.
            return_value = fn(*args, **kwargs)
            tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions
            msg = (head + tail).strip()
            # stacklevel=2 attributes the warning to the code calling the deprecated function rather than
            # to this wrapper, so the reported file and line point at what actually needs to be changed.
            warnings.warn(msg, FutureWarning, stacklevel=2)
            return return_value

        return inner_wrapper

    return outer_wrapper
# Deprecated aliases: each one forwards to the torch function of the same name and
# issues the FutureWarning built by `warn_deprecated` with the given instructions.
rand = warn_deprecated("Use torch.rand() instead.")(torch.rand)
randn = warn_deprecated("Use torch.randn() instead.")(torch.randn)
# Default (rtol, atol) pairs per floating point dtype. Dtypes that are not listed here
# are compared exactly, i.e. with (0.0, 0.0).
_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    """Return the default ``(rtol, atol)`` for comparing ``actual`` with ``expected``.

    Each tolerance is looked up by dtype for both tensors and the larger (more lenient)
    of the two values is used. A dtype without an entry contributes ``0.0``.
    """
    exact = (0.0, 0.0)
    per_tensor = [_DTYPE_PRECISIONS.get(t.dtype, exact) for t in (actual, expected)]
    rtols, atols = zip(*per_tensor)
    return max(rtols), max(atols)
@warn_deprecated(
    "Use torch.testing.assert_close() instead. "
    "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844."
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """Deprecated closeness check that delegates to ``torch.testing.assert_close``.

    Args:
        actual: Actual input. Anything that is not a tensor is converted with ``torch.tensor``.
        expected: Expected input. Anything that is not a tensor is converted with
            ``torch.tensor`` using the dtype of ``actual``.
        rtol: Relative tolerance. If both ``rtol`` and ``atol`` are omitted, defaults are
            selected based on the dtypes of the two tensors.
        atol: Absolute tolerance. See ``rtol``.
        equal_nan: Forwarded to ``torch.testing.assert_close``.
        msg: Optional error message. An empty string is forwarded as ``None``.
    """
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    tolerances_omitted = rtol is None and atol is None
    if tolerances_omitted:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    # Devices have to match, whereas dtypes and strides are deliberately not compared.
    comparison_options = dict(
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )
    torch.testing.assert_close(actual, expected, **comparison_options)
# Instruction builder in the callable form accepted by `warn_deprecated`: the message
# tells the user that the call can be replaced with the value it returned. The other
# three arguments are accepted to satisfy the callback signature, but are not used.
getter_instructions = (
    lambda name, args, kwargs, return_value: f"This call can be replaced with {return_value}."  # noqa: E731
)

# Deprecate and expose all dtype getters: every name listed in `_legacy.__all_dtype_getters__`
# is looked up on `_legacy`, wrapped with the deprecation warning, installed as a module-level
# attribute of this module, and appended to `__all__`.
# Note that the loop variables `name` and `fn` remain defined at module level afterwards.
for name in _legacy.__all_dtype_getters__:
    fn = getattr(_legacy, name)
    globals()[name] = warn_deprecated(getter_instructions)(fn)
    __all__.append(name)

# Same treatment for the device type getter, which is already listed in `__all__` statically.
get_all_device_types = warn_deprecated(getter_instructions)(_legacy.get_all_device_types)
@warn_deprecated(
    "Depending on the use case there are different replacement options:\n\n"
    "- If you are using `make_non_contiguous` in combination with a creation function to create a noncontiguous tensor "
    "with random values, use `torch.testing.make_tensor(..., noncontiguous=True)` instead.\n"
    "- If you are using `make_non_contiguous` with a specific tensor, you can replace this call with "
    "`torch.repeat_interleave(input, 2, dim=-1)[..., ::2]`.\n"
    "- If you are using `make_non_contiguous` in the PyTorch test suite, use "
    "`torch.testing._internal.common_utils.noncontiguous_like` instead."
)
def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Copy ``tensor`` into a randomly chosen view of a larger, freshly allocated buffer.

    The buffer is larger than ``tensor`` in up to two randomly picked dimensions and has one
    additional trailing dimension. Selecting a single index of that trailing dimension and
    narrowing the enlarged dimensions back to the original extents yields a view with the shape
    of ``tensor`` that is intended to be non-contiguous. The values of ``tensor`` are copied
    into that view.

    Args:
        tensor: Tensor to copy. It is not modified.

    Returns:
        A tensor with the same shape and values as ``tensor``. If ``tensor`` has at most one
        element, a plain clone is returned instead.
    """
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()

    padded_size = list(tensor.size())
    # randomly inflate a few dimensions (the same dimension may be picked twice)
    for _ in range(2):
        dim = random.randint(0, len(padded_size) - 1)
        padded_size[dim] += random.randint(4, 15)

    # narrow doesn't make a non-contiguous tensor if we only narrow the 0-th dimension,
    # (which will always happen with a 1-dimensional tensor), so let's make a new
    # right-most dimension and cut it off
    buffer = tensor.new(torch.Size(padded_size + [random.randint(2, 3)]))
    buffer = buffer.select(buffer.dim() - 1, random.randint(0, 1))

    # now cut a view of the original size out of the inflated buffer
    for dim, extent in enumerate(tensor.size()):
        if buffer.size(dim) != extent:
            offset = random.randint(1, buffer.size(dim) - extent)
            buffer = buffer.narrow(dim, offset, extent)

    buffer.copy_(tensor)

    # Use .data here to hide the view relation between the result and the temporary tensors
    return buffer.data
|