storage.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849
  1. import io
  2. import torch
  3. from ._utils import _type, _cuda
  4. from torch.types import Storage
  5. from typing import Any, TypeVar, Type, Union, cast
  6. import copy
  7. import collections
  8. from functools import lru_cache
  9. try:
  10. import numpy as np
  11. HAS_NUMPY = True
  12. except ModuleNotFoundError:
  13. np = None # type: ignore[assignment]
  14. T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
  15. class _StorageBase(object):
  16. _cdata: Any
  17. is_sparse: bool = False
  18. is_sparse_csr: bool = False
  19. device: torch.device
  20. def __init__(self, *args, **kwargs): ... # noqa: E704
  21. def __len__(self) -> int: ... # noqa: E704
  22. def __getitem__(self, idx): ... # noqa: E704
  23. def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
  24. def nbytes(self) -> int: ... # noqa: E704
  25. def size(self) -> int:
  26. return self.nbytes()
  27. def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
  28. def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
  29. def element_size(self) -> int: ... # noqa: E704
  30. def get_device(self) -> int: ... # noqa: E704
  31. def data_ptr(self) -> int: ... # noqa: E704
  32. # Defined in torch/csrc/generic/StorageSharing.cpp
  33. def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
  34. def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
  35. @classmethod
  36. def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
  37. @classmethod
  38. def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
  39. @classmethod
  40. def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
  41. @classmethod
  42. def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
  43. @classmethod
  44. def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
  45. @classmethod
  46. def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
  47. def _shared_decref(self) -> T: ... # noqa: E704
  48. def _write_file(self, *args, **kwargs): ... # noqa: E704
  49. def resize_(self, size: int): ... # noqa: E704
  50. def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
  51. def is_pinned(self) -> bool: ... # noqa: E704
  52. def _set_from_file(self, *args, **kwargs): ... # noqa: E704
  53. def _set_cdata(self, *args, **kwargs): ... # noqa: E704
  54. def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
  55. def is_shared(self) -> bool: ... # noqa: E704
  56. @classmethod
  57. def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
  58. def _shared_incref(self, *args, **kwargs): ... # noqa: E704
  59. @classmethod
  60. def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
  61. @property
  62. def is_cuda(self): ... # noqa: E704
  63. def __str__(self):
  64. data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
  65. return data_str + (
  66. f'\n[{torch.typename(self)}(device={self.device}) '
  67. f'of size {len(self)}]')
  68. def __repr__(self):
  69. return str(self)
  70. def __iter__(self):
  71. return iter(map(lambda i: self[i], range(self.size())))
  72. def __copy__(self):
  73. return self.clone()
  74. def __deepcopy__(self, memo):
  75. memo = memo.setdefault('torch', {})
  76. if self._cdata in memo:
  77. return memo[self._cdata]
  78. new_storage = self.clone()
  79. memo[self._cdata] = new_storage
  80. return new_storage
  81. def __reduce__(self):
  82. b = io.BytesIO()
  83. torch.save(self, b, _use_new_zipfile_serialization=False)
  84. return (_load_from_bytes, (b.getvalue(),))
  85. def __sizeof__(self):
  86. return super(_StorageBase, self).__sizeof__() + self.size()
  87. def clone(self):
  88. """Returns a copy of this storage"""
  89. return type(self)(self.nbytes(), device=self.device).copy_(self)
  90. def tolist(self):
  91. """Returns a list containing the elements of this storage"""
  92. return list(self)
  93. def cpu(self):
  94. """Returns a CPU copy of this storage if it's not already on the CPU"""
  95. if self.device.type != 'cpu':
  96. return torch._UntypedStorage(self.size()).copy_(self, False)
  97. else:
  98. return self
  99. def _to(self, dtype):
  100. if not isinstance(dtype, torch.dtype):
  101. raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
  102. storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
  103. if storage.data_ptr() == self.data_ptr():
  104. storage = storage.clone()
  105. return storage
  106. def double(self):
  107. """Casts this storage to double type"""
  108. return self._to(torch.double)
  109. def float(self):
  110. """Casts this storage to float type"""
  111. return self._to(torch.float)
  112. def half(self):
  113. """Casts this storage to half type"""
  114. return self._to(torch.half)
  115. def long(self):
  116. """Casts this storage to long type"""
  117. return self._to(torch.long)
  118. def int(self):
  119. """Casts this storage to int type"""
  120. return self._to(torch.int)
  121. def short(self):
  122. """Casts this storage to short type"""
  123. return self._to(torch.short)
  124. def char(self):
  125. """Casts this storage to char type"""
  126. return self._to(torch.int8)
  127. def byte(self):
  128. """Casts this storage to byte type"""
  129. return self._to(torch.uint8)
  130. def bool(self):
  131. """Casts this storage to bool type"""
  132. return self._to(torch.bool)
  133. def bfloat16(self):
  134. """Casts this storage to bfloat16 type"""
  135. return self._to(torch.bfloat16)
  136. def complex_double(self):
  137. """Casts this storage to complex double type"""
  138. return self._to(torch.cdouble)
  139. def complex_float(self):
  140. """Casts this storage to complex float type"""
  141. return self._to(torch.cfloat)
  142. def pin_memory(self):
  143. """Copies the storage to pinned memory, if it's not already pinned."""
  144. if self.is_cuda:
  145. raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
  146. import torch.cuda
  147. allocator = torch.cuda.memory._host_allocator() # type: ignore[attr-defined]
  148. return type(self)(self.size(), allocator=allocator).copy_(self)
  149. def share_memory_(self):
  150. """Moves the storage to shared memory.
  151. This is a no-op for storages already in shared memory and for CUDA
  152. storages, which do not need to be moved for sharing across processes.
  153. Storages in shared memory cannot be resized.
  154. Returns: self
  155. """
  156. from torch.multiprocessing import get_sharing_strategy
  157. if self.is_cuda:
  158. pass # CUDA doesn't use POSIX shared memory
  159. elif get_sharing_strategy() == 'file_system':
  160. self._share_filename_cpu_()
  161. else:
  162. self._share_fd_cpu_()
  163. return self
  164. @classmethod
  165. def _new_shared(cls, size, *, device='cpu'):
  166. """Creates a new storage in shared memory with the same data type"""
  167. from torch.multiprocessing import get_sharing_strategy
  168. device = torch.device(device)
  169. if device.type == 'cuda':
  170. return cls(size, device=device)
  171. elif get_sharing_strategy() == 'file_system':
  172. return cls._new_using_filename_cpu(size)
  173. else:
  174. return cls._new_using_fd_cpu(size)
  175. def _untyped(self):
  176. return self
  177. class _UntypedStorage(torch._C.StorageBase, _StorageBase):
  178. pass
  179. @property
  180. def is_cuda(self):
  181. return self.device.type == 'cuda'
  182. def _load_from_bytes(b):
  183. return torch.load(io.BytesIO(b))
  184. _StorageBase.type = _type # type: ignore[assignment]
  185. _StorageBase.cuda = _cuda # type: ignore[assignment]
  186. @lru_cache(maxsize=None)
  187. def _dtype_to_storage_type_map():
  188. # NOTE: We should no longer add dtypes to this map. This map
  189. # is only used for BC/FC with older PyTorch versions. Going forward,
  190. # new dtypes of _TypedStorage should not translate to a legacy
  191. # <type>Storage class. Instead, new dtypes of _TypedStorage should
  192. # be serialized as an _UntypedStorage paired with a torch.dtype
  193. return {
  194. torch.double: 'DoubleStorage',
  195. torch.float: 'FloatStorage',
  196. torch.half: 'HalfStorage',
  197. torch.long: 'LongStorage',
  198. torch.int: 'IntStorage',
  199. torch.int16: 'ShortStorage',
  200. torch.int8: 'CharStorage',
  201. torch.uint8: 'ByteStorage',
  202. torch.bool: 'BoolStorage',
  203. torch.bfloat16: 'BFloat16Storage',
  204. torch.cdouble: 'ComplexDoubleStorage',
  205. torch.cfloat: 'ComplexFloatStorage',
  206. torch.qint8: 'QInt8Storage',
  207. torch.qint32: 'QInt32Storage',
  208. torch.quint8: 'QUInt8Storage',
  209. torch.quint4x2: 'QUInt4x2Storage',
  210. torch.quint2x4: 'QUInt2x4Storage',
  211. }
  212. @lru_cache(maxsize=None)
  213. def _storage_type_to_dtype_map():
  214. dtype_map = {
  215. val: key for key, val in _dtype_to_storage_type_map().items()}
  216. return dtype_map
  217. def _get_storage_from_sequence(sequence, dtype, device):
  218. if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
  219. interpret_dtypes = {
  220. torch.quint8: torch.uint8,
  221. torch.quint4x2: torch.uint8,
  222. torch.quint2x4: torch.uint8,
  223. torch.qint32: torch.int32,
  224. torch.qint8: torch.int8
  225. }
  226. tmp_tensor = torch.tensor(
  227. sequence,
  228. dtype=interpret_dtypes[dtype],
  229. device=device)
  230. else:
  231. tmp_tensor = torch.tensor(
  232. sequence,
  233. dtype=dtype,
  234. device=device)
  235. return tmp_tensor.storage()._untyped()
  236. def _isint(x):
  237. if HAS_NUMPY:
  238. return isinstance(x, (int, np.integer))
  239. else:
  240. return isinstance(x, int)
  241. class _TypedStorage:
  242. is_sparse = False
  243. dtype: torch.dtype
  244. def fill_(self, value):
  245. self[0:len(self)] = value
  246. return self
  247. def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
  248. if cls == torch.storage._LegacyStorage:
  249. raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
  250. if cls == _TypedStorage:
  251. return super().__new__(cls)
  252. else:
  253. arg_error_msg = (
  254. f'{cls}.__new__ received an invalid combination '
  255. f'of arguments. Expected one of:\n'
  256. ' * no arguments\n'
  257. ' * (int size)\n'
  258. ' * (Sequence data)\n'
  259. ' * (*, _UntypedStorage wrap_storage)')
  260. if device is not None:
  261. raise RuntimeError(
  262. arg_error_msg +
  263. "\nKeyword argument 'device' cannot be specified")
  264. if dtype is not None:
  265. raise RuntimeError(
  266. arg_error_msg +
  267. "\nKeyword argument 'dtype' cannot be specified")
  268. if wrap_storage is None:
  269. if len(args) > 1:
  270. raise RuntimeError(
  271. arg_error_msg +
  272. "\nToo many positional arguments")
  273. if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
  274. raise TypeError(
  275. arg_error_msg +
  276. f"\nArgument type not recognized: {type(args[0])}")
  277. return _TypedStorage(
  278. *args,
  279. dtype=cls.dtype,
  280. device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu')
  281. else:
  282. if len(args) != 0:
  283. raise RuntimeError(
  284. arg_error_msg +
  285. "\nNo positional arguments should be given when using "
  286. "'wrap_storage'")
  287. if not isinstance(wrap_storage, torch._UntypedStorage):
  288. raise TypeError(
  289. arg_error_msg +
  290. f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
  291. cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
  292. if wrap_storage.device.type != cls_device:
  293. raise RuntimeError(
  294. arg_error_msg +
  295. f"\nDevice of 'wrap_storage' must be {cls_device}"
  296. f", but got {wrap_storage.device.type}")
  297. return _TypedStorage(
  298. *args,
  299. wrap_storage=wrap_storage,
  300. dtype=cls.dtype)
  301. def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
  302. arg_error_msg = (
  303. '_TypedStorage.__init__ received an invalid combination '
  304. 'of arguments. Expected one of:\n'
  305. ' * (*, torch.device device, torch.dtype dtype)\n'
  306. ' * (int size, *, torch.device device, torch.dtype dtype)\n'
  307. ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
  308. ' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
  309. if wrap_storage is not None:
  310. if len(args) != 0:
  311. raise RuntimeError(
  312. arg_error_msg +
  313. "\nNo positional arguments should be given when using "
  314. "'wrap_storage'")
  315. if dtype is None:
  316. raise RuntimeError(
  317. arg_error_msg +
  318. "\nArgument 'dtype' must be specified")
  319. if not isinstance(dtype, torch.dtype):
  320. raise TypeError(
  321. arg_error_msg +
  322. f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")
  323. if device is not None:
  324. raise RuntimeError(
  325. arg_error_msg +
  326. "\nArgument 'device' should not be specified when 'wrap_storage' is given")
  327. self.dtype = dtype
  328. if not isinstance(wrap_storage, torch._UntypedStorage):
  329. raise TypeError(
  330. arg_error_msg +
  331. f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
  332. self._storage = wrap_storage
  333. else:
  334. self.dtype = torch.get_default_dtype() if dtype is None else dtype
  335. device = torch.device('cpu' if device is None else device)
  336. if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
  337. if device.type == 'cuda':
  338. raise RuntimeError("Cannot create CUDA storage with quantized dtype")
  339. if len(args) == 0:
  340. self._storage = torch._UntypedStorage(device=device)
  341. elif len(args) == 1:
  342. if _isint(args[0]):
  343. self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
  344. elif isinstance(args[0], collections.abc.Sequence):
  345. self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
  346. else:
  347. raise TypeError(
  348. arg_error_msg +
  349. f"\nArgument type not recognized: {type(args[0])}")
  350. else:
  351. raise RuntimeError(
  352. arg_error_msg +
  353. "\nToo many positional arguments")
  354. @property
  355. def is_cuda(self):
  356. return self.device.type == 'cuda'
  357. def _untyped(self):
  358. return self._storage
  359. def _new_wrapped_storage(self, untyped_storage):
  360. assert type(untyped_storage) == torch._UntypedStorage
  361. if type(self) == _TypedStorage:
  362. return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
  363. else:
  364. return type(self)(wrap_storage=untyped_storage)
  365. def __len__(self):
  366. return self._storage.nbytes() // self.element_size()
  367. def _maybe_wrap_index(self, idx, is_stop=False):
  368. if idx is None:
  369. if is_stop:
  370. return self.size()
  371. else:
  372. return 0
  373. else:
  374. if type(idx) != int:
  375. raise TypeError(
  376. f"can't index a {type(self)} with {type(idx)}")
  377. if is_stop:
  378. if (idx > self.size()) or (idx < -self.size()):
  379. raise IndexError(
  380. f'index {idx} out of range for storage of size {self.size()}')
  381. if idx > 0:
  382. return idx
  383. else:
  384. return idx % self.size()
  385. else:
  386. if (idx >= self.size()) or (idx < -self.size()):
  387. raise IndexError(
  388. f'index {idx} out of range for storage of size {self.size()}')
  389. return idx % self.size()
  390. def __setitem__(self, idx, value):
  391. if not isinstance(idx, (int, slice)):
  392. raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
  393. if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
  394. interpret_dtypes = {
  395. torch.quint8: torch.uint8,
  396. torch.quint4x2: torch.uint8,
  397. torch.quint2x4: torch.uint8,
  398. torch.qint32: torch.int32,
  399. torch.qint8: torch.int8
  400. }
  401. tmp_dtype = interpret_dtypes[self.dtype]
  402. tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
  403. wrap_storage=self._storage,
  404. dtype=tmp_dtype))
  405. else:
  406. tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
  407. tmp_tensor[idx] = value
  408. def __getitem__(self, idx):
  409. # NOTE: Before _TypedStorage existed, indexing with a slice used to be
  410. # possible for <type>Storage objects. However, it would return
  411. # a storage view, which would be a hassle to implement in _TypedStorage,
  412. # so it was disabled
  413. if isinstance(idx, slice):
  414. raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
  415. elif not isinstance(idx, int):
  416. raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
  417. if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
  418. interpret_dtypes = {
  419. torch.quint8: torch.uint8,
  420. torch.quint4x2: torch.uint8,
  421. torch.quint2x4: torch.uint8,
  422. torch.qint32: torch.int32,
  423. torch.qint8: torch.int8
  424. }
  425. return _TypedStorage(
  426. wrap_storage=self._storage,
  427. dtype=interpret_dtypes[self.dtype])[idx]
  428. idx_wrapped = self._maybe_wrap_index(idx)
  429. tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
  430. return tmp_tensor[idx_wrapped].item()
  431. def copy_(self, source: T, non_blocking: bool = None):
  432. self._storage.copy_(source._untyped(), non_blocking)
  433. return self
  434. def nbytes(self):
  435. return self._storage.nbytes()
  436. def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
  437. if dtype is None:
  438. legacy_class = self._get_legacy_storage_class()
  439. if legacy_class is not None:
  440. return legacy_class.__module__ + '.' + legacy_class.__name__
  441. return '.'.join([self.__module__, type(self).__name__])
  442. else:
  443. return self._storage.type(dtype, non_blocking)
  444. def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
  445. if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
  446. raise RuntimeError("Cannot create CUDA storage with quantized dtype")
  447. cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
  448. return self._new_wrapped_storage(cuda_storage)
  449. def element_size(self):
  450. return torch._utils._element_size(self.dtype)
  451. def get_device(self) -> int:
  452. return self._storage.get_device()
  453. def __str__(self):
  454. data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
  455. return data_str + (
  456. f'\n[{torch.typename(self)}(dtype={self.dtype}, '
  457. f'device={self.device}) of size {len(self)}]')
  458. def __repr__(self):
  459. return str(self)
  460. def __iter__(self):
  461. return iter(map(lambda i: self[i], range(self.size())))
  462. def __copy__(self):
  463. return self._new_wrapped_storage(copy.copy(self._storage))
  464. def __deepcopy__(self, memo):
  465. return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
  466. def __sizeof__(self):
  467. return super(_TypedStorage, self).__sizeof__() + self.nbytes()
  468. def clone(self):
  469. """Returns a copy of this storage"""
  470. return self._new_wrapped_storage(self._storage.clone())
  471. def tolist(self):
  472. """Returns a list containing the elements of this storage"""
  473. return list(self)
  474. def cpu(self):
  475. """Returns a CPU copy of this storage if it's not already on the CPU"""
  476. return self._new_wrapped_storage(self._storage.cpu())
  477. def pin_memory(self):
  478. """Coppies the storage to pinned memory, if it's not already pinned."""
  479. return self._new_wrapped_storage(self._storage.pin_memory())
  480. def share_memory_(self):
  481. """Moves the storage to shared memory.
  482. This is a no-op for storages already in shared memory and for CUDA
  483. storages, which do not need to be moved for sharing across processes.
  484. Storages in shared memory cannot be resized.
  485. Returns: self
  486. """
  487. self._storage.share_memory_()
  488. return self
  489. def _new_shared(self, size, *, device=None):
  490. """Creates a new storage in shared memory with the same data type"""
  491. if device is None:
  492. device = 'cpu'
  493. device = torch.device(device)
  494. untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
  495. return _TypedStorage(
  496. wrap_storage=untyped_storage,
  497. dtype=self.dtype)
  498. @property
  499. def _cdata(self):
  500. return self._storage._cdata
  501. @property
  502. def device(self):
  503. return self._storage.device
  504. def size(self):
  505. return len(self)
  506. def pickle_storage_type(self):
  507. try:
  508. return _dtype_to_storage_type_map()[self.dtype]
  509. except KeyError:
  510. raise KeyError(f'dtype {self.dtype} is not recognized')
  511. def __reduce__(self):
  512. b = io.BytesIO()
  513. torch.save(self, b, _use_new_zipfile_serialization=False)
  514. return (_load_from_bytes, (b.getvalue(),))
  515. def data_ptr(self):
  516. return self._storage.data_ptr()
  517. def resize_(self, size):
  518. self._storage.resize_(size * self.element_size())
  519. @classmethod
  520. def _free_weak_ref(cls, *args, **kwargs):
  521. return _UntypedStorage._free_weak_ref(*args, **kwargs)
  522. def _weak_ref(self, *args, **kwargs):
  523. return self._storage._weak_ref(*args, **kwargs)
  524. @classmethod
  525. def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
  526. if cls == _TypedStorage:
  527. dtype = torch.get_default_dtype() if dtype is None else dtype
  528. device = torch.device('cpu' if device is None else device)
  529. if device.type != 'cpu':
  530. raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
  531. untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
  532. else:
  533. if dtype is not None or len(args) == 5:
  534. raise RuntimeError((
  535. "from_buffer: 'dtype' can only be specified in "
  536. "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
  537. if device is not None:
  538. raise RuntimeError((
  539. "from_buffer: 'device' can only be specified in "
  540. "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
  541. dtype = cls.dtype
  542. untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
  543. return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
  544. def _to(self, dtype):
  545. if not isinstance(dtype, torch.dtype):
  546. raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
  547. storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
  548. if storage.data_ptr() == self.data_ptr():
  549. storage = storage.clone()
  550. return storage
  551. def double(self):
  552. """Casts this storage to double type"""
  553. return self._to(torch.double)
  554. def float(self):
  555. """Casts this storage to float type"""
  556. return self._to(torch.float)
  557. def half(self):
  558. """Casts this storage to half type"""
  559. return self._to(torch.half)
  560. def long(self):
  561. """Casts this storage to long type"""
  562. return self._to(torch.long)
  563. def int(self):
  564. """Casts this storage to int type"""
  565. return self._to(torch.int)
  566. def short(self):
  567. """Casts this storage to short type"""
  568. return self._to(torch.short)
  569. def char(self):
  570. """Casts this storage to char type"""
  571. return self._to(torch.int8)
  572. def byte(self):
  573. """Casts this storage to byte type"""
  574. return self._to(torch.uint8)
  575. def bool(self):
  576. """Casts this storage to bool type"""
  577. return self._to(torch.bool)
  578. def bfloat16(self):
  579. """Casts this storage to bfloat16 type"""
  580. return self._to(torch.bfloat16)
  581. def complex_double(self):
  582. """Casts this storage to complex double type"""
  583. return self._to(torch.cdouble)
  584. def complex_float(self):
  585. """Casts this storage to complex float type"""
  586. return self._to(torch.cfloat)
  587. @classmethod
  588. def from_file(cls, filename, shared, size):
  589. """
  590. from_file(filename, shared=False, size=0) -> Storage
  591. If `shared` is `True`, then memory is shared between all processes.
  592. All changes are written to the file. If `shared` is `False`, then the changes on
  593. the storage do not affect the file.
  594. `size` is the number of elements in the storage. If `shared` is `False`,
  595. then the file must contain at least `size * sizeof(Type)` bytes
  596. (`Type` is the type of storage). If `shared` is `True` the file will be
  597. created if needed.
  598. Args:
  599. filename (str): file name to map
  600. shared (bool): whether to share memory
  601. size (int): number of elements in the storage
  602. """
  603. if cls == _TypedStorage:
  604. raise RuntimeError('from_file can only be called on derived classes')
  605. untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
  606. filename,
  607. shared,
  608. size * torch._utils._element_size(cls.dtype))
  609. storage = cls(wrap_storage=untyped_storage)
  610. return storage
  611. @classmethod
  612. def _expired(cls, *args, **kwargs):
  613. return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs)
  614. def is_pinned(self):
  615. return self._storage.is_pinned()
  616. def _write_file(self, *args, **kwargs):
  617. return self._storage._write_file(*args, **kwargs)
  618. def _set_from_file(self, *args, **kwargs):
  619. return self._storage._set_from_file(*args, **kwargs)
  620. def _set_cdata(self, *args, **kwargs):
  621. return self._storage._set_cdata(*args, **kwargs)
  622. def _share_cuda_(self, *args, **kwargs):
  623. return self._storage._share_cuda_(*args, **kwargs)
  624. def is_shared(self):
  625. return self._storage.is_shared()
  626. @classmethod
  627. def _new_shared_cuda(cls, *args, **kwargs):
  628. return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)
  629. def _share_filename_cpu_(self, *args, **kwargs):
  630. manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
  631. return manager_handle, storage_handle, size // self.element_size()
  632. def _shared_decref(self):
  633. self._storage._shared_decref()
  634. return self
  635. @classmethod
  636. def _release_ipc_counter(cls, *args, device=None, **kwargs):
  637. return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
  638. def _shared_incref(self, *args, **kwargs):
  639. return self._storage._shared_incref(*args, **kwargs)
  640. def _share_fd_cpu_(self, *args, **kwargs):
  641. fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
  642. return fd, size // self.element_size()
  643. def _get_legacy_storage_class(self):
  644. if self.dtype not in _dtype_to_storage_type_map():
  645. return None
  646. storage_name = _dtype_to_storage_type_map()[self.dtype]
  647. if self.device.type not in ['cpu', 'cuda']:
  648. return None
  649. module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.'
  650. try:
  651. return eval(module + storage_name)
  652. except AttributeError:
  653. return None
  654. _TypedStorage.type.__doc__ = _type.__doc__
  655. _TypedStorage.cuda.__doc__ = _cuda.__doc__
  656. class _LegacyStorageMeta(type):
  657. dtype: torch.dtype
  658. def __instancecheck__(cls, instance):
  659. if type(instance) == _TypedStorage:
  660. cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
  661. return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
  662. return False
  663. class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
  664. @classmethod
  665. def _new_shared(cls, size):
  666. """Creates a new storage in shared memory with the same data type"""
  667. untyped_storage = torch._UntypedStorage._new_shared(size * cls().element_size())
  668. return cls(wrap_storage=untyped_storage)
  669. @classmethod
  670. def _release_ipc_counter(cls, *args, **kwargs):
  671. return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
  672. @classmethod
  673. def _new_shared_filename(cls, manager, obj, size):
  674. bytes_size = size * torch._utils._element_size(cls.dtype)
  675. return cls(wrap_storage=torch._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
  676. def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
  677. try:
  678. return _storage_type_to_dtype_map()[pickle_storage_type]
  679. except KeyError:
  680. raise KeyError(
  681. f'pickle storage type "{pickle_storage_type}" is not recognized')