serialization.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057
  1. import difflib
  2. import os
  3. import io
  4. import shutil
  5. import struct
  6. import sys
  7. import torch
  8. import tarfile
  9. import tempfile
  10. import warnings
  11. from contextlib import closing, contextmanager
  12. from ._utils import _import_dotted_name
  13. from ._six import string_classes as _string_classes
  14. from torch._sources import get_source_lines_and_file
  15. from torch.types import Storage
  16. from torch.storage import _get_dtype_from_pickle_storage_type
  17. from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO
  18. import copyreg
  19. import pickle
  20. import pathlib
  21. DEFAULT_PROTOCOL = 2
  22. LONG_SIZE = struct.Struct('=l').size
  23. INT_SIZE = struct.Struct('=i').size
  24. SHORT_SIZE = struct.Struct('=h').size
  25. MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
  26. PROTOCOL_VERSION = 1001
  27. STORAGE_KEY_SEPARATOR = ','
  28. class SourceChangeWarning(Warning):
  29. pass
  30. @contextmanager
  31. def mkdtemp():
  32. path = tempfile.mkdtemp()
  33. yield path
  34. shutil.rmtree(path)
  35. _package_registry = []
  36. def _is_zipfile(f) -> bool:
  37. # This is a stricter implementation than zipfile.is_zipfile().
  38. # zipfile.is_zipfile() is True if the magic number appears anywhere in the
  39. # binary. Since we expect the files here to be generated by torch.save or
  40. # torch.jit.save, it's safe to only check the start bytes and avoid
  41. # collisions and assume the zip has only 1 file.
  42. # See bugs.python.org/issue28494.
  43. # Read the first 4 bytes of the file
  44. read_bytes = []
  45. start = f.tell()
  46. byte = f.read(1)
  47. while byte != "":
  48. read_bytes.append(byte)
  49. if len(read_bytes) == 4:
  50. break
  51. byte = f.read(1)
  52. f.seek(start)
  53. local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
  54. return read_bytes == local_header_magic_number
  55. def register_package(priority, tagger, deserializer):
  56. queue_elem = (priority, tagger, deserializer)
  57. _package_registry.append(queue_elem)
  58. _package_registry.sort()
  59. def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
  60. '''
  61. Check if a module's version satisfies requirements
  62. Usually, a module's version string will be like 'x.y.z', which would be represented
  63. as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
  64. string does not match the given tuple's format up to the length of the tuple, then
  65. error and exit or emit a warning.
  66. Args:
  67. module: the module to check the version of
  68. req_version_tuple: tuple (usually of ints) representing the required version
  69. error_if_malformed: whether we should exit if module version string is malformed
  70. Returns:
  71. requirement_is_met: bool
  72. '''
  73. try:
  74. version_strs = module.__version__.split('.')
  75. # Cast module version fields to match the types of the required version
  76. module_version = tuple(
  77. type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
  78. )
  79. requirement_is_met = module_version >= req_version_tuple
  80. except Exception as e:
  81. message = (
  82. "'%s' module version string is malformed '%s' and cannot be compared"
  83. " with tuple %s"
  84. ) % (
  85. module.__name__, module.__version__, str(req_version_tuple)
  86. )
  87. if error_if_malformed:
  88. raise RuntimeError(message) from e
  89. else:
  90. warnings.warn(message + ', but continuing assuming that requirement is met')
  91. requirement_is_met = True
  92. return requirement_is_met
  93. def _cpu_tag(obj):
  94. if obj.device.type == 'cpu':
  95. return 'cpu'
  96. def _cuda_tag(obj):
  97. if obj.device.type == 'cuda':
  98. return 'cuda:' + str(obj.device.index)
  99. def _cpu_deserialize(obj, location):
  100. if location == 'cpu':
  101. return obj
  102. def validate_cuda_device(location):
  103. device = torch.cuda._utils._get_device_index(location, True)
  104. if not torch.cuda.is_available():
  105. raise RuntimeError('Attempting to deserialize object on a CUDA '
  106. 'device but torch.cuda.is_available() is False. '
  107. 'If you are running on a CPU-only machine, '
  108. 'please use torch.load with map_location=torch.device(\'cpu\') '
  109. 'to map your storages to the CPU.')
  110. device_count = torch.cuda.device_count()
  111. if device >= device_count:
  112. raise RuntimeError('Attempting to deserialize object on CUDA device '
  113. f'{device} but torch.cuda.device_count() is {device_count}. Please use '
  114. 'torch.load with map_location to map your storages '
  115. 'to an existing device.')
  116. return device
  117. def _cuda_deserialize(obj, location):
  118. if location.startswith('cuda'):
  119. device = validate_cuda_device(location)
  120. if getattr(obj, "_torch_load_uninitialized", False):
  121. with torch.cuda.device(device):
  122. return torch._UntypedStorage(obj.nbytes(), device=torch.device(location))
  123. else:
  124. return obj.cuda(device)
  125. register_package(10, _cpu_tag, _cpu_deserialize)
  126. register_package(20, _cuda_tag, _cuda_deserialize)
  127. def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
  128. for _, tagger, _ in _package_registry:
  129. location = tagger(storage)
  130. if location:
  131. return location
  132. raise RuntimeError("don't know how to determine data location of "
  133. + torch.typename(storage))
  134. def default_restore_location(storage, location):
  135. for _, _, fn in _package_registry:
  136. result = fn(storage, location)
  137. if result is not None:
  138. return result
  139. raise RuntimeError("don't know how to restore data location of "
  140. + torch.typename(storage) + " (tagged with "
  141. + location + ")")
  142. def normalize_storage_type(storage_type):
  143. return getattr(torch, storage_type.__name__)
  144. def storage_to_tensor_type(storage):
  145. storage_type = type(storage)
  146. module = _import_dotted_name(storage_type.__module__)
  147. return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
  148. def _is_path(name_or_buffer):
  149. return isinstance(name_or_buffer, str) or \
  150. isinstance(name_or_buffer, pathlib.Path)
  151. class _opener(object):
  152. def __init__(self, file_like):
  153. self.file_like = file_like
  154. def __enter__(self):
  155. return self.file_like
  156. def __exit__(self, *args):
  157. pass
  158. class _open_file(_opener):
  159. def __init__(self, name, mode):
  160. super(_open_file, self).__init__(open(name, mode))
  161. def __exit__(self, *args):
  162. self.file_like.close()
  163. class _open_buffer_reader(_opener):
  164. def __init__(self, buffer):
  165. super(_open_buffer_reader, self).__init__(buffer)
  166. _check_seekable(buffer)
  167. class _open_buffer_writer(_opener):
  168. def __exit__(self, *args):
  169. self.file_like.flush()
  170. def _open_file_like(name_or_buffer, mode):
  171. if _is_path(name_or_buffer):
  172. return _open_file(name_or_buffer, mode)
  173. else:
  174. if 'w' in mode:
  175. return _open_buffer_writer(name_or_buffer)
  176. elif 'r' in mode:
  177. return _open_buffer_reader(name_or_buffer)
  178. else:
  179. raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
  180. class _open_zipfile_reader(_opener):
  181. def __init__(self, name_or_buffer) -> None:
  182. super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
  183. class _open_zipfile_writer_file(_opener):
  184. def __init__(self, name) -> None:
  185. super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
  186. def __exit__(self, *args) -> None:
  187. self.file_like.write_end_of_file()
  188. class _open_zipfile_writer_buffer(_opener):
  189. def __init__(self, buffer) -> None:
  190. self.buffer = buffer
  191. super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
  192. def __exit__(self, *args) -> None:
  193. self.file_like.write_end_of_file()
  194. self.buffer.flush()
  195. def _open_zipfile_writer(name_or_buffer):
  196. container: Type[_opener]
  197. if _is_path(name_or_buffer):
  198. container = _open_zipfile_writer_file
  199. else:
  200. container = _open_zipfile_writer_buffer
  201. return container(name_or_buffer)
  202. def _is_compressed_file(f) -> bool:
  203. compress_modules = ['gzip']
  204. try:
  205. return f.__module__ in compress_modules
  206. except AttributeError:
  207. return False
  208. def _should_read_directly(f):
  209. """
  210. Checks if f is a file that should be read directly. It should be read
  211. directly if it is backed by a real file (has a fileno) and is not a
  212. a compressed file (e.g. gzip)
  213. """
  214. if _is_compressed_file(f):
  215. return False
  216. try:
  217. return f.fileno() >= 0
  218. except io.UnsupportedOperation:
  219. return False
  220. except AttributeError:
  221. return False
  222. def _check_seekable(f) -> bool:
  223. def raise_err_msg(patterns, e):
  224. for p in patterns:
  225. if p in str(e):
  226. msg = (str(e) + ". You can only torch.load from a file that is seekable."
  227. + " Please pre-load the data into a buffer like io.BytesIO and"
  228. + " try to load from it instead.")
  229. raise type(e)(msg)
  230. raise e
  231. try:
  232. f.seek(f.tell())
  233. return True
  234. except (io.UnsupportedOperation, AttributeError) as e:
  235. raise_err_msg(["seek", "tell"], e)
  236. return False
  237. def _check_dill_version(pickle_module) -> None:
  238. '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
  239. If dill version is lower than 0.3.1, a ValueError is raised.
  240. Args:
  241. pickle_module: module used for pickling metadata and objects
  242. '''
  243. if pickle_module.__name__ == 'dill':
  244. required_dill_version = (0, 3, 1)
  245. if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
  246. raise ValueError((
  247. "'torch' supports dill >= %s, but you have dill %s."
  248. " Please upgrade dill or switch to 'pickle'"
  249. ) % (
  250. '.'.join([str(num) for num in required_dill_version]),
  251. pickle_module.__version__
  252. ))
  253. def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
  254. pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
  255. # Reference: https://github.com/pytorch/pytorch/issues/54354
  256. # The first line of this docstring overrides the one Sphinx generates for the
  257. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  258. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  259. """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
  260. Saves an object to a disk file.
  261. See also: :ref:`saving-loading-tensors`
  262. Args:
  263. obj: saved object
  264. f: a file-like object (has to implement write and flush) or a string or
  265. os.PathLike object containing a file name
  266. pickle_module: module used for pickling metadata and objects
  267. pickle_protocol: can be specified to override the default protocol
  268. .. note::
  269. A common PyTorch convention is to save tensors using .pt file extension.
  270. .. note::
  271. PyTorch preserves storage sharing across serialization. See
  272. :ref:`preserve-storage-sharing` for more details.
  273. .. note::
  274. The 1.6 release of PyTorch switched ``torch.save`` to use a new
  275. zipfile-based file format. ``torch.load`` still retains the ability to
  276. load files in the old format. If for any reason you want ``torch.save``
  277. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
  278. Example:
  279. >>> # Save to file
  280. >>> x = torch.tensor([0, 1, 2, 3, 4])
  281. >>> torch.save(x, 'tensor.pt')
  282. >>> # Save to io.BytesIO buffer
  283. >>> buffer = io.BytesIO()
  284. >>> torch.save(x, buffer)
  285. """
  286. _check_dill_version(pickle_module)
  287. with _open_file_like(f, 'wb') as opened_file:
  288. if _use_new_zipfile_serialization:
  289. with _open_zipfile_writer(opened_file) as opened_zipfile:
  290. _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  291. return
  292. _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  293. def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
  294. import torch.nn as nn
  295. serialized_container_types = {}
  296. serialized_storages = {}
  297. # Since loading storages that view the same data with different dtypes is
  298. # not supported, we need to keep track of the dtype associated with each
  299. # storage data_ptr and throw an error if the dtype is ever different.
  300. # TODO: This feature could be added in the future
  301. storage_dtypes: Dict[int, torch.dtype] = {}
  302. def persistent_id(obj: Any) -> Optional[Tuple]:
  303. # FIXME: the docs say that persistent_id should only return a string
  304. # but torch store returns tuples. This works only in the binary protocol
  305. # see
  306. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  307. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  308. if isinstance(obj, type) and issubclass(obj, nn.Module):
  309. if obj in serialized_container_types:
  310. return None
  311. serialized_container_types[obj] = True
  312. source_file = source = None
  313. try:
  314. source_lines, _, source_file = get_source_lines_and_file(obj)
  315. source = ''.join(source_lines)
  316. except Exception: # saving the source is optional, so we can ignore any errors
  317. warnings.warn("Couldn't retrieve source code for container of "
  318. "type " + obj.__name__ + ". It won't be checked "
  319. "for correctness upon loading.")
  320. return ('module', obj, source_file, source)
  321. if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
  322. storage: torch._UntypedStorage
  323. if isinstance(obj, torch.storage._TypedStorage):
  324. # TODO: Once we decide to break serialization FC, this case
  325. # can be deleted
  326. storage = obj._storage
  327. storage_dtype = obj.dtype
  328. storage_type_str = obj.pickle_storage_type()
  329. storage_type = getattr(torch, storage_type_str)
  330. dtype = obj.dtype
  331. storage_numel = obj.size()
  332. elif isinstance(obj, torch._UntypedStorage):
  333. storage = obj
  334. storage_dtype = torch.uint8
  335. storage_type = normalize_storage_type(type(obj))
  336. dtype = torch.uint8
  337. storage_numel = storage.nbytes()
  338. else:
  339. raise TypeError(f'type not recognized: {type(obj)}')
  340. # If storage is allocated, ensure that any other saved storages
  341. # pointing to the same data all have the same dtype. If storage is
  342. # not allocated, don't perform this check
  343. if storage.data_ptr() != 0:
  344. if storage.data_ptr() in storage_dtypes:
  345. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  346. raise RuntimeError(
  347. 'Cannot save multiple tensors or storages that '
  348. 'view the same data as different types')
  349. else:
  350. storage_dtypes[storage.data_ptr()] = storage_dtype
  351. view_metadata: Optional[Tuple[str, int, int]]
  352. # Offset is always 0, but we keep it for backwards compatibility
  353. # with the old serialization format (which supported storage views)
  354. offset = 0
  355. storage_key = str(storage._cdata)
  356. location = location_tag(storage)
  357. # TODO: There's an issue here with FC. It might be impossible to
  358. # solve, but it's worth noting. Imagine we save a list `[storage,
  359. # tensor]`, where `tensor.storage()` is the same as `storage`, and
  360. # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
  361. # torch.float`. The storage will be serialized with element size
  362. # of 1, since we're choosing to serialize the first occurance of
  363. # a duplicate storage. Since this legacy serialization format saves
  364. # the numel of the storage, rather than nbytes directly, we'll be
  365. # effectively saving nbytes in this case. We'll be able to load it
  366. # and the tensor back up with no problems in _this_ and future
  367. # versions of pytorch, but in older versions, here's the problem:
  368. # the storage will be loaded up as a _UntypedStorage, and then the
  369. # FloatTensor will loaded and the _UntypedStorage will be assigned to
  370. # it. Since the storage dtype does not match the tensor dtype, this
  371. # will cause an error. If we reverse the list, like `[tensor,
  372. # storage]`, then we will save the `tensor.storage()` as a faked
  373. # `FloatStorage`, and the saved size will be the correct
  374. # dtype-specific numel count that old versions expect. `tensor`
  375. # will be able to load up properly in old versions, pointing to
  376. # a FloatStorage. However, `storage` is still being translated to
  377. # a _UntypedStorage, and it will try to resolve to the same
  378. # FloatStorage that `tensor` contains. This will also cause an
  379. # error. It doesn't seem like there's any way around this.
  380. # Probably, we just cannot maintain FC for the legacy format if the
  381. # saved list contains both a tensor and a storage that point to the
  382. # same data. We should still be able to maintain FC for lists of
  383. # just tensors, as long as all views share the same dtype as the
  384. # tensor they are viewing.
  385. if storage_key not in serialized_storages:
  386. serialized_storages[storage_key] = (storage, dtype)
  387. is_view = storage._cdata != storage._cdata
  388. if is_view:
  389. view_metadata = (str(storage._cdata), offset, storage.nbytes())
  390. else:
  391. view_metadata = None
  392. res = ('storage',
  393. storage_type,
  394. storage_key,
  395. location,
  396. storage_numel,
  397. view_metadata)
  398. return res
  399. return None
  400. sys_info = dict(
  401. protocol_version=PROTOCOL_VERSION,
  402. little_endian=sys.byteorder == 'little',
  403. type_sizes=dict(
  404. short=SHORT_SIZE,
  405. int=INT_SIZE,
  406. long=LONG_SIZE,
  407. ),
  408. )
  409. pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
  410. pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
  411. pickle_module.dump(sys_info, f, protocol=pickle_protocol)
  412. pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
  413. pickler.persistent_id = persistent_id
  414. pickler.dump(obj)
  415. serialized_storage_keys = sorted(serialized_storages.keys())
  416. pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
  417. f.flush()
  418. for key in serialized_storage_keys:
  419. storage, dtype = serialized_storages[key]
  420. storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
  421. def _save(obj, zip_file, pickle_module, pickle_protocol):
  422. serialized_storages = {}
  423. id_map: Dict[int, str] = {}
  424. # Since loading storages that view the same data with different dtypes is
  425. # not supported, we need to keep track of the dtype associated with each
  426. # storage data_ptr and throw an error if the dtype is ever different.
  427. # TODO: This feature could be added in the future
  428. storage_dtypes: Dict[int, torch.dtype] = {}
  429. def persistent_id(obj):
  430. # FIXME: the docs say that persistent_id should only return a string
  431. # but torch store returns tuples. This works only in the binary protocol
  432. # see
  433. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  434. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  435. if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
  436. if isinstance(obj, torch.storage._TypedStorage):
  437. # TODO: Once we decide to break serialization FC, this case
  438. # can be deleted
  439. storage = obj._storage
  440. storage_dtype = obj.dtype
  441. storage_type_str = obj.pickle_storage_type()
  442. storage_type = getattr(torch, storage_type_str)
  443. storage_numel = obj.size()
  444. else:
  445. storage = obj
  446. storage_dtype = torch.uint8
  447. storage_type = normalize_storage_type(type(obj))
  448. storage_numel = storage.nbytes()
  449. # If storage is allocated, ensure that any other saved storages
  450. # pointing to the same data all have the same dtype. If storage is
  451. # not allocated, don't perform this check
  452. if storage.data_ptr() != 0:
  453. if storage.data_ptr() in storage_dtypes:
  454. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  455. raise RuntimeError(
  456. 'Cannot save multiple tensors or storages that '
  457. 'view the same data as different types')
  458. else:
  459. storage_dtypes[storage.data_ptr()] = storage_dtype
  460. storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
  461. location = location_tag(storage)
  462. serialized_storages[storage_key] = storage
  463. return ('storage',
  464. storage_type,
  465. storage_key,
  466. location,
  467. storage_numel)
  468. return None
  469. # Write the pickle data for `obj`
  470. data_buf = io.BytesIO()
  471. pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
  472. pickler.persistent_id = persistent_id
  473. pickler.dump(obj)
  474. data_value = data_buf.getvalue()
  475. zip_file.write_record('data.pkl', data_value, len(data_value))
  476. # Write each tensor to a file named tensor/the_tensor_key in the zip archive
  477. for key in sorted(serialized_storages.keys()):
  478. name = f'data/{key}'
  479. storage = serialized_storages[key]
  480. # given that we copy things around anyway, we might use storage.cpu()
  481. # this means to that to get tensors serialized, you need to implement
  482. # .cpu() on the underlying Storage
  483. if storage.device.type != 'cpu':
  484. storage = storage.cpu()
  485. # Now that it is on the CPU we can directly copy it into the zip file
  486. num_bytes = storage.nbytes()
  487. zip_file.write_record(name, storage.data_ptr(), num_bytes)
  488. def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
  489. # Reference: https://github.com/pytorch/pytorch/issues/54354
  490. # The first line of this docstring overrides the one Sphinx generates for the
  491. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  492. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  493. """load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
  494. Loads an object saved with :func:`torch.save` from a file.
  495. :func:`torch.load` uses Python's unpickling facilities but treats storages,
  496. which underlie tensors, specially. They are first deserialized on the
  497. CPU and are then moved to the device they were saved from. If this fails
  498. (e.g. because the run time system doesn't have certain devices), an exception
  499. is raised. However, storages can be dynamically remapped to an alternative
  500. set of devices using the :attr:`map_location` argument.
  501. If :attr:`map_location` is a callable, it will be called once for each serialized
  502. storage with two arguments: storage and location. The storage argument
  503. will be the initial deserialization of the storage, residing on the CPU.
  504. Each serialized storage has a location tag associated with it which
  505. identifies the device it was saved from, and this tag is the second
  506. argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
  507. for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
  508. :attr:`map_location` should return either ``None`` or a storage. If
  509. :attr:`map_location` returns a storage, it will be used as the final deserialized
  510. object, already moved to the right device. Otherwise, :func:`torch.load` will
  511. fall back to the default behavior, as if :attr:`map_location` wasn't specified.
  512. If :attr:`map_location` is a :class:`torch.device` object or a string containing
  513. a device tag, it indicates the location where all tensors should be loaded.
  514. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
  515. appearing in the file (keys), to ones that specify where to put the
  516. storages (values).
  517. User extensions can register their own location tags and tagging and
  518. deserialization methods using :func:`torch.serialization.register_package`.
  519. Args:
  520. f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  521. or a string or os.PathLike object containing a file name
  522. map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
  523. locations
  524. pickle_module: module used for unpickling metadata and objects (has to
  525. match the :attr:`pickle_module` used to serialize file)
  526. pickle_load_args: (Python 3 only) optional keyword arguments passed over to
  527. :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
  528. :attr:`errors=...`.
  529. .. warning::
  530. :func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure.
  531. It is possible to construct malicious pickle data which will execute arbitrary code
  532. during unpickling. Never load data that could have come from an untrusted
  533. source, or that could have been tampered with. **Only load data you trust**.
  534. .. note::
  535. When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
  536. will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
  537. and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
  538. .. note::
  539. By default, we decode byte strings as ``utf-8``. This is to avoid a common error
  540. case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
  541. when loading files saved by Python 2 in Python 3. If this default
  542. is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
  543. these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
  544. to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
  545. as byte arrays which can be decoded later with ``byte_array.decode(...)``.
  546. Example:
  547. >>> torch.load('tensors.pt')
  548. # Load all tensors onto the CPU
  549. >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
  550. # Load all tensors onto the CPU, using a function
  551. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
  552. # Load all tensors onto GPU 1
  553. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
  554. # Map tensors from GPU 1 to GPU 0
  555. >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
  556. # Load tensor from io.BytesIO object
  557. >>> with open('tensor.pt', 'rb') as f:
  558. ... buffer = io.BytesIO(f.read())
  559. >>> torch.load(buffer)
  560. # Load a module with 'ascii' encoding for unpickling
  561. >>> torch.load('module.pt', encoding='ascii')
  562. """
  563. _check_dill_version(pickle_module)
  564. if 'encoding' not in pickle_load_args.keys():
  565. pickle_load_args['encoding'] = 'utf-8'
  566. with _open_file_like(f, 'rb') as opened_file:
  567. if _is_zipfile(opened_file):
  568. # The zipfile reader is going to advance the current file position.
  569. # If we want to actually tail call to torch.jit.load, we need to
  570. # reset back to the original position.
  571. orig_position = opened_file.tell()
  572. with _open_zipfile_reader(opened_file) as opened_zipfile:
  573. if _is_torchscript_zip(opened_zipfile):
  574. warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
  575. " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
  576. " silence this warning)", UserWarning)
  577. opened_file.seek(orig_position)
  578. return torch.jit.load(opened_file)
  579. return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  580. return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  581. # Register pickling support for layout instances such as
  582. # torch.sparse_coo, etc
  583. def _get_layout(name):
  584. """Get layout extension object from its string representation.
  585. """
  586. cache = _get_layout.cache # type: ignore[attr-defined]
  587. if not cache:
  588. for v in torch.__dict__.values():
  589. if isinstance(v, torch.layout):
  590. cache[str(v)] = v
  591. return cache[name]
  592. # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
  593. _get_layout.cache = {} # type: ignore[attr-defined]
  594. copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the pre-zip ``torch.save`` formats.

    Two layouts are handled:

    * a tar archive with members ``storages``, ``tensors`` and ``pickle``;
      this is tried only when ``_should_read_directly(f)`` is true and ``f``
      is positioned at offset 0;
    * otherwise a single pickle stream: magic number, protocol version,
      system info, the pickled object, the list of storage keys, and then the
      raw bytes of each of those storages, in that order.

    Args:
        f: seekable file-like object (checked by ``_check_seekable``).
        map_location: storage relocation spec, resolved by
            ``_get_restore_location``.
        pickle_module: module providing ``load`` and ``Unpickler``.
        **pickle_load_args: forwarded to ``pickle_module.load`` and to the
            ``Unpickler`` constructor.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: if ``f`` turns out to be a zip archive on the tar path,
            if ``f`` lacks ``readinto`` on Python 3.8.0/3.8.1, on a bad magic
            number or protocol version, or on an unknown persistent id type.
    """
    # NOTE(review): this binding appears unused -- ``legacy_load`` declares its
    # own local dict, and the name is rebound to a fresh dict below before the
    # stream-format ``persistent_load`` can ever read it.
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        def find_class(self, mod_name, name):
            # Pickled references to ``*Storage`` classes are replaced by a
            # ``StorageType`` placeholder carrying only the dtype; names the
            # dtype lookup does not recognise (KeyError) fall through to the
            # normal class lookup.
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        # Compare the source recorded at save time with the class's current
        # source and emit a SourceChangeWarning when they differ. When
        # ``container_type.dump_patches`` is set, also try to write a reverse
        # unified diff to ``<ClassName>.patch`` in the working directory.
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    # 'a+' does not truncate: an empty file receives the patch,
                    # while an existing non-empty file is accepted only if it
                    # already holds exactly this patch (otherwise IOError).
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Tar-archive format loader. Populates a local key -> object map from
        # the ``storages`` and ``tensors`` members, then unpickles ``pickle``.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuple ids describe module containers:
            # (container_type, source_file, original_source).
            # Any other id is a key into ``deserialized_objects``.
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        # ``mkdtemp`` is this module's temporary-directory context manager
        # (defined elsewhere in the file) -- presumably cleaned up on exit.
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            # ``storages`` member: a pickled count, then for each storage a
            # pickled (key, location, storage_type) header followed by its raw
            # data, then a pickled list of views
            # (target_cdata, root_cdata, offset, numel) with offsets/sizes in
            # elements of the root's dtype.
            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type.dtype
                    obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[key] = torch.storage._TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype)

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[target_cdata] = torch.storage._TypedStorage(
                        wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype)

            # ``tensors`` member: a pickled count, then for each tensor a
            # pickled (key, storage_id, original_tensor_type) header followed
            # by little-endian binary metadata: ndim (int32 + 4 padding bytes),
            # ndim int64 sizes, ndim int64 strides, and an int64 storage offset.
            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    # NOTE: despite its name, ``numel`` is the per-dimension
                    # size tuple passed to ``set_`` below.
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.tensor([], dtype=storage.dtype).set_(
                        storage._storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            # ``pickle`` member: the object graph, resolving persistent ids
            # through the objects loaded above.
            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # Stream format: maps root and view storage keys to typed storages.
    deserialized_objects = {}

    def persistent_load(saved_id):
        # Persistent ids are tuples whose first item names the kind:
        #   ('module', container_type, source_file, original_source)
        #   ('storage', storage_type, root_key, location, numel, view_metadata)
        # where view_metadata is None or (view_key, offset, view_size), all
        # counted in elements of storage_type's dtype.
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                # Allocated empty here; the bytes are filled in by
                # ``_set_from_file`` after the whole pickle has been read.
                obj = cast(Storage, torch._UntypedStorage(nbytes))
                obj._torch_load_uninitialized = True
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with _TypedStorage
                deserialized_objects[root_key] = torch.storage._TypedStorage(
                    wrap_storage=restore_location(obj, location),
                    dtype=dtype)

            typed_storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with _TypedStorage
                    deserialized_objects[view_key] = torch.storage._TypedStorage(
                        wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    # Stream header: magic number, protocol version, then system info (read
    # and discarded).
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Raw storage bytes follow in ``deserialized_storage_keys`` order. When
    # reading directly, ``offset`` tracks the current position and is
    # refreshed after each storage; otherwise None is passed.
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        typed_storage = deserialized_objects[key]
        typed_storage._storage._set_from_file(
            f, offset, f_should_read_directly,
            torch._utils._element_size(typed_storage.dtype))
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
  782. def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
  783. # When using encoding='bytes' in Py3, some **internal** keys stored as
  784. # strings in Py2 are loaded as bytes. This function decodes them with
  785. # ascii encoding, one that Py3 uses by default.
  786. #
  787. # NOTE: This should only be used on internal keys (e.g., `typename` and
  788. # `location` in `persistent_load` below!
  789. if isinstance(bytes_str, bytes):
  790. return bytes_str.decode('ascii')
  791. return bytes_str
  792. def _get_restore_location(map_location):
  793. if map_location is None:
  794. restore_location = default_restore_location
  795. elif isinstance(map_location, dict):
  796. def restore_location(storage, location):
  797. location = map_location.get(location, location)
  798. return default_restore_location(storage, location)
  799. elif isinstance(map_location, _string_classes):
  800. def restore_location(storage, location):
  801. return default_restore_location(storage, map_location)
  802. elif isinstance(map_location, torch.device):
  803. def restore_location(storage, location):
  804. return default_restore_location(storage, str(map_location))
  805. else:
  806. def restore_location(storage, location):
  807. result = map_location(storage, location)
  808. if result is None:
  809. result = default_restore_location(storage, location)
  810. return result
  811. return restore_location
  812. class StorageType():
  813. def __init__(self, name):
  814. self.dtype = _get_dtype_from_pickle_storage_type(name)
  815. def __str__(self):
  816. return f'StorageType(dtype={self.dtype})'
  817. def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
  818. restore_location = _get_restore_location(map_location)
  819. loaded_storages = {}
  820. def load_tensor(dtype, numel, key, location):
  821. name = f'data/{key}'
  822. storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
  823. # TODO: Once we decide to break serialization FC, we can
  824. # stop wrapping with _TypedStorage
  825. loaded_storages[key] = torch.storage._TypedStorage(
  826. wrap_storage=restore_location(storage, location),
  827. dtype=dtype)
  828. def persistent_load(saved_id):
  829. assert isinstance(saved_id, tuple)
  830. typename = _maybe_decode_ascii(saved_id[0])
  831. data = saved_id[1:]
  832. assert typename == 'storage', \
  833. f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
  834. storage_type, key, location, numel = data
  835. if storage_type is torch._UntypedStorage:
  836. dtype = torch.uint8
  837. else:
  838. dtype = storage_type.dtype
  839. if key not in loaded_storages:
  840. nbytes = numel * torch._utils._element_size(dtype)
  841. load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  842. return loaded_storages[key]
  843. load_module_mapping: Dict[str, str] = {
  844. # See https://github.com/pytorch/pytorch/pull/51633
  845. 'torch.tensor': 'torch._tensor'
  846. }
  847. # Need to subclass Unpickler instead of directly monkey-patching the find_class method
  848. # because it's marked readonly in pickle.
  849. # The type: ignore is because mypy can't statically determine the type of this class.
  850. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  851. # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
  852. # Lets us override the imports that pickle uses when unpickling an object.
  853. # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
  854. def find_class(self, mod_name, name):
  855. if type(name) is str and 'Storage' in name:
  856. try:
  857. return StorageType(name)
  858. except KeyError:
  859. pass
  860. mod_name = load_module_mapping.get(mod_name, mod_name)
  861. return super().find_class(mod_name, name)
  862. # Load the data (which may in turn use `persistent_load` to load tensors)
  863. data_file = io.BytesIO(zip_file.get_record(pickle_file))
  864. unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
  865. unpickler.persistent_load = persistent_load
  866. result = unpickler.load()
  867. torch._utils._validate_loaded_sparse_tensors()
  868. return result
  869. def _is_torchscript_zip(zip_file):
  870. return 'constants.pkl' in zip_file.get_all_records()