# modifier_context.py (1.7 KB)
  1. # @package modifier_context
  2. # Module caffe2.python.modifier_context
  3. DEFAULT_MODIFIER = 'DEFAULT'
  4. class ModifierContext(object):
  5. """
  6. provide context to allow param_info to have different modifiers
  7. """
  8. def __init__(self):
  9. self._modifiers = {}
  10. self._modifiers_list = []
  11. def _rebuild_modifiers(self):
  12. self._modifiers = {}
  13. for m in self._modifiers_list:
  14. self._modifiers.update(m)
  15. def _has_modifier(self, name):
  16. return name in self._modifiers
  17. def _get_modifier(self, name):
  18. return self._modifiers.get(name)
  19. def push_modifiers(self, modifiers):
  20. # modifier override is allowed
  21. self._modifiers_list.append(modifiers)
  22. self._modifiers.update(modifiers)
  23. def pop_modifiers(self):
  24. assert len(self._modifiers_list) > 0
  25. self._modifiers_list.pop()
  26. self._rebuild_modifiers()
  27. class UseModifierBase(object):
  28. '''
  29. context class to allow setting the current context.
  30. Example usage with layer:
  31. modifiers = {'modifier1': modifier1, 'modifier2': modifier2}
  32. with Modifiers(modifiers):
  33. modifier = ModifierContext.current().get_modifier('modifier1')
  34. layer(modifier=modifier)
  35. '''
  36. def __init__(self, modifier_or_dict):
  37. if isinstance(modifier_or_dict, dict):
  38. self._modifiers = modifier_or_dict
  39. else:
  40. self._modifiers = {DEFAULT_MODIFIER: modifier_or_dict}
  41. def _context_class(self):
  42. raise NotImplementedError
  43. def __enter__(self):
  44. self._context_class().current().push_modifiers(self._modifiers)
  45. return self
  46. def __exit__(self, type, value, traceback):
  47. self._context_class().current().pop_modifiers()