| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582 |
- import torch
- from typing import Optional, List, DefaultDict, Any
- import warnings
- from collections import defaultdict
- import sys
- import traceback
def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Report this object's type name, or cast it to ``dtype``.

    With no ``dtype`` the fully qualified type name is returned as a string.
    If the object already has exactly the requested type, it is returned
    unchanged and no copy is made.

    Args:
        dtype (type or string): The desired type. A string is resolved as a
            dotted import path.
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.

    Raises:
        RuntimeError: if the cast would cross the sparse/dense boundary.
    """
    non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs)
    if dtype is None:
        return '.'.join((self.__module__, self.__class__.__name__))

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self

    if not self.is_sparse:
        if dtype.is_sparse:
            raise RuntimeError("Cannot cast dense tensor to sparse tensor")
        return dtype(self.size()).copy_(self, non_blocking)

    if not dtype.is_sparse:
        raise RuntimeError("Cannot cast sparse tensor to dense tensor")
    # Values and indices are cast through the type names obtained by removing
    # '.sparse' from the target type's module path.
    dense_module = dtype.__module__.replace('.sparse', '')
    cast_values = torch.Tensor._values(self).type(dense_module + '.' + dtype.__name__, non_blocking)
    cast_indices = torch.Tensor._indices(self).type(dense_module + '.LongTensor', non_blocking)
    return dtype(cast_indices, cast_values, self.size())
def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Return a copy of this object in CUDA memory.

    When the object already lives in CUDA memory on the requested device,
    it is returned as-is and nothing is copied.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    elif device is None:
        # No device requested for a non-CUDA source: -1 is handed to
        # torch.cuda.device (presumably "keep the current device" -- confirm).
        device = -1

    with torch.cuda.device(device):
        if not self.is_sparse:
            cuda_storage = torch._UntypedStorage(self.size(), device=torch.device('cuda'))
            return cuda_storage.copy_(self, non_blocking)
        sparse_cls = getattr(torch.cuda.sparse, self.__class__.__name__)
        cuda_indices = torch.Tensor._indices(self).cuda(device, non_blocking)
        cuda_values = torch.Tensor._values(self).cuda(device, non_blocking)
        return sparse_cls(cuda_indices, cuda_values, self.size())
- def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
- if not kwargs:
- return non_blocking
- if len(kwargs) != 1 or 'async' not in kwargs:
- message = "{}() got an unexpected keyword argument '{}'"
- argument = list(kwargs.keys()).pop()
- raise TypeError(message.format(function_name, argument))
- warnings.warn("'async' is deprecated; use 'non_blocking'")
- return kwargs['async']
- # Note [Don't serialize hooks]
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # Since time immemorial, we have serialized the backward hooks associated with
- # variables. This kind of half-worked--Python can pickle global functions
- # (but not closures!)--but there were problems.
- #
- # - It's fragile. If you serialize a backward hook into a saved
- # model, and then you rename the function associated with the hook,
- # now your saved model is broken and you can't load it anymore.
- #
#   - It's not actually used.  The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks.  In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages re-adding the hooks!
- #
- # - We didn't serialize them in many cases. Prior to #10220, we
- # were dropping backward hooks in ForkingPickler. We "fixed" this
- # to be convenient with other serialization sites, but lack of
- # serializing backward hooks wasn't actually the root cause of
- # the bug.
- #
- # With these cases in mind, we have decided that a better strategy
- # is to just NOT serialize hooks at all.
- #
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a
# hook has the property _torch_serialize_ignore, we will not emit a warning if
# we attempt to serialize a Tensor with this hook attached to it.
- #
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass None you'll run afoul of #12219.
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a _TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    """Rebuild a tensor that views ``storage`` with the given geometry.

    Args:
        storage: Storage object exposing ``dtype`` and ``_untyped()``.
        storage_offset: Offset passed to ``Tensor.set_``.
        size: Size passed to ``Tensor.set_``.
        stride: Stride passed to ``Tensor.set_``.

    Returns:
        The result of ``Tensor.set_`` on a fresh empty tensor.
    """
    element_dtype = storage.dtype
    untyped_storage = storage._untyped()
    # Start from an empty tensor that already has the right dtype/device,
    # then point it at the storage.
    placeholder = torch.tensor([], dtype=element_dtype, device=untyped_storage.device)
    return placeholder.set_(storage._untyped(), storage_offset, size, stride)
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    """Rebuild a tensor and restore ``requires_grad`` and its backward hooks.

    The first four arguments are forwarded unchanged to ``_rebuild_tensor``.

    Returns:
        The rebuilt tensor with ``requires_grad`` and ``_backward_hooks`` set.
    """
    rebuilt = _rebuild_tensor(storage, storage_offset, size, stride)
    rebuilt.requires_grad = requires_grad
    # NB: Hooks are assigned only for backwards compatibility; the general
    # expectation is that backward_hooks is an empty OrderedDict.
    # See Note [Don't serialize hooks]
    rebuilt._backward_hooks = backward_hooks
    return rebuilt
- _sparse_tensors_to_validate: List["torch.Tensor"] = []
- # In _legacy_load() in serialization.py we unpickle storages after the sparse
- # tensors have been already unpickled. Those storages contain data necessary for
- # validating sparse tensors: indices and values. That's why sparse tensors are
- # first unpickled without any validation, and then this function is called just
- # before _legacy_load() returns, so that all the sparse tensors can be validated
- # in bulk.
- #
- # The same procedure must be followed by _load() in serialization.py because due
- # to Pickler semantics, we have to use the same (non-validating) function for
- # unpickling sparse tensors, regardless of the caller.
- def _validate_loaded_sparse_tensors():
- try:
- for t in _sparse_tensors_to_validate:
- if t.is_sparse:
- torch._validate_sparse_coo_tensor_args(t._indices(), t._values(),
- t.size())
- elif t.is_sparse_csr:
- # TODO: Validation currently involves an expensive traversal
- # on CPU, which may include a device transfer.
- torch._validate_sparse_csr_tensor_args(t.crow_indices(), t.col_indices(),
- t.values(), t.size())
- else:
- raise NotImplementedError(
- '_validate_loaded_sparse_tensors for layout `%s`' % (t.layout))
- finally:
- _sparse_tensors_to_validate.clear()
def _rebuild_sparse_tensor(layout, data):
    """Rebuild a sparse COO tensor without validating it.

    The result is also appended to ``_sparse_tensors_to_validate``.

    Args:
        layout: Must equal ``torch.sparse_coo``.
        data: Triple ``(indices, values, size)``.

    Raises:
        NotImplementedError: for any other layout.
    """
    if layout != torch.sparse_coo:
        raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
    indices, values, size = data
    rebuilt = torch._sparse_coo_tensor_unsafe(indices, values, size)
    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
def _rebuild_sparse_csr_tensor(layout, data):
    """Rebuild a sparse CSR tensor without validating it.

    The result is also appended to ``_sparse_tensors_to_validate``.

    Args:
        layout: Must equal ``torch.sparse_csr``.
        data: Quadruple ``(crow_indices, col_indices, values, size)``.

    Raises:
        NotImplementedError: for any other layout.
    """
    if layout != torch.sparse_csr:
        raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
    crow_indices, col_indices, values, size = data
    rebuilt = torch._sparse_csr_tensor_unsafe(crow_indices, col_indices, values, size)
    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
- def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
- tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
- tensor.requires_grad = requires_grad
- return tensor
- # Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
- _rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
- def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
- return torch.empty_strided(size, stride, dtype=dtype, device='meta', requires_grad=requires_grad)
- def _rebuild_wrapper_subclass(cls, dtype, size, stride, storage_offset, layout, device, requires_grad):
- return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
- cls, size, strides=stride, storage_offset=storage_offset, layout=layout,
- device=device, requires_grad=requires_grad)
- # TODO: Once we decide to break serialization FC, `storage` no longer needs to
- # be a _TypedStorage
- def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
- qscheme = quantizer_params[0]
- if qscheme == torch.per_tensor_affine:
- _, scale, zero_point = quantizer_params
- tensor = torch._empty_affine_quantized(size, scale=scale, zero_point=zero_point, dtype=storage.dtype)
- elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
- _, scales, zero_points, axis = quantizer_params
- if type(scales) is list and type(zero_points) is list:
- if qscheme == torch.per_channel_affine:
- scales = torch.tensor(scales, dtype=torch.double)
- zero_points = torch.tensor(zero_points, dtype=torch.long)
- else:
- scales = torch.tensor(scales, dtype=torch.float)
- zero_points = torch.tensor(zero_points, dtype=torch.float)
- tensor = torch._empty_per_channel_affine_quantized(
- size, scales=scales, zero_points=zero_points, axis=axis, dtype=storage.dtype)
- else:
- raise RuntimeError("Can't deserialize quantized tensor with qscheme {}".format(qscheme))
- tensor.set_(storage, storage_offset, size, stride)
- tensor.requires_grad = requires_grad
- # NB: This line exists only for backwards compatibility; the
- # general expectation is that backward_hooks is an empty
- # OrderedDict. See Note [Don't serialize hooks]
- tensor._backward_hooks = backward_hooks
- return tensor
- def _rebuild_parameter(data, requires_grad, backward_hooks):
- param = torch.nn.Parameter(data, requires_grad)
- # NB: This line exists only for backwards compatibility; the
- # general expectation is that backward_hooks is an empty
- # OrderedDict. See Note [Don't serialize hooks]
- param._backward_hooks = backward_hooks
- return param
- def _import_dotted_name(name):
- components = name.split('.')
- obj = __import__(components[0])
- for component in components[1:]:
- obj = getattr(obj, component)
- return obj
- # Taken from python 3.5 docs
- def _accumulate(iterable, fn=lambda x, y: x + y):
- 'Return running totals'
- # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
- # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
- it = iter(iterable)
- try:
- total = next(it)
- except StopIteration:
- return
- yield total
- for element in it:
- total = fn(total, element)
- yield total
- def _flatten_dense_tensors(tensors):
- """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
- same dense type.
- Since inputs are dense, the resulting tensor will be a concatenated 1D
- buffer. Element-wise operation on this buffer will be equivalent to
- operating individually.
- Args:
- tensors (Iterable[Tensor]): dense tensors to flatten.
- Returns:
- A contiguous 1D buffer containing input tensors.
- """
- return torch._C._nn.flatten_dense_tensors(tensors)
- def _flatten_sparse_tensors(tensors):
- """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
- one of values. Assume tensors are of same sparse type.
- Args:
- tensors (Iterable[Tensor]): sparse tensors to flatten.
- Returns:
- A tuple of two contiguous 1D buffers, one containing input tensors'
- indices and the other containing the values.
- """
- flat_indices = torch._C._nn.flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors])
- flat_values = torch._C._nn.flatten_dense_tensors([torch.Tensor._values(t) for t in tensors])
- return flat_indices, flat_values
- def _unflatten_dense_tensors(flat, tensors):
- """View a flat buffer using the sizes of tensors. Assume that tensors are of
- same dense type, and that flat is given by _flatten_dense_tensors.
- Args:
- flat (Tensor): flattened dense tensors to unflatten.
- tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
- unflatten flat.
- Returns:
- Unflattened dense tensors with sizes same as tensors and values from
- flat.
- """
- return torch._C._nn.unflatten_dense_tensors(flat, tensors)
- def _unflatten_sparse_tensors(flat, tensors):
- """View flat buffer (containing indices and values) using the sizes of
- tensors. Assume that tensors are of same sparse type, and that flat is given
- by _flatten_sparse_tensors.
- Args:
- flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
- tensors to unflatten.
- tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
- unflatten flat.
- Returns:
- Unflattened sparse tensors with sizes same as tensors and values from
- flat.
- """
- flat_indices, flat_values = flat
- indices = torch._C._nn.unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors])
- values = torch._C._nn.unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors])
- outputs = []
- for t, i, v in zip(tensors, indices, values):
- outputs.append(t.new(i, v, t.size()))
- return tuple(outputs)
- def _reorder_tensors_as(tensors, ordered_tensors):
- """Assume that tensors are of same order as ordered_tensors within their
- types, e.g., from _take_tensors. Reorder them to be of same order as
- ordered_tensors.
- Args:
- tensors (Iterable[Tensor]): tensors to be reordered. They should be of
- the same order as ordered_tensors within their own types.
- ordered_tensors (Iterable[Tensor]): tensors whose order will be the
- reference.
- Returns:
- Ordered tuple of tensors with contents from tensors and order of
- ordered_tensors.
- """
- type_dict = defaultdict(list)
- for tensor in tensors:
- type_dict[tensor.type()].append(tensor)
- type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
- return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
- def _take_tensors(tensors, size_limit):
- """Group tensors into chunks. This generator yields a chunk at each time,
- each containing tensors of same type up to certain byte limit in total size.
- Args:
- tensors (Sequence): A sequence of tensors to be separated into chunks.
- size_limit (int): The limit of each chunk in bytes.
- Yields:
- Blocks of tensors of same type and within size_limit. The yielded
- tensors are only ordered as the original sequence within its types.
- """
- buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
- for tensor in tensors:
- t = tensor.type()
- if tensor.is_sparse:
- indices = torch.Tensor._indices(tensor)
- values = torch.Tensor._values(tensor)
- size = indices.numel() * indices.element_size() + values.numel() * values.element_size()
- else:
- size = tensor.numel() * tensor.element_size()
- buf_and_size = buf_dict[t]
- if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
- yield buf_and_size[0]
- buf_and_size = buf_dict[t] = [[], 0]
- buf_and_size[0].append(tensor)
- buf_and_size[1] += size
- for buf, _ in buf_dict.values():
- if len(buf) > 0:
- yield buf
# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    """Build a decorator that sets a function's ``__annotations__``.

    Args:
        ret: Annotation stored under the ``'return'`` key.
        **kwargs: Annotations for the parameters, keyed by parameter name.

    Returns:
        A decorator that replaces ``__annotations__`` on the function it
        receives and returns that same function.
    """
    def dec(fun):
        fun.__annotations__ = {**kwargs, 'return': ret}
        return fun
    return dec
# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding a reference to the frame,
# and the frame (which holds references to all the objects in its temporary
# scope) holding a reference to the traceback.
class KeyErrorMessage(str):
    r"""A ``str`` whose ``repr`` is the string itself, without quotes or escapes."""

    def __repr__(self):
        return self
class ExceptionWrapper:
    r"""Captures an exception type and its formatted traceback so it can be
    re-raised elsewhere, e.g. to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # Only the type and the formatted text are kept; exc_info itself is
        # never stored, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        captured = sys.exc_info() if exc_info is None else exc_info
        self.exc_type = captured[0]
        self.exc_msg = "".join(traceback.format_exception(*captured))
        self.where = where

    def reraise(self):
        r"""Raises an exception of the wrapped type in the current thread"""
        exc_type = self.exc_type
        # Produces a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        text = "Caught {} {}.\nOriginal {}".format(
            exc_type.__name__, self.where, self.exc_msg)
        if exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            text = KeyErrorMessage(text)
        elif getattr(exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise exc_type(message=text)
        try:
            rebuilt = exc_type(text)
        except TypeError:
            # The exception type cannot be built from a single message
            # argument, so fall back to a RuntimeError carrying the text.
            raise RuntimeError(text) from None
        raise rebuilt
- def _get_available_device_type():
- if torch.cuda.is_available():
- return "cuda"
- # add more available device types here
- return None
def _get_device_attr(get_member):
    """Apply ``get_member`` to the module of the available device type.

    Args:
        get_member: Callable that receives the device module
            (``torch.cuda``) and returns the requested value.

    Returns:
        The value produced by ``get_member``, or ``None`` when the available
        device type is not CUDA.
    """
    device_type = _get_available_device_type()
    if not device_type or device_type.lower() != "cuda":
        # add more available device types here
        return None
    return get_member(torch.cuda)
def _get_current_device_index():
    """Return the current device index via ``_get_device_attr``."""
    def current_index(device_module):
        return device_module.current_device()

    return _get_device_attr(current_index)
def _get_all_device_indices():
    """Return the list ``[0, ..., device_count - 1]`` via ``_get_device_attr``."""
    def every_index(device_module):
        return list(range(device_module.device_count()))

    return _get_device_attr(every_index)
def _get_devices_properties(device_ids):
    """Return the device properties for each id in ``device_ids``, in order."""
    collected = []
    for device_id in device_ids:
        collected.append(_get_device_attr(lambda m: m.get_device_properties(device_id)))
    return collected
def get_current_device_index() -> int:
    r"""Returns the index of the current default CUDA device.

    Returns -1 when ``torch.cuda.device_count()`` reports no CUDA devices.

    Arguments: ``None``
    """
    return torch.cuda.current_device() if torch.cuda.device_count() > 0 else -1
- def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = False) -> int:
- r"""Gets the device index from :attr:`device`, which can be a torch.device
- object, a Python integer, or ``None``.
- If :attr:`device` is a torch.device object, returns the device index if it
- has index. Note that for a device without a specified index,
- i.e., ``torch.device('xxx')``, this will return the current default
- device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
- CPU devices will be accepted and ``-1`` will be returned in this case.
- If :attr:`device` is a Python integer, it is returned as is.
- If :attr:`device` is ``None``, this will return the current default
- device of the supported runtime platform if :attr:`optional` is ``True``.
- i.e., the current default CUDA device will be returned if CUDA runtime is supported.
- """
- if isinstance(device, str):
- device = torch.device(device)
- device_idx: Optional[int] = None
- if isinstance(device, torch.device):
- if not allow_cpu and device.type == 'cpu':
- raise ValueError('Expected a non cpu device, but got: {}'.format(device))
- device_idx = -1 if device.type == 'cpu' else device.index
- if isinstance(device, int):
- device_idx = device
- if device_idx is None:
- if optional:
- # The eager API _get_current_device_index uses `lambda` functions which are
- # not supported in JIT and hence not scriptable. The JIT equivalent API to get
- # the current device index is `get_current_device_index()` which can
- # be scripted. We use is_scripting to check the mode we are in and call the
- # appropriate API.
- if torch.jit.is_scripting():
- device_idx = get_current_device_index()
- else:
- device_idx = _get_current_device_index()
- else:
- raise ValueError('Expected a torch.device with a specified index '
- 'or an integer, but got:{}'.format(device))
- return device_idx
- def _handle_complex(tensor):
- """
- Returns a real view of a tensor if complex dtype else just the tensor
- need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
- """
- return torch.view_as_real(tensor) if not isinstance(tensor,
- torch.nn.UninitializedParameter) and tensor.is_complex() else tensor
- def _element_size(dtype):
- """
- Returns the element size for a dtype, in bytes
- """
- if not isinstance(dtype, torch.dtype):
- raise RuntimeError(f'expected torch.dtype, but got {type(dtype)}')
- if dtype.is_complex:
- return torch.finfo(dtype).bits >> 2
- elif dtype.is_floating_point:
- return torch.finfo(dtype).bits >> 3
- elif dtype == torch.bool:
- # NOTE: torch.bool is not supported in torch.iinfo()
- return 1
- else:
- return torch.iinfo(dtype).bits >> 3
- class _ClassPropertyDescriptor:
- def __init__(self, fget, fset=None):
- self.fget = fget
- def __get__(self, instance, owner=None):
- if owner is None:
- owner = type(instance)
- return self.fget.__get__(instance, owner)()
def classproperty(func):
    """Wrap ``func`` in a ``_ClassPropertyDescriptor``.

    A plain function is converted to a ``classmethod`` first; an existing
    ``classmethod`` or ``staticmethod`` is used as given.
    """
    already_wrapped = isinstance(func, (classmethod, staticmethod))
    getter = func if already_wrapped else classmethod(func)
    return _ClassPropertyDescriptor(getter)
|