graph.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import io
  2. import pickle
  3. from torch.utils.data import IterDataPipe, MapDataPipe
  4. from torch.utils.data._utils.serialization import DILL_AVAILABLE
  5. from typing import Any, Dict, Set, Tuple, Type, Union
# Public API of this module: only `traverse` is exported.
__all__ = ["traverse", ]
# Alias for "either flavour of DataPipe"; used in the type annotations below.
DataPipe = Union[IterDataPipe, MapDataPipe]
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably kept for external users -- confirm before removing.
reduce_ex_hook = None
  9. def _stub_unpickler():
  10. return "STUB"
  11. # TODO(VitalyFedyunin): Make sure it works without dill module installed
  12. def _list_connected_datapipes(scan_obj, only_datapipe, cache):
  13. f = io.BytesIO()
  14. p = pickle.Pickler(f) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
  15. if DILL_AVAILABLE:
  16. from dill import Pickler as dill_Pickler
  17. d = dill_Pickler(f)
  18. else:
  19. d = None
  20. def stub_pickler(obj):
  21. return _stub_unpickler, ()
  22. captured_connections = []
  23. def getstate_hook(obj):
  24. state = {}
  25. for k, v in obj.__dict__.items():
  26. if isinstance(v, (IterDataPipe, MapDataPipe, tuple)):
  27. state[k] = v
  28. return state
  29. def reduce_hook(obj):
  30. if obj == scan_obj or obj in cache:
  31. raise NotImplementedError
  32. else:
  33. captured_connections.append(obj)
  34. return _stub_unpickler, ()
  35. datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment]
  36. try:
  37. for cls in datapipe_classes:
  38. cls.set_reduce_ex_hook(reduce_hook)
  39. if only_datapipe:
  40. cls.set_getstate_hook(getstate_hook)
  41. try:
  42. p.dump(scan_obj)
  43. except (pickle.PickleError, AttributeError, TypeError):
  44. if DILL_AVAILABLE:
  45. d.dump(scan_obj)
  46. else:
  47. raise
  48. finally:
  49. for cls in datapipe_classes:
  50. cls.set_reduce_ex_hook(None)
  51. if only_datapipe:
  52. cls.set_getstate_hook(None)
  53. if DILL_AVAILABLE:
  54. from dill import extend as dill_extend
  55. dill_extend(False) # Undo change to dispatch table
  56. return captured_connections
  57. def traverse(datapipe, only_datapipe=False):
  58. cache: Set[DataPipe] = set()
  59. return _traverse_helper(datapipe, only_datapipe, cache)
  60. # Add cache here to prevent infinite recursion on DataPipe
  61. def _traverse_helper(datapipe, only_datapipe, cache):
  62. if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
  63. raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe)))
  64. cache.add(datapipe)
  65. items = _list_connected_datapipes(datapipe, only_datapipe, cache)
  66. d: Dict[DataPipe, Any] = {datapipe: {}}
  67. for item in items:
  68. # Using cache.copy() here is to prevent recursion on a single path rather than global graph
  69. # Single DataPipe can present multiple times in different paths in graph
  70. d[datapipe].update(_traverse_helper(item, only_datapipe, cache.copy()))
  71. return d