_legacy.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
"""This module exists so that functions can be deprecated publicly without being deprecated internally. The
deprecated public versions are defined in torch.testing._deprecated and exposed from torch.testing. The
non-deprecated internal versions should be imported from torch.testing._internal.
"""
  5. from typing import List
  6. import torch
  7. __all_dtype_getters__ = [
  8. "_validate_dtypes",
  9. "_dispatch_dtypes",
  10. "all_types",
  11. "all_types_and",
  12. "all_types_and_complex",
  13. "all_types_and_complex_and",
  14. "all_types_and_half",
  15. "complex_types",
  16. "empty_types",
  17. "floating_and_complex_types",
  18. "floating_and_complex_types_and",
  19. "floating_types",
  20. "floating_types_and",
  21. "double_types",
  22. "floating_types_and_half",
  23. "get_all_complex_dtypes",
  24. "get_all_dtypes",
  25. "get_all_fp_dtypes",
  26. "get_all_int_dtypes",
  27. "get_all_math_dtypes",
  28. "integral_types",
  29. "integral_types_and",
  30. ]
  31. __all__ = [
  32. *__all_dtype_getters__,
  33. "get_all_device_types",
  34. ]
  35. # Functions and classes for describing the dtypes a function supports
  36. # NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
  37. # Verifies each given dtype is a torch.dtype
  38. def _validate_dtypes(*dtypes):
  39. for dtype in dtypes:
  40. assert isinstance(dtype, torch.dtype)
  41. return dtypes
  42. # class for tuples corresponding to a PyTorch dispatch macro
  43. class _dispatch_dtypes(tuple):
  44. def __add__(self, other):
  45. assert isinstance(other, tuple)
  46. return _dispatch_dtypes(tuple.__add__(self, other))
  47. _empty_types = _dispatch_dtypes(())
  48. def empty_types():
  49. return _empty_types
  50. _floating_types = _dispatch_dtypes((torch.float32, torch.float64))
  51. def floating_types():
  52. return _floating_types
  53. _floating_types_and_half = _floating_types + (torch.half,)
  54. def floating_types_and_half():
  55. return _floating_types_and_half
  56. def floating_types_and(*dtypes):
  57. return _floating_types + _validate_dtypes(*dtypes)
  58. _floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
  59. def floating_and_complex_types():
  60. return _floating_and_complex_types
  61. def floating_and_complex_types_and(*dtypes):
  62. return _floating_and_complex_types + _validate_dtypes(*dtypes)
  63. _double_types = _dispatch_dtypes((torch.float64, torch.complex128))
  64. def double_types():
  65. return _double_types
  66. _integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
  67. def integral_types():
  68. return _integral_types
  69. def integral_types_and(*dtypes):
  70. return _integral_types + _validate_dtypes(*dtypes)
  71. _all_types = _floating_types + _integral_types
  72. def all_types():
  73. return _all_types
  74. def all_types_and(*dtypes):
  75. return _all_types + _validate_dtypes(*dtypes)
  76. _complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
  77. def complex_types():
  78. return _complex_types
  79. def complex_types_and(*dtypes):
  80. return _complex_types + _validate_dtypes(*dtypes)
  81. _all_types_and_complex = _all_types + _complex_types
  82. def all_types_and_complex():
  83. return _all_types_and_complex
  84. def all_types_and_complex_and(*dtypes):
  85. return _all_types_and_complex + _validate_dtypes(*dtypes)
  86. _all_types_and_half = _all_types + (torch.half,)
  87. def all_types_and_half():
  88. return _all_types_and_half
  89. # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
  90. # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
  91. def get_all_dtypes(include_half=True,
  92. include_bfloat16=True,
  93. include_bool=True,
  94. include_complex=True,
  95. include_complex32=False,
  96. include_qint=False,
  97. ) -> List[torch.dtype]:
  98. dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
  99. if include_bool:
  100. dtypes.append(torch.bool)
  101. if include_complex:
  102. dtypes += get_all_complex_dtypes(include_complex32)
  103. if include_qint:
  104. dtypes += get_all_qint_dtypes()
  105. return dtypes
  106. def get_all_math_dtypes(device) -> List[torch.dtype]:
  107. return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
  108. include_bfloat16=False) + get_all_complex_dtypes()
  109. def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
  110. return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128]
  111. def get_all_int_dtypes() -> List[torch.dtype]:
  112. return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
  113. def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
  114. dtypes = [torch.float32, torch.float64]
  115. if include_half:
  116. dtypes.append(torch.float16)
  117. if include_bfloat16:
  118. dtypes.append(torch.bfloat16)
  119. return dtypes
  120. def get_all_qint_dtypes() -> List[torch.dtype]:
  121. return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
  122. def get_all_device_types() -> List[str]:
  123. return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']