_deploy.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import io
  2. import torch
  3. from torch.package._package_pickler import create_pickler
  4. from torch.package._package_unpickler import PackageUnpickler
  5. from torch.package import sys_importer, OrderedImporter, PackageImporter, Importer
  6. from torch.serialization import _maybe_decode_ascii
  7. def _save_storages(importer, obj):
  8. serialized_storages = []
  9. serialized_dtypes = []
  10. importer = importer if isinstance(importer, torch.package.PackageImporter) else None
  11. importers: Importer
  12. if importer is not None:
  13. importers = OrderedImporter(importer, sys_importer)
  14. else:
  15. importers = sys_importer
  16. def persistent_id(obj):
  17. if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
  18. if isinstance(obj, torch.storage._TypedStorage):
  19. # TODO: Once we decide to break serialization FC, we can
  20. # remove this case
  21. storage = obj._storage
  22. dtype = obj.dtype
  23. else:
  24. storage = obj
  25. dtype = torch.uint8
  26. serialized_storages.append(obj)
  27. serialized_dtypes.append(dtype)
  28. return ('storage', len(serialized_storages) - 1)
  29. if hasattr(obj, "__reduce_deploy__"):
  30. if _serialized_reduces.get(id(obj)) is None:
  31. _serialized_reduces[id(obj)] = (
  32. "reduce_deploy",
  33. id(obj),
  34. *obj.__reduce_deploy__(importers),
  35. )
  36. return _serialized_reduces[id(obj)]
  37. return None
  38. # Write the pickle data for `obj`
  39. data_buf = io.BytesIO()
  40. pickler = create_pickler(data_buf, importers)
  41. pickler.persistent_id = persistent_id
  42. pickler.dump(obj)
  43. data_value = data_buf.getvalue()
  44. return data_value, serialized_storages, serialized_dtypes, importer.zip_reader if importer else None
  45. def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
  46. def persistent_load(saved_id):
  47. assert isinstance(saved_id, tuple)
  48. typename = _maybe_decode_ascii(saved_id[0])
  49. data = saved_id[1:]
  50. if typename == 'storage':
  51. # TODO: Once we decide to break serialization FC, we can
  52. # stop wrapping with _TypedStorage
  53. storage = serialized_storages[data[0]]
  54. dtype = serialized_dtypes[data[0]]
  55. return torch.storage._TypedStorage(
  56. wrap_storage=storage._untyped(),
  57. dtype=dtype)
  58. if typename == 'reduce_deploy':
  59. reduce_id, func, args = data
  60. if reduce_id not in _loaded_reduces:
  61. _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
  62. return _loaded_reduces[reduce_id]
  63. return None
  64. importer: Importer
  65. if zip_reader is not None:
  66. importer = OrderedImporter(_get_package(zip_reader), sys_importer)
  67. else:
  68. importer = sys_importer
  69. unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
  70. unpickler.persistent_load = persistent_load # type: ignore[assignment]
  71. result = _deploy_objects[id] = unpickler.load()
  72. return result
  73. def _get_package(zip_reader):
  74. if zip_reader not in _raw_packages:
  75. _raw_packages[zip_reader] = PackageImporter(zip_reader)
  76. return _raw_packages[zip_reader]
  77. _raw_packages: dict = {}
  78. _deploy_objects: dict = {}
  79. _serialized_reduces: dict = {}
  80. _loaded_reduces: dict = {}