_comparison.py 57 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362
  1. import abc
  2. import cmath
  3. import collections.abc
  4. import contextlib
  5. from typing import NoReturn, Callable, Sequence, List, Union, Optional, Type, Tuple, Any, Collection
  6. import torch
  7. try:
  8. import numpy as np
  9. NUMPY_AVAILABLE = True
  10. except ModuleNotFoundError:
  11. NUMPY_AVAILABLE = False
  12. class ErrorMeta(Exception):
  13. """Internal testing exception that makes that carries error meta data."""
  14. def __init__(self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()) -> None:
  15. super().__init__(
  16. "If you are a user and see this message during normal operation "
  17. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  18. "If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` "
  19. "for user facing errors."
  20. )
  21. self.type = type
  22. self.msg = msg
  23. self.id = id
  24. def to_error(self, msg: Optional[Union[str, Callable[[str], str]]] = None) -> Exception:
  25. if not isinstance(msg, str):
  26. generated_msg = self.msg
  27. if self.id:
  28. generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
  29. msg = msg(generated_msg) if callable(msg) else generated_msg
  30. return self.type(msg)
  31. # Some analysis of tolerance by logging tests from test_torch.py can be found in
  32. # https://github.com/pytorch/pytorch/pull/32538.
  33. # {dtype: (rtol, atol)}
  34. _DTYPE_PRECISIONS = {
  35. torch.float16: (0.001, 1e-5),
  36. torch.bfloat16: (0.016, 1e-5),
  37. torch.float32: (1.3e-6, 1e-5),
  38. torch.float64: (1e-7, 1e-7),
  39. torch.complex32: (0.001, 1e-5),
  40. torch.complex64: (1.3e-6, 1e-5),
  41. torch.complex128: (1e-7, 1e-7),
  42. }
  43. # The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
  44. # their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
  45. _DTYPE_PRECISIONS.update(
  46. {
  47. dtype: _DTYPE_PRECISIONS[torch.float32]
  48. for dtype in (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32)
  49. }
  50. )
  51. def default_tolerances(*inputs: Union[torch.Tensor, torch.dtype]) -> Tuple[float, float]:
  52. """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
  53. See :func:`assert_close` for a table of the default tolerance for each dtype.
  54. Returns:
  55. (Tuple[float, float]): Loosest tolerances of all input dtypes.
  56. """
  57. dtypes = []
  58. for input in inputs:
  59. if isinstance(input, torch.Tensor):
  60. dtypes.append(input.dtype)
  61. elif isinstance(input, torch.dtype):
  62. dtypes.append(input)
  63. else:
  64. raise TypeError(f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead.")
  65. rtols, atols = zip(*[_DTYPE_PRECISIONS.get(dtype, (0.0, 0.0)) for dtype in dtypes])
  66. return max(rtols), max(atols)
  67. def get_tolerances(
  68. *inputs: Union[torch.Tensor, torch.dtype], rtol: Optional[float], atol: Optional[float], id: Tuple[Any, ...] = ()
  69. ) -> Tuple[float, float]:
  70. """Gets absolute and relative to be used for numeric comparisons.
  71. If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of
  72. :func:`default_tolerances` is used.
  73. Raises:
  74. ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified.
  75. Returns:
  76. (Tuple[float, float]): Valid absolute and relative tolerances.
  77. """
  78. if (rtol is None) ^ (atol is None):
  79. # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
  80. # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
  81. raise ErrorMeta(
  82. ValueError,
  83. f"Both 'rtol' and 'atol' must be either specified or omitted, "
  84. f"but got no {'rtol' if rtol is None else 'atol'}.",
  85. id=id,
  86. )
  87. elif rtol is not None and atol is not None:
  88. return rtol, atol
  89. else:
  90. return default_tolerances(*inputs)
  91. def _make_mismatch_msg(
  92. *,
  93. default_identifier: str,
  94. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  95. extra: Optional[str] = None,
  96. abs_diff: float,
  97. abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
  98. atol: float,
  99. rel_diff: float,
  100. rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
  101. rtol: float,
  102. ) -> str:
  103. """Makes a mismatch error message for numeric values.
  104. Args:
  105. default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
  106. identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
  107. ``default_identifier``. Can be passed as callable in which case it will be called with
  108. ``default_identifier`` to create the description at runtime.
  109. extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
  110. abs_diff (float): Absolute difference.
  111. abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference.
  112. atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are
  113. ``> 0``.
  114. rel_diff (float): Relative difference.
  115. rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference.
  116. rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are
  117. ``> 0``.
  118. """
  119. equality = rtol == 0 and atol == 0
  120. def make_diff_msg(*, type: str, diff: float, idx: Optional[Union[int, Tuple[int, ...]]], tol: float) -> str:
  121. if idx is None:
  122. msg = f"{type.title()} difference: {diff}"
  123. else:
  124. msg = f"Greatest {type} difference: {diff} at index {idx}"
  125. if not equality:
  126. msg += f" (up to {tol} allowed)"
  127. return msg + "\n"
  128. if identifier is None:
  129. identifier = default_identifier
  130. elif callable(identifier):
  131. identifier = identifier(default_identifier)
  132. msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n"
  133. if extra:
  134. msg += f"{extra.strip()}\n"
  135. msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
  136. msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
  137. return msg.strip()
  138. def make_scalar_mismatch_msg(
  139. actual: Union[int, float, complex],
  140. expected: Union[int, float, complex],
  141. *,
  142. rtol: float,
  143. atol: float,
  144. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  145. ) -> str:
  146. """Makes a mismatch error message for scalars.
  147. Args:
  148. actual (Union[int, float, complex]): Actual scalar.
  149. expected (Union[int, float, complex]): Expected scalar.
  150. rtol (float): Relative tolerance.
  151. atol (float): Absolute tolerance.
  152. identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
  153. as callable in which case it will be called by the default value to create the description at runtime.
  154. Defaults to "Scalars".
  155. """
  156. abs_diff = abs(actual - expected)
  157. rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
  158. return _make_mismatch_msg(
  159. default_identifier="Scalars",
  160. identifier=identifier,
  161. abs_diff=abs_diff,
  162. atol=atol,
  163. rel_diff=rel_diff,
  164. rtol=rtol,
  165. )
  166. def make_tensor_mismatch_msg(
  167. actual: torch.Tensor,
  168. expected: torch.Tensor,
  169. mismatches: torch.Tensor,
  170. *,
  171. rtol: float,
  172. atol: float,
  173. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  174. ):
  175. """Makes a mismatch error message for tensors.
  176. Args:
  177. actual (torch.Tensor): Actual tensor.
  178. expected (torch.Tensor): Expected tensor.
  179. mismatches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
  180. location of mismatches.
  181. rtol (float): Relative tolerance.
  182. atol (float): Absolute tolerance.
  183. identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
  184. as callable in which case it will be called by the default value to create the description at runtime.
  185. Defaults to "Tensor-likes".
  186. """
  187. def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
  188. if not mismatches.shape:
  189. return ()
  190. inverse_index = []
  191. for size in mismatches.shape[::-1]:
  192. div, mod = divmod(flat_index, size)
  193. flat_index = div
  194. inverse_index.append(mod)
  195. return tuple(inverse_index[::-1])
  196. number_of_elements = mismatches.numel()
  197. total_mismatches = torch.sum(mismatches).item()
  198. extra = (
  199. f"Mismatched elements: {total_mismatches} / {number_of_elements} "
  200. f"({total_mismatches / number_of_elements:.1%})"
  201. )
  202. a_flat = actual.flatten()
  203. b_flat = expected.flatten()
  204. matches_flat = ~mismatches.flatten()
  205. abs_diff = torch.abs(a_flat - b_flat)
  206. # Ensure that only mismatches are used for the max_abs_diff computation
  207. abs_diff[matches_flat] = 0
  208. max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)
  209. rel_diff = abs_diff / torch.abs(b_flat)
  210. # Ensure that only mismatches are used for the max_rel_diff computation
  211. rel_diff[matches_flat] = 0
  212. max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
  213. return _make_mismatch_msg(
  214. default_identifier="Tensor-likes",
  215. identifier=identifier,
  216. extra=extra,
  217. abs_diff=max_abs_diff.item(),
  218. abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)),
  219. atol=atol,
  220. rel_diff=max_rel_diff.item(),
  221. rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)),
  222. rtol=rtol,
  223. )
  224. class UnsupportedInputs(Exception): # noqa: B903
  225. """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs."""
  226. class Pair(abc.ABC):
  227. """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`.
  228. Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison.
  229. Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the
  230. super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to
  231. handle the inputs and the next pair type will be tried.
  232. All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can
  233. be used to automatically handle overwriting the message with a user supplied one and id handling.
  234. """
  235. def __init__(
  236. self,
  237. actual: Any,
  238. expected: Any,
  239. *,
  240. id: Tuple[Any, ...] = (),
  241. **unknown_parameters: Any,
  242. ) -> None:
  243. self.actual = actual
  244. self.expected = expected
  245. self.id = id
  246. self._unknown_parameters = unknown_parameters
  247. @staticmethod
  248. def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
  249. """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
  250. if not all(isinstance(input, cls) for input in inputs):
  251. raise UnsupportedInputs()
  252. def _make_error_meta(self, type: Type[Exception], msg: str) -> ErrorMeta:
  253. """Makes an :class:`ErrorMeta` from a given exception type and message and the stored id.
  254. .. warning::
  255. Since this method uses instance attributes of :class:`Pair`, it should not be used before the
  256. ``super().__init__(...)`` call in the constructor.
  257. """
  258. return ErrorMeta(type, msg, id=self.id)
  259. @abc.abstractmethod
  260. def compare(self) -> None:
  261. """Compares the inputs and returns an :class`ErrorMeta` in case they mismatch."""
  262. def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]:
  263. """Returns extra information that will be included in the representation.
  264. Should be overwritten by all subclasses that use additional options. The representation of the object will only
  265. be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of
  266. key-value-pairs or attribute names.
  267. """
  268. return []
  269. def __repr__(self) -> str:
  270. head = f"{type(self).__name__}("
  271. tail = ")"
  272. body = [
  273. f" {name}={value!s},"
  274. for name, value in [
  275. ("id", self.id),
  276. ("actual", self.actual),
  277. ("expected", self.expected),
  278. *[(extra, getattr(self, extra)) if isinstance(extra, str) else extra for extra in self.extra_repr()],
  279. ]
  280. ]
  281. return "\n".join((head, *body, *tail))
  282. class ObjectPair(Pair):
  283. """Pair for any type of inputs that will be compared with the `==` operator.
  284. .. note::
  285. Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs
  286. couldn't handle the inputs.
  287. """
  288. def compare(self) -> None:
  289. try:
  290. equal = self.actual == self.expected
  291. except Exception as error:
  292. raise self._make_error_meta(
  293. ValueError, f"{self.actual} == {self.expected} failed with:\n{error}."
  294. ) from error
  295. if not equal:
  296. raise self._make_error_meta(AssertionError, f"{self.actual} != {self.expected}")
  297. class NonePair(Pair):
  298. """Pair for ``None`` inputs."""
  299. def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
  300. if not (actual is None or expected is None):
  301. raise UnsupportedInputs()
  302. super().__init__(actual, expected, **other_parameters)
  303. def compare(self) -> None:
  304. if not (self.actual is None and self.expected is None):
  305. raise self._make_error_meta(AssertionError, f"None mismatch: {self.actual} is not {self.expected}")
  306. class BooleanPair(Pair):
  307. """Pair for :class:`bool` inputs.
  308. .. note::
  309. If ``numpy`` is available, also handles :class:`numpy.bool_` inputs.
  310. """
  311. def __init__(self, actual: Any, expected: Any, *, id: Tuple[Any, ...], **other_parameters: Any) -> None:
  312. actual, expected = self._process_inputs(actual, expected, id=id)
  313. super().__init__(actual, expected, **other_parameters)
  314. @property
  315. def _supported_types(self) -> Tuple[Type, ...]:
  316. cls: List[Type] = [bool]
  317. if NUMPY_AVAILABLE:
  318. cls.append(np.bool_)
  319. return tuple(cls)
  320. def _process_inputs(self, actual: Any, expected: Any, *, id: Tuple[Any, ...]) -> Tuple[bool, bool]:
  321. self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
  322. actual, expected = [self._to_bool(bool_like, id=id) for bool_like in (actual, expected)]
  323. return actual, expected
  324. def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool:
  325. if isinstance(bool_like, bool):
  326. return bool_like
  327. elif isinstance(bool_like, np.bool_):
  328. return bool_like.item()
  329. else:
  330. raise ErrorMeta(TypeError, f"Unknown boolean type {type(bool_like)}.", id=id)
  331. def compare(self) -> None:
  332. if self.actual is not self.expected:
  333. raise self._make_error_meta(AssertionError, f"Booleans mismatch: {self.actual} is not {self.expected}")
  334. class NumberPair(Pair):
  335. """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs.
  336. .. note::
  337. If ``numpy`` is available, also handles :class:`numpy.number` inputs.
  338. Kwargs:
  339. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  340. values based on the type are selected with the below table.
  341. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  342. values based on the type are selected with the below table.
  343. equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
  344. check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``.
  345. The following table displays correspondence between Python number type and the ``torch.dtype``'s. See
  346. :func:`assert_close` for the corresponding tolerances.
  347. +------------------+-------------------------------+
  348. | ``type`` | corresponding ``torch.dtype`` |
  349. +==================+===============================+
  350. | :class:`int` | :attr:`~torch.int64` |
  351. +------------------+-------------------------------+
  352. | :class:`float` | :attr:`~torch.float64` |
  353. +------------------+-------------------------------+
  354. | :class:`complex` | :attr:`~torch.complex64` |
  355. +------------------+-------------------------------+
  356. """
  357. _TYPE_TO_DTYPE = {
  358. int: torch.int64,
  359. float: torch.float64,
  360. complex: torch.complex128,
  361. }
  362. _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys())
  363. def __init__(
  364. self,
  365. actual: Any,
  366. expected: Any,
  367. *,
  368. id: Tuple[Any, ...] = (),
  369. rtol: Optional[float] = None,
  370. atol: Optional[float] = None,
  371. equal_nan: bool = False,
  372. check_dtype: bool = False,
  373. **other_parameters: Any,
  374. ) -> None:
  375. actual, expected = self._process_inputs(actual, expected, id=id)
  376. super().__init__(actual, expected, id=id, **other_parameters)
  377. self.rtol, self.atol = get_tolerances(
  378. *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)], rtol=rtol, atol=atol, id=id
  379. )
  380. self.equal_nan = equal_nan
  381. self.check_dtype = check_dtype
  382. @property
  383. def _supported_types(self) -> Tuple[Type, ...]:
  384. cls = list(self._NUMBER_TYPES)
  385. if NUMPY_AVAILABLE:
  386. cls.append(np.number)
  387. return tuple(cls)
  388. def _process_inputs(
  389. self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
  390. ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]:
  391. self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
  392. actual, expected = [self._to_number(number_like, id=id) for number_like in (actual, expected)]
  393. return actual, expected
  394. def _to_number(self, number_like: Any, *, id: Tuple[Any, ...]) -> Union[int, float, complex]:
  395. if NUMPY_AVAILABLE and isinstance(number_like, np.number):
  396. return number_like.item()
  397. elif isinstance(number_like, self._NUMBER_TYPES):
  398. return number_like
  399. else:
  400. raise ErrorMeta(TypeError, f"Unknown number type {type(number_like)}.", id=id)
  401. def compare(self) -> None:
  402. if self.check_dtype and type(self.actual) is not type(self.expected):
  403. raise self._make_error_meta(
  404. AssertionError,
  405. f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.",
  406. )
  407. if self.actual == self.expected:
  408. return
  409. if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected):
  410. return
  411. abs_diff = abs(self.actual - self.expected)
  412. tolerance = self.atol + self.rtol * abs(self.expected)
  413. if cmath.isfinite(abs_diff) and abs_diff <= tolerance:
  414. return
  415. raise self._make_error_meta(
  416. AssertionError, make_scalar_mismatch_msg(self.actual, self.expected, rtol=self.rtol, atol=self.atol)
  417. )
  418. def extra_repr(self) -> Sequence[str]:
  419. return (
  420. "rtol",
  421. "atol",
  422. "equal_nan",
  423. "check_dtype",
  424. )
  425. class TensorLikePair(Pair):
  426. """Pair for :class:`torch.Tensor`-like inputs.
  427. Kwargs:
  428. allow_subclasses (bool):
  429. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  430. values based on the type are selected. See :func:assert_close: for details.
  431. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  432. values based on the type are selected. See :func:assert_close: for details.
  433. equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
  434. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
  435. :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
  436. :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
  437. check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
  438. check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
  439. :func:`torch.promote_types`) before being compared.
  440. check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
  441. check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
  442. compared.
  443. check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
  444. check_is_coalesced (bool): If ``True`` (default) and corresponding tensors are sparse COO, checks that both
  445. ``actual`` and ``expected`` are either coalesced or uncoalesced. If this check is disabled, tensors are
  446. :meth:`~torch.Tensor.coalesce`'ed before being compared.
  447. """
  448. def __init__(
  449. self,
  450. actual: Any,
  451. expected: Any,
  452. *,
  453. id: Tuple[Any, ...] = (),
  454. allow_subclasses: bool = True,
  455. rtol: Optional[float] = None,
  456. atol: Optional[float] = None,
  457. equal_nan: bool = False,
  458. check_device: bool = True,
  459. check_dtype: bool = True,
  460. check_layout: bool = True,
  461. check_stride: bool = False,
  462. check_is_coalesced: bool = True,
  463. **other_parameters: Any,
  464. ):
  465. actual, expected = self._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
  466. super().__init__(actual, expected, id=id, **other_parameters)
  467. self.rtol, self.atol = get_tolerances(actual, expected, rtol=rtol, atol=atol, id=self.id)
  468. self.equal_nan = equal_nan
  469. self.check_device = check_device
  470. self.check_dtype = check_dtype
  471. self.check_layout = check_layout
  472. self.check_stride = check_stride
  473. self.check_is_coalesced = check_is_coalesced
  474. def _process_inputs(
  475. self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool
  476. ) -> Tuple[torch.Tensor, torch.Tensor]:
  477. directly_related = isinstance(actual, type(expected)) or isinstance(expected, type(actual))
  478. if not directly_related:
  479. raise UnsupportedInputs()
  480. if not allow_subclasses and type(actual) is not type(expected):
  481. raise UnsupportedInputs()
  482. actual, expected = [self._to_tensor(input) for input in (actual, expected)]
  483. for tensor in (actual, expected):
  484. self._check_supported(tensor, id=id)
  485. return actual, expected
  486. def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
  487. if isinstance(tensor_like, torch.Tensor):
  488. return tensor_like
  489. try:
  490. return torch.as_tensor(tensor_like)
  491. except Exception:
  492. raise UnsupportedInputs()
  493. def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
  494. if tensor.layout not in {torch.strided,
  495. torch.sparse_coo,
  496. torch.sparse_csr,
  497. torch.sparse_csc,
  498. torch.sparse_bsr,
  499. torch.sparse_bsc}:
  500. raise ErrorMeta(ValueError, f"Unsupported tensor layout {tensor.layout}", id=id)
  501. def compare(self) -> None:
  502. actual, expected = self.actual, self.expected
  503. self._compare_attributes(actual, expected)
  504. if any(input.device.type == "meta" for input in (actual, expected)):
  505. return
  506. actual, expected = self._equalize_attributes(actual, expected)
  507. self._compare_values(actual, expected)
  508. def _compare_attributes(
  509. self,
  510. actual: torch.Tensor,
  511. expected: torch.Tensor,
  512. ) -> None:
  513. """Checks if the attributes of two tensors match.
  514. Always checks
  515. - the :attr:`~torch.Tensor.shape`,
  516. - whether both inputs are quantized or not,
  517. - and if they use the same quantization scheme.
  518. Checks for
  519. - :attr:`~torch.Tensor.layout`,
  520. - :meth:`~torch.Tensor.stride`,
  521. - :attr:`~torch.Tensor.device`, and
  522. - :attr:`~torch.Tensor.dtype`
  523. are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
  524. """
  525. def raise_mismatch_error(attribute_name: str, actual_value: Any, expected_value: Any) -> NoReturn:
  526. raise self._make_error_meta(
  527. AssertionError,
  528. f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.",
  529. )
  530. if actual.shape != expected.shape:
  531. raise_mismatch_error("shape", actual.shape, expected.shape)
  532. if actual.is_quantized != expected.is_quantized:
  533. raise_mismatch_error("is_quantized", actual.is_quantized, expected.is_quantized)
  534. elif actual.is_quantized and actual.qscheme() != expected.qscheme():
  535. raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())
  536. if actual.layout != expected.layout:
  537. if self.check_layout:
  538. raise_mismatch_error("layout", actual.layout, expected.layout)
  539. elif actual.layout == torch.strided and self.check_stride and actual.stride() != expected.stride():
  540. raise_mismatch_error("stride()", actual.stride(), expected.stride())
  541. if self.check_device and actual.device != expected.device:
  542. raise_mismatch_error("device", actual.device, expected.device)
  543. if self.check_dtype and actual.dtype != expected.dtype:
  544. raise_mismatch_error("dtype", actual.dtype, expected.dtype)
  545. def _equalize_attributes(self, actual: torch.Tensor, expected: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  546. """Equalizes some attributes of two tensors for value comparison.
  547. If ``actual`` and ``expected`` are ...
  548. - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
  549. - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
  550. :func:`torch.promote_types`).
  551. - ... not of the same ``layout``, they are converted to strided tensors.
  552. Args:
  553. actual (Tensor): Actual tensor.
  554. expected (Tensor): Expected tensor.
  555. Returns:
  556. (Tuple[Tensor, Tensor]): Equalized tensors.
  557. """
  558. if actual.device != expected.device:
  559. actual = actual.cpu()
  560. expected = expected.cpu()
  561. if actual.dtype != expected.dtype:
  562. dtype = torch.promote_types(actual.dtype, expected.dtype)
  563. actual = actual.to(dtype)
  564. expected = expected.to(dtype)
  565. if actual.layout != expected.layout:
  566. # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
  567. actual = actual.to_dense() if actual.layout != torch.strided else actual
  568. expected = expected.to_dense() if expected.layout != torch.strided else expected
  569. return actual, expected
  570. def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None:
  571. if actual.is_quantized:
  572. compare_fn = self._compare_quantized_values
  573. elif actual.is_sparse:
  574. compare_fn = self._compare_sparse_coo_values
  575. elif actual.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
  576. compare_fn = self._compare_sparse_compressed_values
  577. else:
  578. compare_fn = self._compare_regular_values_close
  579. compare_fn(actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan)
  580. def _compare_quantized_values(
  581. self, actual: torch.Tensor, expected: torch.Tensor, *, rtol: float, atol: float, equal_nan: bool
  582. ) -> None:
  583. """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness.
  584. .. note::
  585. A detailed discussion about why only the dequantized variant is checked for closeness rather than checking
  586. the individual quantization parameters for closeness and the integer representation for equality can be
  587. found in https://github.com/pytorch/pytorch/issues/68548.
  588. """
  589. return self._compare_regular_values_close(
  590. actual.dequantize(),
  591. expected.dequantize(),
  592. rtol=rtol,
  593. atol=atol,
  594. equal_nan=equal_nan,
  595. identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}",
  596. )
  597. def _compare_sparse_coo_values(
  598. self, actual: torch.Tensor, expected: torch.Tensor, *, rtol: float, atol: float, equal_nan: bool
  599. ) -> None:
  600. """Compares sparse COO tensors by comparing
  601. - the number of sparse dimensions,
  602. - the number of non-zero elements (nnz) for equality,
  603. - the indices for equality, and
  604. - the values for closeness.
  605. """
  606. if actual.sparse_dim() != expected.sparse_dim():
  607. raise self._make_error_meta(
  608. AssertionError,
  609. (
  610. f"The number of sparse dimensions in sparse COO tensors does not match: "
  611. f"{actual.sparse_dim()} != {expected.sparse_dim()}"
  612. ),
  613. )
  614. if actual._nnz() != expected._nnz():
  615. raise self._make_error_meta(
  616. AssertionError,
  617. (
  618. f"The number of specified values in sparse COO tensors does not match: "
  619. f"{actual._nnz()} != {expected._nnz()}"
  620. ),
  621. )
  622. self._compare_regular_values_equal(
  623. actual._indices(),
  624. expected._indices(),
  625. identifier="Sparse COO indices",
  626. )
  627. self._compare_regular_values_close(
  628. actual._values(),
  629. expected._values(),
  630. rtol=rtol,
  631. atol=atol,
  632. equal_nan=equal_nan,
  633. identifier="Sparse COO values",
  634. )
  635. def _compare_sparse_compressed_values(
  636. self, actual: torch.Tensor, expected: torch.Tensor, *, rtol: float, atol: float, equal_nan: bool
  637. ) -> None:
  638. """Compares sparse compressed tensors by comparing
  639. - the number of non-zero elements (nnz) for equality,
  640. - the plain indices for equality,
  641. - the compressed indices for equality, and
  642. - the values for closeness.
  643. """
  644. format_name, compressed_indices_method, plain_indices_method = {
  645. torch.sparse_csr: ('CSR', torch.Tensor.crow_indices, torch.Tensor.col_indices),
  646. torch.sparse_csc: ('CSC', torch.Tensor.ccol_indices, torch.Tensor.row_indices),
  647. torch.sparse_bsr: ('BSR', torch.Tensor.crow_indices, torch.Tensor.col_indices),
  648. torch.sparse_bsc: ('BSC', torch.Tensor.ccol_indices, torch.Tensor.row_indices),
  649. }[actual.layout]
  650. if actual._nnz() != expected._nnz():
  651. raise self._make_error_meta(
  652. AssertionError,
  653. (
  654. f"The number of specified values in sparse {format_name} tensors does not match: "
  655. f"{actual._nnz()} != {expected._nnz()}"
  656. ),
  657. )
  658. self._compare_regular_values_equal(
  659. compressed_indices_method(actual),
  660. compressed_indices_method(expected),
  661. identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
  662. )
  663. self._compare_regular_values_equal(
  664. plain_indices_method(actual),
  665. plain_indices_method(expected),
  666. identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
  667. )
  668. self._compare_regular_values_close(
  669. actual.values(),
  670. expected.values(),
  671. rtol=rtol,
  672. atol=atol,
  673. equal_nan=equal_nan,
  674. identifier=f"Sparse {format_name} values",
  675. )
  676. def _compare_regular_values_equal(
  677. self,
  678. actual: torch.Tensor,
  679. expected: torch.Tensor,
  680. *,
  681. equal_nan: bool = False,
  682. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  683. ) -> None:
  684. """Checks if the values of two tensors are equal."""
  685. self._compare_regular_values_close(actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier)
  686. def _compare_regular_values_close(
  687. self,
  688. actual: torch.Tensor,
  689. expected: torch.Tensor,
  690. *,
  691. rtol: float,
  692. atol: float,
  693. equal_nan: bool,
  694. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  695. ) -> None:
  696. """Checks if the values of two tensors are close up to a desired tolerance."""
  697. actual, expected = self._promote_for_comparison(actual, expected)
  698. matches = torch.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
  699. if torch.all(matches):
  700. return
  701. if actual.shape == torch.Size([]):
  702. msg = make_scalar_mismatch_msg(actual.item(), expected.item(), rtol=rtol, atol=atol, identifier=identifier)
  703. else:
  704. msg = make_tensor_mismatch_msg(actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier)
  705. raise self._make_error_meta(AssertionError, msg)
  706. def _promote_for_comparison(
  707. self, actual: torch.Tensor, expected: torch.Tensor
  708. ) -> Tuple[torch.Tensor, torch.Tensor]:
  709. """Promotes the inputs to the comparison dtype based on the input dtype.
  710. Returns:
  711. Inputs promoted to the highest precision dtype of the same dtype category. :class:`torch.bool` is treated
  712. as integral dtype.
  713. """
  714. # This is called after self._equalize_attributes() and thus `actual` and `expected` already have the same dtype.
  715. if actual.dtype.is_complex:
  716. dtype = torch.complex128
  717. elif actual.dtype.is_floating_point:
  718. dtype = torch.float64
  719. else:
  720. dtype = torch.int64
  721. return actual.to(dtype), expected.to(dtype)
  722. def extra_repr(self) -> Sequence[str]:
  723. return (
  724. "rtol",
  725. "atol",
  726. "equal_nan",
  727. "check_device",
  728. "check_dtype",
  729. "check_layout",
  730. "check_stride",
  731. "check_is_coalesced",
  732. )
  733. def originate_pairs(
  734. actual: Any,
  735. expected: Any,
  736. *,
  737. pair_types: Sequence[Type[Pair]],
  738. sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
  739. mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
  740. id: Tuple[Any, ...] = (),
  741. **options: Any,
  742. ) -> List[Pair]:
  743. """Originates pairs from the individual inputs.
  744. ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
  745. :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them.
  746. Args:
  747. actual (Any): Actual input.
  748. expected (Any): Expected input.
  749. pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs.
  750. First successful pair will be used.
  751. sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
  752. mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
  753. id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message.
  754. **options (Any): Options passed to each pair during construction.
  755. Raises:
  756. ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their
  757. length does not match.
  758. ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of
  759. keys do not match.
  760. ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs.
  761. ErrorMeta: With any expected exception that happens during the construction of a pair.
  762. Returns:
  763. (List[Pair]): Originated pairs.
  764. """
  765. # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
  766. # "a" == "a"[0][0]...
  767. if (
  768. isinstance(actual, sequence_types)
  769. and not isinstance(actual, str)
  770. and isinstance(expected, sequence_types)
  771. and not isinstance(expected, str)
  772. ):
  773. actual_len = len(actual)
  774. expected_len = len(expected)
  775. if actual_len != expected_len:
  776. raise ErrorMeta(
  777. AssertionError, f"The length of the sequences mismatch: {actual_len} != {expected_len}", id=id
  778. )
  779. pairs = []
  780. for idx in range(actual_len):
  781. pairs.extend(
  782. originate_pairs(
  783. actual[idx],
  784. expected[idx],
  785. pair_types=pair_types,
  786. sequence_types=sequence_types,
  787. mapping_types=mapping_types,
  788. id=(*id, idx),
  789. **options,
  790. )
  791. )
  792. return pairs
  793. elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
  794. actual_keys = set(actual.keys())
  795. expected_keys = set(expected.keys())
  796. if actual_keys != expected_keys:
  797. missing_keys = expected_keys - actual_keys
  798. additional_keys = actual_keys - expected_keys
  799. raise ErrorMeta(
  800. AssertionError,
  801. (
  802. f"The keys of the mappings do not match:\n"
  803. f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
  804. f"Additional keys in the actual mapping: {sorted(additional_keys)}"
  805. ),
  806. id=id,
  807. )
  808. keys: Collection = actual_keys
  809. # Since the origination aborts after the first failure, we try to be deterministic
  810. with contextlib.suppress(Exception):
  811. keys = sorted(keys)
  812. pairs = []
  813. for key in keys:
  814. pairs.extend(
  815. originate_pairs(
  816. actual[key],
  817. expected[key],
  818. pair_types=pair_types,
  819. sequence_types=sequence_types,
  820. mapping_types=mapping_types,
  821. id=(*id, key),
  822. **options,
  823. )
  824. )
  825. return pairs
  826. else:
  827. for pair_type in pair_types:
  828. try:
  829. return [pair_type(actual, expected, id=id, **options)]
  830. # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
  831. # inputs. Thus, we try the next pair type.
  832. except UnsupportedInputs:
  833. continue
  834. # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This
  835. # is only in a separate branch, because the one below would also except it.
  836. except ErrorMeta:
  837. raise
  838. # Raising any other exception during origination is unexpected and will give some extra information about
  839. # what happened. If applicable, the exception should be expected in the future.
  840. except Exception as error:
  841. raise RuntimeError(
  842. f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n"
  843. f"{type(actual).__name__}(): {actual}\n\n"
  844. f"and\n\n"
  845. f"{type(expected).__name__}(): {expected}\n\n"
  846. f"resulted in the unexpected exception above. "
  847. f"If you are a user and see this message during normal operation "
  848. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  849. "If you are a developer and working on the comparison functions, "
  850. "please except the previous error and raise an expressive `ErrorMeta` instead."
  851. ) from error
  852. else:
  853. raise ErrorMeta(
  854. TypeError,
  855. f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.",
  856. id=id,
  857. )
  858. def assert_equal(
  859. actual: Any,
  860. expected: Any,
  861. *,
  862. pair_types: Sequence[Type[Pair]] = (ObjectPair,),
  863. sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
  864. mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
  865. msg: Optional[Union[str, Callable[[str], str]]] = None,
  866. **options: Any,
  867. ) -> None:
  868. """Asserts that inputs are equal.
  869. ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
  870. :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them.
  871. Args:
  872. actual (Any): Actual input.
  873. expected (Any): Expected input.
  874. pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the
  875. inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`.
  876. sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
  877. mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
  878. **options (Any): Options passed to each pair during construction.
  879. """
  880. # Hide this function from `pytest`'s traceback
  881. __tracebackhide__ = True
  882. # TODO: the Tensor compare uses bunch of operations which is currently not
  883. # supported by MPS. We will remove this move to CPU after all the
  884. # support is added. https://github.com/pytorch/pytorch/issues/77144
  885. if isinstance(actual, torch.Tensor) and (actual.is_mps):
  886. actual = actual.to('cpu')
  887. if isinstance(expected, torch.Tensor) and (expected.is_mps):
  888. expected = expected.to('cpu')
  889. try:
  890. pairs = originate_pairs(
  891. actual,
  892. expected,
  893. pair_types=pair_types,
  894. sequence_types=sequence_types,
  895. mapping_types=mapping_types,
  896. **options,
  897. )
  898. except ErrorMeta as error_meta:
  899. # Explicitly raising from None to hide the internal traceback
  900. raise error_meta.to_error() from None
  901. error_metas: List[ErrorMeta] = []
  902. for pair in pairs:
  903. try:
  904. pair.compare()
  905. except ErrorMeta as error_meta:
  906. error_metas.append(error_meta)
  907. # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information
  908. # about what happened. If applicable, the exception should be expected in the future.
  909. except Exception as error:
  910. raise RuntimeError(
  911. f"Comparing\n\n"
  912. f"{pair}\n\n"
  913. f"resulted in the unexpected exception above. "
  914. f"If you are a user and see this message during normal operation "
  915. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  916. "If you are a developer and working on the comparison functions, "
  917. "please except the previous error and raise an expressive `ErrorMeta` instead."
  918. ) from error
  919. if not error_metas:
  920. return
  921. # TODO: compose all metas into one AssertionError
  922. raise error_metas[0].to_error(msg)
  923. def assert_close(
  924. actual: Any,
  925. expected: Any,
  926. *,
  927. allow_subclasses: bool = True,
  928. rtol: Optional[float] = None,
  929. atol: Optional[float] = None,
  930. equal_nan: bool = False,
  931. check_device: bool = True,
  932. check_dtype: bool = True,
  933. check_layout: bool = True,
  934. check_stride: bool = False,
  935. msg: Optional[Union[str, Callable[[str], str]]] = None,
  936. ):
  937. r"""Asserts that ``actual`` and ``expected`` are close.
  938. If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if
  939. .. math::
  940. \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert
  941. Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are
  942. only considered equal to each other if ``equal_nan`` is ``True``.
  943. In addition, they are only considered close if they have the same
  944. - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
  945. - ``dtype`` (if ``check_dtype`` is ``True``),
  946. - ``layout`` (if ``check_layout`` is ``True``), and
  947. - stride (if ``check_stride`` is ``True``).
  948. If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.
  949. If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are
  950. checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR,
  951. or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively,
  952. are always checked for equality whereas the values are checked for closeness according to the definition above.
  953. If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
  954. :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
  955. definition above.
  956. ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which
  957. :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types
  958. have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s
  959. or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all
  960. their elements are considered close according to the above definition.
  961. .. note::
  962. Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e.
  963. :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus,
  964. Python scalars of different types can be checked, but require ``check_dtype=False``.
  965. Args:
  966. actual (Any): Actual input.
  967. expected (Any): Expected input.
  968. allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
  969. are allowed. Otherwise type equality is required.
  970. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  971. values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
  972. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  973. values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
  974. equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal.
  975. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
  976. :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
  977. :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
  978. check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
  979. check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
  980. :func:`torch.promote_types`) before being compared.
  981. check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
  982. check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
  983. compared.
  984. check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
  985. msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
  986. the comparison. Can also passed as callable in which case it will be called with the generated message and
  987. should return the new message.
  988. Raises:
  989. ValueError: If no :class:`torch.Tensor` can be constructed from an input.
  990. ValueError: If only ``rtol`` or ``atol`` is specified.
  991. NotImplementedError: If a tensor is a meta tensor. This is a temporary restriction and will be relaxed in the
  992. future.
  993. AssertionError: If corresponding inputs are not Python scalars and are not directly related.
  994. AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
  995. different types.
  996. AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
  997. AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
  998. AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
  999. AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
  1000. :attr:`~torch.Tensor.layout`.
  1001. AssertionError: If only one of corresponding tensors is quantized.
  1002. AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
  1003. AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
  1004. :attr:`~torch.Tensor.device`.
  1005. AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
  1006. AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
  1007. AssertionError: If the values of corresponding tensors are not close according to the definition above.
  1008. The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
  1009. ``dtype``'s, the maximum of both tolerances is used.
  1010. +---------------------------+------------+----------+
  1011. | ``dtype`` | ``rtol`` | ``atol`` |
  1012. +===========================+============+==========+
  1013. | :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` |
  1014. +---------------------------+------------+----------+
  1015. | :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` |
  1016. +---------------------------+------------+----------+
  1017. | :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` |
  1018. +---------------------------+------------+----------+
  1019. | :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` |
  1020. +---------------------------+------------+----------+
  1021. | :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` |
  1022. +---------------------------+------------+----------+
  1023. | :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` |
  1024. +---------------------------+------------+----------+
  1025. | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
  1026. +---------------------------+------------+----------+
  1027. | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` |
  1028. +---------------------------+------------+----------+
  1029. | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` |
  1030. +---------------------------+------------+----------+
  1031. | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` |
  1032. +---------------------------+------------+----------+
  1033. | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` |
  1034. +---------------------------+------------+----------+
  1035. | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` |
  1036. +---------------------------+------------+----------+
  1037. | other | ``0.0`` | ``0.0`` |
  1038. +---------------------------+------------+----------+
  1039. .. note::
  1040. :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
  1041. to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
  1042. define an ``assert_equal`` that uses zero tolrances for every ``dtype`` by default:
  1043. >>> import functools
  1044. >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
  1045. >>> assert_equal(1e-9, 1e-10)
  1046. Traceback (most recent call last):
  1047. ...
  1048. AssertionError: Scalars are not equal!
  1049. <BLANKLINE>
  1050. Absolute difference: 9.000000000000001e-10
  1051. Relative difference: 9.0
  1052. Examples:
  1053. >>> # tensor to tensor comparison
  1054. >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
  1055. >>> actual = torch.acos(torch.cos(expected))
  1056. >>> torch.testing.assert_close(actual, expected)
  1057. >>> # scalar to scalar comparison
  1058. >>> import math
  1059. >>> expected = math.sqrt(2.0)
  1060. >>> actual = 2.0 / math.sqrt(2.0)
  1061. >>> torch.testing.assert_close(actual, expected)
  1062. >>> # numpy array to numpy array comparison
  1063. >>> import numpy as np
  1064. >>> expected = np.array([1e0, 1e-1, 1e-2])
  1065. >>> actual = np.arccos(np.cos(expected))
  1066. >>> torch.testing.assert_close(actual, expected)
  1067. >>> # sequence to sequence comparison
  1068. >>> import numpy as np
  1069. >>> # The types of the sequences do not have to match. They only have to have the same
  1070. >>> # length and their elements have to match.
  1071. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
  1072. >>> actual = tuple(expected)
  1073. >>> torch.testing.assert_close(actual, expected)
  1074. >>> # mapping to mapping comparison
  1075. >>> from collections import OrderedDict
  1076. >>> import numpy as np
  1077. >>> foo = torch.tensor(1.0)
  1078. >>> bar = 2.0
  1079. >>> baz = np.array(3.0)
  1080. >>> # The types and a possible ordering of mappings do not have to match. They only
  1081. >>> # have to have the same set of keys and their elements have to match.
  1082. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
  1083. >>> actual = {"baz": baz, "bar": bar, "foo": foo}
  1084. >>> torch.testing.assert_close(actual, expected)
  1085. >>> expected = torch.tensor([1.0, 2.0, 3.0])
  1086. >>> actual = expected.clone()
  1087. >>> # By default, directly related instances can be compared
  1088. >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
  1089. >>> # This check can be made more strict with allow_subclasses=False
  1090. >>> torch.testing.assert_close(
  1091. ... torch.nn.Parameter(actual), expected, allow_subclasses=False
  1092. ... )
  1093. Traceback (most recent call last):
  1094. ...
  1095. TypeError: No comparison pair was able to handle inputs of type
  1096. <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
  1097. >>> # If the inputs are not directly related, they are never considered close
  1098. >>> torch.testing.assert_close(actual.numpy(), expected)
  1099. Traceback (most recent call last):
  1100. ...
  1101. TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
  1102. and <class 'torch.Tensor'>.
  1103. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
  1104. >>> # their type if check_dtype=False.
  1105. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
  1106. >>> # NaN != NaN by default.
  1107. >>> expected = torch.tensor(float("Nan"))
  1108. >>> actual = expected.clone()
  1109. >>> torch.testing.assert_close(actual, expected)
  1110. Traceback (most recent call last):
  1111. ...
  1112. AssertionError: Scalars are not close!
  1113. <BLANKLINE>
  1114. Absolute difference: nan (up to 1e-05 allowed)
  1115. Relative difference: nan (up to 1.3e-06 allowed)
  1116. >>> torch.testing.assert_close(actual, expected, equal_nan=True)
  1117. >>> expected = torch.tensor([1.0, 2.0, 3.0])
  1118. >>> actual = torch.tensor([1.0, 4.0, 5.0])
  1119. >>> # The default error message can be overwritten.
  1120. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
  1121. Traceback (most recent call last):
  1122. ...
  1123. AssertionError: Argh, the tensors are not close!
  1124. >>> # If msg is a callable, it can be used to augment the generated message with
  1125. >>> # extra information
  1126. >>> torch.testing.assert_close(
  1127. ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
  1128. ... )
  1129. Traceback (most recent call last):
  1130. ...
  1131. AssertionError: Header
  1132. <BLANKLINE>
  1133. Tensor-likes are not close!
  1134. <BLANKLINE>
  1135. Mismatched elements: 2 / 3 (66.7%)
  1136. Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
  1137. Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
  1138. <BLANKLINE>
  1139. Footer
  1140. """
  1141. # Hide this function from `pytest`'s traceback
  1142. __tracebackhide__ = True
  1143. assert_equal(
  1144. actual,
  1145. expected,
  1146. pair_types=(
  1147. NonePair,
  1148. BooleanPair,
  1149. NumberPair,
  1150. TensorLikePair,
  1151. ),
  1152. allow_subclasses=allow_subclasses,
  1153. rtol=rtol,
  1154. atol=atol,
  1155. equal_nan=equal_nan,
  1156. check_device=check_device,
  1157. check_dtype=check_dtype,
  1158. check_layout=check_layout,
  1159. check_stride=check_stride,
  1160. msg=msg,
  1161. )