parameter_sharing.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from caffe2.python import scope
  2. import contextlib
  3. import logging
  4. logger = logging.getLogger(__name__)
  5. class ParameterSharingContext(object):
  6. """
  7. This class manages scope driven way of parameter sharing across different
  8. NameScopes.
  9. """
  10. def __init__(self):
  11. self._scope_overrides = {}
  12. self._contexts = []
  13. def _resolve_scope_overrides(self, candidate_scope):
  14. """
  15. Recursively resolves all scope overrides, i.e multiple steps of
  16. override can be used.
  17. For example, if one provides following scope overrides:
  18. {'scope_b': 'scope_a'} and within 'scope_b' - {'shared_child': ''},
  19. then name 'w' will get resolved to the following blobs depending on the
  20. namescope:
  21. a. 'scope_a' -> 'scope_a/w'
  22. b. 'scope_b' -> 'scope_a/w'
  23. c. 'scope_c' -> 'scope_c/w'
  24. d. 'scope_b/shared_child' -> 'scope_a/w'
  25. d. 'scope_b/unshared_child' -> 'scope_a/unshared_child/w'
  26. """
  27. best_scope = candidate_scope
  28. best_scope_idx = 0
  29. sub_scopes = candidate_scope.split(scope._NAMESCOPE_SEPARATOR)
  30. cur_scope = ''
  31. for idx, sub_scope in enumerate(sub_scopes):
  32. cur_scope = cur_scope + sub_scope + scope._NAMESCOPE_SEPARATOR
  33. if cur_scope in self._scope_overrides:
  34. best_scope = self._scope_overrides[cur_scope]
  35. best_scope_idx = idx
  36. if best_scope == candidate_scope:
  37. return candidate_scope
  38. else:
  39. return (self._resolve_scope_overrides(best_scope) +
  40. scope._NAMESCOPE_SEPARATOR.join(
  41. sub_scopes[best_scope_idx + 1:]))
  42. def get_parameter_name(self, name):
  43. candidate_scope = scope.CurrentNameScope()
  44. best_scope = self._resolve_scope_overrides(candidate_scope)
  45. if best_scope != candidate_scope:
  46. logger.info("Overwriting scope {0} with scope {1}".format(
  47. candidate_scope, best_scope))
  48. return best_scope + name
  49. def add_scope_overrides(self, shared_scopes):
  50. self._contexts.append(shared_scopes)
  51. self._scope_overrides.update(shared_scopes)
  52. def pop(self):
  53. assert len(self._contexts) > 0
  54. self._contexts.pop()
  55. self._scope_overrides = {}
  56. for x in self._contexts:
  57. self._scope_overrides.update(x)
  58. parameter_sharing_context = ParameterSharingContext()
  59. def _normalize_namescope(namescope):
  60. if namescope and namescope[-1] != scope._NAMESCOPE_SEPARATOR:
  61. return namescope + scope._NAMESCOPE_SEPARATOR
  62. else:
  63. return namescope
  64. @contextlib.contextmanager
  65. def ParameterSharing(shared_scopes):
  66. """
  67. Helper function for sharing scopes.
  68. All the parameters within the shared_scopes, will be remapped with the
  69. respect of CurrentNamescope()
  70. I.e. if one calls ParameterSharing with {'scope_b': 'scope_'a'}, from the
  71. scope 'some_global_scope', it'll effectively mean, that all parameters from
  72. 'some_global_scope/scope_b' will shared with the parameters from
  73. 'some_global_scope/scope_a'
  74. """
  75. assert isinstance(shared_scopes, dict)
  76. shared_scope_overrides = {}
  77. current_scope = scope.CurrentNameScope()
  78. for k, v in shared_scopes.items():
  79. assert not v.startswith(k), (
  80. "Illegal override for parameter sharing. {} is prefix of {}".
  81. format(k, v))
  82. k = current_scope + k
  83. v = current_scope + v
  84. # Normalize all the scopes, so scope_a and scope_a/ are equivalent
  85. k = _normalize_namescope(k)
  86. v = _normalize_namescope(v)
  87. shared_scope_overrides[k] = v
  88. try:
  89. parameter_sharing_context.add_scope_overrides(shared_scope_overrides)
  90. yield
  91. finally:
  92. parameter_sharing_context.pop()