immutable_collections.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from typing import Any, Dict, Tuple, List
  2. from ._compatibility import compatibility
  3. from torch.utils._pytree import Context, _register_pytree_node
  4. _help_mutation = """\
  5. If you are attempting to modify the kwargs or args of a torch.fx.Node object,
  6. instead create a new copy of it and assign the copy to the node:
  7. new_args = ... # copy and mutate args
  8. node.args = new_args
  9. """
  10. def _no_mutation(self, *args, **kwargs):
  11. raise NotImplementedError(f"'{type(self).__name__}' object does not support mutation. {_help_mutation}")
  12. def _create_immutable_container(base, mutable_functions):
  13. container = type('immutable_' + base.__name__, (base,), {})
  14. for attr in mutable_functions:
  15. setattr(container, attr, _no_mutation)
  16. return container
  17. immutable_list = _create_immutable_container(list,
  18. ['__delitem__', '__iadd__', '__imul__', '__setitem__', 'append',
  19. 'clear', 'extend', 'insert', 'pop', 'remove'])
  20. immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
  21. compatibility(is_backward_compatible=True)(immutable_list)
  22. immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update'])
  23. immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),))
  24. compatibility(is_backward_compatible=True)(immutable_dict)
  25. # Register immutable collections for PyTree operations
  26. def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
  27. return list(d.values()), list(d.keys())
  28. def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
  29. return immutable_dict({key: value for key, value in zip(context, values)})
  30. def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
  31. return d, None
  32. def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]:
  33. return immutable_list(values)
  34. _register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
  35. _register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)