show_pickle.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. #!/usr/bin/env python3
  2. import sys
  3. import pickle
  4. import struct
  5. import pprint
  6. import zipfile
  7. import fnmatch
  8. from typing import Any, IO, BinaryIO, Union
  9. class FakeObject(object):
  10. def __init__(self, module, name, args):
  11. self.module = module
  12. self.name = name
  13. self.args = args
  14. # NOTE: We don't distinguish between state never set and state set to None.
  15. self.state = None
  16. def __repr__(self):
  17. state_str = "" if self.state is None else f"(state={self.state!r})"
  18. return f"{self.module}.{self.name}{self.args!r}{state_str}"
  19. def __setstate__(self, state):
  20. self.state = state
  21. @staticmethod
  22. def pp_format(printer, obj, stream, indent, allowance, context, level):
  23. if not obj.args and obj.state is None:
  24. stream.write(repr(obj))
  25. return
  26. if obj.state is None:
  27. stream.write(f"{obj.module}.{obj.name}")
  28. printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
  29. return
  30. if not obj.args:
  31. stream.write(f"{obj.module}.{obj.name}()(state=\n")
  32. indent += printer._indent_per_level
  33. stream.write(" " * indent)
  34. printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
  35. stream.write(")")
  36. return
  37. raise Exception("Need to implement")
  38. class FakeClass(object):
  39. def __init__(self, module, name):
  40. self.module = module
  41. self.name = name
  42. self.__new__ = self.fake_new # type: ignore[assignment]
  43. def __repr__(self):
  44. return f"{self.module}.{self.name}"
  45. def __call__(self, *args):
  46. return FakeObject(self.module, self.name, args)
  47. def fake_new(self, *args):
  48. return FakeObject(self.module, self.name, args[1:])
  49. class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined]
  50. def __init__(
  51. self,
  52. file,
  53. *,
  54. catch_invalid_utf8=False,
  55. **kwargs):
  56. super().__init__(file, **kwargs)
  57. self.catch_invalid_utf8 = catch_invalid_utf8
  58. def find_class(self, module, name):
  59. return FakeClass(module, name)
  60. def persistent_load(self, pid):
  61. return FakeObject("pers", "obj", (pid,))
  62. dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined]
  63. # Custom objects in TorchScript are able to return invalid UTF-8 strings
  64. # from their pickle (__getstate__) functions. Install a custom loader
  65. # for strings that catches the decode exception and replaces it with
  66. # a sentinel object.
  67. def load_binunicode(self):
  68. strlen, = struct.unpack("<I", self.read(4)) # type: ignore[attr-defined]
  69. if strlen > sys.maxsize:
  70. raise Exception("String too long.")
  71. str_bytes = self.read(strlen) # type: ignore[attr-defined]
  72. obj: Any
  73. try:
  74. obj = str(str_bytes, "utf-8", "surrogatepass")
  75. except UnicodeDecodeError as exn:
  76. if not self.catch_invalid_utf8:
  77. raise
  78. obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
  79. self.append(obj) # type: ignore[attr-defined]
  80. dispatch[pickle.BINUNICODE[0]] = load_binunicode # type: ignore[assignment]
  81. @classmethod
  82. def dump(cls, in_stream, out_stream):
  83. value = cls(in_stream).load()
  84. pprint.pprint(value, stream=out_stream)
  85. return value
  86. def main(argv, output_stream=None):
  87. if len(argv) != 2:
  88. # Don't spam stderr if not using stdout.
  89. if output_stream is not None:
  90. raise Exception("Pass argv of length 2.")
  91. sys.stderr.write("usage: show_pickle PICKLE_FILE\n")
  92. sys.stderr.write(" PICKLE_FILE can be any of:\n")
  93. sys.stderr.write(" path to a pickle file\n")
  94. sys.stderr.write(" file.zip@member.pkl\n")
  95. sys.stderr.write(" file.zip@*/pattern.*\n")
  96. sys.stderr.write(" (shell glob pattern for members)\n")
  97. sys.stderr.write(" (only first match will be shown)\n")
  98. return 2
  99. fname = argv[1]
  100. handle: Union[IO[bytes], BinaryIO]
  101. if "@" not in fname:
  102. with open(fname, "rb") as handle:
  103. DumpUnpickler.dump(handle, output_stream)
  104. else:
  105. zfname, mname = fname.split("@", 1)
  106. with zipfile.ZipFile(zfname) as zf:
  107. if "*" not in mname:
  108. with zf.open(mname) as handle:
  109. DumpUnpickler.dump(handle, output_stream)
  110. else:
  111. found = False
  112. for info in zf.infolist():
  113. if fnmatch.fnmatch(info.filename, mname):
  114. with zf.open(info) as handle:
  115. DumpUnpickler.dump(handle, output_stream)
  116. found = True
  117. break
  118. if not found:
  119. raise Exception(f"Could not find member matching {mname} in {zfname}")
  120. if __name__ == "__main__":
  121. # This hack works on every version of Python I've tested.
  122. # I've tested on the following versions:
  123. # 3.7.4
  124. if True:
  125. pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore[attr-defined]
  126. sys.exit(main(sys.argv))