utils.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import collections
  2. from itertools import repeat
  3. from typing import List, Dict, Any
  4. def _ntuple(n, name="parse"):
  5. def parse(x):
  6. if isinstance(x, collections.abc.Iterable):
  7. return tuple(x)
  8. return tuple(repeat(x, n))
  9. parse.__name__ = name
  10. return parse
  11. _single = _ntuple(1, "_single")
  12. _pair = _ntuple(2, "_pair")
  13. _triple = _ntuple(3, "_triple")
  14. _quadruple = _ntuple(4, "_quadruple")
  15. def _reverse_repeat_tuple(t, n):
  16. r"""Reverse the order of `t` and repeat each element for `n` times.
  17. This can be used to translate padding arg used by Conv and Pooling modules
  18. to the ones used by `F.pad`.
  19. """
  20. return tuple(x for x in reversed(t) for _ in range(n))
  21. def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
  22. if isinstance(out_size, int):
  23. return out_size
  24. if len(defaults) <= len(out_size):
  25. raise ValueError(
  26. "Input dimension should be at least {}".format(len(out_size) + 1)
  27. )
  28. return [
  29. v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
  30. ]
  31. def consume_prefix_in_state_dict_if_present(
  32. state_dict: Dict[str, Any], prefix: str
  33. ) -> None:
  34. r"""Strip the prefix in state_dict in place, if any.
  35. ..note::
  36. Given a `state_dict` from a DP/DDP model, a local model can load it by applying
  37. `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
  38. :meth:`torch.nn.Module.load_state_dict`.
  39. Args:
  40. state_dict (OrderedDict): a state-dict to be loaded to the model.
  41. prefix (str): prefix.
  42. """
  43. keys = sorted(state_dict.keys())
  44. for key in keys:
  45. if key.startswith(prefix):
  46. newkey = key[len(prefix) :]
  47. state_dict[newkey] = state_dict.pop(key)
  48. # also strip the prefix in metadata if any.
  49. if "_metadata" in state_dict:
  50. metadata = state_dict["_metadata"]
  51. for key in list(metadata.keys()):
  52. # for the metadata dict, the key can be:
  53. # '': for the DDP module, which we want to remove.
  54. # 'module': for the actual model.
  55. # 'module.xx.xx': for the rest.
  56. if len(key) == 0:
  57. continue
  58. newkey = key[len(prefix) :]
  59. metadata[newkey] = metadata.pop(key)