| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
"""This module exists so that functions can be deprecated publicly without being deprecated internally. The deprecated
public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal
versions should be imported from torch.testing._internal.
"""
- from typing import List
- import torch
# Names of the dtype helpers defined in this module. The list is spliced into
# ``__all__`` below.
# NOTE(review): ``complex_types_and`` and ``get_all_qint_dtypes`` are defined in this
# module but are not listed here -- confirm whether that omission is intentional.
__all_dtype_getters__ = [
    "_validate_dtypes",
    "_dispatch_dtypes",
    "all_types",
    "all_types_and",
    "all_types_and_complex",
    "all_types_and_complex_and",
    "all_types_and_half",
    "complex_types",
    "empty_types",
    "floating_and_complex_types",
    "floating_and_complex_types_and",
    "floating_types",
    "floating_types_and",
    "double_types",
    "floating_types_and_half",
    "get_all_complex_dtypes",
    "get_all_dtypes",
    "get_all_fp_dtypes",
    "get_all_int_dtypes",
    "get_all_math_dtypes",
    "integral_types",
    "integral_types_and",
]

# Names exported by ``from <module> import *``: every dtype getter listed above
# plus the device-type helper.
__all__ = [
    *__all_dtype_getters__,
    "get_all_device_types",
]
- # Functions and classes for describing the dtypes a function supports
- # NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
def _validate_dtypes(*dtypes):
    """Check that every positional argument is a ``torch.dtype``.

    The arguments are handed back unchanged, as a tuple, so the call can be
    used inline inside a larger expression. An ``AssertionError`` is raised
    when any argument is not a ``torch.dtype``.
    """
    assert all(isinstance(candidate, torch.dtype) for candidate in dtypes)
    return dtypes
class _dispatch_dtypes(tuple):
    """Tuple of dtypes corresponding to a PyTorch dispatch macro.

    ``+`` is overridden so that concatenating with another tuple yields a
    ``_dispatch_dtypes`` again rather than a plain ``tuple``.
    """

    def __add__(self, other):
        # Only tuples (including other _dispatch_dtypes) may be appended.
        assert isinstance(other, tuple)
        combined = super().__add__(other)
        return _dispatch_dtypes(combined)
# Shared dispatch tuple that holds no dtypes.
_empty_types = _dispatch_dtypes(tuple())


def empty_types():
    """Return the module-level dispatch tuple that contains no dtypes."""
    return _empty_types
# float32 and float64 (torch.float / torch.double are aliases of those dtypes).
_floating_types = _dispatch_dtypes((torch.float, torch.double))


def floating_types():
    """Return the module-level dispatch tuple ``(float32, float64)``."""
    return _floating_types
# The floating dtypes followed by float16 (torch.half is an alias of torch.float16).
_floating_types_and_half = _floating_types + (torch.float16,)


def floating_types_and_half():
    """Return the module-level dispatch tuple of the floating dtypes plus float16."""
    return _floating_types_and_half
def floating_types_and(*dtypes):
    """Return the floating dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _floating_types + extra
# The floating dtypes followed by complex64 and complex128
# (torch.cfloat / torch.cdouble are aliases of those dtypes).
_floating_and_complex_types = _floating_types + (torch.complex64, torch.complex128)


def floating_and_complex_types():
    """Return the module-level dispatch tuple of the floating and complex dtypes."""
    return _floating_and_complex_types
def floating_and_complex_types_and(*dtypes):
    """Return the floating and complex dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _floating_and_complex_types + extra
# float64 and complex128 (torch.double / torch.cdouble are aliases of those dtypes).
_double_types = _dispatch_dtypes((torch.double, torch.cdouble))


def double_types():
    """Return the module-level dispatch tuple ``(float64, complex128)``."""
    return _double_types
# uint8, int8, int16, int32 and int64, in that order
# (torch.short / torch.int / torch.long are aliases of int16 / int32 / int64).
_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.short, torch.int, torch.long))


def integral_types():
    """Return the module-level dispatch tuple of the integer dtypes."""
    return _integral_types
def integral_types_and(*dtypes):
    """Return the integer dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _integral_types + extra
# Floating dtypes first, then the integer dtypes.
_all_types = _dispatch_dtypes((*_floating_types, *_integral_types))


def all_types():
    """Return the module-level dispatch tuple of the floating and integer dtypes."""
    return _all_types
def all_types_and(*dtypes):
    """Return the floating and integer dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _all_types + extra
# complex64 and complex128 (torch.cfloat / torch.cdouble are aliases of those dtypes).
_complex_types = _dispatch_dtypes((torch.complex64, torch.complex128))


def complex_types():
    """Return the module-level dispatch tuple ``(complex64, complex128)``."""
    return _complex_types
def complex_types_and(*dtypes):
    """Return the complex dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _complex_types + extra
# Floating and integer dtypes first, then the complex dtypes.
_all_types_and_complex = _dispatch_dtypes((*_all_types, *_complex_types))


def all_types_and_complex():
    """Return the module-level dispatch tuple of the floating, integer and complex dtypes."""
    return _all_types_and_complex
def all_types_and_complex_and(*dtypes):
    """Return the floating, integer and complex dtypes followed by the given ``dtypes``.

    The extra dtypes are passed through ``_validate_dtypes`` before being appended.
    """
    extra = _validate_dtypes(*dtypes)
    return _all_types_and_complex + extra
# Floating and integer dtypes followed by float16 (torch.half is an alias of torch.float16).
_all_types_and_half = _all_types + (torch.float16,)


def all_types_and_half():
    """Return the module-level dispatch tuple of the floating and integer dtypes plus float16."""
    return _all_types_and_half
- # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
- # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(include_half=True,
                   include_bfloat16=True,
                   include_bool=True,
                   include_complex=True,
                   include_complex32=False,
                   include_qint=False,
                   ) -> List[torch.dtype]:
    """Build a fresh list of dtypes selected by the ``include_*`` flags.

    The integer dtypes always come first, followed by the floating-point
    dtypes (``include_half`` / ``include_bfloat16`` are forwarded to
    ``get_all_fp_dtypes``). ``torch.bool``, the complex dtypes and the
    quantized dtypes are then appended, in that order, when their flag is
    set. ``include_complex32`` is only consulted when ``include_complex``
    is true.
    """
    collected = [
        *get_all_int_dtypes(),
        *get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16),
    ]
    if include_bool:
        collected += [torch.bool]
    if include_complex:
        collected.extend(get_all_complex_dtypes(include_complex32))
    if include_qint:
        collected.extend(get_all_qint_dtypes())
    return collected
def get_all_math_dtypes(device) -> List[torch.dtype]:
    """Return the integer, floating-point and complex dtypes for ``device``.

    Args:
        device: a device string such as ``'cpu'`` or ``'cuda:0'``, or a
            ``torch.device``. Only its string form is inspected.

    Returns:
        A new list: the integer dtypes, then the floating-point dtypes, then
        the complex dtypes. float16 is included only when the string form of
        ``device`` starts with ``'cuda'``; bfloat16 is never included.
    """
    # str() lets a torch.device be passed as well as a plain string; for a
    # string argument the result is the same as calling startswith on it directly.
    on_cuda = str(device).startswith('cuda')
    floating = get_all_fp_dtypes(include_half=on_cuda, include_bfloat16=False)
    return get_all_int_dtypes() + floating + get_all_complex_dtypes()
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
    """Return a new list ``[complex64, complex128]``, with ``complex32`` in front when requested."""
    standard = [torch.complex64, torch.complex128]
    if include_complex32:
        return [torch.complex32] + standard
    return standard
def get_all_int_dtypes() -> List[torch.dtype]:
    """Return a new list of the integer dtypes: uint8, int8, int16, int32, int64."""
    # torch.short / torch.int / torch.long are aliases of int16 / int32 / int64.
    integer_dtypes = (torch.uint8, torch.int8, torch.short, torch.int, torch.long)
    return list(integer_dtypes)
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
    """Return a new list of floating-point dtypes.

    float32 and float64 are always present; float16 and bfloat16 follow, in
    that order, when their respective flag is true.
    """
    optional = ((include_half, torch.float16), (include_bfloat16, torch.bfloat16))
    return [torch.float32, torch.float64] + [dtype for wanted, dtype in optional if wanted]
def get_all_qint_dtypes() -> List[torch.dtype]:
    """Return a new list of the quantized dtypes: qint8, quint8, qint32, quint4x2, quint2x4."""
    quantized = (torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4)
    return list(quantized)
def get_all_device_types() -> List[str]:
    """Return ``['cpu', 'cuda']`` when ``torch.cuda.is_available()`` is true, otherwise ``['cpu']``."""
    if torch.cuda.is_available():
        return ['cpu', 'cuda']
    return ['cpu']
|