utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083
  1. from __future__ import annotations
  2. from typing import Any, Union, Sequence, Optional, Callable, Dict, Tuple, List
  3. from enum import Enum
  4. from functools import reduce, cmp_to_key
  5. import operator
  6. import torch
  7. # nvFuser imports are conditional on being compiled with CUDA
  8. if hasattr(torch._C, "_nvfuser"):
  9. from torch._C._nvfuser import DataType # type: ignore[import]
  10. _torch_dtype_to_nvfuser_dtype_map = {
  11. torch.cdouble: DataType.ComplexDouble,
  12. torch.cfloat: DataType.ComplexFloat,
  13. torch.double: DataType.Double,
  14. torch.float: DataType.Float,
  15. torch.half: DataType.Half,
  16. torch.bfloat16: DataType.BFloat16,
  17. torch.long: DataType.Int,
  18. torch.int: DataType.Int32,
  19. torch.bool: DataType.Bool,
  20. }
  21. else:
  22. _torch_dtype_to_nvfuser_dtype_map = {}
  23. def getnvFuserDtype(dtype: torch.dtype):
  24. """
  25. Translates from torch.dtype to nvFuser's DataType enum
  26. """
  27. return _torch_dtype_to_nvfuser_dtype_map[dtype]
  28. ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
  29. StrideType = Union[List[int], Tuple[int, ...]]
  30. DimsType = Union[int, List[int], Tuple[int, ...]]
  31. DimsSequenceType = Union[List[int], Tuple[int, ...]]
  32. NumberType = Union[bool, int, float, complex]
  33. Number = (bool, int, float, complex)
class TensorMeta(torch.Tensor):
    """
    Model tensor metadata. Not a stock meta tensor because device is modeled
    as the original device (not meta device), also we have different behavior
    for some high level Python bindings

    Instances are built with ``torch.Tensor._make_wrapper_subclass`` from a
    shape, strides, dtype and device, and carry two extra attributes:
    ``tname`` (what ``__repr__``/``__format__`` return) and ``node``.
    Functions reaching ``__torch_function__`` are answered by the function's
    ``meta`` attribute, except for a handful of metadata queries that are
    delegated to ``torch.Tensor``.
    """

    # Note: this will be an fx Node if it's ever
    # populated, but some Meta-internal jobs don't include fx
    node: Optional[Any]
    # Printed name of this tensor (returned by __repr__/__format__). __new__
    # sets it to ""; presumably a tracer assigns it later -- TODO confirm.
    tname: str

    @staticmethod
    def __new__(
        cls,
        tensorlike: Optional[Union[TensorMeta, NumberType, torch.Tensor]] = None,
        *,
        shape: Optional[ShapeType] = None,
        strides: Optional[StrideType] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Creates a TensorMeta from an example ``tensorlike`` and/or explicit
        keyword metadata.

        - ``tensorlike`` is a Number: the result describes a CPU scalar
          (shape ``()``, strides ``()``) whose dtype is
          ``type_to_dtype(type(tensorlike))``; ``shape``/``strides`` must be
          None or empty.
        - ``tensorlike`` is a tensor or TensorMeta: its shape, strides, dtype
          and device are the defaults.
        - ``tensorlike`` is None: all four keyword arguments are required.

        Explicitly passed keyword values override inferred ones; a string
        device is converted with ``torch.device``. The storage offset is always
        0 and ``requires_grad`` is always False.
        """
        if isinstance(tensorlike, Number):
            assert not shape and (shape is None or isinstance(shape, Sequence))
            assert not strides and (strides is None or isinstance(strides, Sequence))
            inferred_shape: Tuple[int, ...] = ()
            inferred_strides: Tuple[int, ...] = ()
            inferred_dtype = type_to_dtype(type(tensorlike))
            inferred_device = torch.device("cpu")
            # TODO: This looks wrong, a number that is wrapped into a tensor
            # needs to behave differently than a scalar tensor for type
            # promotion purposes
        elif tensorlike is not None:
            assert isinstance(tensorlike, (TensorMeta, torch.Tensor))
            inferred_shape = tuple(tensorlike.shape)
            inferred_strides = tuple(tensorlike.stride())
            inferred_dtype = tensorlike.dtype
            inferred_device = tensorlike.device
        else:
            # If no tensorlike "example" is given then all metadata
            # must be provided explicitly
            assert shape is not None
            assert strides is not None
            assert dtype is not None
            assert device is not None

        # In the "no example" branch the inferred_* names are never bound, but
        # the asserts above guarantee the explicit values are used instead.
        shape = inferred_shape if shape is None else tuple(shape)
        strides = inferred_strides if strides is None else tuple(strides)
        dtype = inferred_dtype if dtype is None else dtype
        device = inferred_device if device is None else device
        if isinstance(device, str):
            device = torch.device(device)

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            shape,
            strides=strides,
            storage_offset=0,  # TODO: this is inaccurate
            dtype=dtype,
            device=device,
            requires_grad=False,
        )
        r.tname = ""
        r.node = None
        return r

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        """
        Routes a torch function call on TensorMeta arguments.

        Metadata queries (``ndim``, ``numel``, ``stride``, ``dtype``,
        ``shape``, ``device``) use torch.Tensor's default handling. Any other
        ``func`` must expose a ``meta`` callable, which is invoked with the
        same arguments and whose result is returned.

        Raises:
            ValueError: if ``func`` has no ``meta`` attribute.
        """
        if kwargs is None:
            kwargs = {}

        if func in {
            torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
            torch.Tensor.numel,
            torch.Tensor.stride,
            torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
            torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
            torch.Tensor.device.__get__,  # type: ignore[attr-defined]
        }:
            return super().__torch_function__(func, types, args, kwargs)

        if not hasattr(func, "meta"):
            raise ValueError(f"Callable {func} has no meta function!")

        return func.meta(*args, **kwargs)  # type: ignore[attr-defined]

    @classmethod
    def __torch_dispatch__(
        cls,
        func,
        types,
        args=(),
        kwargs=None,
    ):
        # Always raises: per the message, every operation is expected to be
        # handled by __torch_function__ before reaching dispatch.
        raise RuntimeError("this should be unreachable")

    # TODO: fx uses dunder repr to print objects in code
    def __repr__(self) -> str:
        return self.tname
        # return f"TensorMeta(dtype={self.dtype}, device={self.device}, shape={self.shape}, strides={self.stride()})"

    # format_spec is ignored; formatting always yields the tensor's name.
    def __format__(self, format_spec) -> str:
        return self.tname
  132. TensorLikeType = Union[torch.Tensor, TensorMeta]
  133. TensorLike = (torch.Tensor, TensorMeta)
  134. TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
  135. # TODO: look at using torch.testing.assert_close instead with an option
  136. # to just compare metadata
  137. def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType):
  138. """
  139. Checks that two tensor likes have the same shape,
  140. dtype and device.
  141. In the future this will validate additional metadata, like
  142. strides.
  143. """
  144. assert isinstance(a, TensorLike)
  145. assert isinstance(b, TensorLike)
  146. for x, y in zip(a.shape, b.shape):
  147. if x != y:
  148. msg = "Shapes {0} and {1} are not equal!".format(a.shape, b.shape)
  149. raise AssertionError(msg)
  150. if a.dtype != b.dtype:
  151. msg = "Dtypes {0} and {1} are not equal!".format(a.dtype, b.dtype)
  152. raise AssertionError(msg)
  153. if a.device != b.device:
  154. # Handles special cuda:0 vs cuda case
  155. # TODO: we should review why this happens and see about fixing it
  156. if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
  157. str(b.device) == "cuda:0" or str(b.device) == "cuda"
  158. ):
  159. pass
  160. else:
  161. msg = "Devices {0} and {1} are not equal!".format(a.device, b.device)
  162. raise AssertionError(msg)
  163. same_strides, idx = check_significant_strides(a, b)
  164. if not same_strides:
  165. msg = "Stride mismatch! Strides are {0} and {1} (mismatched at {2})!".format(
  166. a.stride(), b.stride(), idx
  167. )
  168. raise RuntimeError(msg)
  169. def check_significant_strides(
  170. a: TensorLikeType, b: TensorLikeType
  171. ) -> Tuple[bool, Optional[int]]:
  172. # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
  173. # See https://github.com/pytorch/pytorch/issues/77553
  174. # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
  175. # and for tensors with more than one element
  176. if (a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0:
  177. for idx in range(a.ndim):
  178. if a.stride()[idx] != b.stride()[idx] and a.shape[idx] > 1:
  179. return False, idx
  180. return True, None
  181. def is_contiguous(a: TensorLikeType) -> bool:
  182. """
  183. Tests whether a tensor is contiguous or not.
  184. Tensors are contiguous when they have no elements,
  185. or when they have "nested" strides.
  186. """
  187. if a.numel() == 0:
  188. return True
  189. expected_stride = 1
  190. for x, y in reversed(tuple(zip(a.shape, a.stride()))):
  191. # Skips checking strides when a dimension has length 1
  192. if x == 1:
  193. continue
  194. if y != expected_stride:
  195. return False
  196. expected_stride = expected_stride * x
  197. return True
  198. # NOTE: Based on the implementation in TensorIterator.cpp, but note that
  199. # the note [Computing output strides] is incorrect, because it
  200. # says that strides will be preserved even if they are not
  201. # "non overlapping and dense", but this is incorrect. The
  202. # output of elementwise operations are always given
  203. # non overlapping and dense strides.
  204. # This is also INCORRECT because it does not model TensorIterator's
  205. # short-circuit, which can cause different strides.
  206. def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
  207. """
  208. Computes the output strides for elementwise operations.
  209. """
  210. if len(tensors) == 0:
  211. msg = "Can't compute elementwise output strides for zero tensors!"
  212. raise ValueError(msg)
  213. check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
  214. # Filters the tensors to actual tensors
  215. all_tensors = all(isinstance(a, TensorLike) for a in tensors)
  216. tensors = tuple(
  217. a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
  218. )
  219. # Short-circuits for CPU scalar case
  220. if len(tensors) == 0:
  221. return ()
  222. # Short-circuits for shapes with zero or one dimensions
  223. # TODO: are these necessary?
  224. ndim = tensors[0].ndim
  225. if ndim == 0:
  226. return ()
  227. if ndim == 1:
  228. return (1,)
  229. shape = tensors[0].shape
  230. def _cmp(idx_a, idx_b):
  231. for tensor in tensors:
  232. stride_a = tensor.stride()[idx_a]
  233. stride_b = tensor.stride()[idx_b]
  234. if stride_a == 0 or stride_b == 0:
  235. continue
  236. if stride_a < stride_b:
  237. return -1
  238. if stride_a > stride_b:
  239. return 1
  240. # stride_a == stride_b
  241. if shape[idx_a] > shape[idx_b]:
  242. return 1
  243. # NOTE: this case is missing in the C++ impl
  244. if shape[idx_a] < shape[idx_b]:
  245. return -1
  246. # Note: this case is hit if all strides are zero,
  247. # or all strides are equal and all dimensions have the same length
  248. return 0
  249. perm = tuple(range(ndim))
  250. perm = tuple(sorted(perm, key=cmp_to_key(_cmp), reverse=True))
  251. permuted_shape = [-1] * ndim
  252. for idx, x in enumerate(perm):
  253. permuted_shape[idx] = shape[x]
  254. new_strides = make_contiguous_strides_for(permuted_shape)
  255. # print(f"new_strides is {new_strides}")
  256. # print(f"shape is {shape}")
  257. # print(f"permuted_shape is {permuted_shape}")
  258. permuted_strides = [-1] * ndim
  259. for idx, x in enumerate(perm):
  260. permuted_strides[x] = new_strides[idx]
  261. return tuple(permuted_strides)
  262. #
  263. # Common helper functions
  264. #
  265. def validate_dim_length(length: int):
  266. """
  267. Validates that an object represents a valid
  268. dimension length.
  269. """
  270. assert isinstance(length, int)
  271. assert length >= 0
  272. def validate_shape(shape: ShapeType):
  273. """
  274. Validates that a sequence represents a valid shape.
  275. """
  276. assert isinstance(shape, Sequence)
  277. for l in shape:
  278. validate_dim_length(l)
  279. def validate_strides(strides: StrideType):
  280. """
  281. Verifies the object specifies valid strides.
  282. """
  283. assert isinstance(strides, Sequence)
  284. for stride in strides:
  285. assert stride >= 0
  286. def validate_idx(rank: int, idx: int):
  287. """
  288. Validates that idx is a valid index for the given shape.
  289. Assumes the index is already canonicalized.
  290. """
  291. assert isinstance(idx, int)
  292. assert isinstance(rank, int)
  293. assert idx >= 0 and idx < rank or idx == 0
  294. def validate_dimension_indices(rank: int, indices: DimsSequenceType):
  295. for idx in indices:
  296. validate_idx(rank, idx)
  297. def validate_exclusive_idx(rank: int, ex_idx: int):
  298. """
  299. Validates that ex_idx is a valid exclusive index
  300. for the given shape.
  301. """
  302. assert isinstance(ex_idx, int)
  303. assert isinstance(rank, int)
  304. assert ex_idx > 0 and ex_idx <= rank
  305. # "Wraps" a dim (up to one time) for the given rank, allowing
  306. # dims to be specified using negative indices
  307. def canonicalize_dim(rank: int, idx: int) -> int:
  308. # TODO: add a comment for why this is
  309. _rank = rank if rank != 0 else 1
  310. if idx >= 0 and idx < _rank:
  311. return idx
  312. if idx < 0:
  313. _idx = idx + _rank
  314. else:
  315. _idx = idx
  316. if _idx < 0 or _idx > _rank:
  317. msg = "Received out of bounds index {0} for tensor of rank {1}!".format(
  318. idx, rank
  319. )
  320. raise ValueError(msg)
  321. return _idx
  322. # Takes a dimension or sequence of dimensions and "wraps" them,
  323. # mapping negative offsets to positive ones
  324. def canonicalize_dims(rank: int, indices: DimsType) -> DimsType:
  325. if isinstance(indices, int):
  326. return canonicalize_dim(rank, indices)
  327. return tuple(canonicalize_dim(rank, x) for x in indices)
  328. def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
  329. """
  330. Validates that perm is a permutation of length rank.
  331. """
  332. if not isinstance(perm, Sequence):
  333. return False
  334. if not (tuple(sorted(perm)) == tuple(range(0, rank))):
  335. return False
  336. return True
  337. def is_same_shape(a: Sequence, b: Sequence) -> bool:
  338. """
  339. Compares two shapes a and b, returning True if they are the same
  340. (their ranks and corresponding lengths match) and False otherwise.
  341. """
  342. return tuple(a) == tuple(b)
  343. def is_cpu_scalar_tensor(a: Any) -> bool:
  344. return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"
  345. def check_same_device(*args, allow_cpu_scalar_tensors):
  346. """
  347. Checks that all Tensors in args have the same device.
  348. Raises a RuntimeError when:
  349. - args contains an object whose type is not Tensor or Number
  350. - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
  351. """
  352. # Short-circuits if all (one or fewer) arguments are trivially on the same device
  353. if len(args) <= 1:
  354. return
  355. # Note: cannot initialize device to the first arg's device (it may not have one)
  356. device = None
  357. for arg in args:
  358. if isinstance(arg, Number):
  359. continue
  360. elif isinstance(arg, TensorLike):
  361. if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
  362. continue
  363. if device is None:
  364. device = arg.device
  365. if device != arg.device:
  366. msg = (
  367. "Tensor on device "
  368. + str(arg.device)
  369. + " is not on the expected device "
  370. + str(device)
  371. + "!"
  372. )
  373. raise RuntimeError(msg)
  374. else:
  375. msg = (
  376. "Unexpected type when checking for same device, " + str(type(arg)) + "!"
  377. )
  378. raise RuntimeError(msg)
  379. # Asserts if any of the following are true:
  380. # - a non-scalar or non-Tensor is given
  381. # - the shape of any tensors is distinct
  382. def check_same_shape(*args, allow_cpu_scalar_tensors):
  383. """
  384. Checks that all Tensors in args have the same shape.
  385. Raises a RuntimeError when:
  386. - args contains an object whose type is not Tensor or Number
  387. - two Tensor objects in args have different devices
  388. """
  389. shape = None
  390. for arg in args:
  391. if isinstance(arg, Number):
  392. continue
  393. elif isinstance(arg, TensorLike):
  394. if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
  395. continue
  396. if shape is None:
  397. shape = arg.shape
  398. if not is_same_shape(shape, arg.shape):
  399. msg = "Shape {0} is not the expected shape {1}!".format(
  400. arg.shape, shape
  401. )
  402. raise RuntimeError(msg)
  403. else:
  404. msg = (
  405. "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
  406. )
  407. raise RuntimeError(msg)
  408. _integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
  409. _float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
  410. _complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
  411. def is_boolean_dtype(dtype: torch.dtype) -> bool:
  412. assert isinstance(dtype, torch.dtype)
  413. return dtype is torch.bool
  414. def is_integer_dtype(dtype: torch.dtype) -> bool:
  415. assert isinstance(dtype, torch.dtype)
  416. return dtype in _integer_dtypes
  417. def is_float_dtype(dtype: torch.dtype) -> bool:
  418. assert isinstance(dtype, torch.dtype)
  419. return dtype in _float_dtypes
  420. def is_complex_dtype(dtype: torch.dtype) -> bool:
  421. assert isinstance(dtype, torch.dtype)
  422. return dtype in _complex_dtypes
  423. _complex_to_real_dtype_map = {
  424. torch.complex128: torch.float64,
  425. torch.complex64: torch.float32,
  426. torch.complex32: torch.float16,
  427. }
  428. _real_to_complex_dtype_map = {
  429. torch.float16: torch.complex32,
  430. torch.bfloat16: torch.complex64,
  431. torch.float32: torch.complex64,
  432. torch.float64: torch.complex128,
  433. }
  434. def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
  435. return _complex_to_real_dtype_map[dtype]
  436. def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
  437. return _real_to_complex_dtype_map[dtype]
  438. def dtype_to_type(dtype: torch.dtype) -> type:
  439. """
  440. Computes the corresponding Python type (AKA "type kind") for the
  441. given dtype.
  442. """
  443. assert isinstance(dtype, torch.dtype)
  444. if dtype is torch.bool:
  445. return bool
  446. if dtype in _integer_dtypes:
  447. return int
  448. if dtype in _float_dtypes:
  449. return float
  450. if dtype in _complex_dtypes:
  451. return complex
  452. raise ValueError("Invalid dtype!")
  453. _type_to_dtype_map = {
  454. bool: torch.bool,
  455. int: torch.int64,
  456. float: torch.float64,
  457. complex: torch.complex128,
  458. }
  459. def type_to_dtype(typ: type) -> torch.dtype:
  460. """
  461. Computes the corresponding dtype for a Number type.
  462. """
  463. return _type_to_dtype_map[typ]
  464. _ordered_types = (bool, int, float, complex)
  465. def get_higher_type(a: type, b: type) -> type:
  466. """
  467. Returns the higher of the two given Number types.
  468. The types are ordered bool -> int -> float -> complex.
  469. """
  470. # Type checking
  471. assert a in _ordered_types
  472. assert b in _ordered_types
  473. if a is b:
  474. return a
  475. for typ in _ordered_types:
  476. if a is typ:
  477. return b
  478. if b is typ:
  479. return a
  480. raise ValueError("Unknown Python scalar type!")
  481. # Returns the higher of two torch datatypes a and b or, if the two
  482. # are not ordered relative to each other, the next
  483. # higher datatype
  484. def get_higher_dtype(
  485. a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
  486. b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
  487. ) -> Optional[torch.dtype]:
  488. """
  489. Computes the "lowest" datatype that is weakly
  490. "higher" than both a and b.
  491. """
  492. # Type checking
  493. assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
  494. assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))
  495. def _extract_dtype(
  496. x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
  497. ) -> Optional[torch.dtype]:
  498. if x is None:
  499. return None
  500. if isinstance(x, torch.dtype):
  501. return x
  502. if isinstance(x, TensorLike):
  503. return x.dtype
  504. if isinstance(x, Number):
  505. return type_to_dtype(type(x))
  506. raise RuntimeError("Unexpected type given to _extract_dtype!")
  507. a, b = _extract_dtype(a), _extract_dtype(b)
  508. if a is b:
  509. return a
  510. if a is None:
  511. return b
  512. if b is None:
  513. return a
  514. ordered_datatypes = (
  515. (torch.bool,),
  516. (torch.uint8, torch.int8),
  517. (torch.int16,),
  518. (torch.int32,),
  519. (torch.int64,),
  520. (torch.float16, torch.bfloat16),
  521. (torch.float32,),
  522. (torch.float64,),
  523. (torch.complex32,),
  524. (torch.complex64,),
  525. (torch.complex128,),
  526. )
  527. for idx, dtypes in enumerate(ordered_datatypes):
  528. if a in dtypes and b in dtypes:
  529. return ordered_datatypes[idx + 1][0]
  530. if a in dtypes:
  531. return b
  532. if b in dtypes:
  533. return a
  534. raise RuntimeError("Unexpected termination!")
  535. # TODO: maybe unify with can_cast_to?
  536. def is_weakly_lesser_type(a: type, b: type) -> bool:
  537. """
  538. Compares two types, a and b, returning True if a is weakly "less" than b.
  539. The comparison is determined by the following type ordering: bool, int, float, complex.
  540. """
  541. ordered_types = (
  542. bool,
  543. int,
  544. float,
  545. complex,
  546. )
  547. assert a in ordered_types
  548. assert b in ordered_types
  549. for typ in ordered_types:
  550. if a == typ:
  551. return True
  552. if b == typ:
  553. return False
  554. raise RuntimeError("Unexpected termination!")
  555. def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
  556. for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
  557. if fn(cast_to):
  558. return True
  559. if fn(cast_from):
  560. return False
  561. raise ValueError("Received unknown dtypes {0}, {1}!".format(cast_to, cast_from))
  562. def check_same_dtype(*args):
  563. """
  564. Checks that all Tensors in args have the same device and that all Numbers have the
  565. same corresponding Python type.
  566. Raises a RuntimeError when:
  567. - args contains an object whose type is not Tensor or Number
  568. - two Tensors objects in args have different dtypes
  569. - two Number objects in args have different types
  570. - there are Tensors and Numbers in args, and one of those Tensors corresponding
  571. Python types is different from the type of one of those Numbers
  572. """
  573. full_dtype = None
  574. scalar_type = None
  575. for arg in args:
  576. if isinstance(arg, Number):
  577. # Scalar type checking is disabled (and may be removed in the future)
  578. continue
  579. # if scalar_type is None:
  580. # scalar_type = type(arg)
  581. # if scalar_type is not type(arg):
  582. # msg = (
  583. # "Scalar of type "
  584. # + str(type(arg))
  585. # + " is not the expected type of "
  586. # + str(scalar_type)
  587. # + "!"
  588. # )
  589. # raise RuntimeError(msg)
  590. elif isinstance(arg, TensorLike):
  591. if full_dtype is None:
  592. full_dtype = arg.dtype
  593. if scalar_type is None:
  594. scalar_type = dtype_to_type(arg.dtype)
  595. if full_dtype is not arg.dtype:
  596. msg = (
  597. "Tensor with dtype "
  598. + str(arg.dtype)
  599. + " is not the expected dtype of "
  600. + str(full_dtype)
  601. + "!"
  602. )
  603. raise RuntimeError(msg)
  604. arg_type = dtype_to_type(arg.dtype)
  605. if arg_type is not scalar_type:
  606. msg = (
  607. "Tensor with corresponding Python type "
  608. + str(arg_type)
  609. + " is not the expected type of "
  610. + str(scalar_type)
  611. + "!"
  612. )
  613. raise RuntimeError(msg)
  614. else:
  615. msg = (
  616. "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
  617. )
  618. raise RuntimeError(msg)
  619. # Maps datatypes to their computation types for elementwise operations
  620. _computation_dtype_map = {
  621. torch.bfloat16: torch.float32,
  622. torch.float16: torch.float32,
  623. torch.complex32: torch.complex64,
  624. }
  625. def _get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
  626. return _computation_dtype_map.get(dtype, dtype)
  627. class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
  628. DEFAULT = (0,)
  629. NO_OPMATH = (1,)
  630. INT_TO_FLOAT = (2,)
  631. ALWAYS_BOOL = (3,)
  632. COMPLEX_TO_FLOAT = (4,)
  633. BOOL_TO_LONG = (5,)
  634. # TODO: document type promotion kinds
  635. def elementwise_dtypes(
  636. *_args,
  637. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  638. ) -> Tuple[torch.dtype, torch.dtype]:
  639. """
  640. Computes the computation and result dtypes for elementwise type promotion
  641. on the given arguments and with the given elementwise type promotion kind.
  642. Note that not all inputs to an elementwise operation necessarily participate in type promotion.
  643. For example, the "alpha" parameter of torch.add does not participate in type promotion,
  644. although it may be cast to the Python type corresponding to the computation dtype that
  645. the type promotion algorithm determines.
  646. Default elementwise type promotion, which all other type promotion kinds tweak (see below),
  647. first decides which of four ordered types to use:
  648. bool -> integer -> floating point -> complex
  649. The selected type is the "lowest" type in the above list such that all number arguments
  650. have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
  651. type for their dtype.
  652. Once the type is determined, the particular result dtype is found. The dtypes are
  653. partially ordered as follows:
  654. bool -> uint8, int8 -> int16 -> int32 -> int64 ->
  655. float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128
  656. The result dtype is selected by:
  657. - if no tensor's dtype has the same corresponding type as the one selected,
  658. then the result dtype is the (default) dtype corresponding to the selected type
  659. (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
  660. - if the result type is complex then the dtype is:
  661. - the default complex dtype if there are no floating point or complex tensors
  662. - if there are floating point or complex tensors with one or more dimensions, then
  663. the complex dtype corresponding to the highest corresponding complex dtype among those tensors
  664. (for example, double + cfloat -> cdouble)
  665. - if there are only floating point or complex tensors with zero dimensions, then
  666. the complex dtype corresponding to the highest corresponding complex dtype among those tensors
  667. - if the first two cases do not apply, the result dtype is the highest dtype among
  668. all tensors with one or more dimensions of the output type, and if there are no such
  669. tensors then it's the highest dtype among all tensors with zero dimensions of the output type
  670. (for example, long + half -> half, even if the half tensor has zero dimensions)
  671. The "corresponding complex dtypes" are:
  672. float16 -> complex32
  673. bfloat16 -> complex64
  674. float32 -> complex64
  675. float64 -> complex128
  676. complex32 -> complex32
  677. complex64 -> complex64
  678. complex128 -> complex128
  679. The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
  680. dtype by mapping low precision floating point and complex dtypes as follows:
  681. float16 -> float32
  682. bfloat16 -> float32
  683. complex32 -> complex64
  684. This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
  685. computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
  686. which perform no mathematical operations on their tensors (see below for examples).
  687. The INT_TO_FLOAT type promotion kind maps boolean and integer maps result dtypes to the default floating point dtype,
  688. and computation dtypes to the appropriate op math dtype.
  689. The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
  690. mapping:
  691. complex32 -> float16
  692. complex64 -> float32
  693. complex128 -> float64
  694. Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.
  695. The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.
  696. The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.
  697. Example operators for each type promotion option:
  698. DEFAULT : add
  699. NO_OPMATH : where, nextafter, cat
  700. INT_TO_FLOAT : sin
  701. COMPLEX_TO_FLOAT : abs
  702. BOOL_TO_LONG : pow
  703. ALWAYS_BOOL : eq
  704. """
  705. args = tuple(x for x in _args if x is not None)
  706. highest_type: type = bool
  707. for x in args:
  708. if not isinstance(x, (Number, TensorLike)):
  709. msg = (
  710. "Unexpected type {0} when computing elementwise type promotion!".format(
  711. str(type(x))
  712. )
  713. )
  714. raise ValueError(msg)
  715. if isinstance(x, Number):
  716. highest_type = get_higher_type(highest_type, type(x))
  717. else:
  718. # x is a TensorLike
  719. highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))
  720. result_dtype = None
  721. def _find_highest_dtype_filtered(
  722. args, filter, *, float_as_complex=False
  723. ) -> Optional[torch.dtype]:
  724. zero_dim_tensor_dtype = None
  725. one_plus_dim_tensor_dtype = None
  726. for x in args:
  727. if isinstance(x, TensorLike) and filter(x.dtype):
  728. _dtype = x.dtype
  729. if float_as_complex and is_float_dtype(_dtype):
  730. _dtype = corresponding_complex_dtype(_dtype)
  731. if x.ndim == 0:
  732. zero_dim_tensor_dtype = get_higher_dtype(
  733. zero_dim_tensor_dtype, _dtype
  734. )
  735. else:
  736. # x.ndim > 0
  737. one_plus_dim_tensor_dtype = get_higher_dtype(
  738. one_plus_dim_tensor_dtype, _dtype
  739. )
  740. # Prefers dtype of tensors with one or more dimensions
  741. if one_plus_dim_tensor_dtype is not None:
  742. return one_plus_dim_tensor_dtype
  743. return zero_dim_tensor_dtype
  744. if highest_type is float:
  745. result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
  746. result_dtype = (
  747. torch.get_default_dtype() if result_dtype is None else result_dtype
  748. )
  749. elif highest_type is complex:
  750. result_dtype = _find_highest_dtype_filtered(
  751. args,
  752. lambda x: is_float_dtype(x) or is_complex_dtype(x),
  753. float_as_complex=True,
  754. )
  755. if result_dtype is None:
  756. result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
  757. elif highest_type is int:
  758. result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
  759. result_dtype = torch.long if result_dtype is None else result_dtype
  760. else:
  761. # highest_type is bool
  762. result_dtype = torch.bool
  763. if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
  764. return _get_computation_dtype(result_dtype), result_dtype
  765. elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
  766. return result_dtype, result_dtype
  767. elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
  768. if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
  769. result_dtype = torch.get_default_dtype()
  770. return _get_computation_dtype(result_dtype), result_dtype
  771. elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
  772. # NOTE: computation can still occur in a complex dtype
  773. computation_dtype = _get_computation_dtype(result_dtype)
  774. if is_complex_dtype(result_dtype):
  775. result_dtype = corresponding_real_dtype(result_dtype)
  776. return computation_dtype, result_dtype
  777. elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
  778. if is_boolean_dtype(result_dtype):
  779. return torch.long, torch.long
  780. return _get_computation_dtype(result_dtype), result_dtype
  781. elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
  782. return _get_computation_dtype(result_dtype), torch.bool
  783. else:
  784. raise ValueError(
  785. "Unknown type promotion kind {0}".format(str(type_promotion_kind))
  786. )
  787. def wrap_device(d: Union[str, torch.device]) -> torch.device:
  788. """
  789. Wraps strings into torch.device objects.
  790. Given torch.device objects are returned unmodified.
  791. """
  792. assert isinstance(d, (str, torch.device))
  793. if isinstance(d, str):
  794. return torch.device(d)
  795. return d
  796. def make_contiguous_strides_for(shape: ShapeType) -> Tuple[int, ...]:
  797. validate_shape(shape)
  798. if not shape:
  799. return ()
  800. multiplier = 1
  801. strides = []
  802. for l in reversed(shape):
  803. if l != 0:
  804. strides.append(multiplier)
  805. multiplier = l * multiplier
  806. else:
  807. strides.append(multiplier)
  808. result = tuple(reversed(strides))
  809. return result
  810. def compute_reduction_output_shape(
  811. shape: ShapeType, dimensions: Sequence
  812. ) -> Tuple[int, ...]:
  813. for idx in dimensions:
  814. validate_idx(len(shape), idx)
  815. new_shape = []
  816. for idx in range(len(shape)):
  817. if idx in dimensions:
  818. continue
  819. new_shape.append(shape[idx])
  820. return tuple(new_shape)
  821. def validate_no_repeating_dims(dims: Sequence):
  822. if len(dims) != len(set(dims)):
  823. raise RuntimeError("duplicate value in the list of dims")
  824. def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
  825. if dims is None:
  826. return tuple(range(len(shape)))
  827. dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
  828. validate_no_repeating_dims(dims)
  829. return dims
  830. def check_in_bounds_for_storage(
  831. a: torch._TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
  832. ):
  833. """
  834. Determines if the given shape, strides, and offset are valid for the given storage.
  835. """
  836. # Short-circuits if the shape has no elements
  837. if reduce(operator.mul, shape) == 0:
  838. return
  839. length = a.size() - storage_offset
  840. max_offset = 0
  841. for x, y in zip(shape, strides):
  842. max_offset = max_offset + (x - 1) * y
  843. if max_offset >= length:
  844. required_length = max_offset + storage_offset
  845. msg = (
  846. "Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, "
  847. "which requires a storage of size {4}".format(
  848. a.size(), storage_offset, str(shape), str(strides), required_length
  849. )
  850. )
  851. raise ValueError(msg)